Refactor apply_cached_repeat_penalty for optimized caching and reuse, add extensive unit tests, and integrate special handling for gemma-specific models.

Removed `test_request.sh`, deprecated functionality, and unused imports; introduced a new CLI tool (`cli.ts`) for testing the inference engine; and adjusted handling of non-streaming/streaming chat completions.

- Add CPU fallback support for text generation when primary device is unsupported
- Introduce `execute_with_fallback` method to handle device compatibility and shape mismatch errors
- Extend unit tests to reproduce tensor shape mismatch errors specific to model configurations
- Increase HTTP timeout limits in `curl_chat_stream.sh` script for reliable API testing

Chat completion endpoint now works with gemma3 (non-streaming only)

Add benchmarking guide with HTML reporting, Leptos chat crate, and middleware for metrics tracking
This commit is contained in:
geoffsee
2025-08-26 01:30:26 -04:00
parent 7dd23213c9
commit 8338750beb
64 changed files with 14997 additions and 220 deletions

View File

@@ -23,3 +23,4 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
rand = "0.8.5"
async-openai = "0.28.3"
once_cell = "1.19.0"

View File

@@ -1,14 +1,30 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{
response::Json as ResponseJson, routing::{get, post},
response::Json as ResponseJson, routing::{post},
Json,
Router,
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
use once_cell::sync::Lazy;
use tower_http::trace::TraceLayer;
use tracing;
/// Process-wide embedding model shared by every request.
///
/// The closure runs once, the first time the static is dereferenced; later
/// accesses reuse the same `TextEmbedding`. Construction failure panics via
/// `expect`, so a broken model setup surfaces on first use.
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
    tracing::info!("Initializing persistent embedding model (singleton)");
    let init_timer = std::time::Instant::now();

    // Build the options separately so the model choice is easy to spot.
    let options =
        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true);
    let embedder = TextEmbedding::try_new(options)
        .expect("Failed to initialize persistent embedding model");

    let init_elapsed = init_timer.elapsed();
    tracing::info!("Persistent embedding model initialized in {:.2?}", init_elapsed);

    embedder
});
/// Handler that replies with a fixed plain-text greeting.
pub async fn root() -> &'static str {
    // Named constant keeps the response text in one obvious place.
    const GREETING: &str = "Hello, World!";
    GREETING
}
@@ -16,13 +32,21 @@ pub async fn root() -> &'static str {
pub async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
)
.expect("Failed to initialize model");
// Start timing the entire process
let start_time = std::time::Instant::now();
// Phase 1: Access persistent model instance
let model_start_time = std::time::Instant::now();
// Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request
let model_access_time = model_start_time.elapsed();
tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);
// Phase 2: Process input
let input_start_time = std::time::Instant::now();
let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
@@ -33,10 +57,25 @@ pub async fn embeddings_create(
panic!("Array of integer arrays not supported for text embeddings");
}
};
let embeddings = model
let input_processing_time = input_start_time.elapsed();
tracing::debug!("Input processing completed in {:.2?}", input_processing_time);
// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();
let embeddings = EMBEDDING_MODEL
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);
// Memory usage estimation (approximate)
let embedding_size_bytes = embeddings.iter()
.map(|e| e.len() * std::mem::size_of::<f32>())
.sum::<usize>();
tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
@@ -50,6 +89,9 @@ pub async fn embeddings_create(
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
// Phase 4: Post-process embeddings
let postprocessing_start_time = std::time::Instant::now();
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
@@ -92,12 +134,18 @@ pub async fn embeddings_create(
padded_embedding
}
};
let postprocessing_time = postprocessing_start_time.elapsed();
tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
// Phase 5: Prepare response
let response_start_time = std::time::Instant::now();
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
@@ -114,12 +162,25 @@ pub async fn embeddings_create(
"total_tokens": 0
}
});
let response_time = response_start_time.elapsed();
tracing::debug!("Response preparation completed in {:.2?}", response_time);
// Log total time and breakdown
let total_time = start_time.elapsed();
tracing::info!(
"Embeddings request completed in {:.2?} (model_access: {:.2?}, embedding: {:.2?}, postprocessing: {:.2?})",
total_time,
model_access_time,
embedding_generation_time,
postprocessing_time
);
ResponseJson(response)
}
/// Builds the router for the embeddings service.
///
/// Registers `root` for GET `/` and `embeddings_create` for POST
/// `/v1/embeddings`, then wraps the routes in `TraceLayer::new_for_http()`.
// NOTE(review): the imports shown for this file include a `routing::{post}`
// variant without `get`; confirm `get` is still in scope for the `/` route.
pub fn create_embeddings_router() -> Router {
Router::new()
.route("/", get(root))
.route("/v1/embeddings", post(embeddings_create))
.layer(TraceLayer::new_for_http())
}

View File

@@ -124,7 +124,6 @@ async fn embeddings_create(
/// Assembles the application router.
///
/// Registers `root` for GET `/` and `embeddings_create` for POST
/// `/v1/embeddings`, then applies `TraceLayer::new_for_http()` to the routes.
fn create_app() -> Router {
Router::new()
.route("/", get(root))
.route("/v1/embeddings", post(embeddings_create))
.layer(TraceLayer::new_for_http())
}