chat client only displays available models

geoffsee
2025-09-01 22:29:54 -04:00
parent 545e0c9831
commit 2deecb5e51
20 changed files with 3314 additions and 484 deletions


@@ -42,13 +42,18 @@ pub struct AppState {
impl Default for AppState {
fn default() -> Self {
+ // Configure a default model to prevent 503 errors from the chat-ui
+ // This can be overridden by environment variables if needed
+ let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
let gemma_config = GemmaInferenceConfig {
model: gemma_runner::WhichModel::InstructV3_1B,
..Default::default()
};
Self {
model_type: ModelType::Gemma,
model_id: "gemma-3-1b-it".to_string(),
model_id: default_model_id,
gemma_config: Some(gemma_config),
llama_config: None,
}
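The new default resolves the model id from the environment before falling back to the Gemma 3 1B instruct id. A minimal sketch of that fallback, using only the standard library (the helper name resolve_default_model is illustrative, not part of the diff):

    use std::env;

    fn resolve_default_model() -> String {
        // Same pattern as AppState::default(): prefer DEFAULT_MODEL when the
        // variable is set, otherwise fall back to "gemma-3-1b-it".
        env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string())
    }

    fn main() {
        // Prints "gemma-3-1b-it" unless the process was started with
        // e.g. DEFAULT_MODEL=llama-3.2-1b-instruct in its environment.
        println!("{}", resolve_default_model());
    }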
@@ -59,6 +64,34 @@ impl Default for AppState {
// Helper functions
// -------------------------
+ fn model_id_to_which(model_id: &str) -> Option<Which> {
+ let normalized = normalize_model_id(model_id);
+ match normalized.as_str() {
+ "gemma-2b" => Some(Which::Base2B),
+ "gemma-7b" => Some(Which::Base7B),
+ "gemma-2b-it" => Some(Which::Instruct2B),
+ "gemma-7b-it" => Some(Which::Instruct7B),
+ "gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B),
+ "gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B),
+ "codegemma-2b" => Some(Which::CodeBase2B),
+ "codegemma-7b" => Some(Which::CodeBase7B),
+ "codegemma-2b-it" => Some(Which::CodeInstruct2B),
+ "codegemma-7b-it" => Some(Which::CodeInstruct7B),
+ "gemma-2-2b" => Some(Which::BaseV2_2B),
+ "gemma-2-2b-it" => Some(Which::InstructV2_2B),
+ "gemma-2-9b" => Some(Which::BaseV2_9B),
+ "gemma-2-9b-it" => Some(Which::InstructV2_9B),
+ "gemma-3-1b" => Some(Which::BaseV3_1B),
+ "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+ "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+ "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
+ _ => None,
+ }
+ }
fn normalize_model_id(model_id: &str) -> String {
model_id.to_lowercase().replace("_", "-")
}
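Both handlers below funnel the requested model through these two helpers: normalize_model_id lowercases the id and replaces underscores with hyphens, and model_id_to_which maps the canonical id onto a Which variant (None triggers the 400 model_not_supported response). A small standalone sketch of the normalization step (plain string handling only; the assertions mirror the mapping table above):

    fn main() {
        // "Gemma_3_1B_IT" normalizes to "gemma-3-1b-it", which the table above
        // maps to Which::InstructV3_1B; an unknown id like "gpt-4" maps to None.
        let requested = "Gemma_3_1B_IT";
        let normalized = requested.to_lowercase().replace("_", "-");
        assert_eq!(normalized, "gemma-3-1b-it");

        let unknown = "gpt-4".to_lowercase().replace("_", "-");
        assert_eq!(unknown, "gpt-4"); // no table entry => rejected with 400
    }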
@@ -116,90 +149,76 @@ pub async fn chat_completions_non_streaming_proxy(
state: AppState,
request: ChatCompletionRequest,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
- // Enforce model selection behavior: reject if a different model is requested
- let configured_model = state.model_id.clone();
- let requested_model = request.model.clone();
- if requested_model.to_lowercase() != "default" {
- let normalized_requested = normalize_model_id(&requested_model);
- let normalized_configured = normalize_model_id(&configured_model);
- if normalized_requested != normalized_configured {
+ // Use the model specified in the request
+ let model_id = request.model.clone();
+ let which_model = model_id_to_which(&model_id);
+ // Validate that the requested model is supported
+ let which_model = match which_model {
+ Some(model) => model,
+ None => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": format!(
"Requested model '{}' is not available. This server is running '{}' only.",
requested_model, configured_model
),
"type": "model_mismatch"
"message": format!("Unsupported model: {}", model_id),
"type": "model_not_supported"
}
})),
));
}
- }
- let model_id = state.model_id.clone();
+ };
let max_tokens = request.max_tokens.unwrap_or(1000);
- // Build prompt based on model type
- let prompt = match state.model_type {
- ModelType::Gemma => build_gemma_prompt(&request.messages),
- ModelType::Llama => {
- // For Llama, just use the last user message for now
- request
- .messages
- .last()
- .and_then(|m| m.content.as_ref())
- .and_then(|c| match c {
- MessageContent(Either::Left(text)) => Some(text.clone()),
- _ => None,
- })
- .unwrap_or_default()
- }
+ let prompt = if which_model.is_llama_model() {
+ // For Llama, just use the last user message for now
+ request
+ .messages
+ .last()
+ .and_then(|m| m.content.as_ref())
+ .and_then(|c| match c {
+ MessageContent(Either::Left(text)) => Some(text.clone()),
+ _ => None,
+ })
+ .unwrap_or_default()
+ } else {
+ build_gemma_prompt(&request.messages)
};
// Get streaming receiver based on model type
- let rx =
- match state.model_type {
- ModelType::Gemma => {
- if let Some(mut config) = state.gemma_config {
- config.prompt = prompt.clone();
- config.max_tokens = max_tokens;
- run_gemma_api(config).map_err(|e| (
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": { "message": format!("Error initializing Gemma model: {}", e) }
- }))
- ))?
- } else {
- return Err((
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": { "message": "Gemma configuration not available" }
- })),
- ));
- }
- }
- ModelType::Llama => {
- if let Some(mut config) = state.llama_config {
- config.prompt = prompt.clone();
- config.max_tokens = max_tokens;
- run_llama_inference(config).map_err(|e| (
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": { "message": format!("Error initializing Llama model: {}", e) }
- }))
- ))?
- } else {
- return Err((
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": { "message": "Llama configuration not available" }
- })),
- ));
- }
- }
+ let rx = if which_model.is_llama_model() {
+ // Create Llama configuration dynamically
+ let mut config = LlamaInferenceConfig::default();
+ config.prompt = prompt.clone();
+ config.max_tokens = max_tokens;
+ run_llama_inference(config).map_err(|e| (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": { "message": format!("Error initializing Llama model: {}", e) }
+ }))
+ ))?
+ } else {
+ // Create Gemma configuration dynamically
+ let gemma_model = if which_model.is_v3_model() {
+ gemma_runner::WhichModel::InstructV3_1B
+ } else {
+ gemma_runner::WhichModel::InstructV3_1B // Default fallback
+ };
+ let mut config = GemmaInferenceConfig {
+ model: gemma_model,
+ ..Default::default()
+ };
+ config.prompt = prompt.clone();
+ config.max_tokens = max_tokens;
+ run_gemma_api(config).map_err(|e| (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": { "message": format!("Error initializing Gemma model: {}", e) }
+ }))
+ ))?
};
// Collect all tokens from the stream
let mut completion = String::new();
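With this change the non-streaming endpoint accepts any model id that model_id_to_which resolves, instead of only the single configured model. A hedged client-side sketch (assumes the server listens on localhost:8080 and that the reqwest crate is available with its blocking and json features; the port is an assumption, not taken from the diff):

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let body = serde_json::json!({
            "model": "gemma-3-1b-it", // any id the mapping table recognizes
            "messages": [{ "role": "user", "content": "Hello!" }],
            "max_tokens": 64
        });
        let resp = reqwest::blocking::Client::new()
            .post("http://localhost:8080/v1/chat/completions")
            .json(&body)
            .send()?;
        // An unrecognized id now returns 400 with "type": "model_not_supported"
        // rather than the old "model_mismatch" comparison against state.model_id.
        println!("{} {}", resp.status(), resp.text()?);
        Ok(())
    }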
@@ -258,27 +277,25 @@ async fn handle_streaming_request(
state: AppState,
request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
- // Validate requested model vs configured model
- let configured_model = state.model_id.clone();
- let requested_model = request.model.clone();
- if requested_model.to_lowercase() != "default" {
- let normalized_requested = normalize_model_id(&requested_model);
- let normalized_configured = normalize_model_id(&configured_model);
- if normalized_requested != normalized_configured {
+ // Use the model specified in the request
+ let model_id = request.model.clone();
+ let which_model = model_id_to_which(&model_id);
+ // Validate that the requested model is supported
+ let which_model = match which_model {
+ Some(model) => model,
+ None => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": format!(
"Requested model '{}' is not available. This server is running '{}' only.",
requested_model, configured_model
),
"type": "model_mismatch"
"message": format!("Unsupported model: {}", model_id),
"type": "model_not_supported"
}
})),
));
}
- }
+ };
// Generate a unique ID and metadata
let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
@@ -286,24 +303,22 @@ async fn handle_streaming_request(
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
- let model_id = state.model_id.clone();
let max_tokens = request.max_tokens.unwrap_or(1000);
- // Build prompt based on model type
- let prompt = match state.model_type {
- ModelType::Gemma => build_gemma_prompt(&request.messages),
- ModelType::Llama => {
- // For Llama, just use the last user message for now
- request
- .messages
- .last()
- .and_then(|m| m.content.as_ref())
- .and_then(|c| match c {
- MessageContent(Either::Left(text)) => Some(text.clone()),
- _ => None,
- })
- .unwrap_or_default()
- }
+ let prompt = if which_model.is_llama_model() {
+ // For Llama, just use the last user message for now
+ request
+ .messages
+ .last()
+ .and_then(|m| m.content.as_ref())
+ .and_then(|c| match c {
+ MessageContent(Either::Left(text)) => Some(text.clone()),
+ _ => None,
+ })
+ .unwrap_or_default()
+ } else {
+ build_gemma_prompt(&request.messages)
};
tracing::debug!("Formatted prompt: {}", prompt);
@@ -330,51 +345,43 @@ async fn handle_streaming_request(
}
// Get streaming receiver based on model type
- let model_rx = match state.model_type {
- ModelType::Gemma => {
- if let Some(mut config) = state.gemma_config {
- config.prompt = prompt.clone();
- config.max_tokens = max_tokens;
- match run_gemma_api(config) {
- Ok(rx) => rx,
- Err(e) => {
- return Err((
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": { "message": format!("Error initializing Gemma model: {}", e) }
- })),
- ));
- }
- }
- } else {
+ let model_rx = if which_model.is_llama_model() {
+ // Create Llama configuration dynamically
+ let mut config = LlamaInferenceConfig::default();
+ config.prompt = prompt.clone();
+ config.max_tokens = max_tokens;
+ match run_llama_inference(config) {
+ Ok(rx) => rx,
+ Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
"error": { "message": format!("Error initializing Llama model: {}", e) }
})),
));
}
}
- ModelType::Llama => {
- if let Some(mut config) = state.llama_config {
- config.prompt = prompt.clone();
- config.max_tokens = max_tokens;
- match run_llama_inference(config) {
- Ok(rx) => rx,
- Err(e) => {
- return Err((
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": { "message": format!("Error initializing Llama model: {}", e) }
- })),
- ));
- }
- }
- } else {
+ } else {
+ // Create Gemma configuration dynamically
+ let gemma_model = if which_model.is_v3_model() {
+ gemma_runner::WhichModel::InstructV3_1B
+ } else {
+ gemma_runner::WhichModel::InstructV3_1B // Default fallback
+ };
+ let mut config = GemmaInferenceConfig {
+ model: gemma_model,
+ ..Default::default()
+ };
+ config.prompt = prompt.clone();
+ config.max_tokens = max_tokens;
+ match run_gemma_api(config) {
+ Ok(rx) => rx,
+ Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
"error": { "message": format!("Error initializing Gemma model: {}", e) }
})),
));
}
@@ -500,172 +507,69 @@ pub fn create_router(app_state: AppState) -> Router {
/// Handler for GET /v1/models - returns list of available models
pub async fn list_models() -> Json<ModelListResponse> {
// Get all available model variants from the Which enum
- let models = vec![
- // Gemma models
+ let which_variants = vec![
+ Which::Base2B,
+ Which::Base7B,
+ Which::Instruct2B,
+ Which::Instruct7B,
+ Which::InstructV1_1_2B,
+ Which::InstructV1_1_7B,
+ Which::CodeBase2B,
+ Which::CodeBase7B,
+ Which::CodeInstruct2B,
+ Which::CodeInstruct7B,
+ Which::BaseV2_2B,
+ Which::InstructV2_2B,
+ Which::BaseV2_9B,
+ Which::InstructV2_9B,
+ Which::BaseV3_1B,
+ Which::InstructV3_1B,
+ Which::Llama32_1B,
+ Which::Llama32_1BInstruct,
+ Which::Llama32_3B,
+ Which::Llama32_3BInstruct,
+ ];
+ let models: Vec<Model> = which_variants.into_iter().map(|which| {
+ let meta = which.meta();
+ let model_id = match which {
+ Which::Base2B => "gemma-2b",
+ Which::Base7B => "gemma-7b",
+ Which::Instruct2B => "gemma-2b-it",
+ Which::Instruct7B => "gemma-7b-it",
+ Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+ Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+ Which::CodeBase2B => "codegemma-2b",
+ Which::CodeBase7B => "codegemma-7b",
+ Which::CodeInstruct2B => "codegemma-2b-it",
+ Which::CodeInstruct7B => "codegemma-7b-it",
+ Which::BaseV2_2B => "gemma-2-2b",
+ Which::InstructV2_2B => "gemma-2-2b-it",
+ Which::BaseV2_9B => "gemma-2-9b",
+ Which::InstructV2_9B => "gemma-2-9b-it",
+ Which::BaseV3_1B => "gemma-3-1b",
+ Which::InstructV3_1B => "gemma-3-1b-it",
+ Which::Llama32_1B => "llama-3.2-1b",
+ Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+ Which::Llama32_3B => "llama-3.2-3b",
+ Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+ };
+ let owned_by = if meta.id.starts_with("google/") {
+ "google"
+ } else if meta.id.starts_with("meta-llama/") {
+ "meta"
+ } else {
+ "unknown"
+ };
Model {
id: "gemma-2b".to_string(),
id: model_id.to_string(),
object: "model".to_string(),
created: 1686935002, // Using same timestamp as OpenAI example
owned_by: "google".to_string(),
},
Model {
id: "gemma-7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-1.1-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-1.1-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-2b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-2b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-9b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-9b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-3-1b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-3-1b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
// Llama models
Model {
id: "llama-3.2-1b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "llama-3.2-1b-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "llama-3.2-3b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "llama-3.2-3b-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "smollm2-135m".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-135m-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-360m".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-360m-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-1.7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-1.7b-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "tinyllama-1.1b-chat".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "tinyllama".to_string(),
},
];
owned_by: owned_by.to_string(),
}
}).collect();
Json(ModelListResponse {
object: "list".to_string(),