fixes issue with model selection

The chat-ui previously hardcoded "gemma-3-1b-it" as the selected model before the model
list was fetched, and the server ignored the requested model id when building runner
configs, always falling back to Gemma InstructV3_1B. The default is now applied only after
the model list loads, AppState.model_type becomes Option<ModelType>, the llama-3.2 ids are
all covered by the id lookup, and both the streaming and non-streaming handlers map the
requested id to the exact gemma_runner / llama_runner variant, rejecting mismatches.

Author: geoffsee
Date:   2025-09-04 13:42:30 -04:00
Parent: ff55d882c7
Commit: 1e02b12cda

6 changed files with 209 additions and 63 deletions

@@ -365,7 +365,7 @@ fn ChatPage() -> impl IntoView {
     // State for available models and selected model
     let available_models = RwSignal::new(Vec::<ModelInfo>::new());
-    let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
+    let selected_model = RwSignal::new(String::from("")); // No selection until models are fetched
     // State for streaming response
     let streaming_content = RwSignal::new(String::new());
@@ -382,6 +382,7 @@ fn ChatPage() -> impl IntoView {
     match fetch_models().await {
         Ok(models) => {
             available_models.set(models);
+            selected_model.set(String::from("gemma-3-1b-it"));
         }
         Err(error) => {
             console::log_1(&format!("Failed to fetch models: {}", error).into());
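The chat-ui half of the fix: the page no longer pre-selects "gemma-3-1b-it" before the server has answered, and the default is applied in the fetch_models() success arm instead. A slightly more defensive variant would only ever select an id the server actually returned. A minimal sketch, assuming ModelInfo carries an `id: String` field as in the chat-ui; `pick_default` is a hypothetical helper, not part of this commit:

```rust
struct ModelInfo {
    id: String,
}

// Hypothetical helper: prefer the known default, otherwise take the first
// model the server reported; None if the list is empty.
fn pick_default(models: &[ModelInfo]) -> Option<String> {
    models
        .iter()
        .find(|m| m.id == "gemma-3-1b-it")
        .or_else(|| models.first())
        .map(|m| m.id.clone())
}
```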

@@ -7,6 +7,7 @@ use axum::{
 };
 use futures_util::stream::{self, Stream};
 use std::convert::Infallible;
+use std::str::FromStr;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -20,7 +21,7 @@ use crate::openai_types::{
 use crate::Which;
 use either::Either;
 use embeddings_engine::models_list;
-use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
+use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};
 use llama_runner::{run_llama_inference, LlamaInferenceConfig};
 use serde_json::Value;
 // -------------------------
@@ -35,12 +36,13 @@ pub enum ModelType {
 #[derive(Clone)]
 pub struct AppState {
-    pub model_type: ModelType,
+    pub model_type: Option<ModelType>,
     pub model_id: String,
     pub gemma_config: Option<GemmaInferenceConfig>,
     pub llama_config: Option<LlamaInferenceConfig>,
 }

 impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
@@ -48,12 +50,12 @@ impl Default for AppState {
         let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
         let gemma_config = GemmaInferenceConfig {
-            model: gemma_runner::WhichModel::InstructV3_1B,
+            model: None,
             ..Default::default()
         };
         Self {
-            model_type: ModelType::Gemma,
+            model_type: None,
             model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
@@ -84,7 +86,9 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
         "gemma-2-9b-it" => Some(Which::InstructV2_9B),
         "gemma-3-1b" => Some(Which::BaseV3_1B),
         "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+        "llama-3.2-1b" => Some(Which::Llama32_1B),
         "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+        "llama-3.2-3b" => Some(Which::Llama32_3B),
         "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
         _ => None,
     }
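With the llama-3.2 ids in place, the lookup covers every model the handlers can route. A quick unit-test sketch for the mapping, assuming `Which` derives `Debug` and `PartialEq`:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    // Spot-check the id -> Which mapping extended by this commit.
    #[test]
    fn model_ids_resolve() {
        assert_eq!(model_id_to_which("llama-3.2-1b"), Some(Which::Llama32_1B));
        assert_eq!(model_id_to_which("gemma-3-1b-it"), Some(Which::InstructV3_1B));
        assert_eq!(model_id_to_which("not-a-model"), None);
    }
}
```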
@@ -190,7 +194,21 @@ pub async fn chat_completions_non_streaming_proxy(
     // Get streaming receiver based on model type
     let rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    }))
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         run_llama_inference(config).map_err(|e| (
@@ -201,14 +219,35 @@ pub async fn chat_completions_non_streaming_proxy(
         ))?
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
-        };
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    }))
+                ));
+            }
+        };
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
@@ -348,7 +387,21 @@ async fn handle_streaming_request(
     // Get streaming receiver based on model type
     let model_rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    }))
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         match run_llama_inference(config) {
@@ -364,14 +417,35 @@ async fn handle_streaming_request(
         }
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
-        };
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    }))
+                ));
+            }
+        };
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
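The streaming and non-streaming handlers now carry identical Which-to-runner matches, and they disagree on the failure status (BAD_REQUEST vs INTERNAL_SERVER_ERROR for the same wrong-family condition). A follow-up could hoist the mapping into one fallible helper both paths share. A sketch using only the variants shown in this diff, assuming `Which` derives `Debug`; `to_gemma_model` is a hypothetical name, not part of this commit:

```rust
// Hypothetical shared helper: map the server's Which to the Gemma runner's
// enum once, so both handlers report a mismatch the same way.
fn to_gemma_model(which: Which) -> Result<gemma_runner::WhichModel, String> {
    use gemma_runner::WhichModel as G;
    match which {
        Which::Base2B => Ok(G::Base2B),
        Which::Base7B => Ok(G::Base7B),
        Which::Instruct2B => Ok(G::Instruct2B),
        Which::Instruct7B => Ok(G::Instruct7B),
        Which::InstructV1_1_2B => Ok(G::InstructV1_1_2B),
        Which::InstructV1_1_7B => Ok(G::InstructV1_1_7B),
        Which::CodeBase2B => Ok(G::CodeBase2B),
        Which::CodeBase7B => Ok(G::CodeBase7B),
        Which::CodeInstruct2B => Ok(G::CodeInstruct2B),
        Which::CodeInstruct7B => Ok(G::CodeInstruct7B),
        Which::BaseV2_2B => Ok(G::BaseV2_2B),
        Which::InstructV2_2B => Ok(G::InstructV2_2B),
        Which::BaseV2_9B => Ok(G::BaseV2_9B),
        Which::InstructV2_9B => Ok(G::InstructV2_9B),
        Which::BaseV3_1B => Ok(G::BaseV3_1B),
        Which::InstructV3_1B => Ok(G::InstructV3_1B),
        other => Err(format!("model {:?} is not a Gemma model", other)),
    }
}
```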