diff --git a/cli.ts b/cli.ts
index 3ec546f..4169701 100755
--- a/cli.ts
+++ b/cli.ts
@@ -15,12 +15,14 @@ Simple CLI tool for testing the local OpenAI-compatible API server.
 Options:
   --model        Model to use (default: ${DEFAULT_MODEL})
   --prompt       The prompt to send (can also be provided as positional argument)
+  --list-models  List all available models from the server
   --help         Show this help message
 
 Examples:
   ./cli.ts "What is the capital of France?"
   ./cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
   ./cli.ts --prompt "Who was the 16th president of the United States?"
+  ./cli.ts --list-models
 
 The server should be running at http://localhost:8080
 Start it with: ./run_server.sh
@@ -39,6 +41,9 @@ const { values, positionals } = parseArgs({
     help: {
       type: 'boolean',
     },
+    'list-models': {
+      type: 'boolean',
+    },
   },
   strict: false,
   allowPositionals: true,
@@ -67,6 +72,36 @@ async function requestLocalOpenAI(model: string, userPrompt: string) {
   }
 }
 
+async function listModels() {
+  const openai = new OpenAI({
+    baseURL: "http://localhost:8080/v1",
+    apiKey: "not used",
+  });
+  try {
+    const models = await openai.models.list();
+    console.log(`[INFO] Available models from http://localhost:8080/v1:`);
+    console.log("---");
+
+    if (models.data && models.data.length > 0) {
+      models.data.forEach((model, index) => {
+        console.log(`${index + 1}. ${model.id}`);
+        console.log(`   Owner: ${model.owned_by}`);
+        console.log(`   Created: ${new Date(model.created * 1000).toISOString()}`);
+        console.log("");
+      });
+      console.log(`Total: ${models.data.length} models available`);
+    } else {
+      console.log("No models found.");
+    }
+
+  } catch (e) {
+    console.error("[ERROR] Failed to fetch models from local OpenAI server:", e.message);
+    console.error("[HINT] Make sure the server is running at http://localhost:8080");
+    console.error("[HINT] Start it with: ./run_server.sh");
+    throw e;
+  }
+}
+
 async function main() {
   // Show help if requested
   if (values.help) {
@@ -74,6 +109,17 @@
     process.exit(0);
   }
 
+  // List models if requested
+  if (values['list-models']) {
+    try {
+      await listModels();
+      process.exit(0);
+    } catch (error) {
+      console.error("\n[ERROR] Failed to list models:", error.message);
+      process.exit(1);
+    }
+  }
+
   // Get the prompt from either --prompt flag or positional argument
   const prompt = values.prompt || positionals[2]; // positionals[0] is 'bun', positionals[1] is 'client_cli.ts'
 
diff --git a/crates/inference-engine/src/openai_types.rs b/crates/inference-engine/src/openai_types.rs
index 7ec8c76..d42540b 100644
--- a/crates/inference-engine/src/openai_types.rs
+++ b/crates/inference-engine/src/openai_types.rs
@@ -191,4 +191,26 @@ pub struct Usage {
     pub prompt_tokens: usize,
     pub completion_tokens: usize,
     pub total_tokens: usize,
+}
+
+/// Model object representing an available model
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Model {
+    /// The model identifier
+    pub id: String,
+    /// The object type, always "model"
+    pub object: String,
+    /// Unix timestamp of when the model was created
+    pub created: u64,
+    /// The organization that owns the model
+    pub owned_by: String,
+}
+
+/// Response for listing available models
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ModelListResponse {
+    /// The object type, always "list"
+    pub object: String,
+    /// Array of available models
+    pub data: Vec<Model>,
 }
\ No newline at end of file
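Reviewer note: the two structs above serialize to the standard OpenAI models-list payload, which is exactly what the new --list-models path in cli.ts reads back (model.id, model.owned_by, model.created). A minimal TypeScript sketch of that wire shape follows; the interface names and sample values are illustrative only and do not exist in this repo.

// Illustrative shape of the GET /v1/models body produced by ModelListResponse/Model.
// "ModelEntry" and "ModelList" are hypothetical names used only for this sketch.
interface ModelEntry {
  id: string;        // e.g. "gemma-3-1b-it"
  object: "model";   // always "model"
  created: number;   // Unix timestamp in seconds
  owned_by: string;  // "google" for the Gemma/CodeGemma variants listed below
}

interface ModelList {
  object: "list";    // always "list"
  data: ModelEntry[];
}

const sample: ModelList = {
  object: "list",
  data: [
    { id: "gemma-3-1b-it", object: "model", created: 1686935002, owned_by: "google" },
  ],
};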
diff --git a/crates/inference-engine/src/server.rs b/crates/inference-engine/src/server.rs
index 768b6eb..7ee76ae 100644
--- a/crates/inference-engine/src/server.rs
+++ b/crates/inference-engine/src/server.rs
@@ -2,7 +2,7 @@ use axum::{
     extract::State,
     http::StatusCode,
     response::{sse::Event, sse::Sse, IntoResponse},
-    routing::post,
+    routing::{get, post},
     Json, Router,
 };
 use futures_util::stream::{self, Stream};
@@ -16,9 +16,9 @@ use tokio::time;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
 
-use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Usage};
+use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
 use crate::text_generation::TextGeneration;
-use crate::{utilities_lib, Model, Which};
+use crate::{utilities_lib, Model as GemmaModel, Which};
 use either::Either;
 use hf_hub::api::sync::{Api, ApiError};
 use hf_hub::{Repo, RepoType};
@@ -283,17 +283,17 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
         | Which::CodeInstruct7B => {
             let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
             let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
-            Model::V1(model)
+            GemmaModel::V1(model)
         }
         Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
             let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
-            Model::V2(model)
+            GemmaModel::V2(model)
         }
         Which::BaseV3_1B | Which::InstructV3_1B => {
             let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
             let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
-            Model::V3(model)
+            GemmaModel::V3(model)
         }
     };
 
@@ -580,6 +580,114 @@ async fn handle_streaming_request(
     Ok(Sse::new(stream))
 }
 
+/// Handler for GET /v1/models - returns list of available models
+async fn list_models() -> Json<ModelListResponse> {
+    // Get all available model variants from the Which enum
+    let models = vec![
+        Model {
+            id: "gemma-2b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002, // Using same timestamp as OpenAI example
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-7b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-7b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-1.1-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-1.1-7b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-2b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-7b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-7b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
owned_by: "google".to_string(), + }, + Model { + id: "gemma-2-2b".to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: "google".to_string(), + }, + Model { + id: "gemma-2-2b-it".to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: "google".to_string(), + }, + Model { + id: "gemma-2-9b".to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: "google".to_string(), + }, + Model { + id: "gemma-2-9b-it".to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: "google".to_string(), + }, + Model { + id: "gemma-3-1b".to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: "google".to_string(), + }, + Model { + id: "gemma-3-1b-it".to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: "google".to_string(), + }, + ]; + + Json(ModelListResponse { + object: "list".to_string(), + data: models, + }) +} + // ------------------------- // Router // ------------------------- @@ -593,6 +701,7 @@ pub fn create_router(app_state: AppState) -> Router { Router::new() .route("/v1/chat/completions", post(chat_completions)) + .route("/v1/models", get(list_models)) // .route("/v1/chat/completions/stream", post(chat_completions_stream)) .layer(cors) .with_state(app_state) @@ -604,6 +713,35 @@ mod tests { use crate::openai_types::{Message, MessageContent}; use either::Either; + #[tokio::test] + async fn test_models_list_endpoint() { + println!("[DEBUG_LOG] Testing models list endpoint"); + + let response = list_models().await; + let models_response = response.0; + + // Verify response structure + assert_eq!(models_response.object, "list"); + assert_eq!(models_response.data.len(), 16); + + // Verify some key models are present + let model_ids: Vec = models_response.data.iter().map(|m| m.id.clone()).collect(); + assert!(model_ids.contains(&"gemma-2b".to_string())); + assert!(model_ids.contains(&"gemma-7b".to_string())); + assert!(model_ids.contains(&"gemma-3-1b-it".to_string())); + assert!(model_ids.contains(&"codegemma-2b-it".to_string())); + + // Verify model structure + for model in &models_response.data { + assert_eq!(model.object, "model"); + assert_eq!(model.owned_by, "google"); + assert_eq!(model.created, 1686935002); + assert!(!model.id.is_empty()); + } + + println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len()); + } + #[tokio::test] async fn test_reproduce_tensor_shape_mismatch() { // Create a test app state with Gemma 3 model (same as the failing request)