mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
run cargo fmt
@@ -42,7 +42,11 @@ pub struct ModelMeta {
 }
 
 const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-    ModelMeta { id, family, instruct }
+    ModelMeta {
+        id,
+        family,
+        instruct,
+    }
 }
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
@@ -42,13 +42,13 @@ pub struct AppState {
     pub llama_config: Option<LlamaInferenceConfig>,
 }
 
-
 impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
         // This can be overridden by environment variables if needed
-        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+        let default_model_id =
+            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
 
         let gemma_config = GemmaInferenceConfig {
             model: None,
             ..Default::default()
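Note: the default-model selection reformatted above is just an environment lookup with a fallback; a minimal standalone sketch of that behavior (the helper name is hypothetical, only std is used):

    fn resolve_default_model() -> String {
        // Falls back to "gemma-3-1b-it" when DEFAULT_MODEL is unset or not valid UTF-8
        std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string())
    }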
@@ -94,9 +94,6 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
     }
 }
 
-
-
-
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
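As a side note, normalize_model_id above is a plain lowercase-and-underscore-to-dash substitution; a small self-contained check of the expected behavior (example inputs are illustrative):

    fn normalize_model_id(model_id: &str) -> String {
        model_id.to_lowercase().replace("_", "-")
    }

    fn main() {
        // Both spellings resolve to the same canonical id
        assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
        assert_eq!(normalize_model_id("llama-3.2-1b"), "llama-3.2-1b");
    }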
@@ -157,7 +154,7 @@ pub async fn chat_completions_non_streaming_proxy(
     // Use the model specified in the request
     let model_id = request.model.clone();
     let which_model = model_id_to_which(&model_id);
-    
+
     // Validate that the requested model is supported
     let which_model = match which_model {
         Some(model) => model,
@@ -204,19 +201,21 @@ pub async fn chat_completions_non_streaming_proxy(
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
         let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_llama_inference(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Llama model: {}", e) }
-            }))
-        ))?
+        run_llama_inference(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Llama model: {}", e) }
+                })),
+            )
+        })?
     } else {
         // Create Gemma configuration dynamically
         let gemma_model = match which_model {
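The run_llama_inference change above is purely a layout change: rustfmt moves the tuple into a braced closure body once the closure spans multiple lines. A reduced sketch of the same pattern outside Axum (types and names are simplified placeholders, not the service's actual signatures):

    fn load_model() -> Result<String, std::io::Error> {
        Ok("ready".to_string())
    }

    fn handler() -> Result<String, (u16, String)> {
        // Same shape as the diff: map the error into a (status, message) tuple
        load_model().map_err(|e| {
            (
                500,
                format!("Error initializing Llama model: {}", e),
            )
        })
    }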
@@ -241,23 +240,25 @@ pub async fn chat_completions_non_streaming_proxy(
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
-        
+
         let mut config = GemmaInferenceConfig {
             model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_gemma_api(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-            }))
-        ))?
+        run_gemma_api(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                })),
+            )
+        })?
     };
 
     // Collect all tokens from the stream
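The Gemma branch above builds its config with struct-update syntax and then overrides two fields; a minimal sketch of that pattern with an assumed, simplified config struct (the real GemmaInferenceConfig has more fields and uses a model enum rather than a String):

    #[derive(Default)]
    struct GemmaInferenceConfig {
        model: Option<String>,
        prompt: String,
        max_tokens: usize,
    }

    fn build_config(model: &str, prompt: &str, max_tokens: usize) -> GemmaInferenceConfig {
        // Start from defaults, set the model, then fill in the request-specific fields
        let mut config = GemmaInferenceConfig {
            model: Some(model.to_string()),
            ..Default::default()
        };
        config.prompt = prompt.to_string();
        config.max_tokens = max_tokens;
        config
    }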
@@ -320,7 +321,7 @@ async fn handle_streaming_request(
     // Use the model specified in the request
     let model_id = request.model.clone();
     let which_model = model_id_to_which(&model_id);
-    
+
    // Validate that the requested model is supported
     let which_model = match which_model {
         Some(model) => model,
@@ -397,7 +398,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
@@ -439,11 +440,11 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
-        
+
         let mut config = GemmaInferenceConfig {
             model: Some(gemma_model),
             ..Default::default()
@@ -605,59 +606,66 @@ pub async fn list_models() -> Json<ModelListResponse> {
         Which::Llama32_3BInstruct,
     ];
 
-    let mut models: Vec<Model> = which_variants.into_iter().map(|which| {
-        let meta = which.meta();
-        let model_id = match which {
-            Which::Base2B => "gemma-2b",
-            Which::Base7B => "gemma-7b",
-            Which::Instruct2B => "gemma-2b-it",
-            Which::Instruct7B => "gemma-7b-it",
-            Which::InstructV1_1_2B => "gemma-1.1-2b-it",
-            Which::InstructV1_1_7B => "gemma-1.1-7b-it",
-            Which::CodeBase2B => "codegemma-2b",
-            Which::CodeBase7B => "codegemma-7b",
-            Which::CodeInstruct2B => "codegemma-2b-it",
-            Which::CodeInstruct7B => "codegemma-7b-it",
-            Which::BaseV2_2B => "gemma-2-2b",
-            Which::InstructV2_2B => "gemma-2-2b-it",
-            Which::BaseV2_9B => "gemma-2-9b",
-            Which::InstructV2_9B => "gemma-2-9b-it",
-            Which::BaseV3_1B => "gemma-3-1b",
-            Which::InstructV3_1B => "gemma-3-1b-it",
-            Which::Llama32_1B => "llama-3.2-1b",
-            Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
-            Which::Llama32_3B => "llama-3.2-3b",
-            Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
-        };
-
-        let owned_by = if meta.id.starts_with("google/") {
-            "google"
-        } else if meta.id.starts_with("meta-llama/") {
-            "meta"
-        } else {
-            "unknown"
-        };
-
-        Model {
-            id: model_id.to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: owned_by.to_string(),
-        }
-    }).collect();
+    let mut models: Vec<Model> = which_variants
+        .into_iter()
+        .map(|which| {
+            let meta = which.meta();
+            let model_id = match which {
+                Which::Base2B => "gemma-2b",
+                Which::Base7B => "gemma-7b",
+                Which::Instruct2B => "gemma-2b-it",
+                Which::Instruct7B => "gemma-7b-it",
+                Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+                Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+                Which::CodeBase2B => "codegemma-2b",
+                Which::CodeBase7B => "codegemma-7b",
+                Which::CodeInstruct2B => "codegemma-2b-it",
+                Which::CodeInstruct7B => "codegemma-7b-it",
+                Which::BaseV2_2B => "gemma-2-2b",
+                Which::InstructV2_2B => "gemma-2-2b-it",
+                Which::BaseV2_9B => "gemma-2-9b",
+                Which::InstructV2_9B => "gemma-2-9b-it",
+                Which::BaseV3_1B => "gemma-3-1b",
+                Which::InstructV3_1B => "gemma-3-1b-it",
+                Which::Llama32_1B => "llama-3.2-1b",
+                Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+                Which::Llama32_3B => "llama-3.2-3b",
+                Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+            };
+
+            let owned_by = if meta.id.starts_with("google/") {
+                "google"
+            } else if meta.id.starts_with("meta-llama/") {
+                "meta"
+            } else {
+                "unknown"
+            };
+
+            Model {
+                id: model_id.to_string(),
+                object: "model".to_string(),
+                created: 1686935002,
+                owned_by: owned_by.to_string(),
+            }
+        })
+        .collect();
 
     // Get embeddings models and convert them to inference Model format
     let embeddings_response = models_list().await;
-    let embeddings_models: Vec<Model> = embeddings_response.0.data.into_iter().map(|embedding_model| {
-        Model {
+    let embeddings_models: Vec<Model> = embeddings_response
+        .0
+        .data
+        .into_iter()
+        .map(|embedding_model| Model {
             id: embedding_model.id,
             object: embedding_model.object,
             created: 1686935002,
-            owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description),
-        }
-    }).collect();
+            owned_by: format!(
+                "{} - {}",
+                embedding_model.owned_by, embedding_model.description
+            ),
+        })
+        .collect();
 
     // Add embeddings models to the main models list
     models.extend(embeddings_models);
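For reference, each entry assembled in list_models above is an OpenAI-style model record; a minimal sketch of the record shape with one populated example (the struct definition is inferred from the fields used in the diff):

    struct Model {
        id: String,
        object: String,
        created: u64,
        owned_by: String,
    }

    fn example_entry() -> Model {
        // e.g. Which::InstructV3_1B maps to "gemma-3-1b-it", whose meta id starts with "google/"
        Model {
            id: "gemma-3-1b-it".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "google".to_string(),
        }
    }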