mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
cleanup, add ci
This commit is contained in:
@@ -1,9 +1,5 @@
|
||||
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
|
||||
use axum::{
|
||||
response::Json as ResponseJson, routing::{post},
|
||||
Json,
|
||||
Router,
|
||||
};
|
||||
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
|
||||
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
||||
use once_cell::sync::Lazy;
|
||||
use tower_http::trace::TraceLayer;
|
||||
@@ -13,15 +9,18 @@ use tracing;
|
||||
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
|
||||
tracing::info!("Initializing persistent embedding model (singleton)");
|
||||
let model_start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
let model = TextEmbedding::try_new(
|
||||
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
|
||||
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
|
||||
)
|
||||
.expect("Failed to initialize persistent embedding model");
|
||||
|
||||
.expect("Failed to initialize persistent embedding model");
|
||||
|
||||
let model_init_time = model_start_time.elapsed();
|
||||
tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time);
|
||||
|
||||
tracing::info!(
|
||||
"Persistent embedding model initialized in {:.2?}",
|
||||
model_init_time
|
||||
);
|
||||
|
||||
model
|
||||
});
|
||||
|
||||
@@ -30,18 +29,21 @@ pub async fn embeddings_create(
|
||||
) -> ResponseJson<serde_json::Value> {
|
||||
// Start timing the entire process
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
// Phase 1: Access persistent model instance
|
||||
let model_start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
// Access the lazy-initialized persistent model instance
|
||||
// This will only initialize the model on the first request
|
||||
let model_access_time = model_start_time.elapsed();
|
||||
tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);
|
||||
|
||||
tracing::debug!(
|
||||
"Persistent model access completed in {:.2?}",
|
||||
model_access_time
|
||||
);
|
||||
|
||||
// Phase 2: Process input
|
||||
let input_start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
let embedding_input = payload.input;
|
||||
let texts_from_embedding_input = match embedding_input {
|
||||
EmbeddingInput::String(text) => vec![text],
|
||||
@@ -53,41 +55,58 @@ pub async fn embeddings_create(
|
||||
panic!("Array of integer arrays not supported for text embeddings");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
let input_processing_time = input_start_time.elapsed();
|
||||
tracing::debug!("Input processing completed in {:.2?}", input_processing_time);
|
||||
|
||||
tracing::debug!(
|
||||
"Input processing completed in {:.2?}",
|
||||
input_processing_time
|
||||
);
|
||||
|
||||
// Phase 3: Generate embeddings
|
||||
let embedding_start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
let embeddings = EMBEDDING_MODEL
|
||||
.embed(texts_from_embedding_input, None)
|
||||
.expect("failed to embed document");
|
||||
|
||||
|
||||
let embedding_generation_time = embedding_start_time.elapsed();
|
||||
tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);
|
||||
|
||||
tracing::info!(
|
||||
"Embedding generation completed in {:.2?}",
|
||||
embedding_generation_time
|
||||
);
|
||||
|
||||
// Memory usage estimation (approximate)
|
||||
let embedding_size_bytes = embeddings.iter()
|
||||
let embedding_size_bytes = embeddings
|
||||
.iter()
|
||||
.map(|e| e.len() * std::mem::size_of::<f32>())
|
||||
.sum::<usize>();
|
||||
tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
|
||||
tracing::debug!(
|
||||
"Embedding size: {:.2} MB",
|
||||
embedding_size_bytes as f64 / 1024.0 / 1024.0
|
||||
);
|
||||
|
||||
// Only log detailed embedding information at trace level to reduce log volume
|
||||
tracing::trace!("Embeddings length: {}", embeddings.len());
|
||||
tracing::info!("Embedding dimension: {}", embeddings[0].len());
|
||||
|
||||
// Log the first 10 values of the original embedding at trace level
|
||||
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
|
||||
tracing::trace!(
|
||||
"Original embedding preview: {:?}",
|
||||
&embeddings[0][..10.min(embeddings[0].len())]
|
||||
);
|
||||
|
||||
// Check if there are any NaN or zero values in the original embedding
|
||||
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
|
||||
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
|
||||
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
|
||||
tracing::trace!(
|
||||
"Original embedding stats: NaN count={}, zero count={}",
|
||||
nan_count,
|
||||
zero_count
|
||||
);
|
||||
|
||||
// Phase 4: Post-process embeddings
|
||||
let postprocessing_start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
// Create the final embedding
|
||||
let final_embedding = {
|
||||
// Check if the embedding is all zeros
|
||||
@@ -110,6 +129,8 @@ pub async fn embeddings_create(
|
||||
|
||||
// Normalize the random embedding
|
||||
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for i in 0..random_embedding.len() {
|
||||
random_embedding[i] /= norm;
|
||||
}
|
||||
@@ -123,25 +144,35 @@ pub async fn embeddings_create(
|
||||
let target_dimension = 768;
|
||||
if padded_embedding.len() < target_dimension {
|
||||
let padding_needed = target_dimension - padded_embedding.len();
|
||||
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
|
||||
tracing::trace!(
|
||||
"Padding embedding with {} zeros to reach {} dimensions",
|
||||
padding_needed,
|
||||
target_dimension
|
||||
);
|
||||
padded_embedding.extend(vec![0.0; padding_needed]);
|
||||
}
|
||||
|
||||
padded_embedding
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
let postprocessing_time = postprocessing_start_time.elapsed();
|
||||
tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
|
||||
tracing::debug!(
|
||||
"Embedding post-processing completed in {:.2?}",
|
||||
postprocessing_time
|
||||
);
|
||||
|
||||
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
|
||||
|
||||
// Log the first 10 values of the final embedding at trace level
|
||||
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
|
||||
tracing::trace!(
|
||||
"Final embedding preview: {:?}",
|
||||
&final_embedding[..10.min(final_embedding.len())]
|
||||
);
|
||||
|
||||
// Phase 5: Prepare response
|
||||
let response_start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
// Return a response that matches the OpenAI API format
|
||||
let response = serde_json::json!({
|
||||
"object": "list",
|
||||
@@ -158,10 +189,10 @@ pub async fn embeddings_create(
|
||||
"total_tokens": 0
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
let response_time = response_start_time.elapsed();
|
||||
tracing::debug!("Response preparation completed in {:.2?}", response_time);
|
||||
|
||||
|
||||
// Log total time and breakdown
|
||||
let total_time = start_time.elapsed();
|
||||
tracing::info!(
|
||||
@@ -171,7 +202,7 @@ pub async fn embeddings_create(
|
||||
embedding_generation_time,
|
||||
postprocessing_time
|
||||
);
|
||||
|
||||
|
||||
ResponseJson(response)
|
||||
}
|
||||
|
||||
@@ -179,4 +210,4 @@ pub fn create_embeddings_router() -> Router {
|
||||
Router::new()
|
||||
.route("/v1/embeddings", post(embeddings_create))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
}
|
||||
}
|
||||
|
@@ -1,8 +1,8 @@
|
||||
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
|
||||
use axum::{
|
||||
response::Json as ResponseJson, routing::{get, post},
|
||||
Json,
|
||||
Router,
|
||||
Json, Router,
|
||||
response::Json as ResponseJson,
|
||||
routing::{get, post},
|
||||
};
|
||||
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -13,19 +13,17 @@ use tracing;
|
||||
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
|
||||
const DEFAULT_SERVER_PORT: &str = "8080";
|
||||
|
||||
|
||||
async fn embeddings_create(
|
||||
Json(payload): Json<CreateEmbeddingRequest>,
|
||||
) -> ResponseJson<serde_json::Value> {
|
||||
let model = TextEmbedding::try_new(
|
||||
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
|
||||
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
|
||||
)
|
||||
.expect("Failed to initialize model");
|
||||
|
||||
let embedding_input = payload.input;
|
||||
|
||||
let embedding_input = payload.input;
|
||||
|
||||
let texts_from_embedding_input = match embedding_input {
|
||||
let texts_from_embedding_input = match embedding_input {
|
||||
EmbeddingInput::String(text) => vec![text],
|
||||
EmbeddingInput::StringArray(texts) => texts,
|
||||
EmbeddingInput::IntegerArray(_) => {
|
||||
@@ -45,12 +43,19 @@ async fn embeddings_create(
|
||||
tracing::info!("Embedding dimension: {}", embeddings[0].len());
|
||||
|
||||
// Log the first 10 values of the original embedding at trace level
|
||||
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
|
||||
tracing::trace!(
|
||||
"Original embedding preview: {:?}",
|
||||
&embeddings[0][..10.min(embeddings[0].len())]
|
||||
);
|
||||
|
||||
// Check if there are any NaN or zero values in the original embedding
|
||||
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
|
||||
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
|
||||
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
|
||||
tracing::trace!(
|
||||
"Original embedding stats: NaN count={}, zero count={}",
|
||||
nan_count,
|
||||
zero_count
|
||||
);
|
||||
|
||||
// Create the final embedding
|
||||
let final_embedding = {
|
||||
@@ -87,7 +92,11 @@ async fn embeddings_create(
|
||||
let target_dimension = 768;
|
||||
if padded_embedding.len() < target_dimension {
|
||||
let padding_needed = target_dimension - padded_embedding.len();
|
||||
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
|
||||
tracing::trace!(
|
||||
"Padding embedding with {} zeros to reach {} dimensions",
|
||||
padding_needed,
|
||||
target_dimension
|
||||
);
|
||||
padded_embedding.extend(vec![0.0; padding_needed]);
|
||||
}
|
||||
|
||||
@@ -98,7 +107,10 @@ async fn embeddings_create(
|
||||
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
|
||||
|
||||
// Log the first 10 values of the final embedding at trace level
|
||||
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
|
||||
tracing::trace!(
|
||||
"Final embedding preview: {:?}",
|
||||
&final_embedding[..10.min(final_embedding.len())]
|
||||
);
|
||||
|
||||
// Return a response that matches the OpenAI API format
|
||||
let response = serde_json::json!({
|
||||
@@ -120,7 +132,7 @@ async fn embeddings_create(
|
||||
}
|
||||
|
||||
fn create_app() -> Router {
|
||||
Router::new()
|
||||
Router::new()
|
||||
.route("/v1/embeddings", post(embeddings_create))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
}
|
||||
@@ -143,21 +155,21 @@ async fn main() {
|
||||
.init();
|
||||
let app = create_app();
|
||||
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
|
||||
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
|
||||
let server_address = format!("{}:{}", server_host, server_port);
|
||||
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
|
||||
tracing::info!("Listening on {}", listener.local_addr().unwrap());
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
|
||||
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
|
||||
let server_address = format!("{}:{}", server_host, server_port);
|
||||
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
|
||||
tracing::info!("Listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::body::to_bytes;
|
||||
use axum::body::Body;
|
||||
use axum::http::StatusCode;
|
||||
use tower::ServiceExt;
|
||||
use super::*;
|
||||
use axum::body::Body;
|
||||
use axum::body::to_bytes;
|
||||
use axum::http::StatusCode;
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_embeddings_create() {
|
||||
@@ -168,11 +180,13 @@ mod tests {
|
||||
|
||||
let body = CreateEmbeddingRequest {
|
||||
model: "nomic-text-embed".to_string(),
|
||||
input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: Some(768),
|
||||
};
|
||||
input: EmbeddingInput::from(vec![
|
||||
"The food was delicious and the waiter...".to_string(),
|
||||
]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: Some(768),
|
||||
};
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
|
Reference in New Issue
Block a user