mirror of https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00

fixes issue with model selection

@@ -7,6 +7,7 @@ use axum::{
 };
 use futures_util::stream::{self, Stream};
 use std::convert::Infallible;
 use std::str::FromStr;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -20,7 +21,7 @@ use crate::openai_types::{
 use crate::Which;
 use either::Either;
 use embeddings_engine::models_list;
-use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
+use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};
 use llama_runner::{run_llama_inference, LlamaInferenceConfig};
 use serde_json::Value;
 // -------------------------
@@ -35,12 +36,13 @@ pub enum ModelType {

 #[derive(Clone)]
 pub struct AppState {
-    pub model_type: ModelType,
+    pub model_type: Option<ModelType>,
     pub model_id: String,
     pub gemma_config: Option<GemmaInferenceConfig>,
     pub llama_config: Option<LlamaInferenceConfig>,
 }


 impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
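
Making `model_type` an `Option` means a missing model no longer has to be faked with a hard-coded default. A minimal sketch of how a consumer might resolve it per request (the `effective_model_type` helper is hypothetical, and a `ModelType::Llama` variant is assumed alongside the `ModelType::Gemma` used below):

    // Hypothetical: prefer a model type pinned at startup, otherwise infer
    // one from the requested model id via model_id_to_which().
    fn effective_model_type(state: &AppState, requested: &str) -> Option<ModelType> {
        if state.model_type.is_some() {
            return state.model_type.clone();
        }
        model_id_to_which(requested).map(|which| {
            if which.is_llama_model() {
                ModelType::Llama // assumed variant, not shown in this diff
            } else {
                ModelType::Gemma
            }
        })
    }
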
@@ -48,12 +50,12 @@ impl Default for AppState {
         let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());

         let gemma_config = GemmaInferenceConfig {
-            model: gemma_runner::WhichModel::InstructV3_1B,
+            model: None,
             ..Default::default()
         };

         Self {
-            model_type: ModelType::Gemma,
+            model_type: None,
             model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
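
With `model: None` and `model_type: None`, the default state no longer pins `InstructV3_1B`; the concrete model is chosen in the request handlers instead. A quick sanity sketch of the new defaults (assuming `DEFAULT_MODEL` is unset):

    let state = AppState::default();
    assert_eq!(state.model_id, "gemma-3-1b-it");
    assert!(state.model_type.is_none());
    // The bundled Gemma config carries no model selection either.
    assert!(matches!(state.gemma_config, Some(GemmaInferenceConfig { model: None, .. })));
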
@@ -84,7 +86,9 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
         "gemma-2-9b-it" => Some(Which::InstructV2_9B),
         "gemma-3-1b" => Some(Which::BaseV3_1B),
         "gemma-3-1b-it" => Some(Which::InstructV3_1B),
         "llama-3.2-1b" => Some(Which::Llama32_1B),
         "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
         "llama-3.2-3b" => Some(Which::Llama32_3B),
         "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
         _ => None,
     }
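
`model_id_to_which` resolves the OpenAI-style `model` string from a request into a `Which` variant, which the handlers below route to the matching runner. A small usage sketch:

    match model_id_to_which("llama-3.2-1b-instruct") {
        Some(which) if which.is_llama_model() => { /* build a Llama config */ }
        Some(_) => { /* build a Gemma config */ }
        None => { /* respond 400: unrecognized model id */ }
    }
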
@@ -190,7 +194,21 @@ pub async fn chat_completions_non_streaming_proxy(
     // Get streaming receiver based on model type
     let rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    }))
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         run_llama_inference(config).map_err(|e| (
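
The same four-arm mapping from `Which` to `llama_runner::WhichModel` reappears in the streaming handler further down; a hypothetical shared helper (not part of this commit) could express it once:

    fn which_to_llama(which: Which) -> Option<llama_runner::WhichModel> {
        match which {
            Which::Llama32_1B => Some(llama_runner::WhichModel::Llama32_1B),
            Which::Llama32_1BInstruct => Some(llama_runner::WhichModel::Llama32_1BInstruct),
            Which::Llama32_3B => Some(llama_runner::WhichModel::Llama32_3B),
            Which::Llama32_3BInstruct => Some(llama_runner::WhichModel::Llama32_3BInstruct),
            _ => None, // not a Llama model
        }
    }
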
@@ -201,14 +219,35 @@
         ))?
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    }))
+                ));
+            }
         };

         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
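
The sixteen Gemma arms are likewise duplicated verbatim in the streaming path below; a hypothetical helper in the same spirit (middle arms elided, the catch-all keeps the sketch compiling):

    fn which_to_gemma(which: Which) -> Option<gemma_runner::WhichModel> {
        match which {
            Which::Base2B => Some(gemma_runner::WhichModel::Base2B),
            // ...the remaining fourteen arms exactly as in the hunk above...
            Which::InstructV3_1B => Some(gemma_runner::WhichModel::InstructV3_1B),
            _ => None, // not a Gemma model
        }
    }
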
@@ -348,7 +387,21 @@ async fn handle_streaming_request(
     // Get streaming receiver based on model type
     let model_rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    }))
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         match run_llama_inference(config) {
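
This streaming branch matches the non-streaming one except for the status code returned on a mismatch (INTERNAL_SERVER_ERROR here versus BAD_REQUEST above). With the hypothetical `which_to_llama` helper sketched earlier, both call sites could collapse to something like (`model_error` is likewise a hypothetical constructor for the `(StatusCode, Json<Value>)` error tuple):

    let llama_model = which_to_llama(which_model)
        .ok_or_else(|| model_error(StatusCode::BAD_REQUEST, &model_id))?;
    let mut config = LlamaInferenceConfig::new(llama_model);
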
@@ -364,14 +417,35 @@
         }
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    }))
+                ));
+            }
         };

         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();