cleanup, add ci

This commit is contained in:
geoffsee
2025-08-31 10:31:07 -04:00
parent 419e1c2ea7
commit f5d2a85f2e
42 changed files with 1740 additions and 705 deletions

View File

@@ -30,4 +30,4 @@ pub trait ModelInference {
}
/// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;

View File

@@ -1,19 +1,19 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod text_generation;
pub mod token_output_stream;
pub mod utilities_lib;
// pub mod cli;
pub mod server;
pub mod inference;
pub mod server;
// Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which};
pub use server::{create_router, AppState};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
pub use inference::ModelInference;
use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -1,8 +1,8 @@
// use candle_core::Tensor;
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
}
impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
pub fn forward(
&mut self,
input_ids: &candle_core::Tensor,
pos: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {
pub fn is_instruct_model(&self) -> bool {
match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
Self::Base2B
| Self::Base7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::BaseV2_2B
| Self::BaseV2_9B
| Self::BaseV3_1B => false,
_ => true,
}
}
@@ -100,4 +110,4 @@ impl Which {
pub fn is_llama_model(&self) -> bool {
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
}
}
}

View File

@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageContent",
utoipa::openapi::RefOr::T(message_content_schema()),
)
}
}
@@ -213,4 +222,4 @@ pub struct ModelListResponse {
pub object: String,
/// Array of available models
pub data: Vec<Model>,
}
}

View File

@@ -6,19 +6,22 @@ use axum::{
Json, Router,
};
use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::openai_types::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which;
use either::Either;
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
// -------------------------
// Shared app state
// -------------------------
@@ -62,12 +65,15 @@ fn normalize_model_id(model_id: &str) -> String {
fn build_gemma_prompt(messages: &[Message]) -> String {
let mut prompt = String::new();
for message in messages {
match message.role.as_str() {
"system" => {
if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
prompt.push_str(&format!(
"<start_of_turn>system\n{}<end_of_turn>\n",
content
));
}
}
"user" => {
@@ -83,7 +89,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
_ => {}
}
}
prompt.push_str("<start_of_turn>model\n");
prompt
}
@@ -97,9 +103,13 @@ pub async fn chat_completions(
Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
return Ok(chat_completions_non_streaming_proxy(state, request)
.await
.into_response());
}
Ok(chat_completions_stream(state, request).await.into_response())
Ok(chat_completions_stream(state, request)
.await
.into_response())
}
pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
request
.messages
.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
};
// Get streaming receiver based on model type
let rx = match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
let rx =
match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
));
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
})),
));
}
}
}
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
));
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
})),
));
}
}
}
};
};
// Collect all tokens from the stream
let mut completion = String::new();
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
request
.messages
.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: Some("assistant".to_string()), content: None },
delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
})),
));
}
}
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
})),
));
}
}
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
})),
));
}
}
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
})),
));
}
}
@@ -386,16 +404,20 @@ async fn handle_streaming_request(
if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0);
}
// Check for repetitive patterns
if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2];
if last_token == second_last {
repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
tracing::warn!(
"Detected repetition pattern: '{}' (count: {})",
last_token,
repetition_count
);
if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition");
break;
@@ -412,11 +434,14 @@ async fn handle_streaming_request(
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: Some(token) },
delta: Delta {
role: None,
content: Some(token),
},
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
let _ = tx.send(Ok(Event::default().data(json)));
}
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: None },
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
Ok(Sse::new(stream))
}
// -------------------------
// Router
// -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
})
}
#[cfg(test)]
mod tests {
use super::*;
@@ -681,10 +706,7 @@ mod tests {
let prompt = build_gemma_prompt(&messages);
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
<start_of_turn>model\nWho's there?<end_of_turn>\n\
<start_of_turn>user\nGemma.<end_of_turn>\n\
<start_of_turn>model\n";
let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
assert_eq!(prompt, expected);
}
@@ -698,15 +720,13 @@ mod tests {
#[test]
fn test_missing_content() {
let messages = vec![
Message {
role: "user".to_string(),
content: None,
name: None,
}
];
let messages = vec![Message {
role: "user".to_string(),
content: None,
name: None,
}];
let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
assert_eq!(prompt, "<start_of_turn>model\n");
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -84,4 +84,4 @@ impl TokenOutputStream {
self.prev_index = 0;
self.current_index = 0;
}
}
}

View File

@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
@@ -164,4 +165,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}
}