run cargo fmt

This commit is contained in:
geoffsee
2025-09-04 13:45:25 -04:00
parent 1e02b12cda
commit c1c583faab
11 changed files with 241 additions and 170 deletions

View File

@@ -1,5 +1,10 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{Json, Router, response::Json as ResponseJson, routing::{get, post}, http::StatusCode};
use axum::{
Json, Router,
http::StatusCode,
response::Json as ResponseJson,
routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use serde::Serialize;
@@ -9,9 +14,8 @@ use tower_http::trace::TraceLayer;
use tracing;
// Cache for multiple embedding models
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> = Lazy::new(|| {
RwLock::new(HashMap::new())
});
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
Lazy::new(|| RwLock::new(HashMap::new()));
#[derive(Serialize)]
pub struct ModelInfo {
@@ -32,11 +36,19 @@ pub struct ModelsResponse {
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
match model_name {
// Sentence Transformers models
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
Ok(EmbeddingModel::AllMiniLML6V2)
}
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
Ok(EmbeddingModel::AllMiniLML6V2Q)
}
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
Ok(EmbeddingModel::AllMiniLML12V2)
}
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
Ok(EmbeddingModel::AllMiniLML12V2Q)
}
// BGE models
"BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
"BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
@@ -46,41 +58,68 @@ fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
"BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
"BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
"BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
// Nomic models
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
Ok(EmbeddingModel::NomicEmbedTextV1)
}
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
Ok(EmbeddingModel::NomicEmbedTextV15)
}
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
Ok(EmbeddingModel::NomicEmbedTextV15Q)
}
// Paraphrase models
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
| "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
| "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
| "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
// ModernBert
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
Ok(EmbeddingModel::ModernBertEmbedLarge)
}
// Multilingual E5 models
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
Ok(EmbeddingModel::MultilingualE5Small)
}
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
Ok(EmbeddingModel::MultilingualE5Base)
}
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
Ok(EmbeddingModel::MultilingualE5Large)
}
// Mixedbread models
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1)
}
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
}
// GTE models
"Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
Ok(EmbeddingModel::GTEBaseENV15Q)
}
"Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
Ok(EmbeddingModel::GTELargeENV15Q)
}
// CLIP model
"Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
// Jina model
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
}
_ => Err(format!("Unsupported embedding model: {}", model_name)),
}
}
@@ -95,7 +134,9 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize {
EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
EmbeddingModel::BGESmallZHV15 => 512,
EmbeddingModel::BGELargeZHV15 => 1024,
EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::NomicEmbedTextV1
| EmbeddingModel::NomicEmbedTextV15
| EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
EmbeddingModel::ModernBertEmbedLarge => 1024,
@@ -114,37 +155,41 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize {
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
// First try to get from cache (read lock)
{
let cache = MODEL_CACHE.read().map_err(|e| format!("Failed to acquire read lock: {}", e))?;
let cache = MODEL_CACHE
.read()
.map_err(|e| format!("Failed to acquire read lock: {}", e))?;
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model: {:?}", embedding_model);
return Ok(Arc::clone(model));
}
}
// Model not in cache, create it (write lock)
let mut cache = MODEL_CACHE.write().map_err(|e| format!("Failed to acquire write lock: {}", e))?;
let mut cache = MODEL_CACHE
.write()
.map_err(|e| format!("Failed to acquire write lock: {}", e))?;
// Double-check after acquiring write lock
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
return Ok(Arc::clone(model));
}
tracing::info!("Initializing new embedding model: {:?}", embedding_model);
let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new(
InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
)
.map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
let model_init_time = model_start_time.elapsed();
tracing::info!(
"Embedding model {:?} initialized in {:.2?}",
embedding_model,
model_init_time
);
let model_arc = Arc::new(model);
cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
Ok(model_arc)
@@ -158,7 +203,7 @@ pub async fn embeddings_create(
// Phase 1: Parse and get the embedding model
let model_start_time = std::time::Instant::now();
let embedding_model = match parse_embedding_model(&payload.model) {
Ok(model) => model,
Err(e) => {
@@ -166,15 +211,18 @@ pub async fn embeddings_create(
return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
}
};
let model = match get_or_create_model(embedding_model.clone()) {
Ok(model) => model,
Err(e) => {
tracing::error!("Failed to get/create model: {}", e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Model initialization failed: {}", e)));
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Model initialization failed: {}", e),
));
}
};
let model_access_time = model_start_time.elapsed();
tracing::debug!(
"Model access/creation completed in {:.2?}",
@@ -205,12 +253,13 @@ pub async fn embeddings_create(
// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();
let embeddings = model
.embed(texts_from_embedding_input, None)
.map_err(|e| {
tracing::error!("Failed to generate embeddings: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, format!("Embedding generation failed: {}", e))
})?;
let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
tracing::error!("Failed to generate embeddings: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Embedding generation failed: {}", e),
)
})?;
let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!(
@@ -287,7 +336,7 @@ pub async fn embeddings_create(
// Use the actual model dimensions instead of hardcoded 768
let actual_dimensions = padded_embedding.len();
let expected_dimensions = get_model_dimensions(&embedding_model);
if actual_dimensions != expected_dimensions {
tracing::warn!(
"Model {:?} produced {} dimensions but expected {}",
@@ -455,7 +504,8 @@ pub async fn models_list() -> ResponseJson<ModelsResponse> {
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model"
.to_string(),
dimensions: 768,
},
ModelInfo {
@@ -476,7 +526,8 @@ pub async fn models_list() -> ResponseJson<ModelsResponse> {
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search"
.to_string(),
dimensions: 768,
},
ModelInfo {

View File

@@ -18,12 +18,10 @@ async fn embeddings_create(
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
match embeddings_engine::embeddings_create(Json(payload)).await {
Ok(response) => Ok(response),
Err((status_code, message)) => {
Err(axum::response::Response::builder()
.status(status_code)
.body(axum::body::Body::from(message))
.unwrap())
}
Err((status_code, message)) => Err(axum::response::Response::builder()
.status(status_code)
.body(axum::body::Body::from(message))
.unwrap()),
}
}