mirror of https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
run cargo fmt
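Note: this commit is formatting-only; rustfmt's defaults (notably max_width = 100) account for every hunk below. For reference, the usual invocations:

    cargo fmt --all               # rewrite every workspace crate in place
    cargo fmt --all -- --check    # CI variant: exit non-zero if anything would change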
@@ -194,57 +194,57 @@ pub fn send_chat_completion_stream(
 ) {
     use wasm_bindgen::prelude::*;
     use wasm_bindgen::JsCast;

     let request = ChatRequest {
         model,
         messages,
         max_tokens: Some(1024),
         stream: Some(true),
     };

     // We need to send a POST request but EventSource only supports GET
     // So we'll use fetch with a readable stream instead
     let window = web_sys::window().unwrap();
     let request_json = serde_json::to_string(&request).unwrap();

     let opts = web_sys::RequestInit::new();
     opts.set_method("POST");
     opts.set_body(&JsValue::from_str(&request_json));

     let headers = web_sys::Headers::new().unwrap();
     headers.set("Content-Type", "application/json").unwrap();
     headers.set("Accept", "text/event-stream").unwrap();
     opts.set_headers(&headers);

     let request = web_sys::Request::new_with_str_and_init("/v1/chat/completions", &opts).unwrap();

     let promise = window.fetch_with_request(&request);

     wasm_bindgen_futures::spawn_local(async move {
         match wasm_bindgen_futures::JsFuture::from(promise).await {
             Ok(resp_value) => {
                 let resp: web_sys::Response = resp_value.dyn_into().unwrap();

                 if !resp.ok() {
                     on_error(format!("Server error: {}", resp.status()));
                     return;
                 }

                 let body = resp.body();
                 if body.is_none() {
                     on_error("No response body".to_string());
                     return;
                 }

                 let reader = body
                     .unwrap()
                     .get_reader()
                     .dyn_into::<web_sys::ReadableStreamDefaultReader>()
                     .unwrap();

                 let decoder = web_sys::TextDecoder::new().unwrap();
                 let mut buffer = String::new();

                 loop {
                     match wasm_bindgen_futures::JsFuture::from(reader.read()).await {
                         Ok(result) => {
@@ -252,24 +252,25 @@ pub fn send_chat_completion_stream(
                             .unwrap()
                             .as_bool()
                             .unwrap_or(false);

                         if done {
                             break;
                         }

-                        let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
+                        let value =
+                            js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
                         let array = js_sys::Uint8Array::new(&value);
                         let mut bytes = vec![0; array.length() as usize];
                         array.copy_to(&mut bytes);
                         let text = decoder.decode_with_u8_array(&bytes).unwrap();

                         buffer.push_str(&text);

                         // Process complete SSE events from buffer
                         while let Some(event_end) = buffer.find("\n\n") {
                             let event = buffer[..event_end].to_string();
                             buffer = buffer[event_end + 2..].to_string();

                             // Parse SSE event
                             for line in event.lines() {
                                 if let Some(data) = line.strip_prefix("data: ") {
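The buffering in the hunk above is plain string handling and can be exercised outside the browser; a minimal native sketch of the same SSE framing (helper name hypothetical):

    // Sketch: split a growing buffer into complete "\n\n"-terminated SSE events
    // and collect the "data: " payloads, exactly as the loop above does.
    fn drain_sse_events(buffer: &mut String) -> Vec<String> {
        let mut payloads = Vec::new();
        while let Some(event_end) = buffer.find("\n\n") {
            let event = buffer[..event_end].to_string();
            *buffer = buffer[event_end + 2..].to_string();
            for line in event.lines() {
                if let Some(data) = line.strip_prefix("data: ") {
                    payloads.push(data.to_string());
                }
            }
        }
        payloads
    }

    // e.g. a buffer holding "data: {\"a\":1}\n\ndata: [DONE]\n\n" yields two payloads.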
@@ -277,9 +278,11 @@ pub fn send_chat_completion_stream(
                                     on_complete();
                                     return;
                                 }

                                 // Parse JSON chunk
-                                if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
+                                if let Ok(chunk) =
+                                    serde_json::from_str::<StreamChatResponse>(data)
+                                {
                                     if let Some(choice) = chunk.choices.first() {
                                         if let Some(content) = &choice.delta.content {
                                             on_chunk(content.clone());
@@ -296,7 +299,7 @@ pub fn send_chat_completion_stream(
                         }
                     }
                 }

                 on_complete();
             }
             Err(e) => {
@@ -366,11 +369,11 @@ fn ChatPage() -> impl IntoView {
     // State for available models and selected model
     let available_models = RwSignal::new(Vec::<ModelInfo>::new());
     let selected_model = RwSignal::new(String::from("")); // Default model

     // State for streaming response
     let streaming_content = RwSignal::new(String::new());
     let is_streaming = RwSignal::new(false);

     // State for streaming mode toggle
     let use_streaming = RwSignal::new(true); // Default to streaming

@@ -424,7 +427,7 @@ fn ChatPage() -> impl IntoView {
         // Clear streaming content and set streaming flag
         streaming_content.set(String::new());
         is_streaming.set(true);

         // Use streaming API
         send_chat_completion_stream(
             current_messages,
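For context (not part of the diff): a Leptos RwSignal is a reactive read/write cell, so setting the flags above is what drives re-renders. A tiny sketch, assuming the RwSignal::new / set / get API this file already uses:

    let is_streaming = RwSignal::new(false);
    is_streaming.set(true);       // flip the flag; subscribed views re-run
    let now = is_streaming.get(); // read the current value (true)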
@@ -1,5 +1,10 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
-use axum::{Json, Router, response::Json as ResponseJson, routing::{get, post}, http::StatusCode};
+use axum::{
+    Json, Router,
+    http::StatusCode,
+    response::Json as ResponseJson,
+    routing::{get, post},
+};
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use once_cell::sync::Lazy;
 use serde::Serialize;
@@ -9,9 +14,8 @@ use tower_http::trace::TraceLayer;
 use tracing;

 // Cache for multiple embedding models
-static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> = Lazy::new(|| {
-    RwLock::new(HashMap::new())
-});
+static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
+    Lazy::new(|| RwLock::new(HashMap::new()));

 #[derive(Serialize)]
 pub struct ModelInfo {
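Side note, not in the diff: on Rust 1.80+ the once_cell dependency behind this static could be replaced by std's LazyLock. An equivalent sketch with a placeholder value type so it stands alone:

    use std::collections::HashMap;
    use std::sync::{Arc, LazyLock, RwLock};

    struct Model; // placeholder for TextEmbedding, keeps the sketch self-contained

    static MODEL_CACHE: LazyLock<RwLock<HashMap<String, Arc<Model>>>> =
        LazyLock::new(|| RwLock::new(HashMap::new()));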
@@ -32,11 +36,19 @@ pub struct ModelsResponse {
 fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
     match model_name {
         // Sentence Transformers models
-        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
-        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
-        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
-        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),
+        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
+            Ok(EmbeddingModel::AllMiniLML6V2)
+        }
+        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
+            Ok(EmbeddingModel::AllMiniLML6V2Q)
+        }
+        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
+            Ok(EmbeddingModel::AllMiniLML12V2)
+        }
+        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
+            Ok(EmbeddingModel::AllMiniLML12V2Q)
+        }

         // BGE models
         "BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
         "BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
@@ -46,41 +58,68 @@ fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
         "BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
         "BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
         "BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),

         // Nomic models
-        "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
-        "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
-        "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),
+        "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
+            Ok(EmbeddingModel::NomicEmbedTextV1)
+        }
+        "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
+            Ok(EmbeddingModel::NomicEmbedTextV15)
+        }
+        "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
+            Ok(EmbeddingModel::NomicEmbedTextV15Q)
+        }

         // Paraphrase models
-        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
-        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
-        "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
+        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+        | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
+        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
+        | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
+        "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),

         // ModernBert
-        "lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),
+        "lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
+            Ok(EmbeddingModel::ModernBertEmbedLarge)
+        }

         // Multilingual E5 models
-        "intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
-        "intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
-        "intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),
+        "intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
+            Ok(EmbeddingModel::MultilingualE5Small)
+        }
+        "intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
+            Ok(EmbeddingModel::MultilingualE5Base)
+        }
+        "intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
+            Ok(EmbeddingModel::MultilingualE5Large)
+        }

         // Mixedbread models
-        "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
-        "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),
+        "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
+            Ok(EmbeddingModel::MxbaiEmbedLargeV1)
+        }
+        "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
+            Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
+        }

         // GTE models
         "Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
-        "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
+        "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
+            Ok(EmbeddingModel::GTEBaseENV15Q)
+        }
         "Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
-        "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),
+        "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
+            Ok(EmbeddingModel::GTELargeENV15Q)
+        }

         // CLIP model
         "Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),

         // Jina model
-        "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),
+        "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
+            Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
+        }

         _ => Err(format!("Unsupported embedding model: {}", model_name)),
     }
 }
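The rewrapped match arms above are behavior-identical: each model keeps its two (or three) accepted spellings. Illustrative checks against the function above (assuming EmbeddingModel's PartialEq, which the HashMap cache already relies on via Eq):

    assert_eq!(
        parse_embedding_model("BAAI/bge-small-en-v1.5"),
        parse_embedding_model("bge-small-en-v1.5")
    );
    assert!(parse_embedding_model("nomic-text-embed").is_ok()); // extra alias
    assert!(parse_embedding_model("no-such-model").is_err());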
@@ -95,7 +134,9 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize {
         EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
         EmbeddingModel::BGESmallZHV15 => 512,
         EmbeddingModel::BGELargeZHV15 => 1024,
-        EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768,
+        EmbeddingModel::NomicEmbedTextV1
+        | EmbeddingModel::NomicEmbedTextV15
+        | EmbeddingModel::NomicEmbedTextV15Q => 768,
         EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
         EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
         EmbeddingModel::ModernBertEmbedLarge => 1024,
@@ -114,37 +155,41 @@ fn get_model_dimensions(model: &EmbeddingModel) -> usize {
 fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
     // First try to get from cache (read lock)
     {
-        let cache = MODEL_CACHE.read().map_err(|e| format!("Failed to acquire read lock: {}", e))?;
+        let cache = MODEL_CACHE
+            .read()
+            .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
         if let Some(model) = cache.get(&embedding_model) {
             tracing::debug!("Using cached model: {:?}", embedding_model);
             return Ok(Arc::clone(model));
         }
     }

     // Model not in cache, create it (write lock)
-    let mut cache = MODEL_CACHE.write().map_err(|e| format!("Failed to acquire write lock: {}", e))?;
+    let mut cache = MODEL_CACHE
+        .write()
+        .map_err(|e| format!("Failed to acquire write lock: {}", e))?;

     // Double-check after acquiring write lock
     if let Some(model) = cache.get(&embedding_model) {
         tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
         return Ok(Arc::clone(model));
     }

     tracing::info!("Initializing new embedding model: {:?}", embedding_model);
     let model_start_time = std::time::Instant::now();

     let model = TextEmbedding::try_new(
         InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
     )
     .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;

     let model_init_time = model_start_time.elapsed();
     tracing::info!(
         "Embedding model {:?} initialized in {:.2?}",
         embedding_model,
         model_init_time
     );

     let model_arc = Arc::new(model);
     cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
     Ok(model_arc)
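The function above is the classic double-checked pattern over an RwLock: an optimistic shared read, then an exclusive write with a re-check, because another thread may have initialized the model between the two locks. A generic sketch of the same shape (hypothetical helper, pure std):

    use std::collections::HashMap;
    use std::hash::Hash;
    use std::sync::{Arc, RwLock};

    fn get_or_insert<K: Hash + Eq, V>(
        cache: &RwLock<HashMap<K, Arc<V>>>,
        key: K,
        build: impl FnOnce() -> V,
    ) -> Arc<V> {
        {
            let map = cache.read().unwrap(); // fast path: shared lock only
            if let Some(v) = map.get(&key) {
                return Arc::clone(v);
            }
        } // read guard dropped before taking the write lock
        let mut map = cache.write().unwrap();
        if let Some(v) = map.get(&key) {
            return Arc::clone(v); // another thread won the race
        }
        let v = Arc::new(build());
        map.insert(key, Arc::clone(&v));
        v
    }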
@@ -158,7 +203,7 @@ pub async fn embeddings_create(

     // Phase 1: Parse and get the embedding model
     let model_start_time = std::time::Instant::now();

     let embedding_model = match parse_embedding_model(&payload.model) {
         Ok(model) => model,
         Err(e) => {
@@ -166,15 +211,18 @@ pub async fn embeddings_create(
             return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
         }
     };

     let model = match get_or_create_model(embedding_model.clone()) {
         Ok(model) => model,
         Err(e) => {
             tracing::error!("Failed to get/create model: {}", e);
-            return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Model initialization failed: {}", e)));
+            return Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("Model initialization failed: {}", e),
+            ));
         }
     };

     let model_access_time = model_start_time.elapsed();
     tracing::debug!(
         "Model access/creation completed in {:.2?}",
@@ -205,12 +253,13 @@ pub async fn embeddings_create(
     // Phase 3: Generate embeddings
     let embedding_start_time = std::time::Instant::now();

-    let embeddings = model
-        .embed(texts_from_embedding_input, None)
-        .map_err(|e| {
-            tracing::error!("Failed to generate embeddings: {}", e);
-            (StatusCode::INTERNAL_SERVER_ERROR, format!("Embedding generation failed: {}", e))
-        })?;
+    let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
+        tracing::error!("Failed to generate embeddings: {}", e);
+        (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            format!("Embedding generation failed: {}", e),
+        )
+    })?;

     let embedding_generation_time = embedding_start_time.elapsed();
     tracing::info!(
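For orientation, the embed call above is fastembed's batch API (the second argument is an optional batch size). A standalone sketch of the same calls, mapping errors to String as this file does; note the first run downloads model weights:

    use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

    fn demo() -> Result<(), String> {
        let model = TextEmbedding::try_new(
            InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true),
        )
        .map_err(|e| format!("init failed: {e}"))?;
        let embeddings = model
            .embed(vec!["hello world"], None) // None = default batch size
            .map_err(|e| format!("embed failed: {e}"))?;
        println!("{} vectors of dim {}", embeddings.len(), embeddings[0].len());
        Ok(())
    }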
@@ -287,7 +336,7 @@ pub async fn embeddings_create(
     // Use the actual model dimensions instead of hardcoded 768
     let actual_dimensions = padded_embedding.len();
     let expected_dimensions = get_model_dimensions(&embedding_model);

     if actual_dimensions != expected_dimensions {
         tracing::warn!(
             "Model {:?} produced {} dimensions but expected {}",
@@ -455,7 +504,8 @@ pub async fn models_list() -> ResponseJson<ModelsResponse> {
             id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
             object: "model".to_string(),
             owned_by: "nomic-ai".to_string(),
-            description: "Quantized v1.5 release of the 8192 context length english model".to_string(),
+            description: "Quantized v1.5 release of the 8192 context length english model"
+                .to_string(),
             dimensions: 768,
         },
         ModelInfo {
@@ -476,7 +526,8 @@ pub async fn models_list() -> ResponseJson<ModelsResponse> {
             id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
             object: "model".to_string(),
             owned_by: "sentence-transformers".to_string(),
-            description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(),
+            description: "Sentence-transformers model for tasks like clustering or semantic search"
+                .to_string(),
             dimensions: 768,
         },
         ModelInfo {
@@ -18,12 +18,10 @@ async fn embeddings_create(
 ) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
     match embeddings_engine::embeddings_create(Json(payload)).await {
         Ok(response) => Ok(response),
-        Err((status_code, message)) => {
-            Err(axum::response::Response::builder()
-                .status(status_code)
-                .body(axum::body::Body::from(message))
-                .unwrap())
-        }
+        Err((status_code, message)) => Err(axum::response::Response::builder()
+            .status(status_code)
+            .body(axum::body::Body::from(message))
+            .unwrap()),
     }
 }
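Aside: the manual builder above cannot fail for a plain status-plus-body, but axum can also produce the same response from the tuple directly via IntoResponse, avoiding the unwrap; a sketch assuming axum 0.6+:

    use axum::{http::StatusCode, response::IntoResponse};

    fn to_error_response(status_code: StatusCode, message: String) -> axum::response::Response {
        (status_code, message).into_response() // same status + plain-text body, no builder chain
    }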
@@ -42,7 +42,11 @@ pub struct ModelMeta {
 }

 const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-    ModelMeta { id, family, instruct }
+    ModelMeta {
+        id,
+        family,
+        instruct,
+    }
 }

 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
@@ -42,13 +42,13 @@ pub struct AppState {
     pub llama_config: Option<LlamaInferenceConfig>,
 }

 impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
         // This can be overridden by environment variables if needed
-        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+        let default_model_id =
+            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());

         let gemma_config = GemmaInferenceConfig {
             model: None,
             ..Default::default()
@@ -94,9 +94,6 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
     }
 }

-
-
-
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
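The (unchanged) normalizer above is what lets request model ids vary in case and separator; for example:

    assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
    assert_eq!(normalize_model_id("llama-3.2-1b"), "llama-3.2-1b"); // already normalized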
@@ -157,7 +154,7 @@ pub async fn chat_completions_non_streaming_proxy(
     // Use the model specified in the request
     let model_id = request.model.clone();
     let which_model = model_id_to_which(&model_id);

     // Validate that the requested model is supported
     let which_model = match which_model {
         Some(model) => model,
@@ -204,19 +201,21 @@ pub async fn chat_completions_non_streaming_proxy(
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
         let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_llama_inference(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Llama model: {}", e) }
-            }))
-        ))?
+        run_llama_inference(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Llama model: {}", e) }
+                })),
+            )
+        })?
     } else {
         // Create Gemma configuration dynamically
         let gemma_model = match which_model {
@@ -241,23 +240,25 @@ pub async fn chat_completions_non_streaming_proxy(
                     StatusCode::BAD_REQUEST,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };

         let mut config = GemmaInferenceConfig {
             model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_gemma_api(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-            }))
-        ))?
+        run_gemma_api(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                })),
+            )
+        })?
     };

     // Collect all tokens from the stream
@@ -320,7 +321,7 @@ async fn handle_streaming_request(
     // Use the model specified in the request
     let model_id = request.model.clone();
     let which_model = model_id_to_which(&model_id);

     // Validate that the requested model is supported
     let which_model = match which_model {
         Some(model) => model,
@@ -397,7 +398,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Llama model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };
@@ -439,11 +440,11 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Model {} is not a Gemma model", model_id) }
-                    }))
+                    })),
                 ));
             }
         };

         let mut config = GemmaInferenceConfig {
             model: Some(gemma_model),
             ..Default::default()
@@ -605,59 +606,66 @@ pub async fn list_models() -> Json<ModelListResponse> {
         Which::Llama32_3BInstruct,
     ];

-    let mut models: Vec<Model> = which_variants.into_iter().map(|which| {
-        let meta = which.meta();
-        let model_id = match which {
-            Which::Base2B => "gemma-2b",
-            Which::Base7B => "gemma-7b",
-            Which::Instruct2B => "gemma-2b-it",
-            Which::Instruct7B => "gemma-7b-it",
-            Which::InstructV1_1_2B => "gemma-1.1-2b-it",
-            Which::InstructV1_1_7B => "gemma-1.1-7b-it",
-            Which::CodeBase2B => "codegemma-2b",
-            Which::CodeBase7B => "codegemma-7b",
-            Which::CodeInstruct2B => "codegemma-2b-it",
-            Which::CodeInstruct7B => "codegemma-7b-it",
-            Which::BaseV2_2B => "gemma-2-2b",
-            Which::InstructV2_2B => "gemma-2-2b-it",
-            Which::BaseV2_9B => "gemma-2-9b",
-            Which::InstructV2_9B => "gemma-2-9b-it",
-            Which::BaseV3_1B => "gemma-3-1b",
-            Which::InstructV3_1B => "gemma-3-1b-it",
-            Which::Llama32_1B => "llama-3.2-1b",
-            Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
-            Which::Llama32_3B => "llama-3.2-3b",
-            Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
-        };
-
-        let owned_by = if meta.id.starts_with("google/") {
-            "google"
-        } else if meta.id.starts_with("meta-llama/") {
-            "meta"
-        } else {
-            "unknown"
-        };
-
-        Model {
-            id: model_id.to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: owned_by.to_string(),
-        }
-    }).collect();
+    let mut models: Vec<Model> = which_variants
+        .into_iter()
+        .map(|which| {
+            let meta = which.meta();
+            let model_id = match which {
+                Which::Base2B => "gemma-2b",
+                Which::Base7B => "gemma-7b",
+                Which::Instruct2B => "gemma-2b-it",
+                Which::Instruct7B => "gemma-7b-it",
+                Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+                Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+                Which::CodeBase2B => "codegemma-2b",
+                Which::CodeBase7B => "codegemma-7b",
+                Which::CodeInstruct2B => "codegemma-2b-it",
+                Which::CodeInstruct7B => "codegemma-7b-it",
+                Which::BaseV2_2B => "gemma-2-2b",
+                Which::InstructV2_2B => "gemma-2-2b-it",
+                Which::BaseV2_9B => "gemma-2-9b",
+                Which::InstructV2_9B => "gemma-2-9b-it",
+                Which::BaseV3_1B => "gemma-3-1b",
+                Which::InstructV3_1B => "gemma-3-1b-it",
+                Which::Llama32_1B => "llama-3.2-1b",
+                Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+                Which::Llama32_3B => "llama-3.2-3b",
+                Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+            };
+
+            let owned_by = if meta.id.starts_with("google/") {
+                "google"
+            } else if meta.id.starts_with("meta-llama/") {
+                "meta"
+            } else {
+                "unknown"
+            };
+
+            Model {
+                id: model_id.to_string(),
+                object: "model".to_string(),
+                created: 1686935002,
+                owned_by: owned_by.to_string(),
+            }
+        })
+        .collect();

     // Get embeddings models and convert them to inference Model format
     let embeddings_response = models_list().await;
-    let embeddings_models: Vec<Model> = embeddings_response.0.data.into_iter().map(|embedding_model| {
-        Model {
-            id: embedding_model.id,
-            object: embedding_model.object,
-            created: 1686935002,
-            owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description),
-        }
-    }).collect();
+    let embeddings_models: Vec<Model> = embeddings_response
+        .0
+        .data
+        .into_iter()
+        .map(|embedding_model| Model {
+            id: embedding_model.id,
+            object: embedding_model.object,
+            created: 1686935002,
+            owned_by: format!(
+                "{} - {}",
+                embedding_model.owned_by, embedding_model.description
+            ),
+        })
+        .collect();

     // Add embeddings models to the main models list
     models.extend(embeddings_models);
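Hypothetical wiring, for context only (the real routes live elsewhere in this repo): list_models is an argument-less handler returning Json, so exposing it takes one axum route; the path below is an assumption based on the OpenAI-style API this server mimics.

    use axum::{routing::get, Router};

    fn api_router() -> Router {
        // "/v1/models" is assumed, matching the OpenAI convention.
        Router::new().route("/v1/models", get(list_models))
    }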