- Refactored `build_pipeline` usage to ensure pipeline arguments are cloned.

- Introduced `reset_state` for clearing cached state between requests.
- Enhanced chat UI with model selector and dynamic model fetching.
- Improved error logging and detailed debug messages for chat request flows.
- Added fresh instantiation of `TextGeneration` to prevent tensor shape mismatches.
This commit is contained in:
geoffsee
2025-08-27 17:53:50 -04:00
parent f1b57866e1
commit 766d41af78
5 changed files with 185 additions and 209 deletions

View File

@@ -36,15 +36,18 @@ use serde_json::Value;
/// Shared application state handed to the request handlers.
pub struct AppState {
/// Text-generation pipeline behind an `Arc<Mutex<..>>`; handlers lock it
/// before running generation.
pub text_generation: Arc<Mutex<TextGeneration>>,
/// Model identifier; the `Default` impl leaves this as an empty string.
pub model_id: String,
// Store build args to recreate TextGeneration when needed
// (the request handlers clone these and pass them to `build_pipeline`).
pub build_args: PipelineArgs,
}
// Builds an `AppState` from default `PipelineArgs`: the pipeline is created
// via `build_pipeline`, `model_id` starts empty, and the args are kept in
// `build_args`.
impl Default for AppState {
fn default() -> Self {
let args = PipelineArgs::default();
// NOTE(review): this listing looks like a diff with its +/- markers stripped;
// the next two `let` lines appear to be the removed and the added form of the
// same statement. Only the `args.clone()` form leaves `args` available to be
// moved into `build_args` below — confirm against the actual file.
let text_generation = build_pipeline(args);
let text_generation = build_pipeline(args.clone());
Self {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: String::new(),
// Keep the original args so the pipeline can be rebuilt later.
build_args: args,
}
}
}
@@ -318,7 +321,7 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
pub async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
) -> Result<impl IntoResponse, (StatusCode, String)> {
// If streaming was requested, this function shouldn't be called
// A separate route handles streaming requests
if !request.stream.unwrap_or(false) {
@@ -357,7 +360,11 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
// Generate
let mut output = Vec::new();
{
// Recreate TextGeneration instance to ensure completely fresh state
// This prevents KV cache persistence that causes tensor shape mismatches
let fresh_text_gen = build_pipeline(state.build_args.clone());
let mut text_gen = state.text_generation.lock().await;
*text_gen = fresh_text_gen;
let mut buffer = Vec::new();
let max_tokens = request.max_tokens.unwrap_or(1000);
@@ -456,7 +463,12 @@ async fn handle_streaming_request(
// Generate text using existing buffer-based approach
let mut buffer = Vec::new();
{
// Recreate TextGeneration instance to ensure completely fresh state
// This prevents KV cache persistence that causes tensor shape mismatches
let fresh_text_gen = build_pipeline(state.build_args.clone());
let mut text_gen = state.text_generation.lock().await;
*text_gen = fresh_text_gen;
let max_tokens = request.max_tokens.unwrap_or(1000);
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
@@ -752,10 +764,11 @@ mod tests {
println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
// This should reproduce the same conditions as the curl script
let text_generation = build_pipeline(args);
let text_generation = build_pipeline(args.clone());
let app_state = AppState {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: "gemma-3-1b-it".to_string(),
build_args: args,
};
// Create the same request as the curl script