mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
fixes issue with model selection
This commit is contained in:
@@ -1,13 +1,8 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use clap::ValueEnum;
|
||||
|
||||
// Removed gemma_cli import as it's not needed for the API
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
@@ -21,8 +16,10 @@ use std::thread;
|
||||
use tokenizers::Tokenizer;
|
||||
use utils::hub_load_safetensors;
|
||||
use utils::token_output_stream::TokenOutputStream;
|
||||
use std::str::FromStr;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "gemma-2b")]
|
||||
Base2B,
|
||||
@@ -58,6 +55,56 @@ pub enum WhichModel {
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
impl FromStr for WhichModel {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"gemma-2b" => Ok(Self::Base2B),
|
||||
"gemma-7b" => Ok(Self::Base7B),
|
||||
"gemma-2b-it" => Ok(Self::Instruct2B),
|
||||
"gemma-7b-it" => Ok(Self::Instruct7B),
|
||||
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
|
||||
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
|
||||
"codegemma-2b" => Ok(Self::CodeBase2B),
|
||||
"codegemma-7b" => Ok(Self::CodeBase7B),
|
||||
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
|
||||
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
|
||||
"gemma-2-2b" => Ok(Self::BaseV2_2B),
|
||||
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
|
||||
"gemma-2-9b" => Ok(Self::BaseV2_9B),
|
||||
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
|
||||
"gemma-3-1b" => Ok(Self::BaseV3_1B),
|
||||
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
|
||||
_ => Err(format!("Unknown model: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for WhichModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let name = match self {
|
||||
Self::Base2B => "gemma-2b",
|
||||
Self::Base7B => "gemma-7b",
|
||||
Self::Instruct2B => "gemma-2b-it",
|
||||
Self::Instruct7B => "gemma-7b-it",
|
||||
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
|
||||
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
|
||||
Self::CodeBase2B => "codegemma-2b",
|
||||
Self::CodeBase7B => "codegemma-7b",
|
||||
Self::CodeInstruct2B => "codegemma-2b-it",
|
||||
Self::CodeInstruct7B => "codegemma-7b-it",
|
||||
Self::BaseV2_2B => "gemma-2-2b",
|
||||
Self::InstructV2_2B => "gemma-2-2b-it",
|
||||
Self::BaseV2_9B => "gemma-2-9b",
|
||||
Self::InstructV2_9B => "gemma-2-9b-it",
|
||||
Self::BaseV3_1B => "gemma-3-1b",
|
||||
Self::InstructV3_1B => "gemma-3-1b-it",
|
||||
};
|
||||
write!(f, "{}", name)
|
||||
}
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
@@ -145,7 +192,7 @@ impl TextGeneration {
|
||||
// Make sure stdout isn't holding anything (if caller also prints).
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let mut _generated_tokens = 0usize;
|
||||
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
@@ -183,7 +230,7 @@ impl TextGeneration {
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
_generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
@@ -210,7 +257,7 @@ impl TextGeneration {
|
||||
pub struct GemmaInferenceConfig {
|
||||
pub tracing: bool,
|
||||
pub prompt: String,
|
||||
pub model: WhichModel,
|
||||
pub model: Option<WhichModel>,
|
||||
pub cpu: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
@@ -229,7 +276,7 @@ impl Default for GemmaInferenceConfig {
|
||||
Self {
|
||||
tracing: false,
|
||||
prompt: "Hello".to_string(),
|
||||
model: WhichModel::InstructV2_2B,
|
||||
model: Some(WhichModel::InstructV2_2B),
|
||||
cpu: false,
|
||||
dtype: None,
|
||||
model_id: None,
|
||||
@@ -286,28 +333,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
}
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
println!("Raw model string: {:?}", cfg.model_id);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
|
||||
let model_id = cfg.model_id.unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Base2B => "google/gemma-2b",
|
||||
WhichModel::Base7B => "google/gemma-7b",
|
||||
WhichModel::Instruct2B => "google/gemma-2b-it",
|
||||
WhichModel::Instruct7B => "google/gemma-7b-it",
|
||||
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
|
||||
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
|
||||
WhichModel::CodeBase2B => "google/codegemma-2b",
|
||||
WhichModel::CodeBase7B => "google/codegemma-7b",
|
||||
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
|
||||
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
|
||||
WhichModel::BaseV2_2B => "google/gemma-2-2b",
|
||||
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
|
||||
WhichModel::BaseV2_9B => "google/gemma-2-9b",
|
||||
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
|
||||
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
|
||||
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
|
||||
Some(WhichModel::Base2B) => "google/gemma-2b",
|
||||
Some(WhichModel::Base7B) => "google/gemma-7b",
|
||||
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
|
||||
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
|
||||
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
|
||||
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
|
||||
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
|
||||
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
|
||||
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
|
||||
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
|
||||
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
|
||||
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
|
||||
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
|
||||
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
|
||||
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
|
||||
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
|
||||
None => "google/gemma-2-2b-it", // default fallback
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
@@ -318,7 +367,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => vec![repo.get("model.safetensors")?],
|
||||
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("Retrieved files in {:?}", start.elapsed());
|
||||
@@ -329,29 +378,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
|
||||
let model: Model = match cfg.model {
|
||||
WhichModel::Base2B
|
||||
| WhichModel::Base7B
|
||||
| WhichModel::Instruct2B
|
||||
| WhichModel::Instruct7B
|
||||
| WhichModel::InstructV1_1_2B
|
||||
| WhichModel::InstructV1_1_7B
|
||||
| WhichModel::CodeBase2B
|
||||
| WhichModel::CodeBase7B
|
||||
| WhichModel::CodeInstruct2B
|
||||
| WhichModel::CodeInstruct7B => {
|
||||
Some(WhichModel::Base2B)
|
||||
| Some(WhichModel::Base7B)
|
||||
| Some(WhichModel::Instruct2B)
|
||||
| Some(WhichModel::Instruct7B)
|
||||
| Some(WhichModel::InstructV1_1_2B)
|
||||
| Some(WhichModel::InstructV1_1_7B)
|
||||
| Some(WhichModel::CodeBase2B)
|
||||
| Some(WhichModel::CodeBase7B)
|
||||
| Some(WhichModel::CodeInstruct2B)
|
||||
| Some(WhichModel::CodeInstruct7B) => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
}
|
||||
WhichModel::BaseV2_2B
|
||||
| WhichModel::InstructV2_2B
|
||||
| WhichModel::BaseV2_9B
|
||||
| WhichModel::InstructV2_9B => {
|
||||
Some(WhichModel::BaseV2_2B)
|
||||
| Some(WhichModel::InstructV2_2B)
|
||||
| Some(WhichModel::BaseV2_9B)
|
||||
| Some(WhichModel::InstructV2_9B)
|
||||
| None => { // default to V2 model
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
}
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
|
||||
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V3(model)
|
||||
@@ -371,7 +421,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
);
|
||||
|
||||
let prompt = match cfg.model {
|
||||
WhichModel::InstructV3_1B => {
|
||||
Some(WhichModel::InstructV3_1B) => {
|
||||
format!(
|
||||
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
|
||||
cfg.prompt
|
||||
|
@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
|
||||
pub repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl LlamaInferenceConfig {
|
||||
pub fn new(model: WhichModel) -> Self {
|
||||
Self {
|
||||
prompt: String::new(),
|
||||
model,
|
||||
cpu: false,
|
||||
temperature: 1.0,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
seed: 42,
|
||||
max_tokens: 512,
|
||||
no_kv_cache: false,
|
||||
dtype: None,
|
||||
model_id: None,
|
||||
revision: None,
|
||||
use_flash_attn: true,
|
||||
repeat_penalty: 1.1,
|
||||
repeat_last_n: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Default for LlamaInferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
Reference in New Issue
Block a user