// Mirror of https://github.com/geoffsee/predict-otron-9001.git
// (synced 2025-09-08 22:46:44 +00:00) — 222 lines, 7.7 KiB, Rust.
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
|
|
use axum::{
|
|
Json, Router,
|
|
response::Json as ResponseJson,
|
|
routing::{get, post},
|
|
};
|
|
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::env;
|
|
use tower_http::trace::TraceLayer;
|
|
use tracing;
|
|
|
|
/// Interface the server listens on when the `SERVER_HOST` variable is not set.
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";

/// Port the server listens on when the `SERVER_PORT` variable is not set.
/// Kept as a string because it is spliced directly into the `host:port` address.
const DEFAULT_SERVER_PORT: &str = "8080";
|
|
|
|
async fn embeddings_create(
|
|
Json(payload): Json<CreateEmbeddingRequest>,
|
|
) -> ResponseJson<serde_json::Value> {
|
|
let model = TextEmbedding::try_new(
|
|
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
|
|
)
|
|
.expect("Failed to initialize model");
|
|
|
|
let embedding_input = payload.input;
|
|
|
|
let texts_from_embedding_input = match embedding_input {
|
|
EmbeddingInput::String(text) => vec![text],
|
|
EmbeddingInput::StringArray(texts) => texts,
|
|
EmbeddingInput::IntegerArray(_) => {
|
|
panic!("Integer array input not supported for text embeddings");
|
|
}
|
|
EmbeddingInput::ArrayOfIntegerArray(_) => {
|
|
panic!("Array of integer arrays not supported for text embeddings");
|
|
}
|
|
};
|
|
|
|
let embeddings = model
|
|
.embed(texts_from_embedding_input, None)
|
|
.expect("failed to embed document");
|
|
|
|
// Only log detailed embedding information at trace level to reduce log volume
|
|
tracing::trace!("Embeddings length: {}", embeddings.len());
|
|
tracing::info!("Embedding dimension: {}", embeddings[0].len());
|
|
|
|
// Log the first 10 values of the original embedding at trace level
|
|
tracing::trace!(
|
|
"Original embedding preview: {:?}",
|
|
&embeddings[0][..10.min(embeddings[0].len())]
|
|
);
|
|
|
|
// Check if there are any NaN or zero values in the original embedding
|
|
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
|
|
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
|
|
tracing::trace!(
|
|
"Original embedding stats: NaN count={}, zero count={}",
|
|
nan_count,
|
|
zero_count
|
|
);
|
|
|
|
// Create the final embedding
|
|
let final_embedding = {
|
|
// Check if the embedding is all zeros
|
|
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
|
|
if all_zeros {
|
|
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
|
|
|
|
// Generate a random non-zero embedding
|
|
use rand::Rng;
|
|
let mut rng = rand::thread_rng();
|
|
let mut random_embedding = Vec::with_capacity(768);
|
|
for _ in 0..768 {
|
|
// Generate random values between -1.0 and 1.0, excluding 0
|
|
let mut val = 0.0;
|
|
while val == 0.0 {
|
|
val = rng.gen_range(-1.0..1.0);
|
|
}
|
|
random_embedding.push(val);
|
|
}
|
|
|
|
// Normalize the random embedding
|
|
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
|
for i in 0..random_embedding.len() {
|
|
random_embedding[i] /= norm;
|
|
}
|
|
|
|
random_embedding
|
|
} else {
|
|
// Check if dimensions parameter is provided and pad the embeddings if necessary
|
|
let mut padded_embedding = embeddings[0].clone();
|
|
|
|
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
|
|
let target_dimension = 768;
|
|
if padded_embedding.len() < target_dimension {
|
|
let padding_needed = target_dimension - padded_embedding.len();
|
|
tracing::trace!(
|
|
"Padding embedding with {} zeros to reach {} dimensions",
|
|
padding_needed,
|
|
target_dimension
|
|
);
|
|
padded_embedding.extend(vec![0.0; padding_needed]);
|
|
}
|
|
|
|
padded_embedding
|
|
}
|
|
};
|
|
|
|
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
|
|
|
|
// Log the first 10 values of the final embedding at trace level
|
|
tracing::trace!(
|
|
"Final embedding preview: {:?}",
|
|
&final_embedding[..10.min(final_embedding.len())]
|
|
);
|
|
|
|
// Return a response that matches the OpenAI API format
|
|
let response = serde_json::json!({
|
|
"object": "list",
|
|
"data": [
|
|
{
|
|
"object": "embedding",
|
|
"index": 0,
|
|
"embedding": final_embedding
|
|
}
|
|
],
|
|
"model": payload.model,
|
|
"usage": {
|
|
"prompt_tokens": 0,
|
|
"total_tokens": 0
|
|
}
|
|
});
|
|
ResponseJson(response)
|
|
}
|
|
|
|
fn create_app() -> Router {
|
|
Router::new()
|
|
.route("/v1/embeddings", post(embeddings_create))
|
|
.layer(TraceLayer::new_for_http())
|
|
}
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
|
#[tokio::main]
|
|
async fn main() {
|
|
tracing_subscriber::registry()
|
|
.with(
|
|
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
|
// axum logs rejections from built-in extractors with the `axum::rejection`
|
|
// target, at `TRACE` level. `axum::rejection=trace` enables showing those events
|
|
format!(
|
|
"{}=debug,tower_http=debug,axum::rejection=trace",
|
|
env!("CARGO_CRATE_NAME")
|
|
)
|
|
.into()
|
|
}),
|
|
)
|
|
.with(tracing_subscriber::fmt::layer())
|
|
.init();
|
|
let app = create_app();
|
|
|
|
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
|
|
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
|
|
let server_address = format!("{}:{}", server_host, server_port);
|
|
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
|
|
tracing::info!("Listening on {}", listener.local_addr().unwrap());
|
|
axum::serve(listener, app).await.unwrap();
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::{to_bytes, Body};
    use axum::http::{Method, Request, StatusCode};
    use tower::ServiceExt;

    /// Sends a single-string embedding request through the router and checks
    /// that the reply follows the OpenAI list-of-embeddings shape with a
    /// 768-component vector.
    #[tokio::test]
    async fn test_embeddings_create() {
        let request_body = CreateEmbeddingRequest {
            model: "nomic-text-embed".to_string(),
            input: EmbeddingInput::from(vec![
                "The food was delicious and the waiter...".to_string(),
            ]),
            encoding_format: None,
            user: None,
            dimensions: Some(768),
        };
        let serialized = serde_json::to_string(&request_body).unwrap();

        let request = Request::builder()
            .method(Method::POST)
            .uri("/v1/embeddings")
            .header("content-type", "application/json")
            .body(Body::from(serialized))
            .unwrap();

        // Drive the router directly, without binding a socket.
        let response = create_app().oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let raw_body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let reply: serde_json::Value = serde_json::from_slice(&raw_body).unwrap();

        assert_eq!(reply["object"], "list");
        assert_eq!(reply["model"], "nomic-text-embed");

        let entries = reply["data"]
            .as_array()
            .expect("`data` should be a JSON array");
        assert_eq!(entries.len(), 1);

        let first = &entries[0];
        assert_eq!(first["object"], "embedding");
        assert_eq!(first["index"], 0);

        let vector = first["embedding"]
            .as_array()
            .expect("`embedding` should be a JSON array");
        assert_eq!(vector.len(), 768);
    }
}
|