Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)

cleanup, add ci
@@ -1,9 +1,5 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{
response::Json as ResponseJson, routing::{post},
Json,
Router,
};
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use tower_http::trace::TraceLayer;
@@ -13,15 +9,18 @@ use tracing;
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
tracing::info!("Initializing persistent embedding model (singleton)");
let model_start_time = std::time::Instant::now();

let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
)
.expect("Failed to initialize persistent embedding model");

.expect("Failed to initialize persistent embedding model");

let model_init_time = model_start_time.elapsed();
tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time);

tracing::info!(
"Persistent embedding model initialized in {:.2?}",
model_init_time
);

model
});
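The hunk above swaps per-request model construction for a process-wide singleton. A minimal standalone sketch of that pattern, reusing the once_cell and fastembed APIs already imported in this file (the helper name embed_one is hypothetical, and errors are left as panics just as in the handler):

    use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
    use once_cell::sync::Lazy;

    // Built lazily on first access, then shared by every subsequent request.
    static MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
        TextEmbedding::try_new(
            InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
        )
        .expect("failed to initialize embedding model")
    });

    fn embed_one(text: &str) -> Vec<f32> {
        // `embed` works on batches, so a single input is wrapped in a Vec.
        MODEL
            .embed(vec![text.to_string()], None)
            .expect("failed to embed document")
            .remove(0)
    }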
@@ -30,18 +29,21 @@ pub async fn embeddings_create(
) -> ResponseJson<serde_json::Value> {
// Start timing the entire process
let start_time = std::time::Instant::now();

// Phase 1: Access persistent model instance
let model_start_time = std::time::Instant::now();

// Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request
let model_access_time = model_start_time.elapsed();
tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);

tracing::debug!(
"Persistent model access completed in {:.2?}",
model_access_time
);

// Phase 2: Process input
let input_start_time = std::time::Instant::now();

let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
@@ -53,41 +55,58 @@ pub async fn embeddings_create(
panic!("Array of integer arrays not supported for text embeddings");
}
};

let input_processing_time = input_start_time.elapsed();
tracing::debug!("Input processing completed in {:.2?}", input_processing_time);

tracing::debug!(
"Input processing completed in {:.2?}",
input_processing_time
);

// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();

let embeddings = EMBEDDING_MODEL
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");

let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);

tracing::info!(
"Embedding generation completed in {:.2?}",
embedding_generation_time
);

// Memory usage estimation (approximate)
let embedding_size_bytes = embeddings.iter()
let embedding_size_bytes = embeddings
.iter()
.map(|e| e.len() * std::mem::size_of::<f32>())
.sum::<usize>();
tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
tracing::debug!(
"Embedding size: {:.2} MB",
embedding_size_bytes as f64 / 1024.0 / 1024.0
);

// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());

// Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);

// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);

// Phase 4: Post-process embeddings
let postprocessing_start_time = std::time::Instant::now();

// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
@@ -110,6 +129,8 @@ pub async fn embeddings_create(

// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();

#[allow(clippy::needless_range_loop)]
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
@@ -123,25 +144,35 @@ pub async fn embeddings_create(
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]);
}

padded_embedding
}
};

let postprocessing_time = postprocessing_start_time.elapsed();
tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
tracing::debug!(
"Embedding post-processing completed in {:.2?}",
postprocessing_time
);

tracing::trace!("Final embedding dimension: {}", final_embedding.len());

// Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);

// Phase 5: Prepare response
let response_start_time = std::time::Instant::now();

// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
@@ -158,10 +189,10 @@ pub async fn embeddings_create(
"total_tokens": 0
}
});

let response_time = response_start_time.elapsed();
tracing::debug!("Response preparation completed in {:.2?}", response_time);

// Log total time and breakdown
let total_time = start_time.elapsed();
tracing::info!(
@@ -171,7 +202,7 @@ pub async fn embeddings_create(
embedding_generation_time,
postprocessing_time
);

ResponseJson(response)
}

@@ -179,4 +210,4 @@ pub fn create_embeddings_router() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
.layer(TraceLayer::new_for_http())
}
}
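The zero-padding step shown above (and repeated in the standalone server further down) condenses to a small helper. This is a sketch only, with a hypothetical name, using the hard-coded 768-dimension target from the diff:

    /// Zero-pads an embedding up to `target_dimension`; longer inputs pass through unchanged.
    fn pad_embedding(mut embedding: Vec<f32>, target_dimension: usize) -> Vec<f32> {
        if embedding.len() < target_dimension {
            let padding_needed = target_dimension - embedding.len();
            embedding.extend(vec![0.0; padding_needed]);
        }
        embedding
    }

    // e.g. pad_embedding(raw_embedding, 768) before building the response payload.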
@@ -1,8 +1,8 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{
response::Json as ResponseJson, routing::{get, post},
Json,
Router,
Json, Router,
response::Json as ResponseJson,
routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
@@ -13,19 +13,17 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";

async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
)
.expect("Failed to initialize model");

let embedding_input = payload.input;

let embedding_input = payload.input;

let texts_from_embedding_input = match embedding_input {
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => {
@@ -45,12 +43,19 @@ async fn embeddings_create(
tracing::info!("Embedding dimension: {}", embeddings[0].len());

// Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);

// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);

// Create the final embedding
let final_embedding = {
@@ -87,7 +92,11 @@ async fn embeddings_create(
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]);
}

@@ -98,7 +107,10 @@ async fn embeddings_create(
tracing::trace!("Final embedding dimension: {}", final_embedding.len());

// Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);

// Return a response that matches the OpenAI API format
let response = serde_json::json!({
@@ -120,7 +132,7 @@ async fn embeddings_create(
}

fn create_app() -> Router {
Router::new()
Router::new()
.route("/v1/embeddings", post(embeddings_create))
.layer(TraceLayer::new_for_http())
}
@@ -143,21 +155,21 @@ async fn main() {
.init();
let app = create_app();

let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
let server_address = format!("{}:{}", server_host, server_port);
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
tracing::info!("Listening on {}", listener.local_addr().unwrap());
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
let server_address = format!("{}:{}", server_host, server_port);
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
tracing::info!("Listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}

#[cfg(test)]
mod tests {
use super::*;
use axum::body::to_bytes;
use axum::body::Body;
use axum::http::StatusCode;
use tower::ServiceExt;
use super::*;
use axum::body::Body;
use axum::body::to_bytes;
use axum::http::StatusCode;
use tower::ServiceExt;

#[tokio::test]
async fn test_embeddings_create() {
@@ -168,11 +180,13 @@ mod tests {

let body = CreateEmbeddingRequest {
model: "nomic-text-embed".to_string(),
input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]),
encoding_format: None,
user: None,
dimensions: Some(768),
};
input: EmbeddingInput::from(vec![
"The food was delicious and the waiter...".to_string(),
]),
encoding_format: None,
user: None,
dimensions: Some(768),
};

let response = app
.oneshot(
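The test body is truncated at this point in this view. The usual shape of such a check, driving the router with tower's oneshot and asserting on the JSON response, is roughly the following; this is a hypothetical sketch rather than the repository's actual test code, and it assumes axum 0.7's to_bytes(body, limit) signature:

    // Hypothetical continuation pattern, not the committed test body.
    let request = axum::http::Request::builder()
        .method("POST")
        .uri("/v1/embeddings")
        .header("content-type", "application/json")
        .body(Body::from(serde_json::to_vec(&body).unwrap()))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(json["object"], "list");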
@@ -3,16 +3,14 @@ name = "gemma-runner"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-examples = { git = "https://github.com/huggingface/candle.git" }

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
hf-hub = "0.4"
tokenizers = "0.21"
anyhow = "1.0"
@@ -22,6 +20,12 @@ tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
@@ -4,10 +4,10 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use clap::ValueEnum;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;

// Removed gemma_cli import as it's not needed for the API
use candle_core::{utils, DType, Device, Tensor};
@@ -119,7 +119,12 @@ impl TextGeneration {

/// Stream-only generation: sends freshly generated token strings over `tx`.
/// (Does not send the prompt tokens; only newly generated model tokens.)
fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
fn run_stream(
&mut self,
prompt: &str,
sample_len: usize,
tx: Sender<Result<String>>,
) -> Result<()> {
self.tokenizer.clear();

// Encode prompt (context only; do not emit prompt tokens to the stream).
@@ -303,7 +308,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
}
.to_string()
.to_string()
});

println!("Loading model: {}", &model_id);
@@ -337,7 +342,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
WhichModel::BaseV2_2B
| WhichModel::InstructV2_2B
| WhichModel::BaseV2_9B
| WhichModel::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
@@ -1,6 +1,6 @@
use std::io::Write;
use clap::Parser;
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;

#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
@@ -94,4 +94,4 @@ pub fn run_cli() -> anyhow::Result<()> {
}
}
Ok(())
}
}
@@ -2,8 +2,8 @@
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod gemma_cli;
mod gemma_api;
mod gemma_cli;

use anyhow::Error;
use clap::{Parser, ValueEnum};
@@ -14,4 +14,4 @@ use std::io::Write;
/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
run_cli()
}
}
@@ -84,7 +84,10 @@ fn main() -> Result<()> {
|
||||
let services = discover_services(workspace_path)?;
|
||||
println!("Found {} services:", services.len());
|
||||
for service in &services {
|
||||
println!(" - {}: {} (port {})", service.name, service.image, service.port);
|
||||
println!(
|
||||
" - {}: {} (port {})",
|
||||
service.name, service.image, service.port
|
||||
);
|
||||
}
|
||||
|
||||
generate_helm_chart(output_path, chart_name, &services)?;
|
||||
@@ -115,17 +118,20 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
|
||||
fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
|
||||
let content = fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?;
|
||||
|
||||
|
||||
let cargo_toml: CargoToml = toml::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;
|
||||
|
||||
let package = cargo_toml.package
|
||||
let package = cargo_toml
|
||||
.package
|
||||
.ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;
|
||||
|
||||
let metadata = package.metadata
|
||||
let metadata = package
|
||||
.metadata
|
||||
.ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;
|
||||
|
||||
let kube_metadata = metadata.kube
|
||||
let kube_metadata = metadata
|
||||
.kube
|
||||
.ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;
|
||||
|
||||
Ok(ServiceInfo {
|
||||
@@ -136,7 +142,11 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_helm_chart(output_path: &str, chart_name: &str, services: &[ServiceInfo]) -> Result<()> {
|
||||
fn generate_helm_chart(
|
||||
output_path: &str,
|
||||
chart_name: &str,
|
||||
services: &[ServiceInfo],
|
||||
) -> Result<()> {
|
||||
let chart_dir = Path::new(output_path);
|
||||
let templates_dir = chart_dir.join("templates");
|
||||
|
||||
@@ -512,4 +522,4 @@ fn generate_helmignore(chart_dir: &Path) -> Result<()> {
|
||||
|
||||
fs::write(chart_dir.join(".helmignore"), helmignore_content)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@@ -3,18 +3,6 @@ name = "inference-engine"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
||||
[[bin]]
|
||||
name="gemma_inference"
|
||||
path = "src/gemma_inference.rs"
|
||||
required-features = ["bin"]
|
||||
|
||||
[[bin]]
|
||||
name="llama_inference"
|
||||
path = "src/llama_inference.rs"
|
||||
required-features = ["bin"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { version = "0.3.2", optional = true }
|
||||
candle-datasets = { version = "=0.9.1", optional = true }
|
||||
|
@@ -30,4 +30,4 @@ pub trait ModelInference {
|
||||
}
|
||||
|
||||
/// Factory function type for creating model inference implementations
|
||||
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
|
||||
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
|
||||
|
@@ -1,19 +1,19 @@
|
||||
// Expose modules for testing and library usage
|
||||
pub mod token_output_stream;
|
||||
pub mod model;
|
||||
pub mod text_generation;
|
||||
pub mod utilities_lib;
|
||||
pub mod openai_types;
|
||||
pub mod text_generation;
|
||||
pub mod token_output_stream;
|
||||
pub mod utilities_lib;
|
||||
// pub mod cli;
|
||||
pub mod server;
|
||||
pub mod inference;
|
||||
pub mod server;
|
||||
|
||||
// Re-export key components for easier access
|
||||
pub use inference::ModelInference;
|
||||
pub use model::{Model, Which};
|
||||
pub use server::{create_router, AppState};
|
||||
pub use text_generation::TextGeneration;
|
||||
pub use token_output_stream::TokenOutputStream;
|
||||
pub use server::{AppState, create_router};
|
||||
pub use inference::ModelInference;
|
||||
|
||||
use std::env;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
@@ -1,8 +1,8 @@
|
||||
// use candle_core::Tensor;
|
||||
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum Which {
|
||||
@@ -52,7 +52,11 @@ pub enum Model {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
input_ids: &candle_core::Tensor,
|
||||
pos: usize,
|
||||
) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
@@ -88,7 +92,13 @@ impl Which {
|
||||
|
||||
pub fn is_instruct_model(&self) -> bool {
|
||||
match self {
|
||||
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
|
||||
Self::Base2B
|
||||
| Self::Base7B
|
||||
| Self::CodeBase2B
|
||||
| Self::CodeBase7B
|
||||
| Self::BaseV2_2B
|
||||
| Self::BaseV2_9B
|
||||
| Self::BaseV3_1B => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
@@ -100,4 +110,4 @@ impl Which {
|
||||
pub fn is_llama_model(&self) -> bool {
|
||||
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageInnerContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
fn schema() -> (
|
||||
&'static str,
|
||||
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
|
||||
) {
|
||||
(
|
||||
"MessageInnerContent",
|
||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
||||
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageContent(
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
||||
fn schema() -> (
|
||||
&'static str,
|
||||
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
|
||||
) {
|
||||
(
|
||||
"MessageContent",
|
||||
utoipa::openapi::RefOr::T(message_content_schema()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,4 +222,4 @@ pub struct ModelListResponse {
|
||||
pub object: String,
|
||||
/// Array of available models
|
||||
pub data: Vec<Model>,
|
||||
}
|
||||
}
|
||||
|
@@ -6,19 +6,22 @@ use axum::{
|
||||
Json, Router,
|
||||
};
|
||||
use futures_util::stream::{self, Stream};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
|
||||
use crate::openai_types::{
|
||||
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
|
||||
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
|
||||
};
|
||||
use crate::Which;
|
||||
use either::Either;
|
||||
use serde_json::Value;
|
||||
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
|
||||
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
|
||||
use serde_json::Value;
|
||||
// -------------------------
|
||||
// Shared app state
|
||||
// -------------------------
|
||||
@@ -62,12 +65,15 @@ fn normalize_model_id(model_id: &str) -> String {
|
||||
|
||||
fn build_gemma_prompt(messages: &[Message]) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
|
||||
for message in messages {
|
||||
match message.role.as_str() {
|
||||
"system" => {
|
||||
if let Some(MessageContent(Either::Left(content))) = &message.content {
|
||||
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
|
||||
prompt.push_str(&format!(
|
||||
"<start_of_turn>system\n{}<end_of_turn>\n",
|
||||
content
|
||||
));
|
||||
}
|
||||
}
|
||||
"user" => {
|
||||
@@ -83,7 +89,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
prompt.push_str("<start_of_turn>model\n");
|
||||
prompt
|
||||
}
|
||||
@@ -97,9 +103,13 @@ pub async fn chat_completions(
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
if !request.stream.unwrap_or(false) {
|
||||
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
|
||||
return Ok(chat_completions_non_streaming_proxy(state, request)
|
||||
.await
|
||||
.into_response());
|
||||
}
|
||||
Ok(chat_completions_stream(state, request).await.into_response())
|
||||
Ok(chat_completions_stream(state, request)
|
||||
.await
|
||||
.into_response())
|
||||
}
|
||||
|
||||
pub async fn chat_completions_non_streaming_proxy(
|
||||
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
ModelType::Gemma => build_gemma_prompt(&request.messages),
|
||||
ModelType::Llama => {
|
||||
// For Llama, just use the last user message for now
|
||||
request.messages.last()
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
};
|
||||
|
||||
// Get streaming receiver based on model type
|
||||
let rx = match state.model_type {
|
||||
ModelType::Gemma => {
|
||||
if let Some(mut config) = state.gemma_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_gemma_api(config).map_err(|e| (
|
||||
let rx =
|
||||
match state.model_type {
|
||||
ModelType::Gemma => {
|
||||
if let Some(mut config) = state.gemma_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_gemma_api(config).map_err(|e| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
}))
|
||||
))?
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
}))
|
||||
));
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
ModelType::Llama => {
|
||||
if let Some(mut config) = state.llama_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_llama_inference(config).map_err(|e| (
|
||||
ModelType::Llama => {
|
||||
if let Some(mut config) = state.llama_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_llama_inference(config).map_err(|e| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
}))
|
||||
))?
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
}))
|
||||
));
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Collect all tokens from the stream
|
||||
let mut completion = String::new();
|
||||
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
|
||||
ModelType::Gemma => build_gemma_prompt(&request.messages),
|
||||
ModelType::Llama => {
|
||||
// For Llama, just use the last user message for now
|
||||
request.messages.last()
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: Some("assistant".to_string()), content: None },
|
||||
delta: Delta {
|
||||
role: Some("assistant".to_string()),
|
||||
content: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
}))
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
}))
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
}))
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
}))
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -386,16 +404,20 @@ async fn handle_streaming_request(
|
||||
if recent_tokens.len() > REPETITION_WINDOW {
|
||||
recent_tokens.remove(0);
|
||||
}
|
||||
|
||||
|
||||
// Check for repetitive patterns
|
||||
if recent_tokens.len() >= 4 {
|
||||
let last_token = &recent_tokens[recent_tokens.len() - 1];
|
||||
let second_last = &recent_tokens[recent_tokens.len() - 2];
|
||||
|
||||
|
||||
if last_token == second_last {
|
||||
repetition_count += 1;
|
||||
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
|
||||
|
||||
tracing::warn!(
|
||||
"Detected repetition pattern: '{}' (count: {})",
|
||||
last_token,
|
||||
repetition_count
|
||||
);
|
||||
|
||||
if repetition_count >= MAX_REPETITION_COUNT {
|
||||
tracing::info!("Stopping generation due to excessive repetition");
|
||||
break;
|
||||
@@ -412,11 +434,14 @@ async fn handle_streaming_request(
|
||||
model: model_id_clone.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: None, content: Some(token) },
|
||||
delta: Delta {
|
||||
role: None,
|
||||
content: Some(token),
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&chunk) {
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
|
||||
model: model_id_clone.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: None, content: None },
|
||||
delta: Delta {
|
||||
role: None,
|
||||
content: None,
|
||||
},
|
||||
finish_reason: Some("stop".to_string()),
|
||||
}],
|
||||
};
|
||||
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
|
||||
Ok(Sse::new(stream))
|
||||
}
|
||||
|
||||
|
||||
|
||||
// -------------------------
|
||||
// Router
|
||||
// -------------------------
|
||||
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -681,10 +706,7 @@ mod tests {
|
||||
|
||||
let prompt = build_gemma_prompt(&messages);
|
||||
|
||||
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
|
||||
<start_of_turn>model\nWho's there?<end_of_turn>\n\
|
||||
<start_of_turn>user\nGemma.<end_of_turn>\n\
|
||||
<start_of_turn>model\n";
|
||||
let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
|
||||
|
||||
assert_eq!(prompt, expected);
|
||||
}
|
||||
@@ -698,15 +720,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_missing_content() {
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: None,
|
||||
name: None,
|
||||
}
|
||||
];
|
||||
let messages = vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: None,
|
||||
name: None,
|
||||
}];
|
||||
|
||||
let prompt = build_gemma_prompt(&messages);
|
||||
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
|
||||
assert_eq!(prompt, "<start_of_turn>model\n");
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
@@ -84,4 +84,4 @@ impl TokenOutputStream {
|
||||
self.prev_index = 0;
|
||||
self.current_index = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
||||
) -> Result<Vec<std::path::PathBuf>> {
|
||||
let path = path.as_ref();
|
||||
let jsfile = std::fs::File::open(path.join(json_file))?;
|
||||
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => candle_core::bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
@@ -164,4 +165,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
||||
.map(|v| path.join(v))
|
||||
.collect();
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
}
|
||||
|
@@ -9,7 +9,10 @@ mod tests {
|
||||
// Test a few representative model variants
|
||||
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
|
||||
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
|
||||
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
|
||||
assert_eq!(
|
||||
Which::InstructV1_1_2B.to_model_id(),
|
||||
"google/gemma-1.1-2b-it"
|
||||
);
|
||||
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
|
||||
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
|
||||
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
|
||||
@@ -64,4 +67,4 @@ mod tests {
|
||||
// Note: Testing the Model enum's forward method would require creating actual model instances,
|
||||
// which is complex and would require loading model weights. This is better suited for
|
||||
// integration tests or mocking the models.
|
||||
}
|
||||
}
|
||||
|
@@ -106,7 +106,7 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
|
||||
// Create a mock TextGeneration instance
|
||||
// Since we can't easily create a full TextGeneration instance without a model,
|
||||
// we'll test the logic by creating a simple struct with the necessary fields
|
||||
@@ -115,7 +115,7 @@ mod tests {
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -167,16 +167,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 1.0, // No penalty
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// With no penalty, logits should be unchanged
|
||||
assert_eq!(result_data, logits_data);
|
||||
Ok(())
|
||||
@@ -189,13 +190,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -238,16 +239,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0, // Apply penalty
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
|
||||
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
|
||||
assert_eq!(result_data, expected);
|
||||
@@ -261,13 +263,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -308,20 +310,21 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
|
||||
// First call should cache the penalty for token 1
|
||||
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (_result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
// Cache should contain the penalized value for token 1
|
||||
assert!(mock_gen.penalty_cache.contains_key(&1));
|
||||
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -332,13 +335,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens: Vec<u32> = vec![]; // Empty tokens
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -379,16 +382,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// With empty tokens, logits should be unchanged
|
||||
assert_eq!(result_data, logits_data);
|
||||
Ok(())
|
||||
@@ -401,13 +405,13 @@ mod tests {
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
|
||||
|
||||
|
||||
struct MockTextGeneration {
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
penalty_cache: HashMap<usize, f32>,
|
||||
}
|
||||
|
||||
|
||||
impl MockTextGeneration {
|
||||
fn apply_cached_repeat_penalty(
|
||||
&mut self,
|
||||
@@ -448,16 +452,17 @@ mod tests {
|
||||
Ok((result, elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut mock_gen = MockTextGeneration {
|
||||
repeat_penalty: 2.0,
|
||||
repeat_last_n: 3,
|
||||
penalty_cache: HashMap::new(),
|
||||
};
|
||||
|
||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
|
||||
let (result_logits, _duration) =
|
||||
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||
let result_data = result_logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
|
||||
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
|
||||
assert_eq!(result_data, expected);
|
||||
@@ -471,52 +476,52 @@ mod tests {
|
||||
// Since creating a real TextGeneration instance requires a Model which needs model weights,
|
||||
// we'll create a test that demonstrates the method is now public and can be accessed.
|
||||
// The comprehensive functionality testing is already covered by the mock tests above.
|
||||
|
||||
|
||||
// Test data setup
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 3u32];
|
||||
|
||||
|
||||
// Test that we can create the necessary components
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
|
||||
|
||||
// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
|
||||
// This test verifies the method signature and that it's accessible from external code
|
||||
|
||||
|
||||
// We could create a TextGeneration instance if we had a way to mock the Model,
|
||||
// but for now we confirm that the existing mock tests cover the functionality
|
||||
// and the method is properly exposed as public
|
||||
|
||||
|
||||
println!("apply_cached_repeat_penalty method is now public and accessible for testing");
|
||||
assert!(true);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
// Integration test that demonstrates the method usage pattern
|
||||
#[test]
|
||||
#[test]
|
||||
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
|
||||
// This test demonstrates how the apply_cached_repeat_penalty method would be used
|
||||
// in practice, even though we can't create a full TextGeneration instance in unit tests
|
||||
|
||||
|
||||
let device = Device::Cpu;
|
||||
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
|
||||
let logits = Tensor::new(&logits_data[..], &device)?;
|
||||
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
|
||||
|
||||
|
||||
// Test parameters that would be used with TextGeneration
|
||||
let repeat_penalty = 1.2f32;
|
||||
let repeat_last_n = 3usize;
|
||||
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
|
||||
|
||||
|
||||
// Simulate the method's logic to verify it works as expected
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
|
||||
if repeat_penalty != 1.0 {
|
||||
let start_at = tokens.len().saturating_sub(repeat_last_n);
|
||||
let penalty_tokens = &tokens[start_at..];
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
|
||||
for &token_id in penalty_tokens {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
@@ -531,14 +536,14 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let _duration = start_time.elapsed();
|
||||
|
||||
|
||||
// Verify that tokens were processed correctly
|
||||
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
|
||||
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
|
||||
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
|
||||
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
|
||||
|
||||
|
||||
println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
|
||||
Ok(())
|
||||
}
|
||||
|
@@ -1,7 +1,7 @@
|
||||
use inference_engine::token_output_stream::TokenOutputStream;
|
||||
use tokenizers::Tokenizer;
|
||||
use std::path::PathBuf;
|
||||
use anyhow::Result;
|
||||
use inference_engine::token_output_stream::TokenOutputStream;
|
||||
use std::path::PathBuf;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -19,7 +19,7 @@ mod tests {
|
||||
fn test_new_token_output_stream() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
|
||||
// Check that the token stream was created successfully
|
||||
assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
|
||||
Ok(())
|
||||
@@ -29,18 +29,18 @@ mod tests {
|
||||
fn test_clear() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
|
||||
// Add a token
|
||||
let token_id = token_stream.get_token("<eos>").unwrap();
|
||||
token_stream.next_token(token_id)?;
|
||||
|
||||
|
||||
// Clear the stream
|
||||
token_stream.clear();
|
||||
|
||||
|
||||
// Check that the stream is empty by trying to decode all
|
||||
let decoded = token_stream.decode_all()?;
|
||||
assert_eq!(decoded, "");
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -48,15 +48,15 @@ mod tests {
|
||||
fn test_get_token() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
|
||||
// Get a token that should exist
|
||||
let eos_token = token_stream.get_token("<eos>");
|
||||
assert!(eos_token.is_some());
|
||||
|
||||
|
||||
// Get a token that shouldn't exist
|
||||
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
|
||||
assert!(nonexistent_token.is_none());
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -64,11 +64,14 @@ mod tests {
|
||||
fn test_next_token_and_decode() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
|
||||
// Get some tokens
|
||||
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
|
||||
let hello_tokens = token_stream
|
||||
.tokenizer()
|
||||
.encode("Hello world", true)
|
||||
.unwrap();
|
||||
let token_ids = hello_tokens.get_ids();
|
||||
|
||||
|
||||
// Add tokens one by one
|
||||
let mut output = String::new();
|
||||
for &token_id in token_ids {
|
||||
@@ -76,16 +79,16 @@ mod tests {
|
||||
output.push_str(&text);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Get any remaining text
|
||||
if let Some(rest) = token_stream.decode_rest()? {
|
||||
output.push_str(&rest);
|
||||
}
|
||||
|
||||
|
||||
// Check the output
|
||||
assert!(!output.is_empty());
|
||||
assert_eq!(output.trim(), "Hello world");
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -93,22 +96,25 @@ mod tests {
|
||||
fn test_decode_all() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
|
||||
// Get some tokens
|
||||
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
|
||||
let hello_tokens = token_stream
|
||||
.tokenizer()
|
||||
.encode("Hello world", true)
|
||||
.unwrap();
|
||||
let token_ids = hello_tokens.get_ids();
|
||||
|
||||
|
||||
// Add tokens one by one
|
||||
for &token_id in token_ids {
|
||||
token_stream.next_token(token_id)?;
|
||||
}
|
||||
|
||||
|
||||
// Decode all
|
||||
let decoded = token_stream.decode_all()?;
|
||||
|
||||
|
||||
// Check the output
|
||||
assert_eq!(decoded.trim(), "Hello world");
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -116,14 +122,14 @@ mod tests {
|
||||
fn test_into_inner() -> Result<()> {
|
||||
let tokenizer = create_test_tokenizer()?;
|
||||
let token_stream = TokenOutputStream::new(tokenizer);
|
||||
|
||||
|
||||
// Get the inner tokenizer
|
||||
let inner_tokenizer = token_stream.into_inner();
|
||||
|
||||
|
||||
// Check that the inner tokenizer works
|
||||
let encoded = inner_tokenizer.encode("Test", true).unwrap();
|
||||
assert!(encoded.get_ids().len() > 0);
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -5,6 +5,25 @@ use leptos_router::{
|
||||
StaticSegment,
|
||||
};
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::config::OpenAIConfig;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::types::{FinishReason, Role};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::{
|
||||
types::{
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
|
||||
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
|
||||
Model as OpenAIModel,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use futures_util::StreamExt;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use js_sys::Date;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use leptos::task::spawn_local;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(feature = "hydrate")]
|
||||
@@ -12,25 +31,7 @@ use std::collections::VecDeque;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use uuid::Uuid;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use js_sys::Date;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use futures_util::StreamExt;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::{
|
||||
types::{
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
|
||||
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::config::OpenAIConfig;
|
||||
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{Role, FinishReason};
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -43,11 +44,15 @@ pub struct Message {

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent(pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>);
pub struct MessageContent(
pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>,
);

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInnerContent(pub either::Either<String, std::collections::HashMap<String, String>>);
pub struct MessageInnerContent(
pub either::Either<String, std::collections::HashMap<String, String>>,
);

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -62,27 +67,40 @@ const DEFAULT_MODEL: &str = "default";

#[cfg(feature = "hydrate")]
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1");

leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"
);

let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);

match client.models().list().await {
Ok(response) => {
let model_count = response.data.len();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count);

leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Successfully fetched {} models",
model_count
);

if model_count > 0 {
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names);
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Available models: {:?}",
model_names
);
} else {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: No models returned by server");
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: No models returned by server"
);
}

Ok(response.data)
},
}
Err(e) => {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e);
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}",
e
);
Err(format!("Failed to fetch models: {}", e))
}
}
@@ -150,7 +168,7 @@ fn ChatInterface() -> impl IntoView {
{
ChatInterfaceImpl()
}

#[cfg(not(feature = "hydrate"))]
{
view! {
@@ -252,7 +270,7 @@ fn ChatInterfaceImpl() -> impl IntoView {

let current_model = selected_model.get_untracked();
let total_messages = chat_messages.len();

leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
current_model, history_count, total_messages);

@@ -267,17 +285,17 @@ fn ChatInterfaceImpl() -> impl IntoView {
// Send request
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);

leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);

match client.chat().create_stream(request).await {
Ok(mut stream) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream");

let mut assistant_created = false;
let mut content_appended = false;
let mut chunks_received = 0;

while let Some(next) = stream.next().await {
match next {
Ok(chunk) => {
@@ -335,7 +353,11 @@ fn ChatInterfaceImpl() -> impl IntoView {
}
}
Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e);
leptos::logging::log!(
"[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}",
chunks_received,
e
);
set_messages.update(|msgs| {
msgs.push_back(Message {
id: Uuid::new_v4().to_string(),
@@ -364,7 +386,10 @@ fn ChatInterfaceImpl() -> impl IntoView {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
}
Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e);
leptos::logging::log!(
"[DEBUG_LOG] send_message: Request failed with error: {:?}",
e
);
let error_message = Message {
id: Uuid::new_v4().to_string(),
role: "system".to_string(),
@@ -404,7 +429,8 @@ fn ChatInterfaceImpl() -> impl IntoView {
};

let messages_list = move || {
messages.get()
messages
.get()
.into_iter()
.map(|message| {
let role_class = match message.role.as_str() {
@@ -439,7 +465,7 @@ fn ChatInterfaceImpl() -> impl IntoView {
<h1>"Chat Interface"</h1>
<div class="model-selector">
<label for="model-select">"Model: "</label>
<select
<select
id="model-select"
on:change=on_model_change
prop:value=selected_model

@@ -10,10 +10,10 @@ pub fn hydrate() {

#[cfg(feature = "ssr")]
pub fn create_leptos_router() -> axum::Router {
use crate::app::*;
use axum::Router;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use crate::app::*;

let conf = get_configuration(None).unwrap();
let leptos_options = conf.leptos_options;

@@ -1,12 +1,11 @@

#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
use axum::Router;
use leptos::logging::log;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use leptos_app::app::*;
use leptos_axum::{generate_route_list, LeptosRoutes};

let conf = get_configuration(None).unwrap();
let addr = conf.leptos_options.site_addr;

@@ -18,6 +18,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", features = ["
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[target.'cfg(not(target_os = "macos"))'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]

@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";

@@ -1,14 +1,14 @@
use crate::EOS_TOKEN;
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use clap::ValueEnum;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::{EOS_TOKEN};

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
max_tokens: 512,

// Performance flags
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: true, // great speed boost if supported
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: true, // great speed boost if supported

// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
}
}

fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
}
}

fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
.to_string()
.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
@@ -334,4 +331,3 @@ pub fn run_llama_inference(

Ok(rx)
}

@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
}
}

pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = args.into();
@@ -106,4 +105,4 @@ pub fn run_cli() -> anyhow::Result<()> {
}
}
Ok(())
}
}

@@ -2,8 +2,8 @@
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_cli;
mod llama_api;
mod llama_cli;

use anyhow::Result;
use clap::{Parser, ValueEnum};
@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;

const EOS_TOKEN: &str = "</s>";

fn main() -> Result<()> {
run_cli()
}
}

@@ -1,7 +1,9 @@
use serde::{Deserialize, Serialize};
use std::env;
use tracing::info;
use tracing::log::error;

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ServerConfig {
#[serde(default = "default_server_host")]
@@ -10,14 +12,16 @@ pub struct ServerConfig {
pub server_port: u16,
pub server_mode: ServerMode,
#[serde(default)]
pub services: Services,
pub services: Option<Services>,
}

fn default_server_host() -> String {
"127.0.0.1".to_string()
}

fn default_server_port() -> u16 { 8080 }
fn default_server_port() -> u16 {
8080
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "PascalCase")]
@@ -34,17 +38,15 @@ impl Default for ServerMode {

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Services {
#[serde(default = "inference_service_url")]
pub inference_url: String,
#[serde(default = "embeddings_service_url")]
pub embeddings_url: String,
pub inference_url: Option<String>,
pub embeddings_url: Option<String>,
}

impl Default for Services {
fn default() -> Self {
Self {
inference_url: inference_service_url(),
embeddings_url: embeddings_service_url(),
inference_url: None,
embeddings_url: None,
}
}
}
@@ -63,7 +65,7 @@ impl Default for ServerConfig {
server_host: "127.0.0.1".to_string(),
server_port: 8080,
server_mode: ServerMode::Standalone,
services: Services::default(),
services: Some(Services::default()),
}
}
}
@@ -73,21 +75,19 @@ impl ServerConfig {
/// Falls back to default (Local mode) if not set or invalid
pub fn from_env() -> Self {
match env::var("SERVER_CONFIG") {
Ok(config_str) => {
match serde_json::from_str::<ServerConfig>(&config_str) {
Ok(config) => {
tracing::info!("Loaded server configuration: {:?}", config);
config
}
Err(e) => {
tracing::warn!(
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
e
);
ServerConfig::default()
}
Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
Ok(config) => {
tracing::info!("Loaded server configuration: {:?}", config);
config
}
}
Err(e) => {
tracing::warn!(
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
e
);
ServerConfig::default()
}
},
Err(_) => {
tracing::info!("SERVER_CONFIG not set, Standalone mode active");
ServerConfig::default()
@@ -96,18 +96,52 @@ impl ServerConfig {
}

/// Check if the server should run in high availability mode
pub fn is_high_availability(&self) -> bool {
self.server_mode == ServerMode::HighAvailability
pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
if self.server_mode == ServerMode::HighAvailability {
let services_well_defined: bool = self.clone().services.is_some();

let inference_url_well_defined: bool =
services_well_defined && self.clone().services.unwrap().inference_url.is_some();

let embeddings_well_defined: bool =
services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();

let is_well_defined_for_ha =
services_well_defined && inference_url_well_defined && embeddings_well_defined;

if !is_well_defined_for_ha {
let config_string = serde_json::to_string_pretty(&self).unwrap();
error!(
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
config_string
);
let err = std::io::Error::new(
std::io::ErrorKind::Other,
"HighAvailability mode configured but services not well defined!",
);
return Err(err);
}
}

Ok(self.server_mode == ServerMode::HighAvailability)
}

/// Get the inference service URL for proxying
pub fn inference_url(&self) -> &str {
&self.services.inference_url
pub fn inference_url(&self) -> Option<String> {
if self.services.is_some() {
self.services.clone()?.inference_url
} else {
None
}
}

/// Get the embeddings service URL for proxying
pub fn embeddings_url(&self) -> &str {
&self.services.embeddings_url
pub fn embeddings_url(&self) -> Option<String> {
if self.services.is_some() {
self.services.clone()?.embeddings_url
} else {
None
}
}
}

@@ -119,7 +153,7 @@ mod tests {
fn test_default_config() {
let config = ServerConfig::default();
assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability());
assert!(!config.is_high_availability().unwrap());
}

#[test]
@@ -134,23 +168,26 @@ mod tests {

let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::HighAvailability);
assert!(config.is_high_availability());
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
assert!(config.is_high_availability().unwrap());
assert_eq!(
config.inference_url().unwrap(),
"http://inference-service:8080"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://embeddings-service:8080"
);
}

#[test]
fn test_local_mode_config() {
let config_json = r#"{
"serverMode": "Local"
"serverMode": "Standalone"
}"#;

let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability());
// Should use default URLs
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
assert!(!config.is_high_availability().unwrap());
}

#[test]
@@ -164,17 +201,26 @@ mod tests {
}"#;

let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.inference_url(), "http://custom-inference:9000");
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
assert_eq!(
config.inference_url().unwrap(),
"http://custom-inference:9000"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://custom-embeddings:9001"
);
}

#[test]
fn test_minimal_high_availability_config() {
fn test_minimal_high_availability_config_error() {
let config_json = r#"{"serverMode": "HighAvailability"}"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert!(config.is_high_availability());
// Should use default URLs
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");

let is_high_availability = config.is_high_availability();

assert!(is_high_availability.is_err());
// // Should use default URLs
// assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
// assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
}
}
}

@@ -1,7 +1,9 @@
mod config;
mod middleware;
mod proxy;
mod standalone;

use crate::standalone::create_standalone_router;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{Router, http::Uri, response::Html, serve};
@@ -11,6 +13,7 @@ use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use proxy::create_proxy_router;
use rust_embed::Embed;
use std::env;
use std::path::Component::ParentDir;
use tokio::net::TcpListener;
use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer};
@@ -49,33 +52,19 @@ async fn main() {
let default_host = server_config.server_host.clone();
let default_port = server_config.server_port;

// Create router based on server mode
let service_router = if server_config.clone().is_high_availability() {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!(" Inference service URL: {}", server_config.inference_url());
tracing::info!(
" Embeddings service URL: {}",
server_config.embeddings_url()
);

// Use proxy router that forwards requests to external services
create_proxy_router(server_config.clone())
} else {
tracing::info!("Running in Standalone mode - using embedded services");

// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();

// Create AppState with correct model configuration
let app_state = AppState::default();

// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);

// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
let service_router = match server_config.clone().is_high_availability() {
Ok(is_ha) => {
if is_ha {
log_config(server_config.clone());
create_proxy_router(server_config.clone())
} else {
log_config(server_config.clone());
create_standalone_router(server_config)
}
}
Err(error) => {
panic!("{}", error);
}
};

// Create CORS layer
@@ -124,5 +113,25 @@ async fn main() {
serve(listener, app).await.unwrap();
}

fn log_config(config: ServerConfig) {
match config.is_high_availability() {
Ok(is_high) => {
if is_high {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
tracing::info!(
"Embeddings service URL: {}",
config.embeddings_url().unwrap()
);
} else {
tracing::info!("Running in Standalone mode");
}
}
Err(error) => {
panic!("{}", error);
}
}
}

// Chat completions handler that properly uses the inference server crate's error handling
// This function is no longer needed as we're using the inference_engine router directly

@@ -2,6 +2,8 @@ use axum::{
extract::MatchedPath,
http::{Request, Response},
};
use std::fmt;
use std::task::ready;
use std::{
future::Future,
pin::Pin,
@@ -12,8 +14,6 @@ use std::{
use tokio::sync::Mutex;
use tower::{Layer, Service};
use tracing::{debug, info};
use std::task::ready;
use std::fmt;

/// Performance metrics for a specific endpoint
#[derive(Debug, Clone, Default)]
@@ -33,16 +33,16 @@ impl EndpointMetrics {
pub fn add_response_time(&mut self, time_ms: u64) {
self.count += 1;
self.total_time_ms += time_ms;

if self.min_time_ms == 0 || time_ms < self.min_time_ms {
self.min_time_ms = time_ms;
}

if time_ms > self.max_time_ms {
self.max_time_ms = time_ms;
}
}

/// Get the average response time in milliseconds
pub fn avg_time_ms(&self) -> f64 {
if self.count == 0 {
@@ -51,12 +51,15 @@ impl EndpointMetrics {
self.total_time_ms as f64 / self.count as f64
}
}

/// Get a human-readable summary of the metrics
pub fn summary(&self) -> String {
format!(
"requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
self.count,
self.avg_time_ms(),
self.min_time_ms,
self.max_time_ms
)
}
}
@@ -75,14 +78,16 @@ impl MetricsStore {
endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
}
}

/// Record a request's timing information
pub async fn record(&self, path: String, time_ms: u64) {
let mut endpoints = self.endpoints.lock().await;
let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
let metrics = endpoints
.entry(path)
.or_insert_with(EndpointMetrics::default);
metrics.add_response_time(time_ms);
}

/// Get metrics for all endpoints
pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
let endpoints = self.endpoints.lock().await;
@@ -91,12 +96,12 @@ impl MetricsStore {
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}

/// Log a summary of all metrics
pub async fn log_summary(&self) {
let metrics = self.get_all().await;
info!("Performance metrics summary:");

for (path, metric) in metrics {
info!(" {}: {}", path, metric.summary());
}
@@ -163,26 +168,28 @@ where
} else {
req.uri().path().to_string()
};

let method = req.method().clone();
let start = Instant::now();
let metrics_store = self.metrics_store.clone();

let future = self.inner.call(req);

Box::pin(async move {
let response = future.await?;

let time = start.elapsed();
let status = response.status();
let time_ms = time.as_millis() as u64;

// Record the timing in our metrics store
metrics_store.record(format!("{} {}", method, path), time_ms).await;

metrics_store
.record(format!("{} {}", method, path), time_ms)
.await;

// Log the request timing
debug!("{} {} {} - {} ms", method, path, status, time_ms);

Ok(response)
})
}
@@ -214,7 +221,7 @@ impl Future for MetricsLoggerFuture {
metrics_store.log_summary().await;
});
}

Poll::Pending
}
}
}

@@ -1,7 +1,3 @@
pub mod metrics;

pub use metrics::{
MetricsStore,
MetricsLoggerFuture,
MetricsLayer,
};
pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};

@@ -1,10 +1,10 @@
use axum::{
Router,
body::Body,
extract::{Request, State},
http::{HeaderMap, Method, StatusCode, Uri},
response::{IntoResponse, Response},
routing::{get, post},
Router,
};
use reqwest::Client;
use serde_json::Value;
@@ -47,10 +47,16 @@ async fn proxy_chat_completions(
headers: HeaderMap,
body: Body,
) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());

let target_url = format!(
"{}/v1/chat/completions",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration")
);

tracing::info!("Proxying chat completions request to: {}", target_url);

// Extract body as bytes
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
@@ -63,7 +69,9 @@ async fn proxy_chat_completions(
// Check if this is a streaming request
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
json.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false)
} else {
false
}
@@ -72,7 +80,8 @@ async fn proxy_chat_completions(
};

// Forward the request
let mut req_builder = proxy_client.client
let mut req_builder = proxy_client
.client
.post(&target_url)
.body(body_bytes.to_vec());

@@ -85,8 +94,7 @@ async fn proxy_chat_completions(

match req_builder.send().await {
Ok(response) => {
let mut resp_builder = Response::builder()
.status(response.status());
let mut resp_builder = Response::builder().status(response.status());

// Forward response headers
for (name, value) in response.headers().iter() {
@@ -99,14 +107,12 @@ async fn proxy_chat_completions(
if is_streaming {
// For streaming, we need to forward the response as-is
match response.bytes().await {
Ok(body) => {
resp_builder
.header("content-type", "text/plain; charset=utf-8")
.header("cache-control", "no-cache")
.header("connection", "keep-alive")
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.header("content-type", "text/plain; charset=utf-8")
.header("cache-control", "no-cache")
.header("connection", "keep-alive")
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read streaming response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +121,9 @@ async fn proxy_chat_completions(
} else {
// For non-streaming, forward the JSON response
match response.bytes().await {
Ok(body) => {
resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,10 +143,16 @@ async fn proxy_models(
State(proxy_client): State<ProxyClient>,
headers: HeaderMap,
) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/models", proxy_client.config.inference_url());

let target_url = format!(
"{}/v1/models",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration Detected")
);

tracing::info!("Proxying models request to: {}", target_url);

let mut req_builder = proxy_client.client.get(&target_url);

// Forward relevant headers
@@ -154,8 +164,7 @@ async fn proxy_models(

match req_builder.send().await {
Ok(response) => {
let mut resp_builder = Response::builder()
.status(response.status());
let mut resp_builder = Response::builder().status(response.status());

// Forward response headers
for (name, value) in response.headers().iter() {
@@ -165,11 +174,9 @@ async fn proxy_models(
}

match response.bytes().await {
Ok(body) => {
resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read models response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,10 +196,16 @@ async fn proxy_embeddings(
headers: HeaderMap,
body: Body,
) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());

let target_url = format!(
"{}/v1/embeddings",
proxy_client
.config
.embeddings_url()
.expect("Invalid Configuration Detected")
);

tracing::info!("Proxying embeddings request to: {}", target_url);

// Extract body as bytes
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
@@ -203,7 +216,8 @@ async fn proxy_embeddings(
};

// Forward the request
let mut req_builder = proxy_client.client
let mut req_builder = proxy_client
.client
.post(&target_url)
.body(body_bytes.to_vec());

@@ -216,8 +230,7 @@ async fn proxy_embeddings(

match req_builder.send().await {
Ok(response) => {
let mut resp_builder = Response::builder()
.status(response.status());
let mut resp_builder = Response::builder().status(response.status());

// Forward response headers
for (name, value) in response.headers().iter() {
@@ -227,11 +240,9 @@ async fn proxy_embeddings(
}

match response.bytes().await {
Ok(body) => {
resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read embeddings response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -250,7 +261,7 @@ fn should_forward_header(header_name: &str) -> bool {
match header_name.to_lowercase().as_str() {
"content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
"host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
_ => true, // Forward other headers by default
_ => true, // Forward other headers by default
}
}

@@ -259,7 +270,7 @@ fn should_forward_response_header(header_name: &str) -> bool {
match header_name.to_lowercase().as_str() {
"content-type" | "content-length" | "cache-control" | "connection" => true,
"server" | "date" => false, // Don't forward server-specific headers
_ => true, // Forward other headers by default
_ => true, // Forward other headers by default
}
}

@@ -290,14 +301,20 @@ mod tests {
server_host: "127.0.0.1".to_string(),
server_port: 8080,
server_mode: ServerMode::HighAvailability,
services: Services {
inference_url: "http://test-inference:8080".to_string(),
embeddings_url: "http://test-embeddings:8080".to_string(),
},
services: Some(Services {
inference_url: Some("http://test-inference:8080".to_string()),
embeddings_url: Some("http://test-embeddings:8080".to_string()),
}),
};

let proxy_client = ProxyClient::new(config);
assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
assert_eq!(
proxy_client.config.inference_url().unwrap().as_str(),
"http://test-inference:8080"
);
assert_eq!(
proxy_client.config.embeddings_url().unwrap().as_str(),
"http://test-embeddings:8080"
);
}
}
}

19
crates/predict-otron-9000/src/standalone.rs
Normal file
@@ -0,0 +1,19 @@
use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;

pub fn create_standalone_router(server_config: ServerConfig) -> Router {
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();

// Create AppState with correct model configuration
let app_state = AppState::default();

// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);

// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
}
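
A minimal usage sketch (not part of the commit) of how the new create_standalone_router might be wired up inside the predict-otron-9000 binary, mirroring the main.rs changes shown above; the function name and the hard-coded bind address are illustrative assumptions, since the real binary derives the address from ServerConfig.

// Hedged sketch, assuming the crate layout shown in this commit (crate::config,
// crate::standalone) and the axum/tokio setup used in main.rs above.
async fn serve_standalone_sketch() {
    use tokio::net::TcpListener;

    // Load configuration from the SERVER_CONFIG env var, falling back to defaults.
    let server_config = crate::config::ServerConfig::from_env();

    // Merge the embedded embeddings and inference routers into one app.
    let app = crate::standalone::create_standalone_router(server_config);

    // Illustrative address; matches the config defaults (127.0.0.1:8080).
    let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}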
Block a user