mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
reorg + update docs with new paths
This commit is contained in:
7
integration/llama-runner/src/lib.rs
Normal file
7
integration/llama-runner/src/lib.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub mod llama_api;
|
||||
|
||||
use clap::ValueEnum;
|
||||
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
|
||||
|
||||
// Re-export constants and types that might be needed
|
||||
pub const EOS_TOKEN: &str = "</s>";
|
333
integration/llama-runner/src/llama_api.rs
Normal file
333
integration/llama-runner/src/llama_api.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
use crate::EOS_TOKEN;
|
||||
use anyhow::{bail, Error as E};
|
||||
use candle_core::{utils, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use candle_transformers::models::llama as model;
|
||||
use candle_transformers::models::llama::{Llama, LlamaConfig};
|
||||
use clap::ValueEnum;
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use std::sync::mpsc::{self, Receiver};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "llama-3.2-1b")]
|
||||
#[default]
|
||||
Llama32_1B,
|
||||
#[value(name = "llama-3.2-1b-instruct")]
|
||||
Llama32_1BInstruct,
|
||||
#[value(name = "llama-3.2-3b")]
|
||||
Llama32_3B,
|
||||
#[value(name = "llama-3.2-3b-instruct")]
|
||||
Llama32_3BInstruct,
|
||||
#[value(name = "smollm2-135m")]
|
||||
SmolLM2_135M,
|
||||
#[value(name = "smollm2-135m-instruct")]
|
||||
SmolLM2_135MInstruct,
|
||||
#[value(name = "smollm2-360m")]
|
||||
SmolLM2_360M,
|
||||
#[value(name = "smollm2-360m-instruct")]
|
||||
SmolLM2_360MInstruct,
|
||||
#[value(name = "smollm2-1.7b")]
|
||||
SmolLM2_1_7B,
|
||||
#[value(name = "smollm2-1.7b-instruct")]
|
||||
SmolLM2_1_7BInstruct,
|
||||
#[value(name = "tinyllama-1.1b-chat")]
|
||||
TinyLlama1_1BChat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlamaInferenceConfig {
|
||||
pub prompt: String,
|
||||
|
||||
pub model: WhichModel,
|
||||
pub cpu: bool,
|
||||
pub temperature: f64,
|
||||
pub top_p: Option<f64>,
|
||||
pub top_k: Option<usize>,
|
||||
pub seed: u64,
|
||||
pub max_tokens: usize,
|
||||
pub no_kv_cache: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
pub revision: Option<String>,
|
||||
pub use_flash_attn: bool,
|
||||
pub repeat_penalty: f32,
|
||||
pub repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Default for LlamaInferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// Leave prompt empty by default; let call sites set it.
|
||||
prompt: String::new(),
|
||||
|
||||
// Keep your existing model choice; swap at call-site if needed.
|
||||
model: WhichModel::Llama32_1BInstruct,
|
||||
|
||||
// Prefer GPU if available.
|
||||
cpu: false,
|
||||
|
||||
// Sampling: balanced + stable
|
||||
temperature: 0.7,
|
||||
top_p: Some(0.95),
|
||||
top_k: Some(50),
|
||||
|
||||
// Reproducible by default; override for variability.
|
||||
seed: 42,
|
||||
|
||||
// Don’t run unbounded generations.
|
||||
max_tokens: 512,
|
||||
|
||||
// Performance flags
|
||||
no_kv_cache: false, // keep cache ON for speed
|
||||
use_flash_attn: false, // great speed boost if supported
|
||||
|
||||
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
|
||||
dtype: Some("bf16".to_string()),
|
||||
|
||||
// Optional model source pinning (None = app defaults)
|
||||
model_id: None,
|
||||
revision: None,
|
||||
|
||||
// Anti-repeat heuristics
|
||||
repeat_penalty: 1.15,
|
||||
repeat_last_n: 128,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn device(cpu: bool) -> anyhow::Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if utils::cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if utils::metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
fn hub_load_safetensors(
|
||||
api: &hf_hub::api::sync::ApiRepo,
|
||||
json_file: &str,
|
||||
) -> anyhow::Result<Vec<std::path::PathBuf>> {
|
||||
let json_file = api.get(json_file)?;
|
||||
let json_file = std::fs::File::open(json_file)?;
|
||||
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file.to_string());
|
||||
}
|
||||
}
|
||||
let safetensors_files = safetensors_files
|
||||
.iter()
|
||||
.map(|v| api.get(v))
|
||||
.collect::<anyhow::Result<Vec<_>, _>>()?;
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
|
||||
pub fn run_llama_inference(
|
||||
cfg: LlamaInferenceConfig,
|
||||
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
|
||||
// ---- Device & dtype -----------------------------------------------------
|
||||
let device = device(cfg.cpu)?;
|
||||
println!("Device: {:?}", device);
|
||||
|
||||
let dtype = match cfg.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
|
||||
// ---- Load model & tokenizer --------------------------------------------
|
||||
let (llama, tokenizer, mut cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
|
||||
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
|
||||
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
|
||||
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
|
||||
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
|
||||
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
|
||||
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
|
||||
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
|
||||
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
println!("Loading model: {}", model_id);
|
||||
let revision = cfg.revision.clone().unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(cfg.use_flash_attn);
|
||||
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
|
||||
hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
_ => vec![api.get("model.safetensors")?],
|
||||
};
|
||||
|
||||
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let llama = Llama::load(vb, &config)?;
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
(llama, tokenizer, cache)
|
||||
};
|
||||
|
||||
// ---- Prepare prompt & sampler ------------------------------------------
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(EOS_TOKEN)
|
||||
.map(model::LlamaEosToks::Single);
|
||||
|
||||
let mut tokens = tokenizer
|
||||
.encode(cfg.prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
println!("Starting inference...");
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = cfg.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (cfg.top_k, cfg.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(cfg.seed, sampling)
|
||||
};
|
||||
|
||||
// Channel for streaming decoded fragments to the caller.
|
||||
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
|
||||
|
||||
// ---- Spawn generation thread -------------------------------------------
|
||||
std::thread::spawn(move || {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0usize;
|
||||
let mut token_generated = 0usize;
|
||||
|
||||
for index in 0..cfg.max_tokens {
|
||||
// Use KV-cache for single-token step after the first pass.
|
||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||
(1, index_pos)
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let logits = match llama.forward(&input, context_index, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
let logits = match logits.squeeze(0) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let logits = if cfg.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
|
||||
match candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
cfg.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = match logits_processor.sample(&logits) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
token_generated += 1;
|
||||
tokens.push(next_token);
|
||||
|
||||
// Early stop on EOS.
|
||||
let stop = match eos_token_id {
|
||||
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
|
||||
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
|
||||
None => false,
|
||||
};
|
||||
if stop {
|
||||
break;
|
||||
}
|
||||
|
||||
// Decode this token's text and stream it out.
|
||||
match tokenizer.decode(&[next_token], false) {
|
||||
Ok(text) => {
|
||||
if !text.is_empty() {
|
||||
// Best-effort send; if receiver is gone, just stop.
|
||||
if tx.send(Ok(text)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: final stats as a debug line (not sent through the stream).
|
||||
let dt = start_gen.elapsed();
|
||||
eprintln!(
|
||||
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
|
||||
token_generated,
|
||||
token_generated as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
// Dropping tx closes the stream.
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
108
integration/llama-runner/src/llama_cli.rs
Normal file
108
integration/llama-runner/src/llama_cli.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
|
||||
use clap::Parser;
|
||||
use std::io::Write;
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to generate text from
|
||||
#[arg(short, long, default_value = "The capital of France is")]
|
||||
prompt: String,
|
||||
|
||||
/// The model to use
|
||||
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
|
||||
model: WhichModel,
|
||||
|
||||
/// Run on CPU rather than GPU
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples
|
||||
#[arg(short, long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens)
|
||||
#[arg(short = 'n', long, default_value_t = 100)]
|
||||
max_tokens: usize,
|
||||
|
||||
/// Disable the key-value cache
|
||||
#[arg(long)]
|
||||
no_kv_cache: bool,
|
||||
|
||||
/// Use different dtype than f16
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
/// Custom model ID from HuggingFace Hub
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// Model revision
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// Use flash attention
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty
|
||||
#[arg(long, default_value_t = 128)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Into<LlamaInferenceConfig> for Args {
|
||||
fn into(self) -> LlamaInferenceConfig {
|
||||
LlamaInferenceConfig {
|
||||
prompt: self.prompt,
|
||||
model: self.model,
|
||||
cpu: self.cpu,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
top_k: self.top_k,
|
||||
seed: self.seed,
|
||||
max_tokens: self.max_tokens,
|
||||
no_kv_cache: self.no_kv_cache,
|
||||
dtype: self.dtype,
|
||||
model_id: self.model_id,
|
||||
revision: self.revision,
|
||||
use_flash_attn: self.use_flash_attn,
|
||||
repeat_penalty: self.repeat_penalty,
|
||||
repeat_last_n: self.repeat_last_n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_cli() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let cfg = args.into();
|
||||
let rx = run_llama_inference(cfg)?;
|
||||
for msg in rx {
|
||||
match msg {
|
||||
Ok(tok) => {
|
||||
print!("{tok}");
|
||||
let _ = std::io::stdout().flush(); // <- force it out now
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("generation error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
19
integration/llama-runner/src/main.rs
Normal file
19
integration/llama-runner/src/main.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
mod llama_api;
|
||||
mod llama_cli;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use std::io::Write;
|
||||
|
||||
use crate::llama_cli::run_cli;
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
|
||||
fn main() -> Result<()> {
|
||||
run_cli()
|
||||
}
|
Reference in New Issue
Block a user