mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)
chat client only displays available models
@@ -42,13 +42,18 @@ pub struct AppState {
 impl Default for AppState {
     fn default() -> Self {
+        // Configure a default model to prevent 503 errors from the chat-ui
+        // This can be overridden by environment variables if needed
+        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+
         let gemma_config = GemmaInferenceConfig {
             model: gemma_runner::WhichModel::InstructV3_1B,
             ..Default::default()
         };
+
         Self {
             model_type: ModelType::Gemma,
-            model_id: "gemma-3-1b-it".to_string(),
+            model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
         }
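The default model id now comes from the environment. Below is a minimal, self-contained sketch of the same fallback pattern; the free function and `main` are illustrative, not part of the crate.

```rust
// Falls back to "gemma-3-1b-it" when DEFAULT_MODEL is unset, mirroring the
// unwrap_or_else call in AppState::default() above.
fn default_model_id() -> String {
    std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string())
}

fn main() {
    // e.g. run with: DEFAULT_MODEL=llama-3.2-1b-instruct cargo run
    println!("default model: {}", default_model_id());
}
```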
@@ -59,6 +64,34 @@ impl Default for AppState {
 // Helper functions
 // -------------------------
 
+fn model_id_to_which(model_id: &str) -> Option<Which> {
+    let normalized = normalize_model_id(model_id);
+    match normalized.as_str() {
+        "gemma-2b" => Some(Which::Base2B),
+        "gemma-7b" => Some(Which::Base7B),
+        "gemma-2b-it" => Some(Which::Instruct2B),
+        "gemma-7b-it" => Some(Which::Instruct7B),
+        "gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B),
+        "gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B),
+        "codegemma-2b" => Some(Which::CodeBase2B),
+        "codegemma-7b" => Some(Which::CodeBase7B),
+        "codegemma-2b-it" => Some(Which::CodeInstruct2B),
+        "codegemma-7b-it" => Some(Which::CodeInstruct7B),
+        "gemma-2-2b" => Some(Which::BaseV2_2B),
+        "gemma-2-2b-it" => Some(Which::InstructV2_2B),
+        "gemma-2-9b" => Some(Which::BaseV2_9B),
+        "gemma-2-9b-it" => Some(Which::InstructV2_9B),
+        "gemma-3-1b" => Some(Which::BaseV3_1B),
+        "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+        "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+        "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
+        _ => None,
+    }
+}
+
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
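A small sketch of the normalization rule this lookup relies on, using a stand-in copy of `normalize_model_id`; the real mapping to `Which` variants lives in the function above.

```rust
// Lowercase the id and turn underscores into hyphens so spellings of the same
// name ("Gemma_3_1B_IT", "gemma-3-1b-it") hit the same match arm.
fn normalize_model_id(model_id: &str) -> String {
    model_id.to_lowercase().replace('_', "-")
}

fn main() {
    assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
    assert_eq!(normalize_model_id("llama-3.2-1b-instruct"), "llama-3.2-1b-instruct");
    println!("normalization ok");
}
```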
@@ -116,90 +149,76 @@ pub async fn chat_completions_non_streaming_proxy(
     state: AppState,
     request: ChatCompletionRequest,
 ) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
-    // Enforce model selection behavior: reject if a different model is requested
-    let configured_model = state.model_id.clone();
-    let requested_model = request.model.clone();
-    if requested_model.to_lowercase() != "default" {
-        let normalized_requested = normalize_model_id(&requested_model);
-        let normalized_configured = normalize_model_id(&configured_model);
-        if normalized_requested != normalized_configured {
+    // Use the model specified in the request
+    let model_id = request.model.clone();
+    let which_model = model_id_to_which(&model_id);
+
+    // Validate that the requested model is supported
+    let which_model = match which_model {
+        Some(model) => model,
+        None => {
             return Err((
                 StatusCode::BAD_REQUEST,
                 Json(serde_json::json!({
                     "error": {
-                        "message": format!(
-                            "Requested model '{}' is not available. This server is running '{}' only.",
-                            requested_model, configured_model
-                        ),
-                        "type": "model_mismatch"
+                        "message": format!("Unsupported model: {}", model_id),
+                        "type": "model_not_supported"
                     }
                 })),
             ));
         }
-    }
-
-    let model_id = state.model_id.clone();
+    };
+
     let max_tokens = request.max_tokens.unwrap_or(1000);
 
     // Build prompt based on model type
-    let prompt = match state.model_type {
-        ModelType::Gemma => build_gemma_prompt(&request.messages),
-        ModelType::Llama => {
-            // For Llama, just use the last user message for now
-            request
-                .messages
-                .last()
-                .and_then(|m| m.content.as_ref())
-                .and_then(|c| match c {
-                    MessageContent(Either::Left(text)) => Some(text.clone()),
-                    _ => None,
-                })
-                .unwrap_or_default()
-        }
+    let prompt = if which_model.is_llama_model() {
+        // For Llama, just use the last user message for now
+        request
+            .messages
+            .last()
+            .and_then(|m| m.content.as_ref())
+            .and_then(|c| match c {
+                MessageContent(Either::Left(text)) => Some(text.clone()),
+                _ => None,
+            })
+            .unwrap_or_default()
+    } else {
+        build_gemma_prompt(&request.messages)
     };
 
     // Get streaming receiver based on model type
-    let rx =
-        match state.model_type {
-            ModelType::Gemma => {
-                if let Some(mut config) = state.gemma_config {
-                    config.prompt = prompt.clone();
-                    config.max_tokens = max_tokens;
-                    run_gemma_api(config).map_err(|e| (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(serde_json::json!({
-                            "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                        }))
-                    ))?
-                } else {
-                    return Err((
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(serde_json::json!({
-                            "error": { "message": "Gemma configuration not available" }
-                        })),
-                    ));
-                }
-            }
-            ModelType::Llama => {
-                if let Some(mut config) = state.llama_config {
-                    config.prompt = prompt.clone();
-                    config.max_tokens = max_tokens;
-                    run_llama_inference(config).map_err(|e| (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(serde_json::json!({
-                            "error": { "message": format!("Error initializing Llama model: {}", e) }
-                        }))
-                    ))?
-                } else {
-                    return Err((
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(serde_json::json!({
-                            "error": { "message": "Llama configuration not available" }
-                        })),
-                    ));
-                }
-            }
-        };
+    let rx = if which_model.is_llama_model() {
+        // Create Llama configuration dynamically
+        let mut config = LlamaInferenceConfig::default();
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        run_llama_inference(config).map_err(|e| (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(serde_json::json!({
+                "error": { "message": format!("Error initializing Llama model: {}", e) }
+            }))
+        ))?
+    } else {
+        // Create Gemma configuration dynamically
+        let gemma_model = if which_model.is_v3_model() {
+            gemma_runner::WhichModel::InstructV3_1B
+        } else {
+            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        };
+
+        let mut config = GemmaInferenceConfig {
+            model: gemma_model,
+            ..Default::default()
+        };
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        run_gemma_api(config).map_err(|e| (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(serde_json::json!({
+                "error": { "message": format!("Error initializing Gemma model: {}", e) }
+            }))
+        ))?
+    };
 
     // Collect all tokens from the stream
     let mut completion = String::new();
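The non-streaming path now resolves the requested model and builds the matching inference config on the fly instead of relying on a pre-built config in `AppState`. A self-contained sketch of that dispatch shape; `ModelKind`, `LlamaCfg`, `GemmaCfg`, and `InferenceCfg` are stand-ins, not the crate's types.

```rust
#[derive(Debug)]
enum ModelKind { Llama, Gemma }

#[derive(Debug)]
struct LlamaCfg { prompt: String, max_tokens: usize }

#[derive(Debug)]
struct GemmaCfg { prompt: String, max_tokens: usize }

#[derive(Debug)]
enum InferenceCfg { Llama(LlamaCfg), Gemma(GemmaCfg) }

// Build whichever config matches the resolved model for this one request.
fn build_config(kind: ModelKind, prompt: String, max_tokens: usize) -> InferenceCfg {
    match kind {
        ModelKind::Llama => InferenceCfg::Llama(LlamaCfg { prompt, max_tokens }),
        ModelKind::Gemma => InferenceCfg::Gemma(GemmaCfg { prompt, max_tokens }),
    }
}

fn main() {
    let cfg = build_config(ModelKind::Llama, "Hello".into(), 256);
    println!("{cfg:?}");
}
```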
@@ -258,27 +277,25 @@ async fn handle_streaming_request(
     state: AppState,
     request: ChatCompletionRequest,
 ) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
-    // Validate requested model vs configured model
-    let configured_model = state.model_id.clone();
-    let requested_model = request.model.clone();
-    if requested_model.to_lowercase() != "default" {
-        let normalized_requested = normalize_model_id(&requested_model);
-        let normalized_configured = normalize_model_id(&configured_model);
-        if normalized_requested != normalized_configured {
+    // Use the model specified in the request
+    let model_id = request.model.clone();
+    let which_model = model_id_to_which(&model_id);
+
+    // Validate that the requested model is supported
+    let which_model = match which_model {
+        Some(model) => model,
+        None => {
             return Err((
                 StatusCode::BAD_REQUEST,
                 Json(serde_json::json!({
                     "error": {
-                        "message": format!(
-                            "Requested model '{}' is not available. This server is running '{}' only.",
-                            requested_model, configured_model
-                        ),
-                        "type": "model_mismatch"
+                        "message": format!("Unsupported model: {}", model_id),
+                        "type": "model_not_supported"
                     }
                 })),
             ));
         }
-    }
+    };
 
     // Generate a unique ID and metadata
     let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
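Unsupported ids now yield a 400 with a `model_not_supported` error body in both handlers. A sketch of that payload, assuming a `serde_json` dependency; the helper function is illustrative.

```rust
use serde_json::json;

// Shape of the error body returned above for an unknown model id.
fn unsupported_model_error(model_id: &str) -> serde_json::Value {
    json!({
        "error": {
            "message": format!("Unsupported model: {}", model_id),
            "type": "model_not_supported"
        }
    })
}

fn main() {
    println!("{}", unsupported_model_error("not-a-real-model"));
}
```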
@@ -286,24 +303,22 @@ async fn handle_streaming_request(
         .duration_since(std::time::UNIX_EPOCH)
         .unwrap_or_default()
         .as_secs();
-    let model_id = state.model_id.clone();
     let max_tokens = request.max_tokens.unwrap_or(1000);
 
     // Build prompt based on model type
-    let prompt = match state.model_type {
-        ModelType::Gemma => build_gemma_prompt(&request.messages),
-        ModelType::Llama => {
-            // For Llama, just use the last user message for now
-            request
-                .messages
-                .last()
-                .and_then(|m| m.content.as_ref())
-                .and_then(|c| match c {
-                    MessageContent(Either::Left(text)) => Some(text.clone()),
-                    _ => None,
-                })
-                .unwrap_or_default()
-        }
+    let prompt = if which_model.is_llama_model() {
+        // For Llama, just use the last user message for now
+        request
+            .messages
+            .last()
+            .and_then(|m| m.content.as_ref())
+            .and_then(|c| match c {
+                MessageContent(Either::Left(text)) => Some(text.clone()),
+                _ => None,
+            })
+            .unwrap_or_default()
+    } else {
+        build_gemma_prompt(&request.messages)
     };
     tracing::debug!("Formatted prompt: {}", prompt);
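For Llama models the prompt is still just the last message's text. A stand-in sketch of that rule; `Msg` is hypothetical, the real code goes through `MessageContent`/`Either` as shown above.

```rust
struct Msg { content: Option<String> }

// Take the text of the last message, or an empty prompt if there is none.
fn llama_prompt(messages: &[Msg]) -> String {
    messages.last().and_then(|m| m.content.clone()).unwrap_or_default()
}

fn main() {
    let msgs = vec![
        Msg { content: Some("You are a helpful assistant.".into()) },
        Msg { content: Some("Which models are available?".into()) },
    ];
    assert_eq!(llama_prompt(&msgs), "Which models are available?");
}
```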
@@ -330,51 +345,43 @@ async fn handle_streaming_request(
     }
 
     // Get streaming receiver based on model type
-    let model_rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                match run_gemma_api(config) {
-                    Ok(rx) => rx,
-                    Err(e) => {
-                        return Err((
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(serde_json::json!({
-                                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                            })),
-                        ));
-                    }
-                }
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    })),
-                ));
-            }
-        }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                match run_llama_inference(config) {
-                    Ok(rx) => rx,
-                    Err(e) => {
-                        return Err((
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(serde_json::json!({
-                                "error": { "message": format!("Error initializing Llama model: {}", e) }
-                            })),
-                        ));
-                    }
-                }
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    })),
-                ));
-            }
-        }
-    };
+    let model_rx = if which_model.is_llama_model() {
+        // Create Llama configuration dynamically
+        let mut config = LlamaInferenceConfig::default();
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        match run_llama_inference(config) {
+            Ok(rx) => rx,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Llama model: {}", e) }
+                    })),
+                ));
+            }
+        }
+    } else {
+        // Create Gemma configuration dynamically
+        let gemma_model = if which_model.is_v3_model() {
+            gemma_runner::WhichModel::InstructV3_1B
+        } else {
+            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        };
+
+        let mut config = GemmaInferenceConfig {
+            model: gemma_model,
+            ..Default::default()
+        };
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        match run_gemma_api(config) {
+            Ok(rx) => rx,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                    })),
+                ));
+            }
+        }
+    };
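Both branches hand back a receiver of generated tokens that the handler then drains. A minimal sketch of that consumption pattern, assuming a plain `std::sync::mpsc` channel of `String` tokens (the actual channel and token types come from gemma_runner / llama_runner).

```rust
use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel::<String>();

    // Stand-in for run_gemma_api / run_llama_inference: a worker emitting tokens.
    thread::spawn(move || {
        for tok in ["Hello", ", ", "world", "!"] {
            tx.send(tok.to_string()).ok();
        }
    });

    // Collect every token, as the non-streaming handler does; the streaming
    // handler would instead forward each token as an SSE chunk.
    let mut completion = String::new();
    for tok in rx {
        completion.push_str(&tok);
    }
    println!("{completion}");
}
```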
@@ -500,172 +507,69 @@ pub fn create_router(app_state: AppState) -> Router {
 /// Handler for GET /v1/models - returns list of available models
 pub async fn list_models() -> Json<ModelListResponse> {
-    let models = vec![
-        // Gemma models
+    // Get all available model variants from the Which enum
+    let which_variants = vec![
+        Which::Base2B,
+        Which::Base7B,
+        Which::Instruct2B,
+        Which::Instruct7B,
+        Which::InstructV1_1_2B,
+        Which::InstructV1_1_7B,
+        Which::CodeBase2B,
+        Which::CodeBase7B,
+        Which::CodeInstruct2B,
+        Which::CodeInstruct7B,
+        Which::BaseV2_2B,
+        Which::InstructV2_2B,
+        Which::BaseV2_9B,
+        Which::InstructV2_9B,
+        Which::BaseV3_1B,
+        Which::InstructV3_1B,
+        Which::Llama32_1B,
+        Which::Llama32_1BInstruct,
+        Which::Llama32_3B,
+        Which::Llama32_3BInstruct,
+    ];
+
+    let models: Vec<Model> = which_variants.into_iter().map(|which| {
+        let meta = which.meta();
+        let model_id = match which {
+            Which::Base2B => "gemma-2b",
+            Which::Base7B => "gemma-7b",
+            Which::Instruct2B => "gemma-2b-it",
+            Which::Instruct7B => "gemma-7b-it",
+            Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+            Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+            Which::CodeBase2B => "codegemma-2b",
+            Which::CodeBase7B => "codegemma-7b",
+            Which::CodeInstruct2B => "codegemma-2b-it",
+            Which::CodeInstruct7B => "codegemma-7b-it",
+            Which::BaseV2_2B => "gemma-2-2b",
+            Which::InstructV2_2B => "gemma-2-2b-it",
+            Which::BaseV2_9B => "gemma-2-9b",
+            Which::InstructV2_9B => "gemma-2-9b-it",
+            Which::BaseV3_1B => "gemma-3-1b",
+            Which::InstructV3_1B => "gemma-3-1b-it",
+            Which::Llama32_1B => "llama-3.2-1b",
+            Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+            Which::Llama32_3B => "llama-3.2-3b",
+            Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+        };
+
+        let owned_by = if meta.id.starts_with("google/") {
+            "google"
+        } else if meta.id.starts_with("meta-llama/") {
+            "meta"
+        } else {
+            "unknown"
+        };
+
         Model {
-            id: "gemma-2b".to_string(),
+            id: model_id.to_string(),
             object: "model".to_string(),
             created: 1686935002, // Using same timestamp as OpenAI example
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-7b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-7b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-1.1-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-1.1-7b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-2b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-7b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-7b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-2b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-9b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-9b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-3-1b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-3-1b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        // Llama models
-        Model {
-            id: "llama-3.2-1b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "llama-3.2-1b-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "llama-3.2-3b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "llama-3.2-3b-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "smollm2-135m".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-135m-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-360m".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-360m-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-1.7b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-1.7b-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "tinyllama-1.1b-chat".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "tinyllama".to_string(),
-        },
-    ];
+            owned_by: owned_by.to_string(),
+        }
+    }).collect();
 
     Json(ModelListResponse {
         object: "list".to_string(),
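`/v1/models` is now derived from the `Which` enum rather than a hard-coded list, so the chat UI only sees models the server can actually load. A small sketch of the ownership rule used when building each `Model` entry; the prefixes are the ones checked against `which.meta().id` above, and the helper function itself is illustrative.

```rust
// Infer the "owned_by" field from the Hugging Face repo prefix.
fn owned_by(hub_id: &str) -> &'static str {
    if hub_id.starts_with("google/") {
        "google"
    } else if hub_id.starts_with("meta-llama/") {
        "meta"
    } else {
        "unknown"
    }
}

fn main() {
    assert_eq!(owned_by("google/gemma-3-1b-it"), "google");
    assert_eq!(owned_by("meta-llama/Llama-3.2-1B-Instruct"), "meta");
    assert_eq!(owned_by("HuggingFaceTB/SmolLM2-135M"), "unknown");
}
```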