Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00

Commit: cleanup, add ci
@@ -30,4 +30,4 @@ pub trait ModelInference {
 }
 
 /// Factory function type for creating model inference implementations
-pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
+pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
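Since `ModelInferenceFactory` is a plain `fn` pointer, backends can be registered cheaply, e.g. in a lookup table keyed by model id. Below is a minimal sketch of that pattern; the one-method trait, the `NoopInference` type, and the registry are illustrative stand-ins rather than this crate's actual API (the real trait presumably has more methods), and `anyhow::Result` is an assumed error type.

use std::collections::HashMap;

use anyhow::Result; // assumed; the crate's own Result alias is not shown in this diff

// Simplified stand-in for the crate's ModelInference trait.
pub trait ModelInference {
    fn generate(&mut self, prompt: &str) -> Result<String>;
}

pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;

// Hypothetical backend used only for illustration.
struct NoopInference;

impl ModelInference for NoopInference {
    fn generate(&mut self, _prompt: &str) -> Result<String> {
        Ok(String::new())
    }
}

fn make_noop() -> Result<Box<dyn ModelInference>> {
    Ok(Box::new(NoopInference))
}

fn main() -> Result<()> {
    // Factories are Copy and cheap to store, so a registry is just a map.
    let mut registry: HashMap<&str, ModelInferenceFactory> = HashMap::new();
    registry.insert("noop", make_noop);

    let mut model = registry["noop"]()?;
    println!("generated: {:?}", model.generate("hello")?);
    Ok(())
}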
@@ -1,19 +1,19 @@
 // Expose modules for testing and library usage
-pub mod token_output_stream;
 pub mod model;
-pub mod text_generation;
-pub mod utilities_lib;
 pub mod openai_types;
+pub mod text_generation;
+pub mod token_output_stream;
+pub mod utilities_lib;
 // pub mod cli;
-pub mod server;
 pub mod inference;
+pub mod server;
 
 // Re-export key components for easier access
+pub use inference::ModelInference;
 pub use model::{Model, Which};
+pub use server::{create_router, AppState};
 pub use text_generation::TextGeneration;
 pub use token_output_stream::TokenOutputStream;
-pub use server::{AppState, create_router};
-pub use inference::ModelInference;
 
 use std::env;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
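After this reshuffle the re-exports form the crate's public surface: `Model`, `Which`, `TextGeneration`, `TokenOutputStream`, `ModelInference`, plus `create_router`/`AppState` for the HTTP layer. A hedged sketch of how a downstream binary might consume it; the crate name `inference_engine` and a `Default` impl for `AppState` are assumptions, not confirmed by this diff.

use inference_engine::{create_router, AppState, Which};

#[tokio::main]
async fn main() {
    let which = Which::BaseV3_1B; // a variant visible later in this diff
    println!("instruct-tuned? {}", which.is_instruct_model());

    // Assumes AppState: Default; the real crate may require explicit config.
    let app = create_router(AppState::default());
    let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}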
@@ -1,8 +1,8 @@
 // use candle_core::Tensor;
+use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
-use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
 }
 
 impl Model {
-    pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
+    pub fn forward(
+        &mut self,
+        input_ids: &candle_core::Tensor,
+        pos: usize,
+    ) -> candle_core::Result<candle_core::Tensor> {
         match self {
             Self::V1(m) => m.forward(input_ids, pos),
             Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {
 
     pub fn is_instruct_model(&self) -> bool {
        match self {
-            Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
+            Self::Base2B
+            | Self::Base7B
+            | Self::CodeBase2B
+            | Self::CodeBase7B
+            | Self::BaseV2_2B
+            | Self::BaseV2_9B
+            | Self::BaseV3_1B => false,
            _ => true,
        }
    }
@@ -100,4 +110,4 @@ impl Which {
     pub fn is_llama_model(&self) -> bool {
         matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
     }
-}
+}
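The reflowed `forward` above is the crate's single dispatch point over the Gemma generations: one enum, one match arm per backend. The pattern in isolation looks like the sketch below; `Gemma1`/`Gemma2` are placeholder structs standing in for the candle model types, with token and logit types simplified.

// Minimal illustration of enum-based dispatch over model backends.
struct Gemma1;
struct Gemma2;

impl Gemma1 {
    fn forward(&mut self, input: &[u32], _pos: usize) -> Vec<f32> {
        vec![0.0; input.len()]
    }
}

impl Gemma2 {
    fn forward(&mut self, input: &[u32], _pos: usize) -> Vec<f32> {
        vec![1.0; input.len()]
    }
}

enum Model {
    V1(Gemma1),
    V2(Gemma2),
}

impl Model {
    // One signature, one match: callers never branch on the variant.
    fn forward(&mut self, input: &[u32], pos: usize) -> Vec<f32> {
        match self {
            Self::V1(m) => m.forward(input, pos),
            Self::V2(m) => m.forward(input, pos),
        }
    }
}

fn main() {
    let mut model = Model::V1(Gemma1);
    println!("{} logits", model.forward(&[1, 2, 3], 0).len());
}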
@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
 );
 
 impl ToSchema<'_> for MessageInnerContent {
-    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+    fn schema() -> (
+        &'static str,
+        utoipa::openapi::RefOr<utoipa::openapi::Schema>,
+    ) {
         (
             "MessageInnerContent",
             utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct MessageContent(
     #[serde(with = "either::serde_untagged")]
-    pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
+    pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
 );
 
 impl ToSchema<'_> for MessageContent {
-    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
-        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
+    fn schema() -> (
+        &'static str,
+        utoipa::openapi::RefOr<utoipa::openapi::Schema>,
+    ) {
+        (
+            "MessageContent",
+            utoipa::openapi::RefOr::T(message_content_schema()),
+        )
     }
 }
 
@@ -213,4 +222,4 @@ pub struct ModelListResponse {
     pub object: String,
     /// Array of available models
     pub data: Vec<Model>,
-}
+}
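The `MessageContent` newtype above relies on `either::serde_untagged`, which lets the OpenAI-style `content` field be either a bare string or an array of keyed parts. A self-contained sketch of what that accepts; `MessageInnerContent` is simplified to `String` here (the real type wraps richer data), and the `either/serde` feature plus `serde_json` are assumed dependencies.

use std::collections::HashMap;

use either::Either;
use serde::{Deserialize, Serialize};

type MessageInnerContent = String; // simplified stand-in

#[derive(Debug, Deserialize, Serialize)]
struct MessageContent(
    #[serde(with = "either::serde_untagged")]
    Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);

fn main() {
    // Plain-string form: {"role":"user","content":"hi"}
    let a: MessageContent = serde_json::from_str(r#""hi""#).unwrap();
    assert!(matches!(a.0, Either::Left(_)));

    // Array-of-parts form: "content": [{"text":"hi"}]
    let b: MessageContent = serde_json::from_str(r#"[{"text":"hi"}]"#).unwrap();
    assert!(matches!(b.0, Either::Right(_)));
}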
@@ -6,19 +6,22 @@ use axum::{
     Json, Router,
 };
 use futures_util::stream::{self, Stream};
-use tokio_stream::wrappers::UnboundedReceiverStream;
 use std::convert::Infallible;
 use std::sync::Arc;
-use tokio::sync::{Mutex, mpsc};
+use tokio::sync::{mpsc, Mutex};
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
 
-use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
+use crate::openai_types::{
+    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
+    ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
+};
 use crate::Which;
 use either::Either;
-use serde_json::Value;
 use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
 use llama_runner::{run_llama_inference, LlamaInferenceConfig};
+use serde_json::Value;
 // -------------------------
 // Shared app state
 // -------------------------
@@ -62,12 +65,15 @@ fn normalize_model_id(model_id: &str) -> String {
 
 fn build_gemma_prompt(messages: &[Message]) -> String {
     let mut prompt = String::new();
 
     for message in messages {
         match message.role.as_str() {
             "system" => {
                 if let Some(MessageContent(Either::Left(content))) = &message.content {
-                    prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
+                    prompt.push_str(&format!(
+                        "<start_of_turn>system\n{}<end_of_turn>\n",
+                        content
+                    ));
                 }
             }
             "user" => {
@@ -83,7 +89,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
             _ => {}
         }
     }
 
     prompt.push_str("<start_of_turn>model\n");
     prompt
 }
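Per the updated test expectations later in this diff, the builder renders each message as a Gemma turn and leaves the prompt open on a model turn. For a system message "Be terse." followed by a user message "Hi!", the resulting string is:

<start_of_turn>system
Be terse.<end_of_turn>
<start_of_turn>user
Hi!<end_of_turn>
<start_of_turn>model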
@@ -97,9 +103,13 @@ pub async fn chat_completions(
     Json(request): Json<ChatCompletionRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, String)> {
     if !request.stream.unwrap_or(false) {
-        return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
+        return Ok(chat_completions_non_streaming_proxy(state, request)
+            .await
+            .into_response());
     }
-    Ok(chat_completions_stream(state, request).await.into_response())
+    Ok(chat_completions_stream(state, request)
+        .await
+        .into_response())
 }
 
 pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
         ModelType::Gemma => build_gemma_prompt(&request.messages),
         ModelType::Llama => {
             // For Llama, just use the last user message for now
-            request.messages.last()
+            request
+                .messages
+                .last()
                 .and_then(|m| m.content.as_ref())
                 .and_then(|c| match c {
                     MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
     };
 
     // Get streaming receiver based on model type
-    let rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_gemma_api(config).map_err(|e| (
+    let rx =
+        match state.model_type {
+            ModelType::Gemma => {
+                if let Some(mut config) = state.gemma_config {
+                    config.prompt = prompt.clone();
+                    config.max_tokens = max_tokens;
+                    run_gemma_api(config).map_err(|e| (
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Error initializing Gemma model: {}", e) }
                     }))
                 ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    }))
-                ));
-            }
-        }
+                } else {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(serde_json::json!({
+                            "error": { "message": "Gemma configuration not available" }
+                        })),
+                    ));
+                }
+            }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_llama_inference(config).map_err(|e| (
+            ModelType::Llama => {
+                if let Some(mut config) = state.llama_config {
+                    config.prompt = prompt.clone();
+                    config.max_tokens = max_tokens;
+                    run_llama_inference(config).map_err(|e| (
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Error initializing Llama model: {}", e) }
                     }))
                 ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    }))
-                ));
-            }
-        }
-    };
+                } else {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(serde_json::json!({
+                            "error": { "message": "Llama configuration not available" }
+                        })),
+                    ));
+                }
+            }
+        };
 
     // Collect all tokens from the stream
     let mut completion = String::new();
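Both arms use the same error convention: a backend failure is mapped into an axum-compatible `(StatusCode, Json<Value>)` tuple whose body mimics OpenAI's `{"error": {"message": ...}}` envelope. A condensed, standalone sketch of that pattern; `start_backend` is a made-up stand-in for `run_gemma_api`/`run_llama_inference`.

use axum::{http::StatusCode, Json};
use serde_json::Value;

// Hypothetical backend starter standing in for run_gemma_api / run_llama_inference.
fn start_backend(ok: bool) -> Result<&'static str, String> {
    if ok { Ok("rx") } else { Err("model not found".to_string()) }
}

// Map any backend error into the OpenAI-style error envelope the handlers use.
fn start_or_http_error(ok: bool) -> Result<&'static str, (StatusCode, Json<Value>)> {
    start_backend(ok).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({
                "error": { "message": format!("Error initializing model: {}", e) }
            })),
        )
    })
}

fn main() {
    assert!(start_or_http_error(true).is_ok());
    let (status, body) = start_or_http_error(false).unwrap_err();
    println!("{} {}", status, body.0);
}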
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
         ModelType::Gemma => build_gemma_prompt(&request.messages),
         ModelType::Llama => {
             // For Llama, just use the last user message for now
-            request.messages.last()
+            request
+                .messages
+                .last()
                 .and_then(|m| m.content.as_ref())
                 .and_then(|c| match c {
                     MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
         model: model_id.clone(),
         choices: vec![ChatCompletionChunkChoice {
             index: 0,
-            delta: Delta { role: Some("assistant".to_string()), content: None },
+            delta: Delta {
+                role: Some("assistant".to_string()),
+                content: None,
+            },
             finish_reason: None,
         }],
     };
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                    }))
+                    })),
                 ));
             }
         }
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": "Gemma configuration not available" }
-                    }))
+                    })),
                 ));
             }
         }
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Error initializing Llama model: {}", e) }
-                    }))
+                    })),
                 ));
             }
         }
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": "Llama configuration not available" }
-                    }))
+                    })),
                 ));
             }
         }
@@ -386,16 +404,20 @@ async fn handle_streaming_request(
                     if recent_tokens.len() > REPETITION_WINDOW {
                         recent_tokens.remove(0);
                     }
 
                     // Check for repetitive patterns
                     if recent_tokens.len() >= 4 {
                         let last_token = &recent_tokens[recent_tokens.len() - 1];
                         let second_last = &recent_tokens[recent_tokens.len() - 2];
 
                         if last_token == second_last {
                             repetition_count += 1;
-                            tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
+                            tracing::warn!(
+                                "Detected repetition pattern: '{}' (count: {})",
+                                last_token,
+                                repetition_count
+                            );
 
                             if repetition_count >= MAX_REPETITION_COUNT {
                                 tracing::info!("Stopping generation due to excessive repetition");
                                 break;
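For reference, this guard keeps a sliding window of recent tokens and aborts the stream once the same token repeats consecutively too often. A self-contained sketch of the logic; the window and threshold values are illustrative (the constants' definitions are not shown in this diff), and resetting the counter when the streak breaks is an assumption.

// Standalone sketch of the repetition guard used in the streaming loop.
const REPETITION_WINDOW: usize = 8; // assumed value
const MAX_REPETITION_COUNT: usize = 3; // assumed value

fn main() {
    let tokens = ["the", "cat", "sat", "sat", "sat", "sat", "on"];
    let mut recent_tokens: Vec<String> = Vec::new();
    let mut repetition_count = 0usize;

    for token in tokens {
        recent_tokens.push(token.to_string());
        // Keep only the last REPETITION_WINDOW tokens.
        if recent_tokens.len() > REPETITION_WINDOW {
            recent_tokens.remove(0);
        }

        // Count consecutive duplicates at the tail of the window.
        if recent_tokens.len() >= 4 {
            let last = &recent_tokens[recent_tokens.len() - 1];
            let second_last = &recent_tokens[recent_tokens.len() - 2];
            if last == second_last {
                repetition_count += 1;
                if repetition_count >= MAX_REPETITION_COUNT {
                    println!("stopping: '{}' repeated too often", last);
                    break;
                }
            } else {
                repetition_count = 0; // reset on a broken streak (assumed)
            }
        }
    }
}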
@@ -412,11 +434,14 @@ async fn handle_streaming_request(
                         model: model_id_clone.clone(),
                         choices: vec![ChatCompletionChunkChoice {
                             index: 0,
-                            delta: Delta { role: None, content: Some(token) },
+                            delta: Delta {
+                                role: None,
+                                content: Some(token),
+                            },
                             finish_reason: None,
                         }],
                     };
 
                     if let Ok(json) = serde_json::to_string(&chunk) {
                         let _ = tx.send(Ok(Event::default().data(json)));
                     }
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
                     model: model_id_clone.clone(),
                     choices: vec![ChatCompletionChunkChoice {
                         index: 0,
-                        delta: Delta { role: None, content: None },
+                        delta: Delta {
+                            role: None,
+                            content: None,
+                        },
                         finish_reason: Some("stop".to_string()),
                     }],
                 };
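Together, the three reflowed `Delta` literals in this file trace the OpenAI streaming contract: an opening chunk carrying only the role, one chunk per generated token carrying `content`, and a final chunk with an empty delta and `finish_reason: "stop"`. On the wire, a client sees SSE events roughly like the following (ids omitted and token text illustrative):

data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}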
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
     Ok(Sse::new(stream))
 }
 
-
-
 // -------------------------
 // Router
 // -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
     })
 }
 
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -681,10 +706,7 @@ mod tests {
 
         let prompt = build_gemma_prompt(&messages);
 
-        let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
-            <start_of_turn>model\nWho's there?<end_of_turn>\n\
-            <start_of_turn>user\nGemma.<end_of_turn>\n\
-            <start_of_turn>model\n";
+        let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
 
         assert_eq!(prompt, expected);
     }
@@ -698,15 +720,13 @@ mod tests {
 
     #[test]
     fn test_missing_content() {
-        let messages = vec![
-            Message {
-                role: "user".to_string(),
-                content: None,
-                name: None,
-            }
-        ];
+        let messages = vec![Message {
+            role: "user".to_string(),
+            content: None,
+            name: None,
+        }];
 
         let prompt = build_gemma_prompt(&messages);
-        assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
+        assert_eq!(prompt, "<start_of_turn>model\n");
     }
 }
File diff suppressed because it is too large.
@@ -84,4 +84,4 @@ impl TokenOutputStream {
         self.prev_index = 0;
         self.current_index = 0;
     }
-}
+}
|
@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
||||
) -> Result<Vec<std::path::PathBuf>> {
|
||||
let path = path.as_ref();
|
||||
let jsfile = std::fs::File::open(path.join(json_file))?;
|
||||
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => candle_core::bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
@@ -164,4 +165,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
||||
.map(|v| path.join(v))
|
||||
.collect();
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
}
|
||||
|
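`hub_load_local_safetensors` resolves a sharded checkpoint by reading the index JSON's `weight_map` and collecting the distinct shard paths. A hedged usage sketch; the crate name `inference_engine`, the local directory, and the index filename (the usual Hugging Face convention) are assumptions.

// Usage sketch; the index file maps tensor names to shards, e.g.
// {"weight_map": {"model.layers.0.q_proj.weight": "model-00001-of-00002.safetensors", ...}}
use inference_engine::utilities_lib::hub_load_local_safetensors;

fn main() -> candle_core::Result<()> {
    let shards = hub_load_local_safetensors(
        "/models/gemma-2b", // hypothetical local model directory
        "model.safetensors.index.json",
    )?;
    for p in &shards {
        println!("shard: {}", p.display());
    }
    Ok(())
}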