Add support for listing available models via CLI and HTTP endpoint

This commit is contained in:
geoffsee
2025-08-27 16:35:08 -04:00
parent 432c04d9df
commit 9e28e259ad
3 changed files with 212 additions and 6 deletions

46
cli.ts
View File

@@ -15,12 +15,14 @@ Simple CLI tool for testing the local OpenAI-compatible API server.
Options:
--model <model> Model to use (default: ${DEFAULT_MODEL})
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message
Examples:
./cli.ts "What is the capital of France?"
./cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
./cli.ts --prompt "Who was the 16th president of the United States?"
./cli.ts --list-models
The server should be running at http://localhost:8080
Start it with: ./run_server.sh
@@ -39,6 +41,9 @@ const { values, positionals } = parseArgs({
help: {
type: 'boolean',
},
'list-models': {
type: 'boolean',
},
},
strict: false,
allowPositionals: true,
@@ -67,6 +72,36 @@ async function requestLocalOpenAI(model: string, userPrompt: string) {
}
}
async function listModels() {
const openai = new OpenAI({
baseURL: "http://localhost:8080/v1",
apiKey: "not used",
});
try {
const models = await openai.models.list();
console.log(`[INFO] Available models from http://localhost:8080/v1:`);
console.log("---");
if (models.data && models.data.length > 0) {
models.data.forEach((model, index) => {
console.log(`${index + 1}. ${model.id}`);
console.log(` Owner: ${model.owned_by}`);
console.log(` Created: ${new Date(model.created * 1000).toISOString()}`);
console.log("");
});
console.log(`Total: ${models.data.length} models available`);
} else {
console.log("No models found.");
}
} catch (e) {
console.error("[ERROR] Failed to fetch models from local OpenAI server:", e.message);
console.error("[HINT] Make sure the server is running at http://localhost:8080");
console.error("[HINT] Start it with: ./run_server.sh");
throw e;
}
}
async function main() {
// Show help if requested
if (values.help) {
@@ -74,6 +109,17 @@ async function main() {
process.exit(0);
}
// List models if requested
if (values['list-models']) {
try {
await listModels();
process.exit(0);
} catch (error) {
console.error("\n[ERROR] Failed to list models:", error.message);
process.exit(1);
}
}
// Get the prompt from either --prompt flag or positional argument
const prompt = values.prompt || positionals[2]; // positionals[0] is 'bun', positionals[1] is 'client_cli.ts'

View File

@@ -191,4 +191,26 @@ pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
/// Model object representing an available model
#[derive(Debug, Serialize, ToSchema)]
pub struct Model {
/// The model identifier
pub id: String,
/// The object type, always "model"
pub object: String,
/// Unix timestamp of when the model was created
pub created: u64,
/// The organization that owns the model
pub owned_by: String,
}
/// Response for listing available models
#[derive(Debug, Serialize, ToSchema)]
pub struct ModelListResponse {
/// The object type, always "list"
pub object: String,
/// Array of available models
pub data: Vec<Model>,
}

View File

@@ -2,7 +2,7 @@ use axum::{
extract::State,
http::StatusCode,
response::{sse::Event, sse::Sse, IntoResponse},
routing::post,
routing::{get, post},
Json, Router,
};
use futures_util::stream::{self, Stream};
@@ -16,9 +16,9 @@ use tokio::time;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Usage};
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::text_generation::TextGeneration;
use crate::{utilities_lib, Model, Which};
use crate::{utilities_lib, Model as GemmaModel, Which};
use either::Either;
use hf_hub::api::sync::{Api, ApiError};
use hf_hub::{Repo, RepoType};
@@ -283,17 +283,17 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
Model::V1(model)
GemmaModel::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
Model::V2(model)
GemmaModel::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
Model::V3(model)
GemmaModel::V3(model)
}
};
@@ -580,6 +580,114 @@ async fn handle_streaming_request(
Ok(Sse::new(stream))
}
/// Handler for GET /v1/models - returns list of available models
async fn list_models() -> Json<ModelListResponse> {
// Get all available model variants from the Which enum
let models = vec![
Model {
id: "gemma-2b".to_string(),
object: "model".to_string(),
created: 1686935002, // Using same timestamp as OpenAI example
owned_by: "google".to_string(),
},
Model {
id: "gemma-7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-1.1-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-1.1-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-2b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "codegemma-7b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-2b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-2b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-9b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-2-9b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-3-1b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
Model {
id: "gemma-3-1b-it".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "google".to_string(),
},
];
Json(ModelListResponse {
object: "list".to_string(),
data: models,
})
}
// -------------------------
// Router
// -------------------------
@@ -593,6 +701,7 @@ pub fn create_router(app_state: AppState) -> Router {
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
// .route("/v1/chat/completions/stream", post(chat_completions_stream))
.layer(cors)
.with_state(app_state)
@@ -604,6 +713,35 @@ mod tests {
use crate::openai_types::{Message, MessageContent};
use either::Either;
#[tokio::test]
async fn test_models_list_endpoint() {
println!("[DEBUG_LOG] Testing models list endpoint");
let response = list_models().await;
let models_response = response.0;
// Verify response structure
assert_eq!(models_response.object, "list");
assert_eq!(models_response.data.len(), 16);
// Verify some key models are present
let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
assert!(model_ids.contains(&"gemma-2b".to_string()));
assert!(model_ids.contains(&"gemma-7b".to_string()));
assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
// Verify model structure
for model in &models_response.data {
assert_eq!(model.object, "model");
assert_eq!(model.owned_by, "google");
assert_eq!(model.created, 1686935002);
assert!(!model.id.is_empty());
}
println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
}
#[tokio::test]
async fn test_reproduce_tensor_shape_mismatch() {
// Create a test app state with Gemma 3 model (same as the failing request)