fixes issue with model selection

geoffsee
2025-09-04 13:42:30 -04:00
parent ff55d882c7
commit 1e02b12cda
6 changed files with 209 additions and 63 deletions

View File

@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
 > This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
 > By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
-Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
+Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
 A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.

View File

@@ -4,7 +4,7 @@
"": { "": {
"name": "predict-otron-9000", "name": "predict-otron-9000",
}, },
"crates/cli/package": { "integration/cli/package": {
"name": "cli", "name": "cli",
"dependencies": { "dependencies": {
"install": "^0.13.0", "install": "^0.13.0",
@@ -13,7 +13,7 @@
}, },
}, },
"packages": { "packages": {
"cli": ["cli@workspace:crates/cli/package"], "cli": ["cli@workspace:integration/cli/package"],
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="], "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],

View File

@@ -365,7 +365,7 @@ fn ChatPage() -> impl IntoView {
     // State for available models and selected model
     let available_models = RwSignal::new(Vec::<ModelInfo>::new());
-    let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
+    let selected_model = RwSignal::new(String::from("")); // Default model

     // State for streaming response
     let streaming_content = RwSignal::new(String::new());
@@ -382,6 +382,7 @@ fn ChatPage() -> impl IntoView {
             match fetch_models().await {
                 Ok(models) => {
                     available_models.set(models);
+                    selected_model.set(String::from("gemma-3-1b-it"));
                 }
                 Err(error) => {
                     console::log_1(&format!("Failed to fetch models: {}", error).into());
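
Note: the UI now leaves the selection empty until /v1/models responds, then pins it to "gemma-3-1b-it". A hedged sketch of a more defensive variant (not part of this commit; it assumes ModelInfo exposes a public id: String field) that falls back to the first advertised model when the preferred default is absent:

// Hypothetical helper, not in this commit: choose a default from the fetched
// list, preferring the Gemma default when the server actually advertises it.
fn pick_default_model(models: &[ModelInfo], preferred: &str) -> String {
    models
        .iter()
        .find(|m| m.id == preferred) // assumes ModelInfo { id: String, .. }
        .or_else(|| models.first())
        .map(|m| m.id.clone())
        .unwrap_or_else(|| preferred.to_string())
}

// Usage inside the Ok(models) arm, before available_models.set(models):
// selected_model.set(pick_default_model(&models, "gemma-3-1b-it"));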

View File

@@ -7,6 +7,7 @@ use axum::{
 };
 use futures_util::stream::{self, Stream};
 use std::convert::Infallible;
+use std::str::FromStr;
 use std::sync::Arc;
 use tokio::sync::{mpsc, Mutex};
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -20,7 +21,7 @@ use crate::openai_types::{
 use crate::Which;
 use either::Either;
 use embeddings_engine::models_list;
-use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
+use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};
 use llama_runner::{run_llama_inference, LlamaInferenceConfig};
 use serde_json::Value;
 // -------------------------
@@ -35,12 +36,13 @@ pub enum ModelType {
 #[derive(Clone)]
 pub struct AppState {
-    pub model_type: ModelType,
+    pub model_type: Option<ModelType>,
     pub model_id: String,
     pub gemma_config: Option<GemmaInferenceConfig>,
     pub llama_config: Option<LlamaInferenceConfig>,
 }

 impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
@@ -48,12 +50,12 @@ impl Default for AppState {
         let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());

         let gemma_config = GemmaInferenceConfig {
-            model: gemma_runner::WhichModel::InstructV3_1B,
+            model: None,
             ..Default::default()
         };

         Self {
-            model_type: ModelType::Gemma,
+            model_type: None,
             model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
@@ -84,7 +86,9 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
         "gemma-2-9b-it" => Some(Which::InstructV2_9B),
         "gemma-3-1b" => Some(Which::BaseV3_1B),
         "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+        "llama-3.2-1b" => Some(Which::Llama32_1B),
         "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+        "llama-3.2-3b" => Some(Which::Llama32_3B),
         "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
         _ => None,
     }
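
The two added arms make the base (non-instruct) Llama 3.2 checkpoints selectable by id. A minimal test sketch for the extended mapping, assuming a #[cfg(test)] module in the same file so the private model_id_to_which and Which are in scope:

#[cfg(test)]
mod model_selection_tests {
    use super::*;

    #[test]
    fn maps_new_llama_base_ids() {
        // Ids added in this commit resolve to the base Llama 3.2 variants.
        assert!(matches!(model_id_to_which("llama-3.2-1b"), Some(Which::Llama32_1B)));
        assert!(matches!(model_id_to_which("llama-3.2-3b"), Some(Which::Llama32_3B)));
        // Unknown ids still fall through to None.
        assert!(matches!(model_id_to_which("not-a-model"), None));
    }
}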
@@ -190,7 +194,21 @@ pub async fn chat_completions_non_streaming_proxy(
     // Get streaming receiver based on model type
     let rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    }))
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         run_llama_inference(config).map_err(|e| (
@@ -201,14 +219,35 @@ pub async fn chat_completions_non_streaming_proxy(
         ))?
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    }))
+                ));
+            }
         };
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
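
The Which → gemma_runner::WhichModel match above is repeated verbatim in the streaming handler below. One way it could be factored out, sketched here as a suggestion rather than part of the commit (the variant list is taken from the arms above; the error response stays in each handler so it can keep its own status code):

// Hypothetical shared helper: map the server's Which to the Gemma runner's enum,
// returning None for Llama (or any future non-Gemma) variants.
fn to_gemma_model(which: &Which) -> Option<gemma_runner::WhichModel> {
    use gemma_runner::WhichModel as G;
    match which {
        Which::Base2B => Some(G::Base2B),
        Which::Base7B => Some(G::Base7B),
        Which::Instruct2B => Some(G::Instruct2B),
        Which::Instruct7B => Some(G::Instruct7B),
        Which::InstructV1_1_2B => Some(G::InstructV1_1_2B),
        Which::InstructV1_1_7B => Some(G::InstructV1_1_7B),
        Which::CodeBase2B => Some(G::CodeBase2B),
        Which::CodeBase7B => Some(G::CodeBase7B),
        Which::CodeInstruct2B => Some(G::CodeInstruct2B),
        Which::CodeInstruct7B => Some(G::CodeInstruct7B),
        Which::BaseV2_2B => Some(G::BaseV2_2B),
        Which::InstructV2_2B => Some(G::InstructV2_2B),
        Which::BaseV2_9B => Some(G::BaseV2_9B),
        Which::InstructV2_9B => Some(G::InstructV2_9B),
        Which::BaseV3_1B => Some(G::BaseV3_1B),
        Which::InstructV3_1B => Some(G::InstructV3_1B),
        _ => None, // Llama variants are not Gemma models
    }
}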
@@ -348,7 +387,21 @@ async fn handle_streaming_request(
     // Get streaming receiver based on model type
     let model_rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    }))
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         match run_llama_inference(config) {
@@ -364,14 +417,35 @@ async fn handle_streaming_request(
         }
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    }))
+                ));
+            }
         };
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();

View File

@@ -1,13 +1,8 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API // Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor}; use candle_core::{DType, Device, Tensor};
@@ -21,8 +16,10 @@ use std::thread;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use utils::hub_load_safetensors; use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream; use utils::token_output_stream::TokenOutputStream;
use std::str::FromStr;
use std::fmt;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel { pub enum WhichModel {
#[value(name = "gemma-2b")] #[value(name = "gemma-2b")]
Base2B, Base2B,
@@ -58,6 +55,56 @@ pub enum WhichModel {
     InstructV3_1B,
 }

+impl FromStr for WhichModel {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "gemma-2b" => Ok(Self::Base2B),
+            "gemma-7b" => Ok(Self::Base7B),
+            "gemma-2b-it" => Ok(Self::Instruct2B),
+            "gemma-7b-it" => Ok(Self::Instruct7B),
+            "gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
+            "gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
+            "codegemma-2b" => Ok(Self::CodeBase2B),
+            "codegemma-7b" => Ok(Self::CodeBase7B),
+            "codegemma-2b-it" => Ok(Self::CodeInstruct2B),
+            "codegemma-7b-it" => Ok(Self::CodeInstruct7B),
+            "gemma-2-2b" => Ok(Self::BaseV2_2B),
+            "gemma-2-2b-it" => Ok(Self::InstructV2_2B),
+            "gemma-2-9b" => Ok(Self::BaseV2_9B),
+            "gemma-2-9b-it" => Ok(Self::InstructV2_9B),
+            "gemma-3-1b" => Ok(Self::BaseV3_1B),
+            "gemma-3-1b-it" => Ok(Self::InstructV3_1B),
+            _ => Err(format!("Unknown model: {}", s)),
+        }
+    }
+}
+
+impl fmt::Display for WhichModel {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let name = match self {
+            Self::Base2B => "gemma-2b",
+            Self::Base7B => "gemma-7b",
+            Self::Instruct2B => "gemma-2b-it",
+            Self::Instruct7B => "gemma-7b-it",
+            Self::InstructV1_1_2B => "gemma-1.1-2b-it",
+            Self::InstructV1_1_7B => "gemma-1.1-7b-it",
+            Self::CodeBase2B => "codegemma-2b",
+            Self::CodeBase7B => "codegemma-7b",
+            Self::CodeInstruct2B => "codegemma-2b-it",
+            Self::CodeInstruct7B => "codegemma-7b-it",
+            Self::BaseV2_2B => "gemma-2-2b",
+            Self::InstructV2_2B => "gemma-2-2b-it",
+            Self::BaseV2_9B => "gemma-2-9b",
+            Self::InstructV2_9B => "gemma-2-9b-it",
+            Self::BaseV3_1B => "gemma-3-1b",
+            Self::InstructV3_1B => "gemma-3-1b-it",
+        };
+        write!(f, "{}", name)
+    }
+}
+
 enum Model {
     V1(Model1),
     V2(Model2),
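
Together these impls give every Gemma variant a stable string id that round-trips with the names the HTTP layer uses. A short usage sketch (assuming this crate is the gemma_runner the server imports, and that WhichModel is public as the derive above implies):

use std::str::FromStr;

use gemma_runner::WhichModel;

fn main() {
    // Parse a wire-format id into the enum and print it back out.
    let model = WhichModel::from_str("gemma-3-1b-it").expect("known model id");
    assert_eq!(model.to_string(), "gemma-3-1b-it");

    // Unknown ids surface as Err with a readable message.
    let err = WhichModel::from_str("gemma-99b").unwrap_err();
    println!("{}", err); // prints: Unknown model: gemma-99b
}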
@@ -145,7 +192,7 @@ impl TextGeneration {
         // Make sure stdout isn't holding anything (if caller also prints).
         std::io::stdout().flush()?;

-        let mut generated_tokens = 0usize;
+        let mut _generated_tokens = 0usize;

         let eos_token = match self.tokenizer.get_token("<eos>") {
             Some(token) => token,
@@ -183,7 +230,7 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            generated_tokens += 1;
+            _generated_tokens += 1;

             if next_token == eos_token || next_token == eot_token {
                 break;
@@ -210,7 +257,7 @@ impl TextGeneration {
 pub struct GemmaInferenceConfig {
     pub tracing: bool,
     pub prompt: String,
-    pub model: WhichModel,
+    pub model: Option<WhichModel>,
     pub cpu: bool,
     pub dtype: Option<String>,
     pub model_id: Option<String>,
@@ -229,7 +276,7 @@ impl Default for GemmaInferenceConfig {
         Self {
             tracing: false,
             prompt: "Hello".to_string(),
-            model: WhichModel::InstructV2_2B,
+            model: Some(WhichModel::InstructV2_2B),
             cpu: false,
             dtype: None,
             model_id: None,
@@ -286,28 +333,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
         }
     };
     println!("Using dtype: {:?}", dtype);
+    println!("Raw model string: {:?}", cfg.model_id);

     let start = std::time::Instant::now();
     let api = Api::new()?;
     let model_id = cfg.model_id.unwrap_or_else(|| {
         match cfg.model {
-            WhichModel::Base2B => "google/gemma-2b",
-            WhichModel::Base7B => "google/gemma-7b",
-            WhichModel::Instruct2B => "google/gemma-2b-it",
-            WhichModel::Instruct7B => "google/gemma-7b-it",
-            WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
-            WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
-            WhichModel::CodeBase2B => "google/codegemma-2b",
-            WhichModel::CodeBase7B => "google/codegemma-7b",
-            WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
-            WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
-            WhichModel::BaseV2_2B => "google/gemma-2-2b",
-            WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
-            WhichModel::BaseV2_9B => "google/gemma-2-9b",
-            WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
-            WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
-            WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
+            Some(WhichModel::Base2B) => "google/gemma-2b",
+            Some(WhichModel::Base7B) => "google/gemma-7b",
+            Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
+            Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
+            Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
+            Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
+            Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
+            Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
+            Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
+            Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
+            Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
+            Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
+            Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
+            Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
+            Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
+            Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
+            None => "google/gemma-2-2b-it", // default fallback
         }
         .to_string()
     });
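
With model now optional, a config that sets neither model nor model_id lands on the None arm above. A hedged caller-side sketch of that default path (it simply restates the fallback encoded here; the receiver and error types are assumed to be the anyhow-based ones in this file's signature):

use gemma_runner::{run_gemma_api, GemmaInferenceConfig};

fn main() -> anyhow::Result<()> {
    // No explicit model: the loader falls back to "google/gemma-2-2b-it"
    // and builds the gemma2 (V2) model in the match further below.
    let cfg = GemmaInferenceConfig {
        model: None,
        model_id: None,
        prompt: "Say hello.".to_string(),
        ..Default::default()
    };
    let rx = run_gemma_api(cfg)?;
    for chunk in rx {
        print!("{}", chunk?);
    }
    Ok(())
}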
@@ -318,7 +367,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let config_filename = repo.get("config.json")?;
     let filenames = match cfg.model {
-        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
+        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
         _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +378,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };

     let model: Model = match cfg.model {
-        WhichModel::Base2B
-        | WhichModel::Base7B
-        | WhichModel::Instruct2B
-        | WhichModel::Instruct7B
-        | WhichModel::InstructV1_1_2B
-        | WhichModel::InstructV1_1_7B
-        | WhichModel::CodeBase2B
-        | WhichModel::CodeBase7B
-        | WhichModel::CodeInstruct2B
-        | WhichModel::CodeInstruct7B => {
+        Some(WhichModel::Base2B)
+        | Some(WhichModel::Base7B)
+        | Some(WhichModel::Instruct2B)
+        | Some(WhichModel::Instruct7B)
+        | Some(WhichModel::InstructV1_1_2B)
+        | Some(WhichModel::InstructV1_1_7B)
+        | Some(WhichModel::CodeBase2B)
+        | Some(WhichModel::CodeBase7B)
+        | Some(WhichModel::CodeInstruct2B)
+        | Some(WhichModel::CodeInstruct7B) => {
             let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
             Model::V1(model)
         }
-        WhichModel::BaseV2_2B
-        | WhichModel::InstructV2_2B
-        | WhichModel::BaseV2_9B
-        | WhichModel::InstructV2_9B => {
+        Some(WhichModel::BaseV2_2B)
+        | Some(WhichModel::InstructV2_2B)
+        | Some(WhichModel::BaseV2_9B)
+        | Some(WhichModel::InstructV2_9B)
+        | None => { // default to V2 model
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
             Model::V2(model)
         }
-        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
+        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
             let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
             Model::V3(model)
@@ -371,7 +421,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     );

     let prompt = match cfg.model {
-        WhichModel::InstructV3_1B => {
+        Some(WhichModel::InstructV3_1B) => {
             format!(
                 "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
                 cfg.prompt

View File

@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
     pub repeat_last_n: usize,
 }

+impl LlamaInferenceConfig {
+    pub fn new(model: WhichModel) -> Self {
+        Self {
+            prompt: String::new(),
+            model,
+            cpu: false,
+            temperature: 1.0,
+            top_p: None,
+            top_k: None,
+            seed: 42,
+            max_tokens: 512,
+            no_kv_cache: false,
+            dtype: None,
+            model_id: None,
+            revision: None,
+            use_flash_attn: true,
+            repeat_penalty: 1.1,
+            repeat_last_n: 64,
+        }
+    }
+}
+
 impl Default for LlamaInferenceConfig {
     fn default() -> Self {
         Self {
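
For reference, this constructor is what the chat handlers above now call; a minimal standalone sketch, assuming the crate is named llama_runner, that WhichModel and run_llama_inference are public as the server-side imports in this commit suggest, and that the returned receiver yields anyhow-style results:

use llama_runner::{run_llama_inference, LlamaInferenceConfig, WhichModel};

fn main() -> anyhow::Result<()> {
    // Build a config for a specific Llama variant, then override the two
    // fields the chat handlers set: the prompt and the token budget.
    let mut config = LlamaInferenceConfig::new(WhichModel::Llama32_1BInstruct);
    config.prompt = "Write a haiku about air-gapped networks.".to_string();
    config.max_tokens = 128;

    // run_llama_inference hands back a receiver of incremental results,
    // mirroring how the streaming endpoint consumes it.
    let rx = run_llama_inference(config)?;
    for token in rx {
        print!("{}", token?);
    }
    Ok(())
}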