Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
Supports small Llama and Gemma models
Refactor inference into dedicated crates for Llama and Gemma inferencing (not yet integrated)
@@ -5,304 +5,85 @@ use axum::{
    routing::{get, post},
    Json, Router,
};
use candle_core::DType;
use candle_nn::VarBuilder;
use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible;
use std::{path::PathBuf, sync::Arc};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;

use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::text_generation::TextGeneration;
use crate::{utilities_lib, Model as GemmaModel, Which};
use crate::Which;
use either::Either;
use hf_hub::api::sync::{Api, ApiError};
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};

// -------------------------
// Shared app state
// -------------------------

#[derive(Clone, Debug)]
pub enum ModelType {
    Gemma,
    Llama,
}

#[derive(Clone)]
pub struct AppState {
    pub text_generation: Arc<Mutex<TextGeneration>>,
    pub model_type: ModelType,
    pub model_id: String,
    // Store build args to recreate TextGeneration when needed
    pub build_args: PipelineArgs,
    pub gemma_config: Option<GemmaInferenceConfig>,
    pub llama_config: Option<LlamaInferenceConfig>,
}

impl Default for AppState {
    fn default() -> Self {
        let args = PipelineArgs::default();
        let text_generation = build_pipeline(args.clone());
        let gemma_config = GemmaInferenceConfig {
            model: gemma_runner::WhichModel::InstructV3_1B,
            ..Default::default()
        };
        Self {
            text_generation: Arc::new(Mutex::new(text_generation)),
            model_id: args.model_id.clone(),
            build_args: args,
            model_type: ModelType::Gemma,
            model_id: "gemma-3-1b-it".to_string(),
            gemma_config: Some(gemma_config),
            llama_config: None,
        }
    }
}

// -------------------------
// Pipeline configuration
// Helper functions
// -------------------------

#[derive(Debug, Clone)]
pub struct PipelineArgs {
    pub model_id: String,
    pub which: Which,
    pub revision: Option<String>,
    pub tokenizer_path: Option<PathBuf>,
    pub config_path: Option<PathBuf>,
    pub weight_paths: Vec<PathBuf>,
    pub use_flash_attn: bool,
    pub force_cpu: bool,
    pub seed: u64,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub repeat_penalty: f32,
    pub repeat_last_n: usize,
}

impl Default for PipelineArgs {
    fn default() -> Self {
        Self {
            model_id: Which::InstructV3_1B.to_model_id().to_string(),
            which: Which::InstructV3_1B,
            revision: None,
            tokenizer_path: None,
            config_path: None,
            weight_paths: Vec::new(),
            use_flash_attn: false,
            force_cpu: false,
            seed: 299792458, // Speed of light in vacuum (m/s)
            temperature: Some(0.8), // Good balance between creativity and coherence
            top_p: Some(0.9), // Keep diverse but reasonable options
            repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
            repeat_last_n: 64, // Consider last 64 tokens for repetition
        }
    }
}
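
Because PipelineArgs derives Clone and provides a Default impl, callers can override individual knobs with struct-update syntax and keep the rest of the defaults. A minimal sketch, assuming the pre-refactor PipelineArgs and build_pipeline items shown in this diff; the particular overrides are illustrative only:

// Sketch only: customize a few sampling knobs, keep every other default.
// Assumes `PipelineArgs` and `build_pipeline` from this module (pre-refactor).
let args = PipelineArgs {
    force_cpu: true,            // skip GPU/Metal device selection
    temperature: Some(0.2),     // more deterministic sampling
    repeat_penalty: 1.1,        // milder anti-repetition penalty
    ..PipelineArgs::default()
};
let text_generation = build_pipeline(args);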

fn normalize_model_id(model_id: &str) -> String {
    if model_id.contains('/') {
        model_id.to_string()
    } else {
        format!("google/{}", model_id)
    }
}
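
This diff carries two versions of normalize_model_id: the pre-refactor one above, which prefixes bare model names with the google/ organization, and the post-refactor body that appears further down (model_id.to_lowercase().replace("_", "-")), which only canonicalizes spelling. A self-contained sketch contrasting the two behaviors; the function names and example model ids here are illustrative:

// Sketch only: the two normalization behaviors present in this diff.
fn normalize_old(model_id: &str) -> String {
    if model_id.contains('/') {
        model_id.to_string()
    } else {
        format!("google/{}", model_id)
    }
}

fn normalize_new(model_id: &str) -> String {
    model_id.to_lowercase().replace("_", "-")
}

fn main() {
    // Bare names get an org prefix; fully qualified ids pass through unchanged.
    assert_eq!(normalize_old("gemma-3-1b-it"), "google/gemma-3-1b-it");
    assert_eq!(normalize_old("meta-llama/Llama-3.2-1B"), "meta-llama/Llama-3.2-1B");
    // The new behavior only lowercases and swaps underscores for dashes.
    assert_eq!(normalize_new("Llama_3.2_1B_Instruct"), "llama-3.2-1b-instruct");
}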

fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
    let repo = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    match repo.get("config.json") {
        Ok(_) => Ok(()),
        Err(e) => match e {
            ApiError::RequestError(resp) => {
                let error_str = resp.to_string();
                if error_str.contains("404") {
                    anyhow::bail!(
                        "Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
                    )
                }
                Err(anyhow::Error::new(ApiError::RequestError(resp)))
            }
            other => Err(anyhow::Error::new(other)),
        },
    }
}

// -------------------------
// Pipeline builder
// -------------------------

pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    let start = std::time::Instant::now();
    let api = Api::new().unwrap();
    let revision = args.revision.as_deref().unwrap_or("main");

    if args.model_id.trim().is_empty() {
        panic!("No model ID specified.");
    }
    args.model_id = normalize_model_id(&args.model_id);

    match ensure_repo_exists(&api, &args.model_id, revision) {
        Ok(_) => {}
        Err(e) => panic!("{}", e),
    };

    let repo = api.repo(Repo::with_revision(
        args.model_id.clone(),
        RepoType::Model,
        revision.to_string(),
    ));

    let tokenizer_path = args
        .tokenizer_path
        .unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
    let config_path = args
        .config_path
        .unwrap_or_else(|| repo.get("config.json").unwrap());

    if !matches!(
        args.which,
        Which::Base2B
            | Which::Base7B
            | Which::Instruct2B
            | Which::Instruct7B
            | Which::InstructV1_1_2B
            | Which::InstructV1_1_7B
            | Which::CodeBase2B
            | Which::CodeBase7B
            | Which::CodeInstruct2B
            | Which::CodeInstruct7B
            | Which::BaseV2_2B
            | Which::InstructV2_2B
            | Which::BaseV2_9B
            | Which::InstructV2_9B
            | Which::BaseV3_1B
            | Which::InstructV3_1B
    ) {
        if args.model_id.contains("gemma-2-2b-it") {
            args.which = Which::InstructV2_2B;
        } else if args.model_id.contains("gemma-3-1b-it") {
            args.which = Which::InstructV3_1B;
        } else if let Ok(file) = std::fs::File::open(config_path.clone()) {
            if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
                if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
                    if model_type.contains("gemma3") {
                        args.which = Which::InstructV3_1B;
                    } else if model_type.contains("gemma2") {
                        args.which = Which::InstructV2_2B;
                    } else {
                        args.which = Which::Instruct2B;
                    }
                }
            }
        }
    }

    let weight_paths = if !args.weight_paths.is_empty() {
        args.weight_paths
    } else {
        match repo.get("model.safetensors") {
            Ok(single) => vec![single],
            Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
                Ok(paths) => paths,
                Err(e) => {
                    panic!("Unable to locate model weights: {}", e);
                }
            },
        }
    };

    println!("retrieved the files in {:?}", start.elapsed());

    let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();

    let initial_device = utilities_lib::device(args.force_cpu).unwrap();
    let is_v3_model = args.which.is_v3_model();
    let is_metal = !initial_device.is_cpu()
        && candle_core::utils::metal_is_available()
        && !args.force_cpu;

    let device = if is_v3_model && is_metal {
        candle_core::Device::Cpu
    } else {
        initial_device
    };

    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };

    let model = match args.which {
        Which::Base2B
        | Which::Base7B
        | Which::Instruct2B
        | Which::Instruct7B
        | Which::InstructV1_1_2B
        | Which::InstructV1_1_7B
        | Which::CodeBase2B
        | Which::CodeBase7B
        | Which::CodeInstruct2B
        | Which::CodeInstruct7B => {
            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
            GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
        }
        Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
            GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
        }
        Which::BaseV3_1B | Which::InstructV3_1B => {
            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
            GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
        }
    };

    TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    )
    model_id.to_lowercase().replace("_", "-")
}

fn build_gemma_prompt(messages: &[Message]) -> String {
    let mut prompt = String::new();
    let mut system_prompt: Option<String> = None;

    for message in messages {
        let content = match &message.content {
            Some(content) => match &content.0 {
                Either::Left(text) => text.clone(),
                Either::Right(_) => "".to_string(),
            },
            None => "".to_string(),
        };

        match message.role.as_str() {
            "system" => system_prompt = Some(content),
            "user" => {
                prompt.push_str("<start_of_turn>user\n");
                if let Some(sys_prompt) = system_prompt.take() {
                    prompt.push_str(&sys_prompt);
                    prompt.push_str("\n\n");
            "system" => {
                if let Some(MessageContent(Either::Left(content))) = &message.content {
                    prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
                }
            }
            "user" => {
                if let Some(MessageContent(Either::Left(content))) = &message.content {
                    prompt.push_str(&format!("<start_of_turn>user\n{}<end_of_turn>\n", content));
                }
                prompt.push_str(&content);
                prompt.push_str("<end_of_turn>\n");
            }
            "assistant" => {
                prompt.push_str("<start_of_turn>model\n");
                prompt.push_str(&content);
                prompt.push_str("<end_of_turn>\n");
                if let Some(MessageContent(Either::Left(content))) = &message.content {
                    prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
                }
            }
            _ => {}
        }
    }

    prompt.push_str("<start_of_turn>model\n");
    prompt
}
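
Both variants of build_gemma_prompt interleaved above follow the Gemma chat template: each message becomes a <start_of_turn>ROLE ... <end_of_turn> block (with the OpenAI "assistant" role rendered as Gemma's "model" turn), and the prompt ends with an opening <start_of_turn>model so the model generates the next turn. A self-contained sketch of the new-style formatting, using plain (role, text) pairs instead of the crate's Message/MessageContent types:

// Sketch only: Gemma-style turn formatting over simple (role, text) pairs.
fn format_gemma(messages: &[(&str, &str)]) -> String {
    let mut prompt = String::new();
    for &(role, text) in messages {
        // "assistant" is rendered as Gemma's "model" turn.
        let turn = if role == "assistant" { "model" } else { role };
        prompt.push_str(&format!("<start_of_turn>{}\n{}<end_of_turn>\n", turn, text));
    }
    // Prime the model to produce the next turn.
    prompt.push_str("<start_of_turn>model\n");
    prompt
}

fn main() {
    let prompt = format_gemma(&[
        ("system", "You are a concise assistant."),
        ("user", "What is 2 + 2?"),
    ]);
    // Produces:
    // <start_of_turn>system
    // You are a concise assistant.<end_of_turn>
    // <start_of_turn>user
    // What is 2 + 2?<end_of_turn>
    // <start_of_turn>model
    println!("{prompt}");
}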

@@ -325,14 +106,13 @@ pub async fn chat_completions_non_streaming_proxy(
    state: AppState,
    request: ChatCompletionRequest,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
    let prompt = build_gemma_prompt(&request.messages);

    // Enforce model selection behavior: reject if a different model is requested
    let configured_model = state.build_args.model_id.clone();
    let configured_model = state.model_id.clone();
    let requested_model = request.model.clone();
    if requested_model.to_lowercase() != "default" {
        let normalized_requested = normalize_model_id(&requested_model);
        if normalized_requested != configured_model {
        let normalized_configured = normalize_model_id(&configured_model);
        if normalized_requested != normalized_configured {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
@@ -349,35 +129,81 @@ pub async fn chat_completions_non_streaming_proxy(
    }

    let model_id = state.model_id.clone();
    let max_tokens = request.max_tokens.unwrap_or(1000);

    let mut buffer = Vec::new();
    {
        let mut text_gen = state.text_generation.lock().await;
        // Reset per-request state without rebuilding the whole pipeline
        text_gen.reset_state();
        let max_tokens = request.max_tokens.unwrap_or(1000);
        if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
                    "error": { "message": format!("Error generating text: {}", e) }
                })),
            ));
        }
    }

    let completion = match String::from_utf8(buffer) {
        Ok(s) => s,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": { "message": format!("UTF-8 conversion error: {}", e) }
                })),
            ));
    // Build prompt based on model type
    let prompt = match state.model_type {
        ModelType::Gemma => build_gemma_prompt(&request.messages),
        ModelType::Llama => {
            // For Llama, just use the last user message for now
            request.messages.last()
                .and_then(|m| m.content.as_ref())
                .and_then(|c| match c {
                    MessageContent(Either::Left(text)) => Some(text.clone()),
                    _ => None,
                })
                .unwrap_or_default()
        }
    };

    // Get streaming receiver based on model type
    let rx = match state.model_type {
        ModelType::Gemma => {
            if let Some(mut config) = state.gemma_config {
                config.prompt = prompt.clone();
                config.max_tokens = max_tokens;
                run_gemma_api(config).map_err(|e| (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
                    }))
                ))?
            } else {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": "Gemma configuration not available" }
                    }))
                ));
            }
        }
        ModelType::Llama => {
            if let Some(mut config) = state.llama_config {
                config.prompt = prompt.clone();
                config.max_tokens = max_tokens;
                run_llama_inference(config).map_err(|e| (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Error initializing Llama model: {}", e) }
                    }))
                ))?
            } else {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": "Llama configuration not available" }
                    }))
                ));
            }
        }
    };

    // Collect all tokens from the stream
    let mut completion = String::new();
    while let Ok(token_result) = rx.recv() {
        match token_result {
            Ok(token) => completion.push_str(&token),
            Err(e) => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": { "message": format!("Error generating text: {}", e) }
                    })),
                ));
            }
        }
    }

    let response = ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
        object: "chat.completion".to_string(),
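
The non-streaming path above drains the runner's token channel into one completion string before building the response. A minimal sketch of that pattern, assuming the runner hands back a std::sync::mpsc::Receiver of Result<String, anyhow::Error>; the concrete channel and error types live in the gemma_runner / llama_runner crates and may differ:

use std::sync::mpsc::Receiver;

// Sketch only: collect streamed tokens, stopping at the first generation error.
// recv() returns Err once the sender side is dropped, which ends the loop.
fn collect_completion(rx: Receiver<Result<String, anyhow::Error>>) -> Result<String, String> {
    let mut completion = String::new();
    while let Ok(token_result) = rx.recv() {
        match token_result {
            Ok(token) => completion.push_str(&token),
            Err(e) => return Err(format!("Error generating text: {}", e)),
        }
    }
    Ok(completion)
}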
@@ -420,11 +246,12 @@ async fn handle_streaming_request(
    request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
    // Validate requested model vs configured model
    let configured_model = state.build_args.model_id.clone();
    let configured_model = state.model_id.clone();
    let requested_model = request.model.clone();
    if requested_model.to_lowercase() != "default" {
        let normalized_requested = normalize_model_id(&requested_model);
        if normalized_requested != configured_model {
        let normalized_configured = normalize_model_id(&configured_model);
        if normalized_requested != normalized_configured {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({
@@ -447,9 +274,22 @@ async fn handle_streaming_request(
        .unwrap_or_default()
        .as_secs();
    let model_id = state.model_id.clone();
    let max_tokens = request.max_tokens.unwrap_or(1000);

    // Build prompt
    let prompt = build_gemma_prompt(&request.messages);
    // Build prompt based on model type
    let prompt = match state.model_type {
        ModelType::Gemma => build_gemma_prompt(&request.messages),
        ModelType::Llama => {
            // For Llama, just use the last user message for now
            request.messages.last()
                .and_then(|m| m.content.as_ref())
                .and_then(|c| match c {
                    MessageContent(Either::Left(text)) => Some(text.clone()),
                    _ => None,
                })
                .unwrap_or_default()
        }
    };
    tracing::debug!("Formatted prompt: {}", prompt);

    // Channel for streaming SSE events
@@ -471,80 +311,121 @@ async fn handle_streaming_request(
        let _ = tx.send(Ok(Event::default().data(json)));
    }

    // Spawn generation task that streams tokens as they are generated
    let state_clone = state.clone();
    let response_id_clone = response_id.clone();
    tokio::spawn(async move {
        let max_tokens = request.max_tokens.unwrap_or(1000);
        let mut text_gen = state_clone.text_generation.lock().await;
        text_gen.reset_state();
    // Get streaming receiver based on model type
    let model_rx = match state.model_type {
        ModelType::Gemma => {
            if let Some(mut config) = state.gemma_config {
                config.prompt = prompt.clone();
                config.max_tokens = max_tokens;
                match run_gemma_api(config) {
                    Ok(rx) => rx,
                    Err(e) => {
                        return Err((
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(serde_json::json!({
                                "error": { "message": format!("Error initializing Gemma model: {}", e) }
                            }))
                        ));
                    }
                }
            } else {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": "Gemma configuration not available" }
                    }))
                ));
            }
        }
        ModelType::Llama => {
            if let Some(mut config) = state.llama_config {
                config.prompt = prompt.clone();
                config.max_tokens = max_tokens;
                match run_llama_inference(config) {
                    Ok(rx) => rx,
                    Err(e) => {
                        return Err((
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(serde_json::json!({
                                "error": { "message": format!("Error initializing Llama model: {}", e) }
                            }))
                        ));
                    }
                }
            } else {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": "Llama configuration not available" }
                    }))
                ));
            }
        }
    };

    // Stream tokens via callback with repetition detection
    // Spawn task to receive tokens from model and forward as SSE events
    let response_id_clone = response_id.clone();
    let model_id_clone = model_id.clone();
    tokio::spawn(async move {
        // Stream tokens with repetition detection
        let mut recent_tokens = Vec::new();
        let mut repetition_count = 0;
        const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
        const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns

        let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
            // Debug log to verify token content
            tracing::debug!("Streaming token: '{}'", token);

            // Skip sending empty tokens
            if token.is_empty() {
                tracing::debug!("Skipping empty token");
                return Ok(());
            }

            // Add token to recent history for repetition detection
            recent_tokens.push(token.to_string());
            if recent_tokens.len() > REPETITION_WINDOW {
                recent_tokens.remove(0);
            }

            // Check for repetitive patterns
            if recent_tokens.len() >= 4 {
                let last_token = &recent_tokens[recent_tokens.len() - 1];
                let second_last = &recent_tokens[recent_tokens.len() - 2];

                // Check if we're repeating the same token or pattern
                if last_token == second_last ||
                   (last_token.trim() == "plus" && second_last.trim() == "plus") ||
                   (recent_tokens.len() >= 6 &&
                    recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
                    repetition_count += 1;
                    tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);

                    if repetition_count >= MAX_REPETITION_COUNT {
                        tracing::info!("Stopping generation due to excessive repetition");
                        return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
        const MAX_REPETITION_COUNT: usize = 5;
        const REPETITION_WINDOW: usize = 8;

        while let Ok(token_result) = model_rx.recv() {
            match token_result {
                Ok(token) => {
                    // Skip sending empty tokens
                    if token.is_empty() {
                        continue;
                    }
                } else {
                    repetition_count = 0; // Reset counter if pattern breaks

                    // Add token to recent history for repetition detection
                    recent_tokens.push(token.clone());
                    if recent_tokens.len() > REPETITION_WINDOW {
                        recent_tokens.remove(0);
                    }

                    // Check for repetitive patterns
                    if recent_tokens.len() >= 4 {
                        let last_token = &recent_tokens[recent_tokens.len() - 1];
                        let second_last = &recent_tokens[recent_tokens.len() - 2];

                        if last_token == second_last {
                            repetition_count += 1;
                            tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);

                            if repetition_count >= MAX_REPETITION_COUNT {
                                tracing::info!("Stopping generation due to excessive repetition");
                                break;
                            }
                        } else {
                            repetition_count = 0;
                        }
                    }

                    let chunk = ChatCompletionChunk {
                        id: response_id_clone.clone(),
                        object: "chat.completion.chunk".to_string(),
                        created,
                        model: model_id_clone.clone(),
                        choices: vec![ChatCompletionChunkChoice {
                            index: 0,
                            delta: Delta { role: None, content: Some(token) },
                            finish_reason: None,
                        }],
                    };

                    if let Ok(json) = serde_json::to_string(&chunk) {
                        let _ = tx.send(Ok(Event::default().data(json)));
                    }
                }
                Err(e) => {
                    tracing::info!("Text generation stopped: {}", e);
                    break;
                }
            }

            let chunk = ChatCompletionChunk {
                id: response_id_clone.clone(),
                object: "chat.completion.chunk".to_string(),
                created,
                model: model_id.clone(),
                choices: vec![ChatCompletionChunkChoice {
                    index: 0,
                    delta: Delta { role: None, content: Some(token.to_string()) },
                    finish_reason: None,
                }],
            };
            if let Ok(json) = serde_json::to_string(&chunk) {
                tracing::debug!("Sending chunk with content: '{}'", token);
                let _ = tx.send(Ok(Event::default().data(json)));
            }
            Ok(())
        }).await;

        // Log result of generation
        match result {
            Ok(_) => tracing::debug!("Text generation completed successfully"),
            Err(e) => tracing::info!("Text generation stopped: {}", e),
        }

        // Send final stop chunk and DONE marker
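
The repetition guard in the streaming task above keeps a sliding window of the most recent tokens and stops the stream after several consecutive repeats. A self-contained sketch of that check, factored into a small helper; the window size and threshold match the constants above, while the extra pattern rules in the old callback version (for example the "plus" special case) are omitted:

// Sketch only: sliding-window repetition detector mirroring the streaming guard.
const REPETITION_WINDOW: usize = 8;    // look at the last 8 tokens
const MAX_REPETITION_COUNT: usize = 5; // stop after 5 consecutive repeats

struct RepetitionGuard {
    recent: Vec<String>,
    consecutive: usize,
}

impl RepetitionGuard {
    fn new() -> Self {
        Self { recent: Vec::new(), consecutive: 0 }
    }

    /// Returns true when generation should stop.
    fn push(&mut self, token: &str) -> bool {
        self.recent.push(token.to_string());
        if self.recent.len() > REPETITION_WINDOW {
            self.recent.remove(0);
        }
        let n = self.recent.len();
        if n >= 2 && self.recent[n - 1] == self.recent[n - 2] {
            self.consecutive += 1;
        } else {
            self.consecutive = 0;
        }
        self.consecutive >= MAX_REPETITION_COUNT
    }
}

fn main() {
    let mut guard = RepetitionGuard::new();
    let stop = ["a", "b", "x", "x", "x", "x", "x", "x"]
        .into_iter()
        .any(|t| guard.push(t));
    assert!(stop); // six consecutive "x" tokens trip the guard
}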
@@ -552,7 +433,7 @@ async fn handle_streaming_request(
            id: response_id_clone.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: model_id.clone(),
            model: model_id_clone.clone(),
            choices: vec![ChatCompletionChunkChoice {
                index: 0,
                delta: Delta { role: None, content: None },
@@ -594,6 +475,7 @@ pub fn create_router(app_state: AppState) -> Router {
pub async fn list_models() -> Json<ModelListResponse> {
    // Get all available model variants from the Which enum
    let models = vec![
        // Gemma models
        Model {
            id: "gemma-2b".to_string(),
            object: "model".to_string(),
@@ -690,6 +572,73 @@ pub async fn list_models() -> Json<ModelListResponse> {
            created: 1686935002,
            owned_by: "google".to_string(),
        },
        // Llama models
        Model {
            id: "llama-3.2-1b".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "meta".to_string(),
        },
        Model {
            id: "llama-3.2-1b-instruct".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "meta".to_string(),
        },
        Model {
            id: "llama-3.2-3b".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "meta".to_string(),
        },
        Model {
            id: "llama-3.2-3b-instruct".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "meta".to_string(),
        },
        Model {
            id: "smollm2-135m".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "huggingface".to_string(),
        },
        Model {
            id: "smollm2-135m-instruct".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "huggingface".to_string(),
        },
        Model {
            id: "smollm2-360m".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "huggingface".to_string(),
        },
        Model {
            id: "smollm2-360m-instruct".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "huggingface".to_string(),
        },
        Model {
            id: "smollm2-1.7b".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "huggingface".to_string(),
        },
        Model {
            id: "smollm2-1.7b-instruct".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "huggingface".to_string(),
        },
        Model {
            id: "tinyllama-1.1b-chat".to_string(),
            object: "model".to_string(),
            created: 1686935002,
            owned_by: "tinyllama".to_string(),
        },
    ];

    Json(ModelListResponse {