Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)
- Refactored build_pipeline usage to ensure pipeline arguments are cloned.
- Introduced `reset_state` for clearing cached state between requests.
- Enhanced chat UI with model selector and dynamic model fetching.
- Improved error logging and detailed debug messages for chat request flows.
- Added fresh instantiation of `TextGeneration` to prevent tensor shape mismatches.
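The server-side pattern behind these changes, sketched below for orientation. This is a minimal illustrative sketch that reuses the `AppState`, `PipelineArgs`, `TextGeneration`, and `build_pipeline` items from the diff; the helper name `refresh_pipeline` is hypothetical and the snippet is not a verbatim excerpt of the crate.

    use std::sync::Arc;
    use tokio::sync::Mutex;

    // AppState keeps the original pipeline arguments so a completely fresh
    // TextGeneration can be rebuilt for every request.
    pub struct AppState {
        pub text_generation: Arc<Mutex<TextGeneration>>,
        pub model_id: String,
        pub build_args: PipelineArgs,
    }

    // Hypothetical helper: rebuilding the pipeline before generating discards
    // any KV cache or penalty cache left over from the previous request,
    // which is what caused the tensor shape mismatches.
    async fn refresh_pipeline(state: &AppState) {
        let fresh_text_gen = build_pipeline(state.build_args.clone());
        let mut text_gen = state.text_generation.lock().await;
        *text_gen = fresh_text_gen;
    }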
@@ -36,15 +36,18 @@ use serde_json::Value;
 pub struct AppState {
     pub text_generation: Arc<Mutex<TextGeneration>>,
     pub model_id: String,
+    // Store build args to recreate TextGeneration when needed
+    pub build_args: PipelineArgs,
 }
 
 impl Default for AppState {
     fn default() -> Self {
         let args = PipelineArgs::default();
-        let text_generation = build_pipeline(args);
+        let text_generation = build_pipeline(args.clone());
         Self {
             text_generation: Arc::new(Mutex::new(text_generation)),
             model_id: String::new(),
+            build_args: args,
         }
     }
 }
@@ -318,7 +321,7 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
 pub async fn chat_completions(
     State(state): State<AppState>,
     Json(request): Json<ChatCompletionRequest>,
-) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
+) -> Result<impl IntoResponse, (StatusCode, String)> {
     // If streaming was requested, this function shouldn't be called
     // A separate route handles streaming requests
     if !request.stream.unwrap_or(false) {
@@ -357,7 +360,11 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
     // Generate
     let mut output = Vec::new();
     {
+        // Recreate TextGeneration instance to ensure completely fresh state
+        // This prevents KV cache persistence that causes tensor shape mismatches
+        let fresh_text_gen = build_pipeline(state.build_args.clone());
        let mut text_gen = state.text_generation.lock().await;
+        *text_gen = fresh_text_gen;
 
        let mut buffer = Vec::new();
        let max_tokens = request.max_tokens.unwrap_or(1000);
@@ -456,7 +463,12 @@ async fn handle_streaming_request(
     // Generate text using existing buffer-based approach
     let mut buffer = Vec::new();
     {
+        // Recreate TextGeneration instance to ensure completely fresh state
+        // This prevents KV cache persistence that causes tensor shape mismatches
+        let fresh_text_gen = build_pipeline(state.build_args.clone());
        let mut text_gen = state.text_generation.lock().await;
+        *text_gen = fresh_text_gen;
+
        let max_tokens = request.max_tokens.unwrap_or(1000);
 
        if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
@@ -752,10 +764,11 @@ mod tests {
         println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
 
         // This should reproduce the same conditions as the curl script
-        let text_generation = build_pipeline(args);
+        let text_generation = build_pipeline(args.clone());
         let app_state = AppState {
             text_generation: Arc::new(Mutex::new(text_generation)),
             model_id: "gemma-3-1b-it".to_string(),
+            build_args: args,
         };
 
         // Create the same request as the curl script
@@ -117,6 +117,16 @@ impl TextGeneration {
         }
     }
 
+    // Reset method to clear state between requests
+    pub fn reset_state(&mut self) {
+        // Reset the primary device flag so we try the primary device first for each new request
+        if !self.device.is_cpu() {
+            self.try_primary_device = true;
+        }
+        // Clear the penalty cache to avoid stale cached values from previous requests
+        self.penalty_cache.clear();
+    }
+
     // Helper method to apply repeat penalty with caching for optimization
     pub fn apply_cached_repeat_penalty(
         &mut self,
@@ -34,6 +34,7 @@ web-sys = { version = "0.3", features = [
     "Element",
     "HtmlElement",
     "HtmlInputElement",
+    "HtmlSelectElement",
     "HtmlTextAreaElement",
     "Event",
     "EventTarget",
@@ -10,12 +10,12 @@ use futures_util::StreamExt;
 use async_openai_wasm::{
     types::{
         ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
-        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
+        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
     },
     Client,
 };
 use async_openai_wasm::config::OpenAIConfig;
-use async_openai_wasm::types::ChatCompletionResponseStream;
+use async_openai_wasm::types::{ChatCompletionResponseStream, Model};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Message {
@@ -89,6 +89,43 @@ pub fn App() -> impl IntoView {
     }
 }
 
+async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
+    log::info!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1");
+
+    let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
+    let client = Client::with_config(config);
+
+    match client.models().list().await {
+        Ok(response) => {
+            let model_count = response.data.len();
+            log::info!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count);
+
+            if model_count > 0 {
+                let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
+                log::debug!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names);
+            } else {
+                log::warn!("[DEBUG_LOG] fetch_available_models: No models returned by server");
+            }
+
+            Ok(response.data)
+        },
+        Err(e) => {
+            log::error!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e);
+
+            let error_details = format!("{:?}", e);
+            if error_details.contains("400") || error_details.contains("Bad Request") {
+                log::error!("[DEBUG_LOG] fetch_available_models: HTTP 400 - Server rejected models request");
+            } else if error_details.contains("404") || error_details.contains("Not Found") {
+                log::error!("[DEBUG_LOG] fetch_available_models: HTTP 404 - Models endpoint not found");
+            } else if error_details.contains("Connection") || error_details.contains("connection") {
+                log::error!("[DEBUG_LOG] fetch_available_models: Connection error - server may be down");
+            }
+
+            Err(format!("Failed to fetch models: {}", e))
+        }
+    }
+}
+
 async fn send_chat_request(chat_request: ChatRequest) -> ChatCompletionResponseStream {
     let config = OpenAIConfig::new().with_api_base("http://localhost:8080".to_string());
     let client = Client::with_config(config);
@@ -168,19 +205,47 @@ async fn send_chat_request(chat_request: ChatRequest) -> ChatCompletionResponseS
 // Err("leptos-chat chat request only supported on wasm32 target".to_string())
 // }
 
+const DEFAULT_MODEL: &str = "gemma-2b-it";
+
 #[component]
 fn ChatInterface() -> impl IntoView {
     let (messages, set_messages) = create_signal::<VecDeque<Message>>(VecDeque::new());
     let (input_value, set_input_value) = create_signal(String::new());
     let (is_loading, set_is_loading) = create_signal(false);
+    let (available_models, set_available_models) = create_signal::<Vec<OpenAIModel>>(Vec::new());
+    let (selected_model, set_selected_model) = create_signal(DEFAULT_MODEL.to_string());
+    let (models_loading, set_models_loading) = create_signal(false);
 
+    // Fetch models on component initialization
+    create_effect(move |_| {
+        spawn_local(async move {
+            set_models_loading.set(true);
+            match fetch_available_models().await {
+                Ok(models) => {
+                    set_available_models.set(models);
+                    set_models_loading.set(false);
+                }
+                Err(e) => {
+                    log::error!("Failed to fetch models: {}", e);
+                    // Set a default model if fetching fails
+                    set_available_models.set(vec![]);
+                    set_models_loading.set(false);
+                }
+            }
+        });
+    });
+
     let send_message = create_action(move |content: &String| {
         let content = content.clone();
         async move {
             if content.trim().is_empty() {
+                log::debug!("[DEBUG_LOG] send_message: Empty content, skipping");
                 return;
             }
 
+            log::info!("[DEBUG_LOG] send_message: Starting message send process");
+            log::debug!("[DEBUG_LOG] send_message: User message content length: {}", content.len());
+
             set_is_loading.set(true);
 
             // Add user message to chat
@@ -204,7 +269,8 @@ fn ChatInterface() -> impl IntoView {
             chat_messages.push(system_message.into());
 
             // Add history messages
-            messages.with(|msgs| {
+            let history_count = messages.with_untracked(|msgs| {
+                let count = msgs.len();
                 for msg in msgs.iter() {
                     let message = ChatCompletionRequestUserMessageArgs::default()
                         .content(msg.content.clone())
@@ -212,6 +278,7 @@ fn ChatInterface() -> impl IntoView {
                         .expect("failed to build message");
                     chat_messages.push(message.into());
                 }
+                count
             });
 
             // Add current user message
@@ -221,20 +288,37 @@ fn ChatInterface() -> impl IntoView {
                 .expect("failed to build user message");
             chat_messages.push(message.into());
 
+            let current_model = selected_model.get_untracked();
+            let total_messages = chat_messages.len();
+
+            log::info!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
+                current_model, history_count, total_messages);
+
             let request = CreateChatCompletionRequestArgs::default()
-                .model("gemma-2b-it")
+                .model(current_model.as_str())
                 .max_tokens(512u32)
                 .messages(chat_messages)
                 .stream(true) // ensure server streams
                 .build()
                 .expect("failed to build request");
 
+            // Log request details for debugging server issues
+            log::info!("[DEBUG_LOG] send_message: Request configuration - model: '{}', max_tokens: 512, stream: true, messages_count: {}",
+                current_model, total_messages);
+            log::debug!("[DEBUG_LOG] send_message: Request details - history_messages: {}, system_messages: 1, user_messages: {}",
+                history_count, total_messages - history_count - 1);
+
             // Send request
             let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
             let client = Client::with_config(config);
+
+            log::info!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);
+
             match client.chat().create_stream(request).await {
                 Ok(mut stream) => {
+                    log::info!("[DEBUG_LOG] send_message: Successfully created stream, starting to receive response");
+
                     // Insert a placeholder assistant message to append into
                     let assistant_id = Uuid::new_v4().to_string();
                     set_messages.update(|msgs| {
@@ -246,10 +330,12 @@ fn ChatInterface() -> impl IntoView {
                         });
                     });
 
+                    let mut chunks_received = 0;
                     // Stream loop: append deltas to the last message
                     while let Some(next) = stream.next().await {
                         match next {
                             Ok(chunk) => {
+                                chunks_received += 1;
                                 // Try to pull out the content delta in a tolerant way.
                                 // async-openai 0.28.x stream chunk usually looks like:
                                 // choices[0].delta.content: Option<String>
@@ -281,12 +367,13 @@ fn ChatInterface() -> impl IntoView {
                                 }
                             }
                             Err(e) => {
-                                log::error!("Stream error: {:?}", e);
+                                log::error!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e);
+                                log::error!("[DEBUG_LOG] send_message: Stream error details - model: '{}', chunks_received: {}", current_model, chunks_received);
                                 set_messages.update(|msgs| {
                                     msgs.push_back(Message {
                                         id: Uuid::new_v4().to_string(),
                                         role: "system".to_string(),
-                                        content: format!("Stream error: {e}"),
+                                        content: format!("Stream error after {} chunks: {}", chunks_received, e),
                                         timestamp: Date::now(),
                                     });
                                 });
@@ -294,13 +381,39 @@ fn ChatInterface() -> impl IntoView {
                             }
                         }
                     }
+                    log::info!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
                 }
                 Err(e) => {
-                    log::error!("Failed to send request: {:?}", e);
+                    // Detailed error logging for different types of errors
+                    log::error!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e);
+                    log::error!("[DEBUG_LOG] send_message: Request context - model: '{}', total_messages: {}, endpoint: http://localhost:8080/v1",
+                        current_model, total_messages);
+
+                    // Try to extract more specific error information
+                    let error_details = format!("{:?}", e);
+                    let user_message = if error_details.contains("400") || error_details.contains("Bad Request") {
+                        log::error!("[DEBUG_LOG] send_message: HTTP 400 Bad Request detected - possible issues:");
+                        log::error!("[DEBUG_LOG] send_message: - Invalid model name: '{}'", current_model);
+                        log::error!("[DEBUG_LOG] send_message: - Invalid message format or content");
+                        log::error!("[DEBUG_LOG] send_message: - Server configuration issue");
+                        format!("Error: HTTP 400 Bad Request - Check model '{}' and message format. See console for details.", current_model)
+                    } else if error_details.contains("404") || error_details.contains("Not Found") {
+                        log::error!("[DEBUG_LOG] send_message: HTTP 404 Not Found - server endpoint may be incorrect");
+                        "Error: HTTP 404 Not Found - Server endpoint not found".to_string()
+                    } else if error_details.contains("500") || error_details.contains("Internal Server Error") {
+                        log::error!("[DEBUG_LOG] send_message: HTTP 500 Internal Server Error - server-side issue");
+                        "Error: HTTP 500 Internal Server Error - Server problem".to_string()
+                    } else if error_details.contains("Connection") || error_details.contains("connection") {
+                        log::error!("[DEBUG_LOG] send_message: Connection error - server may be down");
+                        "Error: Cannot connect to server at http://localhost:8080".to_string()
+                    } else {
+                        format!("Error: Request failed - {}", e)
+                    };
+
                     let error_message = Message {
                         id: Uuid::new_v4().to_string(),
                         role: "system".to_string(),
-                        content: "Error: Failed to connect to server".to_string(),
+                        content: user_message,
                         timestamp: Date::now(),
                     };
                     set_messages.update(|msgs| msgs.push_back(error_message));
@@ -330,6 +443,11 @@ fn ChatInterface() -> impl IntoView {
         }
     };
 
+    let on_model_change = move |ev| {
+        let select = event_target::<web_sys::HtmlSelectElement>(&ev);
+        set_selected_model.set(select.value());
+    };
+
     let messages_list = move || {
         messages.get()
             .into_iter()
@@ -364,6 +482,36 @@ fn ChatInterface() -> impl IntoView {
     view! {
         <div class="chat-container">
             <h1>"Chat Interface"</h1>
+            <div class="model-selector">
+                <label for="model-select">"Model: "</label>
+                <select
+                    id="model-select"
+                    on:change=on_model_change
+                    prop:value=selected_model
+                    prop:disabled=models_loading
+                >
+                    {move || {
+                        if models_loading.get() {
+                            view! {
+                                <option value="">"Loading models..."</option>
+                            }.into_view()
+                        } else {
+                            let models = available_models.get();
+                            if models.is_empty() {
+                                view! {
+                                    <option selected=true value="gemma-3b-it">"gemma-3b-it (default)"</option>
+                                }.into_view()
+                            } else {
+                                models.into_iter().map(|model| {
+                                    view! {
+                                        <option value=model.id.clone() selected={model.id == DEFAULT_MODEL}>{model.id}</option>
+                                    }
+                                }).collect_view()
+                            }
+                        }
+                    }}
+                </select>
+            </div>
             <div class="messages-container">
                 {messages_list}
                 {loading_indicator}
@@ -390,203 +538,6 @@ fn ChatInterface() -> impl IntoView {
     }
 }
-
-//
-// #[component]
-// fn ChatInterface() -> impl IntoView {
-// let (messages, set_messages) = create_signal::<VecDeque<Message>>(VecDeque::new());
-// let (input_value, set_input_value) = create_signal(String::new());
-// let (is_loading, set_is_loading) = create_signal(false);
-//
-// let send_message = create_action(move |content: &String| {
-// let content = content.clone();
-// async move {
-// if content.trim().is_empty() {
-// return;
-// }
-//
-// set_is_loading.set(true);
-//
-// // Add user message to chat
-// let user_message = Message {
-// id: Uuid::new_v4().to_string(),
-// role: "user".to_string(),
-// content: content.clone(),
-// timestamp: Date::now(),
-// };
-//
-// set_messages.update(|msgs| msgs.push_back(user_message.clone()));
-// set_input_value.set(String::new());
-//
-// let mut chat_messages = Vec::new();
-//
-// // Add system message
-// let system_message = ChatCompletionRequestSystemMessageArgs::default()
-// .content("You are a helpful assistant.")
-// .build()
-// .expect("failed to build system message");
-// chat_messages.push(system_message.into());
-//
-// // Add history messages
-// messages.with(|msgs| {
-// for msg in msgs.iter() {
-// let message = ChatCompletionRequestUserMessageArgs::default()
-// .content(msg.content.clone().into())
-// .build()
-// .expect("failed to build message");
-// chat_messages.push(message.into());
-// }
-// });
-//
-// // Add current user message
-// let message = ChatCompletionRequestUserMessageArgs::default()
-// .content(user_message.content.clone().into())
-// .build()
-// .expect("failed to build user message");
-// chat_messages.push(message.into());
-//
-// let request = CreateChatCompletionRequestArgs::default()
-// .model("gemma-2b-it")
-// .max_tokens(512u32)
-// .messages(chat_messages)
-// .build()
-// .expect("failed to build request");
-//
-// // Send request
-// let config = OpenAIConfig::new().with_api_base("http://localhost:8080".to_string());
-// let client = Client::with_config(config);
-//
-// match client
-// .chat()
-// .create_stream(request)
-// .await
-// {
-// Ok(chat_response) => {
-//
-//
-// // if let Some(choice) = chat_response {
-// // // Extract content from the message
-// // let content_text = match &choice.message.content {
-// // Some(message_content) => {
-// // match &message_content.0 {
-// // either::Either::Left(text) => text.clone(),
-// // either::Either::Right(_) => "Complex content not supported".to_string(),
-// // }
-// // }
-// // None => "No content provided".to_string(),
-// // };
-// //
-// // let assistant_message = Message {
-// // id: Uuid::new_v4().to_string(),
-// // role: "assistant".to_string(),
-// // content: content_text,
-// // timestamp: Date::now(),
-// // };
-// // set_messages.update(|msgs| msgs.push_back(assistant_message));
-// //
-// //
-// //
-// // // Log token usage information
-// // log::debug!("Token usage - Prompt: {}, Completion: {}, Total: {}",
-// // chat_response.usage.prompt_tokens,
-// // chat_response.usage.completion_tokens,
-// // chat_response.usage.total_tokens);
-// // }
-// }
-// Err(e) => {
-// log::error!("Failed to send request: {:?}", e);
-// let error_message = Message {
-// id: Uuid::new_v4().to_string(),
-// role: "system".to_string(),
-// content: "Error: Failed to connect to server".to_string(),
-// timestamp: Date::now(),
-// };
-// set_messages.update(|msgs| msgs.push_back(error_message));
-// }
-// }
-//
-// set_is_loading.set(false);
-// }
-// });
-//
-// let on_input = move |ev| {
-// let input = event_target::<HtmlInputElement>(&ev);
-// set_input_value.set(input.value());
-// };
-//
-// let on_submit = move |ev: SubmitEvent| {
-// ev.prevent_default();
-// let content = input_value.get();
-// send_message.dispatch(content);
-// };
-//
-// let on_keypress = move |ev: KeyboardEvent| {
-// if ev.key() == "Enter" && !ev.shift_key() {
-// ev.prevent_default();
-// let content = input_value.get();
-// send_message.dispatch(content);
-// }
-// };
-//
-// let messages_list = move || {
-// messages.get()
-// .into_iter()
-// .map(|message| {
-// let role_class = match message.role.as_str() {
-// "user" => "user-message",
-// "assistant" => "assistant-message",
-// _ => "system-message",
-// };
-//
-// view! {
-// <div class=format!("message {}", role_class)>
-// <div class="message-role">{message.role}</div>
-// <div class="message-content">{message.content}</div>
-// </div>
-// }
-// })
-// .collect_view()
-// };
-//
-// let loading_indicator = move || {
-// is_loading.get().then(|| {
-// view! {
-// <div class="message assistant-message">
-// <div class="message-role">"assistant"</div>
-// <div class="message-content">"Thinking..."</div>
-// </div>
-// }
-// })
-// };
-//
-// view! {
-// <div class="chat-container">
-// <h1>"Chat Interface"</h1>
-// <div class="messages-container">
-// {messages_list}
-// {loading_indicator}
-// </div>
-// <form class="input-form" on:submit=on_submit>
-// <input
-// type="text"
-// class="message-input"
-// placeholder="Type your message here..."
-// prop:value=input_value
-// on:input=on_input
-// on:keypress=on_keypress
-// prop:disabled=is_loading
-// />
-// <button
-// type="submit"
-// class="send-button"
-// prop:disabled=move || is_loading.get() || input_value.get().trim().is_empty()
-// >
-// "Send"
-// </button>
-// </form>
-// </div>
-// }
-// }
 
 #[wasm_bindgen::prelude::wasm_bindgen(start)]
 pub fn main() {
     // Set up error handling and logging for WebAssembly
@@ -53,10 +53,11 @@ async fn main() {
     pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
     pipeline_args.which = Which::InstructV3_1B;
 
-    let text_generation = build_pipeline(pipeline_args);
+    let text_generation = build_pipeline(pipeline_args.clone());
     let app_state = AppState {
         text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
         model_id: "google/gemma-3-1b-it".to_string(),
+        build_args: pipeline_args,
     };
 
     // Get the inference router directly from the inference engine