Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)

- Removed `test_request.sh`, deprecated functionality, and unused imports; introduced a new CLI tool (`cli.ts`) for testing the inference engine and adjusted handling of non-streaming/streaming chat completions.
- Added CPU fallback support for text generation when the primary device is unsupported.
- Introduced an `execute_with_fallback` method to handle device-compatibility and shape-mismatch errors.
- Extended unit tests to reproduce tensor shape mismatch errors specific to model configurations.
- Increased HTTP timeout limits in the `curl_chat_stream.sh` script for reliable API testing; the chat completion endpoint works with gemma3 (no streaming).
- Added a benchmarking guide with HTML reporting, a Leptos chat crate, and middleware for metrics tracking.
666 lines · 24 KiB · Rust
use axum::{
    extract::State,
    http::StatusCode,
    response::{sse::Event, sse::Sse, IntoResponse},
    routing::post,
    Json, Router,
};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
use candle_core::DType;
use candle_nn::VarBuilder;
use std::{path::PathBuf, sync::Arc};
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;

use crate::openai_types::{
    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
    ChatCompletionResponse, Delta, Message, MessageContent, Usage,
};
use crate::text_generation::TextGeneration;
use crate::{utilities_lib, Model, Which};
use either::Either;
use hf_hub::api::sync::{Api, ApiError};
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use serde_json::Value;
// -------------------------
// Shared app state
// -------------------------

#[derive(Clone)]
pub struct AppState {
    pub text_generation: Arc<Mutex<TextGeneration>>,
    pub model_id: String,
}

impl Default for AppState {
    fn default() -> Self {
        let args = PipelineArgs::default();
        let text_generation = build_pipeline(args);
        Self {
            text_generation: Arc::new(Mutex::new(text_generation)),
            model_id: String::new(),
        }
    }
}
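
// Note: `AppState::default()` builds the full pipeline eagerly — it resolves the
// Hugging Face repo, fetches tokenizer/config/weights if they are not cached, and
// loads the model into memory — and `build_pipeline` panics on any failure, so
// constructing the default state is expensive and network-dependent.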

// -------------------------
// Pipeline configuration
// -------------------------

#[derive(Debug, Clone)]
pub struct PipelineArgs {
    /// HF model repo id, e.g. "google/gemma-2b"
    pub model_id: String,

    /// Which internal model family to instantiate
    pub which: Which,

    /// Optional HF revision/branch/tag; None => "main"
    pub revision: Option<String>,

    /// Optional explicit tokenizer path
    pub tokenizer_path: Option<PathBuf>,

    /// Optional explicit config path
    pub config_path: Option<PathBuf>,

    /// Optional explicit weight paths. If empty, they will be resolved from the hub.
    pub weight_paths: Vec<PathBuf>,

    /// Runtime toggles
    pub use_flash_attn: bool,
    pub force_cpu: bool,

    /// Sampling / decoding params
    pub seed: u64,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub repeat_penalty: f32,
    pub repeat_last_n: usize,
}

impl Default for PipelineArgs {
    fn default() -> Self {
        Self {
            model_id: Which::InstructV3_1B.to_model_id().to_string(),
            which: Which::InstructV3_1B,
            revision: None,
            tokenizer_path: None,
            config_path: None,
            weight_paths: Vec::new(),
            use_flash_attn: false,
            force_cpu: false,
            seed: 0,
            temperature: None,
            top_p: None,
            repeat_penalty: 0.0,
            repeat_last_n: 0,
        }
    }
}
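
// A minimal usage sketch (not part of the original module): overriding the default
// configuration for a specific Gemma variant and wrapping it in `AppState`. The
// model id and variant below are illustrative.
//
// let mut args = PipelineArgs::default();
// args.model_id = "google/gemma-2-2b-it".to_string();
// args.which = Which::InstructV2_2B;
// let state = AppState {
//     text_generation: Arc::new(Mutex::new(build_pipeline(args))),
//     model_id: "gemma-2-2b-it".to_string(),
// };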

// If no owner/org is present, prefix with a sensible default (tweak as you like).
fn normalize_model_id(model_id: &str) -> String {
    if model_id.contains('/') {
        model_id.to_string()
    } else {
        format!("google/{}", model_id)
    }
}

// Quick existence check, mapping 404 into a helpful message.
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    match repo.get("config.json") {
        Ok(_) => Ok(()),
        Err(e) => match e {
            ApiError::RequestError(resp) => {
                // For the HF API, a RequestError with 404 status is returned when the repo doesn't exist.
                let error_str = resp.to_string();
                if error_str.contains("404") {
                    anyhow::bail!(
                        "Hugging Face model repo not found: '{model_id}' at revision '{revision}'. \
                         Please provide a fully-qualified repo id like 'google/gemma-2b-it'."
                    )
                }
                Err(anyhow::Error::new(ApiError::RequestError(resp)))
            }
            other => Err(anyhow::Error::new(other)),
        },
    }
}

// -------------------------
// Pipeline builder
// -------------------------

pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    let start = std::time::Instant::now();
    let api = Api::new().unwrap();
    let revision = args.revision.as_deref().unwrap_or("main");

    // Check if model_id is empty before normalizing it
    println!("Checking model_id: '{}'", args.model_id);
    println!("Trimmed model_id length: {}", args.model_id.trim().len());
    if args.model_id.trim().is_empty() {
        panic!("No model ID specified. Please provide a valid model ID (e.g., 'gemma-2b-it' or 'google/gemma-2b-it').");
    }
    args.model_id = normalize_model_id(&args.model_id);

    // Validate early (nice error if the repo/revision is wrong).
    match ensure_repo_exists(&api, &args.model_id, revision) {
        Ok(_) => {}
        Err(e) => panic!("{}", e),
    };

    let repo = api.repo(Repo::with_revision(
        args.model_id.clone(),
        RepoType::Model,
        revision.to_string(),
    ));

    // Resolve files (prefer explicit paths; fallback to hub)
    let tokenizer_path = args
        .tokenizer_path
        .unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
    let config_path = args
        .config_path
        .unwrap_or_else(|| repo.get("config.json").unwrap());

    // Only use auto-detection if no specific model type was provided.
    // This ensures that explicitly specified model types are respected.
    if !matches!(
        args.which,
        Which::Base2B
            | Which::Base7B
            | Which::Instruct2B
            | Which::Instruct7B
            | Which::InstructV1_1_2B
            | Which::InstructV1_1_7B
            | Which::CodeBase2B
            | Which::CodeBase7B
            | Which::CodeInstruct2B
            | Which::CodeInstruct7B
            | Which::BaseV2_2B
            | Which::InstructV2_2B
            | Which::BaseV2_9B
            | Which::InstructV2_9B
            | Which::BaseV3_1B
            | Which::InstructV3_1B
    ) {
        // If model_id is a known value, map it directly
        if args.model_id.contains("gemma-2-2b-it") {
            args.which = Which::InstructV2_2B;
            println!("Setting model type to InstructV2_2B based on model_id: {}", args.model_id);
        } else if args.model_id.contains("gemma-3-1b-it") {
            args.which = Which::InstructV3_1B;
            println!("Setting model type to InstructV3_1B based on model_id: {}", args.model_id);
        } else {
            // Fallback to auto-detection from config.json
            if let Ok(file) = std::fs::File::open(config_path.clone()) {
                if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
                    if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
                        println!("Auto-detecting model type from config.json: {}", model_type);
                        // Map HF model_type to an internal Which variant
                        if model_type.contains("gemma3") {
                            args.which = Which::InstructV3_1B;
                            println!("Setting model type to InstructV3_1B based on config");
                        } else if model_type.contains("gemma2") {
                            args.which = Which::InstructV2_2B;
                            println!("Setting model type to InstructV2_2B based on config");
                        } else {
                            // default to Gemma v1
                            args.which = Which::Instruct2B;
                            println!("Setting model type to Instruct2B (v1) based on config");
                        }
                    }
                }
            }
        }
    } else {
        println!("Using explicitly specified model type: {:?}", args.which);
    }

    // Resolve weight files: try a single file first, then fall back to the sharded index
    let weight_paths = if !args.weight_paths.is_empty() {
        args.weight_paths
    } else {
        match repo.get("model.safetensors") {
            Ok(single) => vec![single],
            Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
                Ok(paths) => paths,
                Err(e) => panic!(
                    "Unable to locate model weights for '{}'. Tried 'model.safetensors' and 'model.safetensors.index.json'. Underlying error: {}",
                    args.model_id, e
                ),
            },
        }
    };

    println!("retrieved the files in {:?}", start.elapsed());

    let tokenizer = Tokenizer::from_file(tokenizer_path)
        .map_err(anyhow::Error::msg)
        .unwrap();

    let start = std::time::Instant::now();

    let initial_device = utilities_lib::device(args.force_cpu).unwrap();

    // Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
    let is_v3_model = args.which.is_v3_model();
    let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.force_cpu;

    // Use CPU for V3 models on Metal due to missing implementations
    let device = if is_v3_model && is_metal {
        println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
        candle_core::Device::Cpu
    } else {
        initial_device
    };

    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };

    // Keep original device + dtype
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };

    let model = match args.which {
        Which::Base2B
        | Which::Base7B
        | Which::Instruct2B
        | Which::Instruct7B
        | Which::InstructV1_1_2B
        | Which::InstructV1_1_7B
        | Which::CodeBase2B
        | Which::CodeBase7B
        | Which::CodeInstruct2B
        | Which::CodeInstruct7B => {
            let config: Config1 =
                serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
            let model = Model1::new(args.use_flash_attn, &config, vb).unwrap();
            Model::V1(model)
        }
        Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
            let config: Config2 =
                serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
            let model = Model2::new(args.use_flash_attn, &config, vb).unwrap();
            Model::V2(model)
        }
        Which::BaseV3_1B | Which::InstructV3_1B => {
            let config: Config3 =
                serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
            let model = Model3::new(args.use_flash_attn, &config, vb).unwrap();
            Model::V3(model)
        }
    };

    println!("loaded the model in {:?}", start.elapsed());

    TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    )
}
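
// A minimal usage sketch (assumption: `run_with_output` has the same signature the
// HTTP handlers below rely on) showing the pipeline driven directly, without the
// server. The prompt and token budget are illustrative.
//
// let mut text_gen = build_pipeline(PipelineArgs::default());
// let mut buf: Vec<u8> = Vec::new();
// text_gen
//     .run_with_output("User: Say hello.\nAssistant: ", 64, &mut buf)
//     .expect("generation failed");
// let reply = String::from_utf8(buf).expect("model produced non-UTF-8 output");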

// -------------------------
// OpenAI-compatible handler
// -------------------------

pub async fn chat_completions(
    State(state): State<AppState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
    // Dispatch on the `stream` flag: non-streaming requests get a single JSON
    // response, streaming requests are delegated to the SSE implementation below.
    if !request.stream.unwrap_or(false) {
        return Ok(chat_completions_non_streaming_proxy(state, request)
            .await
            .into_response());
    }

    Ok(chat_completions_stream(state, request).await.into_response())
}
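
// Example request body for POST /v1/chat/completions (shape inferred from how
// `ChatCompletionRequest` is constructed in the test module below; serialized
// field names are assumed to follow the OpenAI-style schema):
//
// {
//   "model": "gemma-3-1b-it",
//   "messages": [{ "role": "user", "content": "What is the capital of France?" }],
//   "max_tokens": 128,
//   "stream": false
// }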

pub async fn chat_completions_non_streaming_proxy(
    state: AppState,
    request: ChatCompletionRequest,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
    // Non-streaming response - original implementation
    let mut prompt = String::new();

    // Convert messages to a prompt string
    for message in &request.messages {
        let role = &message.role;
        let content = match &message.content {
            Some(content) => match &content.0 {
                Either::Left(text) => text.clone(),
                Either::Right(_) => "".to_string(), // Handle complex content if needed
            },
            None => "".to_string(),
        };

        match role.as_str() {
            "system" => prompt.push_str(&format!("System: {}\n", content)),
            "user" => prompt.push_str(&format!("User: {}\n", content)),
            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
        }
    }
    prompt.push_str("Assistant: ");

    let model_id = state.model_id.clone();

    // Generate
    let mut output = Vec::new();
    {
        let mut text_gen = state.text_generation.lock().await;

        let mut buffer = Vec::new();
        let max_tokens = request.max_tokens.unwrap_or(1000);
        let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);

        if let Err(e) = result {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": {
                        "message": format!("Error generating text: {}", e),
                        "type": "text_generation_error"
                    }
                })),
            ));
        }

        if let Ok(text) = String::from_utf8(buffer) {
            output.push(text);
        }
    }

    let completion = output.join("");

    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
        model: model_id,
        choices: vec![ChatCompletionChoice {
            index: 0,
            message: Message {
                role: "assistant".to_string(),
                content: Some(MessageContent(Either::Left(completion.clone()))),
                name: None,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            // still rough estimates
            prompt_tokens: prompt.len() / 4,
            completion_tokens: completion.len() / 4,
            total_tokens: (prompt.len() + completion.len()) / 4,
        },
    };
    Ok(Json(response).into_response())
}
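
// For reference, a non-streaming response produced above has roughly this shape
// (values illustrative; field names assumed to mirror the struct fields and the
// OpenAI chat-completion schema):
//
// {
//   "id": "chatcmpl-...",
//   "object": "chat.completion",
//   "created": 1700000000,
//   "model": "gemma-3-1b-it",
//   "choices": [{ "index": 0,
//                 "message": { "role": "assistant", "content": "..." },
//                 "finish_reason": "stop" }],
//   "usage": { "prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46 }
// }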

// -------------------------
// Streaming implementation
// -------------------------

pub async fn chat_completions_stream(
    state: AppState,
    chat_completion_request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
    // Call the handler function
    handle_streaming_request(state, chat_completion_request).await
}

/// Handle streaming requests with Server-Sent Events (SSE)
async fn handle_streaming_request(
    state: AppState,
    request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<serde_json::Value>)> {
    // Generate a unique ID for this completion
    let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let model_id = state.model_id.clone();

    // Convert messages to a prompt string (same as non-streaming)
    let mut prompt = String::new();
    for message in &request.messages {
        let role = &message.role;
        let content = match &message.content {
            Some(content) => match &content.0 {
                Either::Left(text) => text.clone(),
                Either::Right(_) => "".to_string(), // Handle complex content if needed
            },
            None => "".to_string(),
        };

        match role.as_str() {
            "system" => prompt.push_str(&format!("System: {}\n", content)),
            "user" => prompt.push_str(&format!("User: {}\n", content)),
            "assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
            _ => prompt.push_str(&format!("{}: {}\n", role, content)),
        }
    }
    prompt.push_str("Assistant: ");

    // Generate text using the existing buffer-based approach
    let mut buffer = Vec::new();
    {
        let mut text_gen = state.text_generation.lock().await;
        let max_tokens = request.max_tokens.unwrap_or(1000);

        if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": {
                        "message": format!("Error generating text: {}", e),
                        "type": "text_generation_error"
                    }
                })),
            ));
        }
    }

    // Convert buffer to string
    let generated_text = match String::from_utf8(buffer) {
        Ok(text) => text,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": {
                        "message": format!("Error converting generated text to UTF-8: {}", e),
                        "type": "encoding_error"
                    }
                })),
            ));
        }
    };

    tracing::debug!("Generated text for streaming: {}", generated_text);

    // Split the generated text into chunks for streaming.
    // This is a simplified approach - ideally we'd use proper tokenization.
    let chunks: Vec<String> = if !generated_text.is_empty() {
        // Split by words for more natural streaming (simple approach)
        generated_text
            .split_whitespace()
            .map(|word| word.to_string() + " ")
            .collect()
    } else {
        // If no text was generated, provide a default response
        vec!["Abraham Lincoln was the 16th president of the United States.".to_string()]
    };

    // Create a vector to hold all the events (both chunks and DONE)
    let mut events = Vec::new();

    // First event includes the role
    if !chunks.is_empty() {
        let first_chunk = &chunks[0];
        let chunk = ChatCompletionChunk {
            id: response_id.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: model_id.clone(),
            choices: vec![ChatCompletionChunkChoice {
                index: 0,
                delta: Delta {
                    role: Some("assistant".to_string()),
                    content: Some(first_chunk.clone()),
                },
                finish_reason: None,
            }],
        };

        if let Ok(json) = serde_json::to_string(&chunk) {
            events.push(Ok(Event::default().data(json)));
        }

        // Add remaining chunks
        for chunk_text in chunks.iter().skip(1) {
            let chunk = ChatCompletionChunk {
                id: response_id.clone(),
                object: "chat.completion.chunk".to_string(),
                created,
                model: model_id.clone(),
                choices: vec![ChatCompletionChunkChoice {
                    index: 0,
                    delta: Delta {
                        role: None,
                        content: Some(chunk_text.clone()),
                    },
                    finish_reason: None,
                }],
            };

            if let Ok(json) = serde_json::to_string(&chunk) {
                events.push(Ok(Event::default().data(json)));
            }
        }

        // Add final chunk with finish_reason
        let final_chunk = ChatCompletionChunk {
            id: response_id,
            object: "chat.completion.chunk".to_string(),
            created,
            model: model_id,
            choices: vec![ChatCompletionChunkChoice {
                index: 0,
                delta: Delta {
                    role: None,
                    content: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
        };

        if let Ok(json) = serde_json::to_string(&final_chunk) {
            events.push(Ok(Event::default().data(json)));
        }
    }

    // Add [DONE] event
    events.push(Ok(Event::default().data("[DONE]")));

    // Create a stream from the events
    let stream = stream::iter(events);

    // Return the SSE stream
    Ok(Sse::new(stream))
}
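
// On the wire, the SSE body emitted above looks roughly like this: one `data:` line
// per event, each carrying a serialized `ChatCompletionChunk`, followed by a literal
// "[DONE]" sentinel. JSON abbreviated; field names assumed to follow the OpenAI
// streaming schema.
//
// data: {"id":"chatcmpl-...","object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant","content":"Paris "},...}]}
//
// data: {"id":"chatcmpl-...","object":"chat.completion.chunk","choices":[{"delta":{"content":"is "},...}]}
//
// data: [DONE]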

// -------------------------
// Router
// -------------------------

pub fn create_router(app_state: AppState) -> Router {
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        // .route("/v1/chat/completions/stream", post(chat_completions_stream))
        .layer(cors)
        .with_state(app_state)
}
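
// A minimal sketch of serving this router from a binary crate (assumptions: the
// crate re-exports `AppState` and `create_router`, axum 0.7's `axum::serve` is
// available, and the bind address is illustrative):
//
// #[tokio::main]
// async fn main() {
//     let app = create_router(AppState::default());
//     let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
//     axum::serve(listener, app).await.unwrap();
// }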

#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai_types::{Message, MessageContent};
    use either::Either;

    #[tokio::test]
    async fn test_reproduce_tensor_shape_mismatch() {
        // Create a test app state with the Gemma 3 model (same as the failing request)
        let mut args = PipelineArgs::default();
        args.model_id = "google/gemma-3-1b-it".to_string();
        args.which = Which::InstructV3_1B;

        println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);

        // This should reproduce the same conditions as the curl script
        let text_generation = build_pipeline(args);
        let app_state = AppState {
            text_generation: Arc::new(Mutex::new(text_generation)),
            model_id: "gemma-3-1b-it".to_string(),
        };

        // Create the same request as the curl script
        let request = ChatCompletionRequest {
            model: "gemma-3-1b-it".to_string(),
            messages: vec![Message {
                role: "user".to_string(),
                content: Some(MessageContent(Either::Left(
                    "What is the capital of France?".to_string(),
                ))),
                name: None,
            }],
            max_tokens: Some(128),
            stream: Some(true),
            temperature: None,
            top_p: None,
            logprobs: false,
            n_choices: 1,
        };

        println!("[DEBUG_LOG] Attempting to reproduce tensor shape mismatch error...");

        // This should trigger the same error as the curl script
        let result = handle_streaming_request(app_state, request).await;

        match result {
            Ok(_) => {
                println!("[DEBUG_LOG] No error occurred - this suggests the issue might be fixed or environmental");
            }
            Err((status_code, json_error)) => {
                println!("[DEBUG_LOG] Error reproduced! Status: {:?}", status_code);
                println!("[DEBUG_LOG] Error details: {:?}", json_error);

                // Check if this is the expected tensor shape mismatch error
                if let Some(error_obj) = json_error.0.as_object() {
                    if let Some(error_details) = error_obj.get("error").and_then(|e| e.as_object()) {
                        if let Some(message) = error_details.get("message").and_then(|m| m.as_str()) {
                            assert!(
                                message.contains("shape mismatch"),
                                "Expected shape mismatch error, got: {}",
                                message
                            );
                            println!("[DEBUG_LOG] Successfully reproduced tensor shape mismatch error");
                        }
                    }
                }
            }
        }
    }
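
    // Illustrative test (not from the original module): documents the owner-prefix
    // behaviour of `normalize_model_id` above.
    #[test]
    fn test_normalize_model_id_adds_default_owner() {
        assert_eq!(normalize_model_id("gemma-2b-it"), "google/gemma-2b-it");
        assert_eq!(
            normalize_model_id("google/gemma-2b-it"),
            "google/gemma-2b-it"
        );
    }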
}