mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
- Refactored `build_pipeline` usage to ensure pipeline arguments are cloned.
- Introduced `reset_state` for clearing cached state between requests.
- Enhanced chat UI with model selector and dynamic model fetching.
- Improved error logging and detailed debug messages for chat request flows.
- Added fresh instantiation of `TextGeneration` to prevent tensor shape mismatches.
This commit is contained in:
@@ -36,15 +36,18 @@ use serde_json::Value;
|
||||
pub struct AppState {
|
||||
pub text_generation: Arc<Mutex<TextGeneration>>,
|
||||
pub model_id: String,
|
||||
// Store build args to recreate TextGeneration when needed
|
||||
pub build_args: PipelineArgs,
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
let args = PipelineArgs::default();
|
||||
let text_generation = build_pipeline(args);
|
||||
let text_generation = build_pipeline(args.clone());
|
||||
Self {
|
||||
text_generation: Arc::new(Mutex::new(text_generation)),
|
||||
model_id: String::new(),
|
||||
build_args: args,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -318,7 +321,7 @@ pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
pub async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
// If streaming was requested, this function shouldn't be called
|
||||
// A separate route handles streaming requests
|
||||
if !request.stream.unwrap_or(false) {
|
||||
@@ -357,7 +360,11 @@ pub async fn chat_completions_non_streaming_proxy(state: AppState, request: Chat
|
||||
// Generate
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
// Recreate TextGeneration instance to ensure completely fresh state
|
||||
// This prevents KV cache persistence that causes tensor shape mismatches
|
||||
let fresh_text_gen = build_pipeline(state.build_args.clone());
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
*text_gen = fresh_text_gen;
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
@@ -456,7 +463,12 @@ async fn handle_streaming_request(
|
||||
// Generate text using existing buffer-based approach
|
||||
let mut buffer = Vec::new();
|
||||
{
|
||||
// Recreate TextGeneration instance to ensure completely fresh state
|
||||
// This prevents KV cache persistence that causes tensor shape mismatches
|
||||
let fresh_text_gen = build_pipeline(state.build_args.clone());
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
*text_gen = fresh_text_gen;
|
||||
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
|
||||
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
|
||||
@@ -752,10 +764,11 @@ mod tests {
|
||||
println!("[DEBUG_LOG] Creating pipeline with model: {}", args.model_id);
|
||||
|
||||
// This should reproduce the same conditions as the curl script
|
||||
let text_generation = build_pipeline(args);
|
||||
let text_generation = build_pipeline(args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: Arc::new(Mutex::new(text_generation)),
|
||||
model_id: "gemma-3-1b-it".to_string(),
|
||||
build_args: args,
|
||||
};
|
||||
|
||||
// Create the same request as the curl script
|
||||
|
Reference in New Issue
Block a user