mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
Add support for listing available models via CLI and HTTP endpoint
This commit is contained in:
46
cli.ts
46
cli.ts
@@ -15,12 +15,14 @@ Simple CLI tool for testing the local OpenAI-compatible API server.
|
||||
Options:
|
||||
--model <model> Model to use (default: ${DEFAULT_MODEL})
|
||||
--prompt <prompt> The prompt to send (can also be provided as positional argument)
|
||||
--list-models List all available models from the server
|
||||
--help Show this help message
|
||||
|
||||
Examples:
|
||||
./cli.ts "What is the capital of France?"
|
||||
./cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
./cli.ts --prompt "Who was the 16th president of the United States?"
|
||||
./cli.ts --list-models
|
||||
|
||||
The server should be running at http://localhost:8080
|
||||
Start it with: ./run_server.sh
|
||||
@@ -39,6 +41,9 @@ const { values, positionals } = parseArgs({
|
||||
help: {
|
||||
type: 'boolean',
|
||||
},
|
||||
'list-models': {
|
||||
type: 'boolean',
|
||||
},
|
||||
},
|
||||
strict: false,
|
||||
allowPositionals: true,
|
||||
@@ -67,6 +72,36 @@ async function requestLocalOpenAI(model: string, userPrompt: string) {
|
||||
}
|
||||
}
|
||||
|
||||
async function listModels() {
|
||||
const openai = new OpenAI({
|
||||
baseURL: "http://localhost:8080/v1",
|
||||
apiKey: "not used",
|
||||
});
|
||||
try {
|
||||
const models = await openai.models.list();
|
||||
console.log(`[INFO] Available models from http://localhost:8080/v1:`);
|
||||
console.log("---");
|
||||
|
||||
if (models.data && models.data.length > 0) {
|
||||
models.data.forEach((model, index) => {
|
||||
console.log(`${index + 1}. ${model.id}`);
|
||||
console.log(` Owner: ${model.owned_by}`);
|
||||
console.log(` Created: ${new Date(model.created * 1000).toISOString()}`);
|
||||
console.log("");
|
||||
});
|
||||
console.log(`Total: ${models.data.length} models available`);
|
||||
} else {
|
||||
console.log("No models found.");
|
||||
}
|
||||
|
||||
} catch (e) {
|
||||
console.error("[ERROR] Failed to fetch models from local OpenAI server:", e.message);
|
||||
console.error("[HINT] Make sure the server is running at http://localhost:8080");
|
||||
console.error("[HINT] Start it with: ./run_server.sh");
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
async function main() {
|
||||
// Show help if requested
|
||||
if (values.help) {
|
||||
@@ -74,6 +109,17 @@ async function main() {
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
// List models if requested
|
||||
if (values['list-models']) {
|
||||
try {
|
||||
await listModels();
|
||||
process.exit(0);
|
||||
} catch (error) {
|
||||
console.error("\n[ERROR] Failed to list models:", error.message);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the prompt from either --prompt flag or positional argument
|
||||
const prompt = values.prompt || positionals[2]; // positionals[0] is 'bun', positionals[1] is 'client_cli.ts'
|
||||
|
||||
|
@@ -192,3 +192,25 @@ pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
/// Model object representing an available model
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Model {
|
||||
/// The model identifier
|
||||
pub id: String,
|
||||
/// The object type, always "model"
|
||||
pub object: String,
|
||||
/// Unix timestamp of when the model was created
|
||||
pub created: u64,
|
||||
/// The organization that owns the model
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
/// Response for listing available models
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ModelListResponse {
|
||||
/// The object type, always "list"
|
||||
pub object: String,
|
||||
/// Array of available models
|
||||
pub data: Vec<Model>,
|
||||
}
|
@@ -2,7 +2,7 @@ use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{sse::Event, sse::Sse, IntoResponse},
|
||||
routing::post,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures_util::stream::{self, Stream};
|
||||
@@ -16,9 +16,9 @@ use tokio::time;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Usage};
|
||||
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
|
||||
use crate::text_generation::TextGeneration;
|
||||
use crate::{utilities_lib, Model, Which};
|
||||
use crate::{utilities_lib, Model as GemmaModel, Which};
|
||||
use either::Either;
|
||||
use hf_hub::api::sync::{Api, ApiError};
|
||||
use hf_hub::{Repo, RepoType};
|
||||
@@ -283,17 +283,17 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
| Which::CodeInstruct7B => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
|
||||
let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
|
||||
Model::V1(model)
|
||||
GemmaModel::V1(model)
|
||||
}
|
||||
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
|
||||
let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
|
||||
Model::V2(model)
|
||||
GemmaModel::V2(model)
|
||||
}
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
|
||||
let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
|
||||
Model::V3(model)
|
||||
GemmaModel::V3(model)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -580,6 +580,114 @@ async fn handle_streaming_request(
|
||||
Ok(Sse::new(stream))
|
||||
}
|
||||
|
||||
/// Handler for GET /v1/models - returns list of available models
|
||||
async fn list_models() -> Json<ModelListResponse> {
|
||||
// Get all available model variants from the Which enum
|
||||
let models = vec![
|
||||
Model {
|
||||
id: "gemma-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002, // Using same timestamp as OpenAI example
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-7b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-7b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-1.1-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-1.1-7b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-7b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-7b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-9b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-9b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-3-1b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-3-1b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
Json(ModelListResponse {
|
||||
object: "list".to_string(),
|
||||
data: models,
|
||||
})
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// Router
|
||||
// -------------------------
|
||||
@@ -593,6 +701,7 @@ pub fn create_router(app_state: AppState) -> Router {
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
// .route("/v1/chat/completions/stream", post(chat_completions_stream))
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
@@ -604,6 +713,35 @@ mod tests {
|
||||
use crate::openai_types::{Message, MessageContent};
|
||||
use either::Either;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_models_list_endpoint() {
|
||||
println!("[DEBUG_LOG] Testing models list endpoint");
|
||||
|
||||
let response = list_models().await;
|
||||
let models_response = response.0;
|
||||
|
||||
// Verify response structure
|
||||
assert_eq!(models_response.object, "list");
|
||||
assert_eq!(models_response.data.len(), 16);
|
||||
|
||||
// Verify some key models are present
|
||||
let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
|
||||
assert!(model_ids.contains(&"gemma-2b".to_string()));
|
||||
assert!(model_ids.contains(&"gemma-7b".to_string()));
|
||||
assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
|
||||
assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
|
||||
|
||||
// Verify model structure
|
||||
for model in &models_response.data {
|
||||
assert_eq!(model.object, "model");
|
||||
assert_eq!(model.owned_by, "google");
|
||||
assert_eq!(model.created, 1686935002);
|
||||
assert!(!model.id.is_empty());
|
||||
}
|
||||
|
||||
println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reproduce_tensor_shape_mismatch() {
|
||||
// Create a test app state with Gemma 3 model (same as the failing request)
|
||||
|
Reference in New Issue
Block a user