fixes issue with model selection

geoffsee
2025-09-04 13:42:30 -04:00
parent ff55d882c7
commit 1e02b12cda
6 changed files with 209 additions and 63 deletions

@@ -1,13 +1,8 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
-use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
@@ -21,8 +16,10 @@ use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
+use std::str::FromStr;
+use std::fmt;
-#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
@@ -58,6 +55,56 @@ pub enum WhichModel {
InstructV3_1B,
}
+impl FromStr for WhichModel {
+type Err = String;
+fn from_str(s: &str) -> Result<Self, Self::Err> {
+match s {
+"gemma-2b" => Ok(Self::Base2B),
+"gemma-7b" => Ok(Self::Base7B),
+"gemma-2b-it" => Ok(Self::Instruct2B),
+"gemma-7b-it" => Ok(Self::Instruct7B),
+"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
+"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
+"codegemma-2b" => Ok(Self::CodeBase2B),
+"codegemma-7b" => Ok(Self::CodeBase7B),
+"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
+"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
+"gemma-2-2b" => Ok(Self::BaseV2_2B),
+"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
+"gemma-2-9b" => Ok(Self::BaseV2_9B),
+"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
+"gemma-3-1b" => Ok(Self::BaseV3_1B),
+"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
+_ => Err(format!("Unknown model: {}", s)),
+}
+}
+}
+impl fmt::Display for WhichModel {
+fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+let name = match self {
+Self::Base2B => "gemma-2b",
+Self::Base7B => "gemma-7b",
+Self::Instruct2B => "gemma-2b-it",
+Self::Instruct7B => "gemma-7b-it",
+Self::InstructV1_1_2B => "gemma-1.1-2b-it",
+Self::InstructV1_1_7B => "gemma-1.1-7b-it",
+Self::CodeBase2B => "codegemma-2b",
+Self::CodeBase7B => "codegemma-7b",
+Self::CodeInstruct2B => "codegemma-2b-it",
+Self::CodeInstruct7B => "codegemma-7b-it",
+Self::BaseV2_2B => "gemma-2-2b",
+Self::InstructV2_2B => "gemma-2-2b-it",
+Self::BaseV2_9B => "gemma-2-9b",
+Self::InstructV2_9B => "gemma-2-9b-it",
+Self::BaseV3_1B => "gemma-3-1b",
+Self::InstructV3_1B => "gemma-3-1b-it",
+};
+write!(f, "{}", name)
+}
+}
enum Model {
V1(Model1),
V2(Model2),
@@ -145,7 +192,7 @@ impl TextGeneration {
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;
-let mut generated_tokens = 0usize;
+let mut _generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
@@ -183,7 +230,7 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
-generated_tokens += 1;
+_generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
@@ -210,7 +257,7 @@ impl TextGeneration {
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
-pub model: WhichModel,
+pub model: Option<WhichModel>,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
@@ -229,7 +276,7 @@ impl Default for GemmaInferenceConfig {
Self {
tracing: false,
prompt: "Hello".to_string(),
-model: WhichModel::InstructV2_2B,
+model: Some(WhichModel::InstructV2_2B),
cpu: false,
dtype: None,
model_id: None,
@@ -286,28 +333,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
}
};
println!("Using dtype: {:?}", dtype);
println!("Raw model string: {:?}", cfg.model_id);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
-WhichModel::Base2B => "google/gemma-2b",
-WhichModel::Base7B => "google/gemma-7b",
-WhichModel::Instruct2B => "google/gemma-2b-it",
-WhichModel::Instruct7B => "google/gemma-7b-it",
-WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
-WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
-WhichModel::CodeBase2B => "google/codegemma-2b",
-WhichModel::CodeBase7B => "google/codegemma-7b",
-WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
-WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
-WhichModel::BaseV2_2B => "google/gemma-2-2b",
-WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
-WhichModel::BaseV2_9B => "google/gemma-2-9b",
-WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
-WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
-WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
+Some(WhichModel::Base2B) => "google/gemma-2b",
+Some(WhichModel::Base7B) => "google/gemma-7b",
+Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
+Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
+Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
+Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
+Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
+Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
+Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
+Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
+Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
+Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
+Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
+Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
+Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
+Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
+None => "google/gemma-2-2b-it", // default fallback
}
.to_string()
});
@@ -318,7 +367,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
-WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
+Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +378,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model {
-WhichModel::Base2B
-| WhichModel::Base7B
-| WhichModel::Instruct2B
-| WhichModel::Instruct7B
-| WhichModel::InstructV1_1_2B
-| WhichModel::InstructV1_1_7B
-| WhichModel::CodeBase2B
-| WhichModel::CodeBase7B
-| WhichModel::CodeInstruct2B
-| WhichModel::CodeInstruct7B => {
+Some(WhichModel::Base2B)
+| Some(WhichModel::Base7B)
+| Some(WhichModel::Instruct2B)
+| Some(WhichModel::Instruct7B)
+| Some(WhichModel::InstructV1_1_2B)
+| Some(WhichModel::InstructV1_1_7B)
+| Some(WhichModel::CodeBase2B)
+| Some(WhichModel::CodeBase7B)
+| Some(WhichModel::CodeInstruct2B)
+| Some(WhichModel::CodeInstruct7B) => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
-WhichModel::BaseV2_2B
-| WhichModel::InstructV2_2B
-| WhichModel::BaseV2_9B
-| WhichModel::InstructV2_9B => {
+Some(WhichModel::BaseV2_2B)
+| Some(WhichModel::InstructV2_2B)
+| Some(WhichModel::BaseV2_9B)
+| Some(WhichModel::InstructV2_9B)
+| None => { // default to V2 model
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
-WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
+Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
@@ -371,7 +421,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
);
let prompt = match cfg.model {
-WhichModel::InstructV3_1B => {
+Some(WhichModel::InstructV3_1B) => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt

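To make the model-selection change concrete, here is a minimal, hedged sketch of a caller built on the items edited above (`WhichModel`, `GemmaInferenceConfig`, `run_gemma_api`). The wrapper function and its parameters are invented for illustration; the fallback behaviour (an unset or unrecognized model resolving to `google/gemma-2-2b-it`) comes from the `None` arms in the diff.

```rust
use std::str::FromStr;

use anyhow::Result;

// Hypothetical caller; `WhichModel`, `GemmaInferenceConfig`, and
// `run_gemma_api` are the items changed in the diff above.
fn start_generation(raw_model: Option<&str>, prompt: &str) -> Result<()> {
    // `model` is now an Option: an unknown or missing name is dropped to
    // `None` here, and `run_gemma_api` falls back to "google/gemma-2-2b-it".
    let model = raw_model.and_then(|s| WhichModel::from_str(s).ok());
    if let Some(m) = model {
        println!("Selected model: {m}"); // uses the new Display impl
    }

    let cfg = GemmaInferenceConfig {
        prompt: prompt.to_string(),
        model,
        ..Default::default() // remaining fields from the Default impl above
    };

    // The API hands back a Receiver that streams generated text chunks.
    for chunk in run_gemma_api(cfg)? {
        print!("{}", chunk?);
    }
    Ok(())
}
```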

@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
pub repeat_last_n: usize,
}
+impl LlamaInferenceConfig {
+pub fn new(model: WhichModel) -> Self {
+Self {
+prompt: String::new(),
+model,
+cpu: false,
+temperature: 1.0,
+top_p: None,
+top_k: None,
+seed: 42,
+max_tokens: 512,
+no_kv_cache: false,
+dtype: None,
+model_id: None,
+revision: None,
+use_flash_attn: true,
+repeat_penalty: 1.1,
+repeat_last_n: 64,
+}
+}
+}
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
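On the llama side, the new convenience constructor seeds a config from a model choice and leaves only per-request fields to fill in. A hedged sketch, assuming the config fields are public (as `pub repeat_last_n` suggests) and using a hypothetical helper name:

```rust
// Hypothetical helper; `LlamaInferenceConfig` and the llama-side `WhichModel`
// come from the file changed above.
fn config_for(model: WhichModel, prompt: &str) -> LlamaInferenceConfig {
    // `new` applies the defaults shown in the diff (temperature 1.0, seed 42,
    // max_tokens 512, flash attention enabled); only request-specific fields
    // are overridden here.
    let mut cfg = LlamaInferenceConfig::new(model);
    cfg.prompt = prompt.to_string();
    cfg.max_tokens = 256; // assumption: a tighter cap for short completions
    cfg
}
```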