From c1c583faab806e5ddcbd5896f35d99c85071ac5e Mon Sep 17 00:00:00 2001
From: geoffsee <>
Date: Thu, 4 Sep 2025 13:45:25 -0400
Subject: [PATCH] run cargo fmt

---
 .github/workflows/ci.yml                  |   2 +-
 crates/chat-ui/src/app.rs                 |  51 +++----
 crates/embeddings-engine/src/lib.rs       | 161 ++++++++++++++--------
 crates/embeddings-engine/src/main.rs      |  10 +-
 crates/inference-engine/src/model.rs      |   6 +-
 crates/inference-engine/src/server.rs     | 150 ++++++++++----------
 integration/gemma-runner/src/gemma_api.rs |  12 +-
 integration/helm-chart-tool/src/main.rs   |   4 +-
 integration/llama-runner/src/llama_api.rs |   2 +-
 integration/utils/src/imagenet.rs         |   2 +-
 integration/utils/src/lib.rs              |  11 +-
 11 files changed, 241 insertions(+), 170 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 80b00be..2aa4f62 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -44,7 +44,7 @@ jobs:
 
       - name: Clippy
         shell: bash
-        run: cargo clippy --all-targets
+        run: cargo clippy --all
 
       - name: Tests
         shell: bash
diff --git a/crates/chat-ui/src/app.rs b/crates/chat-ui/src/app.rs
index 0b1ad11..8d8b2bf 100644
--- a/crates/chat-ui/src/app.rs
+++ b/crates/chat-ui/src/app.rs
@@ -194,57 +194,57 @@ pub fn send_chat_completion_stream(
 ) {
     use wasm_bindgen::prelude::*;
     use wasm_bindgen::JsCast;
-
+
     let request = ChatRequest {
         model,
         messages,
         max_tokens: Some(1024),
         stream: Some(true),
     };
-
+
     // We need to send a POST request but EventSource only supports GET
     // So we'll use fetch with a readable stream instead
     let window = web_sys::window().unwrap();
     let request_json = serde_json::to_string(&request).unwrap();
-
+
     let opts = web_sys::RequestInit::new();
     opts.set_method("POST");
     opts.set_body(&JsValue::from_str(&request_json));
-
+
     let headers = web_sys::Headers::new().unwrap();
     headers.set("Content-Type", "application/json").unwrap();
     headers.set("Accept", "text/event-stream").unwrap();
    opts.set_headers(&headers);
-
+
     let request = web_sys::Request::new_with_str_and_init("/v1/chat/completions", &opts).unwrap();
-
+
     let promise = window.fetch_with_request(&request);
-
+
     wasm_bindgen_futures::spawn_local(async move {
         match wasm_bindgen_futures::JsFuture::from(promise).await {
             Ok(resp_value) => {
                 let resp: web_sys::Response = resp_value.dyn_into().unwrap();
-
+
                 if !resp.ok() {
                     on_error(format!("Server error: {}", resp.status()));
                     return;
                 }
-
+
                 let body = resp.body();
                 if body.is_none() {
                     on_error("No response body".to_string());
                     return;
                 }
-
+
                 let reader = body
                     .unwrap()
                     .get_reader()
                     .dyn_into::()
                     .unwrap();
-
+
                 let decoder = web_sys::TextDecoder::new().unwrap();
                 let mut buffer = String::new();
-
+
                 loop {
                     match wasm_bindgen_futures::JsFuture::from(reader.read()).await {
                         Ok(result) => {
@@ -252,24 +252,25 @@ pub fn send_chat_completion_stream(
                                 .unwrap()
                                 .as_bool()
                                 .unwrap_or(false);
-
+
                             if done {
                                 break;
                             }
-
-                            let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
+
+                            let value =
+                                js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
                             let array = js_sys::Uint8Array::new(&value);
                             let mut bytes = vec![0; array.length() as usize];
                             array.copy_to(&mut bytes);
                             let text = decoder.decode_with_u8_array(&bytes).unwrap();
-
+
                             buffer.push_str(&text);
-
+
                             // Process complete SSE events from buffer
                             while let Some(event_end) = buffer.find("\n\n") {
                                 let event = buffer[..event_end].to_string();
                                 buffer = buffer[event_end + 2..].to_string();
-
+
                                 // Parse SSE event
                                 for line in event.lines() {
                                     if let Some(data) = line.strip_prefix("data: ") {
@@ -277,9 +278,11 @@ pub fn
send_chat_completion_stream( on_complete(); return; } - + // Parse JSON chunk - if let Ok(chunk) = serde_json::from_str::(data) { + if let Ok(chunk) = + serde_json::from_str::(data) + { if let Some(choice) = chunk.choices.first() { if let Some(content) = &choice.delta.content { on_chunk(content.clone()); @@ -296,7 +299,7 @@ pub fn send_chat_completion_stream( } } } - + on_complete(); } Err(e) => { @@ -366,11 +369,11 @@ fn ChatPage() -> impl IntoView { // State for available models and selected model let available_models = RwSignal::new(Vec::::new()); let selected_model = RwSignal::new(String::from("")); // Default model - + // State for streaming response let streaming_content = RwSignal::new(String::new()); let is_streaming = RwSignal::new(false); - + // State for streaming mode toggle let use_streaming = RwSignal::new(true); // Default to streaming @@ -424,7 +427,7 @@ fn ChatPage() -> impl IntoView { // Clear streaming content and set streaming flag streaming_content.set(String::new()); is_streaming.set(true); - + // Use streaming API send_chat_completion_stream( current_messages, diff --git a/crates/embeddings-engine/src/lib.rs b/crates/embeddings-engine/src/lib.rs index 2433214..4fe5f9a 100644 --- a/crates/embeddings-engine/src/lib.rs +++ b/crates/embeddings-engine/src/lib.rs @@ -1,5 +1,10 @@ use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput}; -use axum::{Json, Router, response::Json as ResponseJson, routing::{get, post}, http::StatusCode}; +use axum::{ + Json, Router, + http::StatusCode, + response::Json as ResponseJson, + routing::{get, post}, +}; use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; use once_cell::sync::Lazy; use serde::Serialize; @@ -9,9 +14,8 @@ use tower_http::trace::TraceLayer; use tracing; // Cache for multiple embedding models -static MODEL_CACHE: Lazy>>> = Lazy::new(|| { - RwLock::new(HashMap::new()) -}); +static MODEL_CACHE: Lazy>>> = + Lazy::new(|| RwLock::new(HashMap::new())); #[derive(Serialize)] pub struct ModelInfo { @@ -32,11 +36,19 @@ pub struct ModelsResponse { fn parse_embedding_model(model_name: &str) -> Result { match model_name { // Sentence Transformers models - "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2), - "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q), - "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2), - "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q), - + "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => { + Ok(EmbeddingModel::AllMiniLML6V2) + } + "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => { + Ok(EmbeddingModel::AllMiniLML6V2Q) + } + "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => { + Ok(EmbeddingModel::AllMiniLML12V2) + } + "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => { + Ok(EmbeddingModel::AllMiniLML12V2Q) + } + // BGE models "BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15), "BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q), @@ -46,41 +58,68 @@ fn parse_embedding_model(model_name: &str) -> Result { "BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q), "BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15), "BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => 
Ok(EmbeddingModel::BGELargeZHV15), - + // Nomic models - "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1), - "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15), - "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q), - + "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => { + Ok(EmbeddingModel::NomicEmbedTextV1) + } + "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => { + Ok(EmbeddingModel::NomicEmbedTextV15) + } + "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => { + Ok(EmbeddingModel::NomicEmbedTextV15Q) + } + // Paraphrase models - "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2), - "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q), - "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2), - + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2), + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" + | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q), + "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2), + // ModernBert - "lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge), - + "lightonai/modernbert-embed-large" | "modernbert-embed-large" => { + Ok(EmbeddingModel::ModernBertEmbedLarge) + } + // Multilingual E5 models - "intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small), - "intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base), - "intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large), - + "intfloat/multilingual-e5-small" | "multilingual-e5-small" => { + Ok(EmbeddingModel::MultilingualE5Small) + } + "intfloat/multilingual-e5-base" | "multilingual-e5-base" => { + Ok(EmbeddingModel::MultilingualE5Base) + } + "intfloat/multilingual-e5-large" | "multilingual-e5-large" => { + Ok(EmbeddingModel::MultilingualE5Large) + } + // Mixedbread models - "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1), - "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q), - + "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => { + Ok(EmbeddingModel::MxbaiEmbedLargeV1) + } + "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => { + Ok(EmbeddingModel::MxbaiEmbedLargeV1Q) + } + // GTE models "Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15), - "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q), + "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => { + Ok(EmbeddingModel::GTEBaseENV15Q) + } "Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15), - "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => 
Ok(EmbeddingModel::GTELargeENV15Q), - + "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => { + Ok(EmbeddingModel::GTELargeENV15Q) + } + // CLIP model "Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32), - + // Jina model - "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode), - + "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => { + Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode) + } + _ => Err(format!("Unsupported embedding model: {}", model_name)), } } @@ -95,7 +134,9 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize { EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384, EmbeddingModel::BGESmallZHV15 => 512, EmbeddingModel::BGELargeZHV15 => 1024, - EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768, + EmbeddingModel::NomicEmbedTextV1 + | EmbeddingModel::NomicEmbedTextV15 + | EmbeddingModel::NomicEmbedTextV15Q => 768, EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384, EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768, EmbeddingModel::ModernBertEmbedLarge => 1024, @@ -114,37 +155,41 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize { fn get_or_create_model(embedding_model: EmbeddingModel) -> Result, String> { // First try to get from cache (read lock) { - let cache = MODEL_CACHE.read().map_err(|e| format!("Failed to acquire read lock: {}", e))?; + let cache = MODEL_CACHE + .read() + .map_err(|e| format!("Failed to acquire read lock: {}", e))?; if let Some(model) = cache.get(&embedding_model) { tracing::debug!("Using cached model: {:?}", embedding_model); return Ok(Arc::clone(model)); } } - + // Model not in cache, create it (write lock) - let mut cache = MODEL_CACHE.write().map_err(|e| format!("Failed to acquire write lock: {}", e))?; - + let mut cache = MODEL_CACHE + .write() + .map_err(|e| format!("Failed to acquire write lock: {}", e))?; + // Double-check after acquiring write lock if let Some(model) = cache.get(&embedding_model) { tracing::debug!("Using cached model (double-check): {:?}", embedding_model); return Ok(Arc::clone(model)); } - + tracing::info!("Initializing new embedding model: {:?}", embedding_model); let model_start_time = std::time::Instant::now(); - + let model = TextEmbedding::try_new( InitOptions::new(embedding_model.clone()).with_show_download_progress(true), ) .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?; - + let model_init_time = model_start_time.elapsed(); tracing::info!( "Embedding model {:?} initialized in {:.2?}", embedding_model, model_init_time ); - + let model_arc = Arc::new(model); cache.insert(embedding_model.clone(), Arc::clone(&model_arc)); Ok(model_arc) @@ -158,7 +203,7 @@ pub async fn embeddings_create( // Phase 1: Parse and get the embedding model let model_start_time = std::time::Instant::now(); - + let embedding_model = match parse_embedding_model(&payload.model) { Ok(model) => model, Err(e) => { @@ -166,15 +211,18 @@ pub async fn embeddings_create( return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e))); } }; - + let model = match get_or_create_model(embedding_model.clone()) { Ok(model) => model, Err(e) => { tracing::error!("Failed to get/create model: {}", e); - return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Model initialization failed: {}", e))); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Model initialization 
failed: {}", e), + )); } }; - + let model_access_time = model_start_time.elapsed(); tracing::debug!( "Model access/creation completed in {:.2?}", @@ -205,12 +253,13 @@ pub async fn embeddings_create( // Phase 3: Generate embeddings let embedding_start_time = std::time::Instant::now(); - let embeddings = model - .embed(texts_from_embedding_input, None) - .map_err(|e| { - tracing::error!("Failed to generate embeddings: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, format!("Embedding generation failed: {}", e)) - })?; + let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| { + tracing::error!("Failed to generate embeddings: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Embedding generation failed: {}", e), + ) + })?; let embedding_generation_time = embedding_start_time.elapsed(); tracing::info!( @@ -287,7 +336,7 @@ pub async fn embeddings_create( // Use the actual model dimensions instead of hardcoded 768 let actual_dimensions = padded_embedding.len(); let expected_dimensions = get_model_dimensions(&embedding_model); - + if actual_dimensions != expected_dimensions { tracing::warn!( "Model {:?} produced {} dimensions but expected {}", @@ -455,7 +504,8 @@ pub async fn models_list() -> ResponseJson { id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(), object: "model".to_string(), owned_by: "nomic-ai".to_string(), - description: "Quantized v1.5 release of the 8192 context length english model".to_string(), + description: "Quantized v1.5 release of the 8192 context length english model" + .to_string(), dimensions: 768, }, ModelInfo { @@ -476,7 +526,8 @@ pub async fn models_list() -> ResponseJson { id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), - description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(), + description: "Sentence-transformers model for tasks like clustering or semantic search" + .to_string(), dimensions: 768, }, ModelInfo { diff --git a/crates/embeddings-engine/src/main.rs b/crates/embeddings-engine/src/main.rs index 44135aa..c429562 100644 --- a/crates/embeddings-engine/src/main.rs +++ b/crates/embeddings-engine/src/main.rs @@ -18,12 +18,10 @@ async fn embeddings_create( ) -> Result, axum::response::Response> { match embeddings_engine::embeddings_create(Json(payload)).await { Ok(response) => Ok(response), - Err((status_code, message)) => { - Err(axum::response::Response::builder() - .status(status_code) - .body(axum::body::Body::from(message)) - .unwrap()) - } + Err((status_code, message)) => Err(axum::response::Response::builder() + .status(status_code) + .body(axum::body::Body::from(message)) + .unwrap()), } } diff --git a/crates/inference-engine/src/model.rs b/crates/inference-engine/src/model.rs index 89270ff..e8af7d7 100644 --- a/crates/inference-engine/src/model.rs +++ b/crates/inference-engine/src/model.rs @@ -42,7 +42,11 @@ pub struct ModelMeta { } const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta { - ModelMeta { id, family, instruct } + ModelMeta { + id, + family, + instruct, + } } #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] diff --git a/crates/inference-engine/src/server.rs b/crates/inference-engine/src/server.rs index bd2e91e..a7c0f77 100644 --- a/crates/inference-engine/src/server.rs +++ b/crates/inference-engine/src/server.rs @@ -42,13 +42,13 @@ pub struct AppState { pub llama_config: Option, } - impl Default for AppState { fn default() -> Self 
{ // Configure a default model to prevent 503 errors from the chat-ui // This can be overridden by environment variables if needed - let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string()); - + let default_model_id = + std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string()); + let gemma_config = GemmaInferenceConfig { model: None, ..Default::default() @@ -94,9 +94,6 @@ fn model_id_to_which(model_id: &str) -> Option { } } - - - fn normalize_model_id(model_id: &str) -> String { model_id.to_lowercase().replace("_", "-") } @@ -157,7 +154,7 @@ pub async fn chat_completions_non_streaming_proxy( // Use the model specified in the request let model_id = request.model.clone(); let which_model = model_id_to_which(&model_id); - + // Validate that the requested model is supported let which_model = match which_model { Some(model) => model, @@ -204,19 +201,21 @@ pub async fn chat_completions_non_streaming_proxy( StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": { "message": format!("Model {} is not a Llama model", model_id) } - })) + })), )); } }; let mut config = LlamaInferenceConfig::new(llama_model); config.prompt = prompt.clone(); config.max_tokens = max_tokens; - run_llama_inference(config).map_err(|e| ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": format!("Error initializing Llama model: {}", e) } - })) - ))? + run_llama_inference(config).map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { "message": format!("Error initializing Llama model: {}", e) } + })), + ) + })? } else { // Create Gemma configuration dynamically let gemma_model = match which_model { @@ -241,23 +240,25 @@ pub async fn chat_completions_non_streaming_proxy( StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": { "message": format!("Model {} is not a Gemma model", model_id) } - })) + })), )); } }; - + let mut config = GemmaInferenceConfig { model: Some(gemma_model), ..Default::default() }; config.prompt = prompt.clone(); config.max_tokens = max_tokens; - run_gemma_api(config).map_err(|e| ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": format!("Error initializing Gemma model: {}", e) } - })) - ))? + run_gemma_api(config).map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { "message": format!("Error initializing Gemma model: {}", e) } + })), + ) + })? 
}; // Collect all tokens from the stream @@ -320,7 +321,7 @@ async fn handle_streaming_request( // Use the model specified in the request let model_id = request.model.clone(); let which_model = model_id_to_which(&model_id); - + // Validate that the requested model is supported let which_model = match which_model { Some(model) => model, @@ -397,7 +398,7 @@ async fn handle_streaming_request( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": { "message": format!("Model {} is not a Llama model", model_id) } - })) + })), )); } }; @@ -439,11 +440,11 @@ async fn handle_streaming_request( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": { "message": format!("Model {} is not a Gemma model", model_id) } - })) + })), )); } }; - + let mut config = GemmaInferenceConfig { model: Some(gemma_model), ..Default::default() @@ -605,59 +606,66 @@ pub async fn list_models() -> Json { Which::Llama32_3BInstruct, ]; + let mut models: Vec = which_variants + .into_iter() + .map(|which| { + let meta = which.meta(); + let model_id = match which { + Which::Base2B => "gemma-2b", + Which::Base7B => "gemma-7b", + Which::Instruct2B => "gemma-2b-it", + Which::Instruct7B => "gemma-7b-it", + Which::InstructV1_1_2B => "gemma-1.1-2b-it", + Which::InstructV1_1_7B => "gemma-1.1-7b-it", + Which::CodeBase2B => "codegemma-2b", + Which::CodeBase7B => "codegemma-7b", + Which::CodeInstruct2B => "codegemma-2b-it", + Which::CodeInstruct7B => "codegemma-7b-it", + Which::BaseV2_2B => "gemma-2-2b", + Which::InstructV2_2B => "gemma-2-2b-it", + Which::BaseV2_9B => "gemma-2-9b", + Which::InstructV2_9B => "gemma-2-9b-it", + Which::BaseV3_1B => "gemma-3-1b", + Which::InstructV3_1B => "gemma-3-1b-it", + Which::Llama32_1B => "llama-3.2-1b", + Which::Llama32_1BInstruct => "llama-3.2-1b-instruct", + Which::Llama32_3B => "llama-3.2-3b", + Which::Llama32_3BInstruct => "llama-3.2-3b-instruct", + }; + let owned_by = if meta.id.starts_with("google/") { + "google" + } else if meta.id.starts_with("meta-llama/") { + "meta" + } else { + "unknown" + }; - let mut models: Vec = which_variants.into_iter().map(|which| { - let meta = which.meta(); - let model_id = match which { - Which::Base2B => "gemma-2b", - Which::Base7B => "gemma-7b", - Which::Instruct2B => "gemma-2b-it", - Which::Instruct7B => "gemma-7b-it", - Which::InstructV1_1_2B => "gemma-1.1-2b-it", - Which::InstructV1_1_7B => "gemma-1.1-7b-it", - Which::CodeBase2B => "codegemma-2b", - Which::CodeBase7B => "codegemma-7b", - Which::CodeInstruct2B => "codegemma-2b-it", - Which::CodeInstruct7B => "codegemma-7b-it", - Which::BaseV2_2B => "gemma-2-2b", - Which::InstructV2_2B => "gemma-2-2b-it", - Which::BaseV2_9B => "gemma-2-9b", - Which::InstructV2_9B => "gemma-2-9b-it", - Which::BaseV3_1B => "gemma-3-1b", - Which::InstructV3_1B => "gemma-3-1b-it", - Which::Llama32_1B => "llama-3.2-1b", - Which::Llama32_1BInstruct => "llama-3.2-1b-instruct", - Which::Llama32_3B => "llama-3.2-3b", - Which::Llama32_3BInstruct => "llama-3.2-3b-instruct", - }; - - let owned_by = if meta.id.starts_with("google/") { - "google" - } else if meta.id.starts_with("meta-llama/") { - "meta" - } else { - "unknown" - }; - - Model { - id: model_id.to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: owned_by.to_string(), - } - }).collect(); + Model { + id: model_id.to_string(), + object: "model".to_string(), + created: 1686935002, + owned_by: owned_by.to_string(), + } + }) + .collect(); // Get embeddings models and convert them to inference Model format let embeddings_response 
= models_list().await; - let embeddings_models: Vec = embeddings_response.0.data.into_iter().map(|embedding_model| { - Model { + let embeddings_models: Vec = embeddings_response + .0 + .data + .into_iter() + .map(|embedding_model| Model { id: embedding_model.id, object: embedding_model.object, created: 1686935002, - owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description), - } - }).collect(); + owned_by: format!( + "{} - {}", + embedding_model.owned_by, embedding_model.description + ), + }) + .collect(); // Add embeddings models to the main models list models.extend(embeddings_models); diff --git a/integration/gemma-runner/src/gemma_api.rs b/integration/gemma-runner/src/gemma_api.rs index 4dfd7d9..52d82aa 100644 --- a/integration/gemma-runner/src/gemma_api.rs +++ b/integration/gemma-runner/src/gemma_api.rs @@ -1,4 +1,3 @@ - use anyhow::{Error as E, Result}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; @@ -11,13 +10,13 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::io::Write; +use std::fmt; +use std::str::FromStr; use std::sync::mpsc::{self, Receiver, Sender}; use std::thread; use tokenizers::Tokenizer; use utils::hub_load_safetensors; use utils::token_output_stream::TokenOutputStream; -use std::str::FromStr; -use std::fmt; #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] pub enum WhichModel { @@ -367,7 +366,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result vec![repo.get("model.safetensors")?], + Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => { + vec![repo.get("model.safetensors")?] + } _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?, }; println!("Retrieved files in {:?}", start.elapsed()); @@ -396,7 +397,8 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result { // default to V2 model + | None => { + // default to V2 model let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let model = Model2::new(cfg.use_flash_attn, &config, vb)?; Model::V2(model) diff --git a/integration/helm-chart-tool/src/main.rs b/integration/helm-chart-tool/src/main.rs index 888bb5a..16809b9 100644 --- a/integration/helm-chart-tool/src/main.rs +++ b/integration/helm-chart-tool/src/main.rs @@ -105,7 +105,9 @@ fn discover_services(workspace_path: &str) -> Result> { .into_iter() .filter_map(|e| e.ok()) { - if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("../../../Cargo.toml") { + if entry.file_name() == "Cargo.toml" + && entry.path() != workspace_root.join("../../../Cargo.toml") + { if let Ok(service_info) = parse_cargo_toml(entry.path()) { services.push(service_info); } diff --git a/integration/llama-runner/src/llama_api.rs b/integration/llama-runner/src/llama_api.rs index 41aacd8..24c04f1 100644 --- a/integration/llama-runner/src/llama_api.rs +++ b/integration/llama-runner/src/llama_api.rs @@ -102,7 +102,7 @@ impl Default for LlamaInferenceConfig { max_tokens: 512, // Performance flags - no_kv_cache: false, // keep cache ON for speed + no_kv_cache: false, // keep cache ON for speed use_flash_attn: false, // great speed boost if supported // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed. 
diff --git a/integration/utils/src/imagenet.rs b/integration/utils/src/imagenet.rs
index 3dcb312..e514b05 100644
--- a/integration/utils/src/imagenet.rs
+++ b/integration/utils/src/imagenet.rs
@@ -1,5 +1,5 @@
-use candle_transformers::models::mimi::candle;
 use candle_core::{Device, Result, Tensor};
+use candle_transformers::models::mimi::candle;
 
 pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
 pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
diff --git a/integration/utils/src/lib.rs b/integration/utils/src/lib.rs
index 3b13714..c1f8919 100644
--- a/integration/utils/src/lib.rs
+++ b/integration/utils/src/lib.rs
@@ -8,8 +8,10 @@ pub mod coco_classes;
 pub mod imagenet;
 pub mod token_output_stream;
 pub mod wav;
-use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
-
+use candle_core::{
+    utils::{cuda_is_available, metal_is_available},
+    Device, Tensor,
+};
 
 pub fn device(cpu: bool) -> Result {
     if cpu {
@@ -126,7 +128,7 @@ pub fn hub_load_safetensors(
             repo.get(v)
                 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
         })
-        .collect::, std::io::Error, >>()?;
+        .collect::, std::io::Error>>()?;
     Ok(safetensors_files)
 }
 
@@ -136,7 +138,8 @@ pub fn hub_load_local_safetensors>(
 ) -> Result, anyhow::Error> {
     let path = path.as_ref();
     let jsfile = std::fs::File::open(path.join(json_file))?;
-    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
         None => anyhow::bail!("no weight map in {json_file:?}"),
         Some(serde_json::Value::Object(map)) => map,