mirror of https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
Refactor `apply_cached_repeat_penalty` for optimized caching and reuse, add extensive unit tests, and integrate special handling for Gemma-specific models.

- Remove `test_request.sh`, deprecated functionality, and unused imports; introduce a new CLI tool (`cli.ts`) for testing the inference engine, and adjust handling of non-streaming and streaming chat completions
- Add CPU fallback support for text generation when the primary device is unsupported
- Introduce an `execute_with_fallback` method to handle device-compatibility and shape-mismatch errors
- Extend unit tests to reproduce tensor shape mismatch errors specific to model configurations
- Increase HTTP timeout limits in the `curl_chat_stream.sh` script for reliable API testing; the chat completion endpoint now works with gemma3 (no streaming)
- Add a benchmarking guide with HTML reporting, a Leptos chat crate, and middleware for metrics tracking
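For context on the first item: a repeat penalty rescales the logits of tokens that already occur in the generated context so the sampler is less likely to emit them again, and caching ensures the same token id is not rescaled twice when it repeats. The commit's actual implementation is not shown on this page; the following is a minimal, hypothetical sketch over plain f32 logits rather than the engine's tensor type:

use std::collections::HashSet;

/// Hypothetical sketch of a cached repeat penalty: each distinct token id in
/// `context` is penalized exactly once, with the `seen` set acting as the cache.
fn apply_cached_repeat_penalty(logits: &mut [f32], context: &[u32], penalty: f32) {
    let mut seen = HashSet::new();
    for &tok in context {
        // Skip token ids already penalized in this pass.
        if !seen.insert(tok) {
            continue;
        }
        if let Some(logit) = logits.get_mut(tok as usize) {
            // Conventional repeat penalty: shrink positive logits and push
            // negative ones further down, making repeats less probable.
            *logit = if *logit >= 0.0 { *logit / penalty } else { *logit * penalty };
        }
    }
}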
@@ -23,3 +23,4 @@ tracing = "0.1"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 rand = "0.8.5"
 async-openai = "0.28.3"
+once_cell = "1.19.0"
@@ -1,14 +1,30 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
 use axum::{
-    response::Json as ResponseJson, routing::{get, post},
+    response::Json as ResponseJson, routing::{post},
     Json,
     Router,
 };
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use serde::{Deserialize, Serialize};
+use once_cell::sync::Lazy;
 use tower_http::trace::TraceLayer;
 use tracing;
 
+// Persistent model instance (singleton pattern)
+static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
+    tracing::info!("Initializing persistent embedding model (singleton)");
+    let model_start_time = std::time::Instant::now();
+
+    let model = TextEmbedding::try_new(
+        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
+    )
+    .expect("Failed to initialize persistent embedding model");
+
+    let model_init_time = model_start_time.elapsed();
+    tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time);
+
+    model
+});
+
 pub async fn root() -> &'static str {
     "Hello, World!"
 }
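The `Lazy` static above is the core of the change: the model is constructed on the first dereference and then shared by every request, instead of being rebuilt per call. On Rust 1.70+ the same singleton pattern can be expressed with the standard library's `std::sync::OnceLock`; a sketch for comparison (not part of this commit):

use std::sync::OnceLock;

// Stand-in for the embedding model; OnceLock guarantees the closure passed
// to get_or_init runs at most once, even under concurrent first calls.
static MODEL: OnceLock<String> = OnceLock::new();

fn model() -> &'static String {
    MODEL.get_or_init(|| {
        // Expensive construction happens here, on the first call only.
        "loaded model".to_string()
    })
}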
@@ -16,13 +32,21 @@ pub async fn root() -> &'static str {
 pub async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
 ) -> ResponseJson<serde_json::Value> {
-    let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
-    )
-    .expect("Failed to initialize model");
+    // Start timing the entire process
+    let start_time = std::time::Instant::now();
+
+    // Phase 1: Access persistent model instance
+    let model_start_time = std::time::Instant::now();
+    // Access the lazy-initialized persistent model instance
+    // This will only initialize the model on the first request
+    let model_access_time = model_start_time.elapsed();
+    tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);
+
+    // Phase 2: Process input
+    let input_start_time = std::time::Instant::now();
 
     let embedding_input = payload.input;
 
     let texts_from_embedding_input = match embedding_input {
         EmbeddingInput::String(text) => vec![text],
         EmbeddingInput::StringArray(texts) => texts,
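Each phase follows the same `Instant::now()` / `elapsed()` / `tracing::debug!` shape. If this pattern keeps growing, it could be folded into a small helper; a hypothetical sketch, not in the diff:

use std::time::Instant;

/// Hypothetical helper: run a closure, log its duration at debug level
/// under the given label, and pass its result through unchanged.
fn timed<T>(label: &str, f: impl FnOnce() -> T) -> T {
    let start = Instant::now();
    let out = f();
    tracing::debug!("{} completed in {:.2?}", label, start.elapsed());
    out
}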
@@ -33,10 +57,25 @@ pub async fn embeddings_create(
             panic!("Array of integer arrays not supported for text embeddings");
         }
     };
 
-    let embeddings = model
+    let input_processing_time = input_start_time.elapsed();
+    tracing::debug!("Input processing completed in {:.2?}", input_processing_time);
+
+    // Phase 3: Generate embeddings
+    let embedding_start_time = std::time::Instant::now();
+
+    let embeddings = EMBEDDING_MODEL
         .embed(texts_from_embedding_input, None)
         .expect("failed to embed document");
 
+    let embedding_generation_time = embedding_start_time.elapsed();
+    tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);
+
+    // Memory usage estimation (approximate)
+    let embedding_size_bytes = embeddings.iter()
+        .map(|e| e.len() * std::mem::size_of::<f32>())
+        .sum::<usize>();
+    tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
+
+    // Only log detailed embedding information at trace level to reduce log volume
     tracing::trace!("Embeddings length: {}", embeddings.len());
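As a rough sanity check on the size estimate: Nomic Embed Text v1.5 produces 768-dimensional vectors by default, so one f32 embedding is 768 × 4 bytes = 3 KiB, and a batch of 100 texts is on the order of 0.3 MB.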
@@ -50,6 +89,9 @@ pub async fn embeddings_create(
     let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
     tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
 
+    // Phase 4: Post-process embeddings
+    let postprocessing_start_time = std::time::Instant::now();
+
     // Create the final embedding
     let final_embedding = {
         // Check if the embedding is all zeros
@@ -92,12 +134,18 @@ pub async fn embeddings_create(
             padded_embedding
         }
     };
 
+    let postprocessing_time = postprocessing_start_time.elapsed();
+    tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
+
     tracing::trace!("Final embedding dimension: {}", final_embedding.len());
 
     // Log the first 10 values of the final embedding at trace level
     tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
 
+    // Phase 5: Prepare response
+    let response_start_time = std::time::Instant::now();
+
     // Return a response that matches the OpenAI API format
     let response = serde_json::json!({
         "object": "list",
@@ -114,12 +162,25 @@ pub async fn embeddings_create(
             "total_tokens": 0
         }
     });
 
+    let response_time = response_start_time.elapsed();
+    tracing::debug!("Response preparation completed in {:.2?}", response_time);
+
+    // Log total time and breakdown
+    let total_time = start_time.elapsed();
+    tracing::info!(
+        "Embeddings request completed in {:.2?} (model_access: {:.2?}, embedding: {:.2?}, postprocessing: {:.2?})",
+        total_time,
+        model_access_time,
+        embedding_generation_time,
+        postprocessing_time
+    );
+
     ResponseJson(response)
 }
 
 pub fn create_embeddings_router() -> Router {
     Router::new()
-        .route("/", get(root))
         .route("/v1/embeddings", post(embeddings_create))
         .layer(TraceLayer::new_for_http())
 }
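For reference, a minimal binary that mounts this router might look like the sketch below; it assumes axum 0.7's `axum::serve` and an arbitrary bind address, neither of which is taken from this commit:

use tokio::net::TcpListener;

#[tokio::main]
async fn main() {
    // Emit the tracing output produced by the TraceLayer and the handlers.
    tracing_subscriber::fmt::init();

    // create_embeddings_router() is the function shown in the diff above.
    let app = create_embeddings_router();

    // 0.0.0.0:8080 is an assumed address, not taken from this commit.
    let listener = TcpListener::bind("0.0.0.0:8080").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}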
@@ -124,7 +124,6 @@ async fn embeddings_create(
 
 fn create_app() -> Router {
     Router::new()
-        .route("/", get(root))
         .route("/v1/embeddings", post(embeddings_create))
         .layer(TraceLayer::new_for_http())
 }