mirror of https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00

Add support for listing available models via CLI and HTTP endpoint

cli.ts (+46 lines)
@@ -15,12 +15,14 @@ Simple CLI tool for testing the local OpenAI-compatible API server.
 Options:
   --model <model>    Model to use (default: ${DEFAULT_MODEL})
   --prompt <prompt>  The prompt to send (can also be provided as positional argument)
+  --list-models      List all available models from the server
   --help             Show this help message
 
 Examples:
   ./cli.ts "What is the capital of France?"
   ./cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
   ./cli.ts --prompt "Who was the 16th president of the United States?"
+  ./cli.ts --list-models
 
 The server should be running at http://localhost:8080
 Start it with: ./run_server.sh
@@ -39,6 +41,9 @@ const { values, positionals } = parseArgs({
     help: {
       type: 'boolean',
     },
+    'list-models': {
+      type: 'boolean',
+    },
   },
   strict: false,
   allowPositionals: true,
@@ -67,6 +72,36 @@ async function requestLocalOpenAI(model: string, userPrompt: string) {
   }
 }
 
+async function listModels() {
+  const openai = new OpenAI({
+    baseURL: "http://localhost:8080/v1",
+    apiKey: "not used",
+  });
+  try {
+    const models = await openai.models.list();
+    console.log(`[INFO] Available models from http://localhost:8080/v1:`);
+    console.log("---");
+
+    if (models.data && models.data.length > 0) {
+      models.data.forEach((model, index) => {
+        console.log(`${index + 1}. ${model.id}`);
+        console.log(`   Owner: ${model.owned_by}`);
+        console.log(`   Created: ${new Date(model.created * 1000).toISOString()}`);
+        console.log("");
+      });
+      console.log(`Total: ${models.data.length} models available`);
+    } else {
+      console.log("No models found.");
+    }
+
+  } catch (e) {
+    console.error("[ERROR] Failed to fetch models from local OpenAI server:", e.message);
+    console.error("[HINT] Make sure the server is running at http://localhost:8080");
+    console.error("[HINT] Start it with: ./run_server.sh");
+    throw e;
+  }
+}
+
 async function main() {
   // Show help if requested
   if (values.help) {
@@ -74,6 +109,17 @@ async function main() {
     process.exit(0);
   }
 
+  // List models if requested
+  if (values['list-models']) {
+    try {
+      await listModels();
+      process.exit(0);
+    } catch (error) {
+      console.error("\n[ERROR] Failed to list models:", error.message);
+      process.exit(1);
+    }
+  }
+
   // Get the prompt from either --prompt flag or positional argument
   const prompt = values.prompt || positionals[2]; // positionals[0] is 'bun', positionals[1] is 'client_cli.ts'
 
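Under the hood, the new --list-models flag simply asks the OpenAI SDK for models.list(), which issues a GET to /v1/models on the local server. Below is a minimal sketch of hitting that route directly with fetch, assuming the server started by ./run_server.sh is listening on http://localhost:8080; the helper name is illustrative and the response shape comes from the Rust types added later in this commit.

// Illustrative only, not part of this commit: query the new endpoint without the SDK.
async function printAvailableModels(): Promise<void> {
  const res = await fetch("http://localhost:8080/v1/models");
  if (!res.ok) {
    throw new Error(`GET /v1/models failed: ${res.status} ${res.statusText}`);
  }
  // The body mirrors the Rust ModelListResponse: { object: "list", data: [ ... ] }
  const body = (await res.json()) as {
    object: string;
    data: { id: string; object: string; created: number; owned_by: string }[];
  };
  for (const model of body.data) {
    console.log(`${model.id} (owner: ${model.owned_by})`);
  }
}

printAvailableModels().catch((err) => console.error("[ERROR]", err));

The Rust changes below define the types this endpoint serializes and wire up the route itself.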
@@ -191,4 +191,26 @@ pub struct Usage {
     pub prompt_tokens: usize,
     pub completion_tokens: usize,
     pub total_tokens: usize,
+}
+
+/// Model object representing an available model
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Model {
+    /// The model identifier
+    pub id: String,
+    /// The object type, always "model"
+    pub object: String,
+    /// Unix timestamp of when the model was created
+    pub created: u64,
+    /// The organization that owns the model
+    pub owned_by: String,
+}
+
+/// Response for listing available models
+#[derive(Debug, Serialize, ToSchema)]
+pub struct ModelListResponse {
+    /// The object type, always "list"
+    pub object: String,
+    /// Array of available models
+    pub data: Vec<Model>,
 }
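With serde's default field naming, the two structs above serialize field-for-field into the usual OpenAI list shape. A TypeScript mirror of that assumed wire format follows; the interface names are mine and not part of the commit.

// Assumed JSON wire format of GET /v1/models, derived from the Rust structs above.
interface ModelObject {
  id: string;       // e.g. "gemma-3-1b-it"
  object: string;   // always "model"
  created: number;  // Unix timestamp in seconds
  owned_by: string; // e.g. "google"
}

interface ModelList {
  object: string;   // always "list"
  data: ModelObject[];
}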
@@ -2,7 +2,7 @@ use axum::{
     extract::State,
     http::StatusCode,
     response::{sse::Event, sse::Sse, IntoResponse},
-    routing::post,
+    routing::{get, post},
     Json, Router,
 };
 use futures_util::stream::{self, Stream};
@@ -16,9 +16,9 @@ use tokio::time;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
 
-use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Usage};
+use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
 use crate::text_generation::TextGeneration;
-use crate::{utilities_lib, Model, Which};
+use crate::{utilities_lib, Model as GemmaModel, Which};
 use either::Either;
 use hf_hub::api::sync::{Api, ApiError};
 use hf_hub::{Repo, RepoType};
@@ -283,17 +283,17 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
         | Which::CodeInstruct7B => {
             let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
             let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
-            Model::V1(model)
+            GemmaModel::V1(model)
         }
         Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
             let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
-            Model::V2(model)
+            GemmaModel::V2(model)
         }
         Which::BaseV3_1B | Which::InstructV3_1B => {
             let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
             let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
-            Model::V3(model)
+            GemmaModel::V3(model)
         }
     };
 
@@ -580,6 +580,114 @@ async fn handle_streaming_request(
     Ok(Sse::new(stream))
 }
 
+/// Handler for GET /v1/models - returns list of available models
+async fn list_models() -> Json<ModelListResponse> {
+    // Get all available model variants from the Which enum
+    let models = vec![
+        Model {
+            id: "gemma-2b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002, // Using same timestamp as OpenAI example
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-7b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-7b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-1.1-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-1.1-7b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-2b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-7b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "codegemma-7b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-2-2b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-2-2b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-2-9b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-2-9b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-3-1b".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+        Model {
+            id: "gemma-3-1b-it".to_string(),
+            object: "model".to_string(),
+            created: 1686935002,
+            owned_by: "google".to_string(),
+        },
+    ];
+
+    Json(ModelListResponse {
+        object: "list".to_string(),
+        data: models,
+    })
+}
+
 // -------------------------
 // Router
 // -------------------------
@@ -593,6 +701,7 @@ pub fn create_router(app_state: AppState) -> Router {
 
     Router::new()
         .route("/v1/chat/completions", post(chat_completions))
+        .route("/v1/models", get(list_models))
         // .route("/v1/chat/completions/stream", post(chat_completions_stream))
         .layer(cors)
         .with_state(app_state)
@@ -604,6 +713,35 @@ mod tests {
     use crate::openai_types::{Message, MessageContent};
     use either::Either;
 
+    #[tokio::test]
+    async fn test_models_list_endpoint() {
+        println!("[DEBUG_LOG] Testing models list endpoint");
+
+        let response = list_models().await;
+        let models_response = response.0;
+
+        // Verify response structure
+        assert_eq!(models_response.object, "list");
+        assert_eq!(models_response.data.len(), 16);
+
+        // Verify some key models are present
+        let model_ids: Vec<String> = models_response.data.iter().map(|m| m.id.clone()).collect();
+        assert!(model_ids.contains(&"gemma-2b".to_string()));
+        assert!(model_ids.contains(&"gemma-7b".to_string()));
+        assert!(model_ids.contains(&"gemma-3-1b-it".to_string()));
+        assert!(model_ids.contains(&"codegemma-2b-it".to_string()));
+
+        // Verify model structure
+        for model in &models_response.data {
+            assert_eq!(model.object, "model");
+            assert_eq!(model.owned_by, "google");
+            assert_eq!(model.created, 1686935002);
+            assert!(!model.id.is_empty());
+        }
+
+        println!("[DEBUG_LOG] Models list endpoint test passed - {} models available", models_response.data.len());
+    }
+
     #[tokio::test]
     async fn test_reproduce_tensor_shape_mismatch() {
         // Create a test app state with Gemma 3 model (same as the failing request)
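Taken together, a client can now discover a model through the new route and then chat with it over the same base URL. A short end-to-end sketch with the OpenAI SDK, assuming the server is running locally as in the CLI examples above:

import OpenAI from "openai";

// Sketch only: list models via GET /v1/models, then reuse one id for a chat completion.
// Assumes the server from ./run_server.sh is listening on http://localhost:8080.
const openai = new OpenAI({ baseURL: "http://localhost:8080/v1", apiKey: "not used" });

const models = await openai.models.list();
const ids = models.data.map((m) => m.id);
console.log("Available models:", ids.join(", "));

// gemma-3-1b-it is one of the ids returned by the list_models handler above.
const model = ids.includes("gemma-3-1b-it") ? "gemma-3-1b-it" : ids[0];

const completion = await openai.chat.completions.create({
  model,
  messages: [{ role: "user", content: "What is the capital of France?" }],
});
console.log(completion.choices[0]?.message?.content);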