Add support for listing available models via CLI and HTTP endpoint

This commit is contained in:
geoffsee
2025-08-27 16:35:08 -04:00
parent 432c04d9df
commit 9e28e259ad
3 changed files with 212 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, sse::Sse, IntoResponse},
routing::post,
routing::{get, post},
Json, Router,
};
use futures_util::stream::{self, Stream};
@@ -16,9 +16,9 @@ use tokio::time;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Usage};
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::text_generation::TextGeneration;
use crate::{utilities_lib, Model, Which};
use crate::{utilities_lib, Model as GemmaModel, Which};
use either::Either;
use hf_hub::api::sync::{Api, ApiError};
use hf_hub::{Repo, RepoType};
@@ -283,17 +283,17 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
Model::V1(model)
GemmaModel::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
Model::V2(model)
GemmaModel::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
Model::V3(model)
GemmaModel::V3(model)
}
};
@@ -580,6 +580,114 @@ async fn handle_streaming_request(
Ok(Sse::new(stream))
}
/// Handler for GET /v1/models - returns list of available models
async fn list_models() -> Json<ModelListResponse> {
// Get all available model variants from the Which enum
let models = vec![
Model {
id: "gemma-2b".to_string(),
object: "model".to_string(),
created: 1686935002, // Using same timestamp as OpenAI example
owned_by: "google".to_string(),
},
Model {
id: "gemma-7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-1.1-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-1.1-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-2b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-2b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-9b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-9b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-3-1b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-3-1b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
];
Json(ModelListResponse {
object: "list".to_string(),
data: models,
})
}
// -------------------------
// Router
// -------------------------
@@ -593,6 +701,7 @@ pub fn create_router(app_state: AppState) -> Router {
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
// .route("/v1/chat/completions/stream", post(chat_completions_stream))
.layer(cors)
.with_state(app_state)
@@ -604,6 +713,35 @@ mod tests {
use crate::openai_types::{Message, MessageContent};
use either::Either;
#[tokio::test]
async fn test_models_list_endpoint() {
println!("[DEBUG_LOG] Testing models list endpoint");
let response = list_models().await;
let models_response = response.0;
// Verify response structure
assert_eq!(models_response.object, "list");
assert_eq!(models_response.data.len(), 16);
// Verify some key models are present
let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
assert!(model_ids.contains(&"gemma-2b".to_string()));
assert!(model_ids.contains(&"gemma-7b".to_string()));
assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
// Verify model structure
for model in &models_response.data {
assert_eq!(model.object, "model");
assert_eq!(model.owned_by, "google");
assert_eq!(model.created, 1686935002);
assert!(!model.id.is_empty());
}
println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
}
#[tokio::test]
async fn test_reproduce_tensor_shape_mismatch() {
// Create a test app state with Gemma 3 model (same as the failing request)