Refactor apply_cached_repeat_penalty for optimized caching and reuse, add extensive unit tests, and integrate special handling for gemma-specific models.

Removed `test_request.sh`, deprecated functionality, and unused imports; introduced a new CLI tool (`cli.ts`) for testing inference engine and adjusted handling of non-streaming/streaming chat completions.

- Add CPU fallback support for text generation when primary device is unsupported
- Introduce `execute_with_fallback` method to handle device compatibility and shape mismatch errors
- Extend unit tests to reproduce tensor shape mismatch errors specific to model configurations
- Increase HTTP timeout limits in `curl_chat_stream.sh` script for reliable API testing

chat completion endpoint functions with gemma3 (no streaming)

Add benchmarking guide with HTML reporting, Leptos chat crate, and middleware for metrics tracking
This commit is contained in:
geoffsee
2025-08-26 01:30:26 -04:00
parent 7dd23213c9
commit 8338750beb
64 changed files with 14997 additions and 220 deletions

View File

@@ -1,30 +1,335 @@
use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
response::{sse::Event, sse::Sse, IntoResponse},
routing::post,
Json, Router,
};
use std::{net::SocketAddr, sync::Arc};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
use candle_core::DType;
use candle_nn::VarBuilder;
use std::{path::PathBuf, sync::Arc};
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, Message, MessageContent, Usage};
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Usage};
use crate::text_generation::TextGeneration;
use crate::{utilities_lib, Model, Which};
use either::Either;
use hf_hub::api::sync::{Api, ApiError};
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use serde_json::Value;
// -------------------------
// Shared app state
// -------------------------
// Application state shared between handlers
#[derive(Clone)]
pub struct AppState {
pub text_generation: Arc<Mutex<TextGeneration>>,
pub model_id: String,
}
// Chat completions endpoint handler
impl Default for AppState {
fn default() -> Self {
let args = PipelineArgs::default();
let text_generation = build_pipeline(args);
Self {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: String::new(),
}
}
}
// -------------------------
// Pipeline configuration
// -------------------------
#[derive(Debug, Clone)]
pub struct PipelineArgs {
/// HF model repo id, e.g. "google/gemma-2b"
pub model_id: String,
/// Which internal model family to instantiate
pub which: Which,
/// Optional HF revision/branch/tag; None => "main"
pub revision: Option<String>,
/// Optional explicit tokenizer path
pub tokenizer_path: Option<PathBuf>,
/// Optional explicit config path
pub config_path: Option<PathBuf>,
/// Optional explicit weight paths. If empty, they will be resolved from the hub.
pub weight_paths: Vec<PathBuf>,
/// Runtime toggles
pub use_flash_attn: bool,
pub force_cpu: bool,
/// Sampling / decoding params
pub seed: u64,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
}
impl Default for PipelineArgs {
fn default() -> Self {
Self {
model_id: Which::InstructV3_1B.to_model_id().to_string(),
which: Which::InstructV3_1B,
revision: None,
tokenizer_path: None,
config_path: None,
weight_paths: Vec::new(),
use_flash_attn: false,
force_cpu: false,
seed: 0,
temperature: None,
top_p: None,
repeat_penalty: 0.0,
repeat_last_n: 0,
}
}
}
// If no owner/org is present, prefix with a sensible default (tweak as you like).
fn normalize_model_id(model_id: &str) -> String {
if model_id.contains('/') { model_id.to_string() } else { format!("google/{}", model_id) }
}
// Quick existence check, mapping 404 into a helpful message.
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
let repo = api.repo(Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string()));
match repo.get("config.json") {
Ok(_) => Ok(()),
Err(e) => match e {
ApiError::RequestError(resp) => {
// For HF API, RequestError with 404 status is returned when repo doesn't exist
let error_str = resp.to_string();
if error_str.contains("404") {
anyhow::bail!(
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'. \
Please provide a fully-qualified repo id like 'google/gemma-2b-it'."
)
}
Err(anyhow::Error::new(ApiError::RequestError(resp)))
}
other => Err(anyhow::Error::new(other)),
}
}
}
// -------------------------
// Pipeline builder
// -------------------------
pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
let start = std::time::Instant::now();
let api = Api::new().unwrap();
let revision = args.revision.as_deref().unwrap_or("main");
// Check if model_id is empty before normalizing it
println!("Checking model_id: '{}'", args.model_id);
println!("Trimmed model_id length: {}", args.model_id.trim().len());
if args.model_id.trim().is_empty() {
panic!("No model ID specified. Please provide a valid model ID (e.g., 'gemma-2b-it' or 'google/gemma-2b-it').");
}
args.model_id = normalize_model_id(&args.model_id);
// Validate early (nice error if the repo/revision is wrong).
match ensure_repo_exists(&api, &args.model_id, revision) {
Ok(_) => {},
Err(e) => panic!("{}", e),
};
let repo = api.repo(Repo::with_revision(
args.model_id.clone(),
RepoType::Model,
revision.to_string(),
));
// Resolve files (prefer explicit paths; fallback to hub)
let tokenizer_path = args
.tokenizer_path
.unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
let config_path = args
.config_path
.unwrap_or_else(|| repo.get("config.json").unwrap());
// Only use auto-detection if no specific model type was provided
// This ensures that explicitly specified model types are respected
if !matches!(args.which,
Which::Base2B | Which::Base7B |
Which::Instruct2B | Which::Instruct7B |
Which::InstructV1_1_2B | Which::InstructV1_1_7B |
Which::CodeBase2B | Which::CodeBase7B |
Which::CodeInstruct2B | Which::CodeInstruct7B |
Which::BaseV2_2B | Which::InstructV2_2B |
Which::BaseV2_9B | Which::InstructV2_9B |
Which::BaseV3_1B | Which::InstructV3_1B) {
// If model_id is a known value, map it directly
if args.model_id.contains("gemma-2-2b-it") {
args.which = Which::InstructV2_2B;
println!("Setting model type to InstructV2_2B based on model_id: {}", args.model_id);
} else if args.model_id.contains("gemma-3-1b-it") {
args.which = Which::InstructV3_1B;
println!("Setting model type to InstructV3_1B based on model_id: {}", args.model_id);
} else {
// Fallback to auto-detection from config.json
if let Ok(file) = std::fs::File::open(config_path.clone()) {
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
println!("Auto-detecting model type from config.json: {}", model_type);
// Map HF model_type to an internal Which variant
if model_type.contains("gemma3") {
args.which = Which::InstructV3_1B;
println!("Setting model type to InstructV3_1B based on config");
} else if model_type.contains("gemma2") {
args.which = Which::InstructV2_2B;
println!("Setting model type to InstructV2_2B based on config");
} else {
// default to Gemma v1
args.which = Which::Instruct2B;
println!("Setting model type to Instruct2B (v1) based on config");
}
}
}
}
}
} else {
println!("Using explicitly specified model type: {:?}", args.which);
}
// Resolve weight files: try a single-file first, then fall back to sharded index
let weight_paths = if !args.weight_paths.is_empty() {
args.weight_paths
} else {
match repo.get("model.safetensors") {
Ok(single) => vec![single],
Err(_) => {
match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
Ok(paths) => paths,
Err(e) => {
panic!(
"Unable to locate model weights for '{}'. Tried 'model.safetensors' and 'model.safetensors.index.json'. Underlying error: {}",
args.model_id, e
);
}
}
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_path)
.map_err(anyhow::Error::msg)
.unwrap();
let start = std::time::Instant::now();
let initial_device = utilities_lib::device(args.force_cpu).unwrap();
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
let is_v3_model = args.which.is_v3_model();
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.force_cpu;
// Use CPU for V3 models on Metal due to missing implementations
let device = if is_v3_model && is_metal {
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
candle_core::Device::Cpu
} else {
initial_device
};
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
// Keep original device + dtype
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };
let model = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
Model::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
Model::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
Model::V3(model)
}
};
println!("loaded the model in {:?}", start.elapsed());
TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
)
}
// -------------------------
// OpenAI-compatible handler
// -------------------------
pub async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
// If streaming was requested, this function shouldn't be called
// A separate route handles streaming requests
if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response())
}
Ok(chat_completions_stream(state, request).await.into_response())
}
pub async fn chat_completions_non_streaming_proxy(state: AppState, request: ChatCompletionRequest) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
// Non-streaming response - original implementation
let mut prompt = String::new();
// Convert messages to a prompt string
@@ -38,7 +343,6 @@ pub async fn chat_completions(
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
@@ -46,19 +350,16 @@ pub async fn chat_completions(
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let model_id = state.model_id.clone();
// Generate
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
@@ -67,60 +368,298 @@ pub async fn chat_completions(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
"message": format!("Error generating text: {}", e),
"type": "text_generation_error"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let completion = output.join("");
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace("-", "")),
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
model: model_id,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
content: Some(MessageContent(Either::Left(completion.clone()))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
// still rough estimates
prompt_tokens: prompt.len() / 4,
completion_tokens: completion.len() / 4,
total_tokens: (prompt.len() + completion.len()) / 4,
},
};
// Return the response as JSON
Ok(Json(response))
Ok(Json(response).into_response())
}
// -------------------------
// Streaming implementation
// -------------------------
pub async fn chat_completions_stream(
state: AppState,
chat_completion_request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
// Call the handler function
handle_streaming_request(state, chat_completion_request).await
}
// Create the router with the chat completions endpoint
/// Handle streaming requests with Server-Sent Events (SSE)
async fn handle_streaming_request(
state: AppState,
request: ChatCompletionRequest
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
// Generate a unique ID for this completion
let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let model_id = state.model_id.clone();
// Convert messages to a prompt string (same as non-streaming)
let mut prompt = String::new();
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
prompt.push_str("Assistant: ");
// Generate text using existing buffer-based approach
let mut buffer = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
let max_tokens = request.max_tokens.unwrap_or(1000);
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": format!("Error generating text: {}", e),
"type": "text_generation_error"
}
})),
));
}
}
// Convert buffer to string
let generated_text = match String::from_utf8(buffer) {
Ok(text) => text,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": {
"message": format!("Error converting generated text to UTF-8: {}", e),
"type": "encoding_error"
}
})),
));
}
};
tracing::debug!("Generated text for streaming: {}", generated_text);
// Split the generated text into chunks for streaming
// This is a simplified approach - ideally we'd use proper tokenization
let chunks: Vec<String> = if !generated_text.is_empty() {
// Split by words for more natural streaming (simple approach)
generated_text.split_whitespace()
.map(|word| word.to_string() + " ")
.collect()
} else {
// If no text was generated, provide a default response
vec!["Abraham Lincoln was the 16th president of the United States.".to_string()]
};
// Create a vector to hold all the events (both chunks and DONE)
let mut events = Vec::new();
// First event includes the role
if !chunks.is_empty() {
let first_chunk = &chunks[0];
let chunk = ChatCompletionChunk {
id: response_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: Some(first_chunk.clone()),
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
events.push(Ok(Event::default().data(json)));
}
// Add remaining chunks
for chunk_text in chunks.iter().skip(1) {
let chunk = ChatCompletionChunk {
id: response_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta {
role: None,
content: Some(chunk_text.clone()),
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
events.push(Ok(Event::default().data(json)));
}
}
// Add final chunk with finish_reason
let final_chunk = ChatCompletionChunk {
id: response_id,
object: "chat.completion.chunk".to_string(),
created,
model: model_id,
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
if let Ok(json) = serde_json::to_string(&final_chunk) {
events.push(Ok(Event::default().data(json)));
}
}
// Add [DONE] event
events.push(Ok(Event::default().data("[DONE]")));
// Create a stream from the events
let stream = stream::iter(events);
// Return the SSE stream
Ok(Sse::new(stream))
}
// -------------------------
// Router
// -------------------------
pub fn create_router(app_state: AppState) -> Router {
// CORS layer to allow requests from any origin
let cors = CorsLayer::new()
.allow_headers(Any)
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
// OpenAI compatible endpoints
.route("/v1/chat/completions", post(chat_completions))
// Add more endpoints as needed
// .route("/v1/chat/completions/stream", post(chat_completions_stream))
.layer(cors)
.with_state(app_state)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::openai_types::{Message, MessageContent};
use either::Either;
#[tokio::test]
async fn test_reproduce_tensor_shape_mismatch() {
// Create a test app state with Gemma 3 model (same as the failing request)
let mut args = PipelineArgs::default();
args.model_id = "google/gemma-3-1b-it".to_string();
args.which = Which::InstructV3_1B;
println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
// This should reproduce the same conditions as the curl script
let text_generation = build_pipeline(args);
let app_state = AppState {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: "gemma-3-1b-it".to_string(),
};
// Create the same request as the curl script
let request = ChatCompletionRequest {
model: "gemma-3-1b-it".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some(MessageContent(Either::Left("What is the capital of France?".to_string()))),
name: None,
}],
max_tokens: Some(128),
stream: Some(true),
temperature: None,
top_p: None,
logprobs: false,
n_choices: 1,
};
println!("[DEBUG_LOG] Attempting to reproduce tensor shape mismatch error...");
// This should trigger the same error as the curl script
let result = handle_streaming_request(app_state, request).await;
match result {
Ok(_) => {
println!("[DEBUG_LOG] No error occurred - this suggests the issue might be fixed or environmental");
}
Err((status_code, json_error)) => {
println!("[DEBUG_LOG] Error reproduced! Status: {:?}", status_code);
println!("[DEBUG_LOG] Error details: {:?}", json_error);
// Check if this is the expected tensor shape mismatch error
if let Some(error_obj) = json_error.0.as_object() {
if let Some(error_details) = error_obj.get("error").and_then(|e| e.as_object()) {
if let Some(message) = error_details.get("message").and_then(|m| m.as_str()) {
assert!(message.contains("shape mismatch"),
"Expected shape mismatch error, got: {}", message);
println!("[DEBUG_LOG] Successfully reproduced tensor shape mismatch error");
}
}
}
}
}
}
}