run cargo fmt

geoffsee
2025-09-04 13:45:25 -04:00
parent 1e02b12cda
commit c1c583faab
11 changed files with 241 additions and 170 deletions
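This is a formatting-only pass; no behavior change is intended. For reference, assuming a standard Cargo workspace, the same pass can be reproduced and verified with:

    cargo fmt --all            # rewrite every workspace crate in place
    cargo fmt --all -- --check # exit nonzero if anything would change

The --check form is the usual way to gate CI on formatting so commits like this one don't accumulate.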


@@ -42,7 +42,11 @@ pub struct ModelMeta {
 }

 const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-    ModelMeta { id, family, instruct }
+    ModelMeta {
+        id,
+        family,
+        instruct,
+    }
 }

 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
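The expansion above is rustfmt's struct-literal heuristic: a struct literal wider than struct_lit_width (by default 18 columns) is broken into one field per line, even when the one-line form would still fit under max_width. The same rule explains every ModelMeta-style expansion in this commit.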


@@ -42,13 +42,13 @@ pub struct AppState {
     pub llama_config: Option<LlamaInferenceConfig>,
 }

 impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
         // This can be overridden by environment variables if needed
-        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+        let default_model_id =
+            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
         let gemma_config = GemmaInferenceConfig {
             model: None,
             ..Default::default()
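Here the pre-fmt line ran past rustfmt's default max_width of 100 columns, so the initializer was wrapped after the `=`. Behavior is unchanged: DEFAULT_MODEL still overrides the gemma-3-1b-it fallback at startup, e.g.

    DEFAULT_MODEL=llama-3.2-1b-instruct cargo run

(the exact run command depends on the workspace layout; the variable name and fallback value come straight from the code above).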
@@ -94,9 +94,6 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
     }
 }
-
-
-

 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
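normalize_model_id is shown in full above; it lowercases an id and maps underscores to hyphens, presumably so lookups tolerate client-supplied spelling variants. A self-contained sketch, with a hypothetical input string that does not appear in the diff:

    fn normalize_model_id(model_id: &str) -> String {
        model_id.to_lowercase().replace("_", "-")
    }

    fn main() {
        // "Gemma_3_1B_IT" is a hypothetical client spelling of an id from this repo.
        assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
    }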
@@ -157,7 +154,7 @@ pub async fn chat_completions_non_streaming_proxy(
     // Use the model specified in the request
     let model_id = request.model.clone();
     let which_model = model_id_to_which(&model_id);
-
+
     // Validate that the requested model is supported
     let which_model = match which_model {
         Some(model) => model,
@@ -204,19 +201,21 @@ pub async fn chat_completions_non_streaming_proxy(
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
         let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_llama_inference(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Llama model: {}", e) }
-            }))
-        ))?
+        run_llama_inference(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Llama model: {}", e) }
+                })),
+            )
+        })?
     } else {
         // Create Gemma configuration dynamically
         let gemma_model = match which_model {
@@ -241,23 +240,25 @@ pub async fn chat_completions_non_streaming_proxy(
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };

         let mut config = GemmaInferenceConfig {
             model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_gemma_api(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-            }))
-        ))?
+        run_gemma_api(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                })),
+            )
+        })?
     };

     // Collect all tokens from the stream
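Two mechanical changes recur through this file. First, `}))` becomes `})),` because rustfmt adds a trailing comma to the last element of a multi-line tuple or argument list. Second, a closure whose body is a multi-line tuple has that tuple wrapped in a block, turning `|e| ( ... )` into `|e| { ( ... ) }`. The pattern being reformatted is the usual Axum rejection tuple; a minimal compilable sketch, assuming axum and serde_json as dependencies and a hypothetical fallible() standing in for run_llama_inference / run_gemma_api:

    use axum::{http::StatusCode, Json};

    // Hypothetical stand-in for the inference entry points in the diff.
    fn fallible() -> Result<String, String> {
        Err("model weights not found".to_string())
    }

    // Map an internal error into the (status, JSON body) tuple that Axum
    // converts into an HTTP error response.
    fn load() -> Result<String, (StatusCode, Json<serde_json::Value>)> {
        fallible().map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": { "message": format!("Error initializing model: {}", e) }
                })),
            )
        })
    }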
@@ -320,7 +321,7 @@ async fn handle_streaming_request(
     // Use the model specified in the request
     let model_id = request.model.clone();
     let which_model = model_id_to_which(&model_id);
-
+
     // Validate that the requested model is supported
     let which_model = match which_model {
         Some(model) => model,
@@ -397,7 +398,7 @@ async fn handle_streaming_request(
                 StatusCode::INTERNAL_SERVER_ERROR,
                 Json(serde_json::json!({
                     "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                }))
+                })),
             ));
         }
     };
@@ -439,11 +440,11 @@ async fn handle_streaming_request(
                 StatusCode::INTERNAL_SERVER_ERROR,
                 Json(serde_json::json!({
                     "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                }))
+                })),
             ));
         }
     };

     let mut config = GemmaInferenceConfig {
         model: Some(gemma_model),
         ..Default::default()
@@ -605,59 +606,66 @@ pub async fn list_models() -> Json<ModelListResponse> {
         Which::Llama32_3BInstruct,
     ];

-    let mut models: Vec<Model> = which_variants.into_iter().map(|which| {
-        let meta = which.meta();
-        let model_id = match which {
-            Which::Base2B => "gemma-2b",
-            Which::Base7B => "gemma-7b",
-            Which::Instruct2B => "gemma-2b-it",
-            Which::Instruct7B => "gemma-7b-it",
-            Which::InstructV1_1_2B => "gemma-1.1-2b-it",
-            Which::InstructV1_1_7B => "gemma-1.1-7b-it",
-            Which::CodeBase2B => "codegemma-2b",
-            Which::CodeBase7B => "codegemma-7b",
-            Which::CodeInstruct2B => "codegemma-2b-it",
-            Which::CodeInstruct7B => "codegemma-7b-it",
-            Which::BaseV2_2B => "gemma-2-2b",
-            Which::InstructV2_2B => "gemma-2-2b-it",
-            Which::BaseV2_9B => "gemma-2-9b",
-            Which::InstructV2_9B => "gemma-2-9b-it",
-            Which::BaseV3_1B => "gemma-3-1b",
-            Which::InstructV3_1B => "gemma-3-1b-it",
-            Which::Llama32_1B => "llama-3.2-1b",
-            Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
-            Which::Llama32_3B => "llama-3.2-3b",
-            Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
-        };
-        let owned_by = if meta.id.starts_with("google/") {
-            "google"
-        } else if meta.id.starts_with("meta-llama/") {
-            "meta"
-        } else {
-            "unknown"
-        };
-        Model {
-            id: model_id.to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: owned_by.to_string(),
-        }
-    }).collect();
+    let mut models: Vec<Model> = which_variants
+        .into_iter()
+        .map(|which| {
+            let meta = which.meta();
+            let model_id = match which {
+                Which::Base2B => "gemma-2b",
+                Which::Base7B => "gemma-7b",
+                Which::Instruct2B => "gemma-2b-it",
+                Which::Instruct7B => "gemma-7b-it",
+                Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+                Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+                Which::CodeBase2B => "codegemma-2b",
+                Which::CodeBase7B => "codegemma-7b",
+                Which::CodeInstruct2B => "codegemma-2b-it",
+                Which::CodeInstruct7B => "codegemma-7b-it",
+                Which::BaseV2_2B => "gemma-2-2b",
+                Which::InstructV2_2B => "gemma-2-2b-it",
+                Which::BaseV2_9B => "gemma-2-9b",
+                Which::InstructV2_9B => "gemma-2-9b-it",
+                Which::BaseV3_1B => "gemma-3-1b",
+                Which::InstructV3_1B => "gemma-3-1b-it",
+                Which::Llama32_1B => "llama-3.2-1b",
+                Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+                Which::Llama32_3B => "llama-3.2-3b",
+                Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+            };
+            let owned_by = if meta.id.starts_with("google/") {
+                "google"
+            } else if meta.id.starts_with("meta-llama/") {
+                "meta"
+            } else {
+                "unknown"
+            };
+            Model {
+                id: model_id.to_string(),
+                object: "model".to_string(),
+                created: 1686935002,
+                owned_by: owned_by.to_string(),
+            }
+        })
+        .collect();

     // Get embeddings models and convert them to inference Model format
     let embeddings_response = models_list().await;
-    let embeddings_models: Vec<Model> = embeddings_response.0.data.into_iter().map(|embedding_model| {
-        Model {
-            id: embedding_model.id,
-            object: embedding_model.object,
-            created: 1686935002,
-            owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description),
-        }
-    }).collect();
+    let embeddings_models: Vec<Model> = embeddings_response
+        .0
+        .data
+        .into_iter()
+        .map(|embedding_model| Model {
+            id: embedding_model.id,
+            object: embedding_model.object,
+            created: 1686935002,
+            owned_by: format!(
+                "{} - {}",
+                embedding_model.owned_by, embedding_model.description
+            ),
+        })
+        .collect();

     // Add embeddings models to the main models list
     models.extend(embeddings_models);
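The churn in this last hunk is rustfmt's chain rule: once any link in a method chain contains a multi-line block, the whole chain is laid out with one link per line, and the indentation of everything inside shifts, which is why the entire closure body appears twice in the diff. A minimal sketch using only the standard library:

    fn main() {
        // A short chain with a one-line closure stays inline.
        let flat: Vec<i32> = (1..4).map(|x| x * 2).collect();

        // Once the closure body spans lines, rustfmt breaks the chain,
        // one method per line, as in the list_models() hunk above.
        let labeled: Vec<String> = (1..4)
            .map(|x| {
                let parity = if x % 2 == 0 { "even" } else { "odd" };
                format!("{x} is {parity}")
            })
            .collect();

        assert_eq!(flat, vec![2, 4, 6]);
        assert_eq!(labeled[0], "1 is odd");
    }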