Supports small Llama and Gemma models

Refactor inference

Dedicated crates for Llama and Gemma inference; not yet integrated.
geoffsee
2025-08-29 18:15:29 -04:00
parent d06b16bb12
commit 315ef17605
26 changed files with 2136 additions and 1402 deletions

View File

@@ -0,0 +1,24 @@
[package]
name = "llama-runner"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -0,0 +1,188 @@
# Llama Runner
A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.
## Features
- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
- ⚡ **Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults
## Supported Models
| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |
Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).
## Installation
```bash
# Clone the repository
git clone <repository-url>
cd llama-runner
# Build with GPU acceleration (recommended)
cargo build --release --features metal # macOS
cargo build --release --features cuda # Linux/Windows with NVIDIA GPU
# CPU-only build
cargo build --release
```
## Quick Start
```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"
# Specify a model and parameters
cargo run --features metal -- \
--prompt "Write a short story about space exploration" \
--model smollm2-360m \
--max-tokens 100 \
--temperature 0.8
# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```
## Usage Examples
### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"
# Creative writing with higher temperature
cargo run --features metal -- \
--prompt "Once upon a time" \
--temperature 1.0 \
--max-tokens 200
```
### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
--prompt "Explain artificial intelligence" \
--top-k 40 \
--top-p 0.9 \
--temperature 0.7
# Reduce repetition
cargo run --features metal -- \
--prompt "List the benefits of renewable energy" \
--repeat-penalty 1.2 \
--repeat-last-n 64
```
### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
--prompt "Quick test" \
--model smollm2-135m
# Better quality with larger model
cargo run --features metal -- \
--prompt "Explain quantum physics" \
--model llama-3.2-1b \
--max-tokens 150
```
## Command-Line Options
| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `llama-3.2-1b-instruct` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |
## Performance
Typical performance on Apple M2 with Metal acceleration:

| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |
## Requirements
- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows
## GPU Support
### macOS (Metal)
```bash
cargo run --features metal -- [options]
```
### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```
### CPU Only
```bash
cargo run -- --cpu [options]
```
## Model Downloads
Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:
- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes
## Troubleshooting
### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model
### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues
### Model Download Issues
- Check internet connection
- Some models may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`
## Contributing
Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.
## License
MIT License - see LICENSE file for details.

View File

@@ -0,0 +1,8 @@
pub mod llama_api;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Fallback end-of-sequence token, used when the model config does not provide one.
pub const EOS_TOKEN: &str = "</s>";
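
The re-exports above are the crate's entire public surface. As a rough sketch (not part of this commit), a downstream binary that depends on this package as `llama-runner` could stream tokens like this; the prompt and overrides are illustrative:

```rust
use llama_runner::{run_llama_inference, LlamaInferenceConfig, WhichModel};

fn main() -> anyhow::Result<()> {
    // Start from the defaults and override only what we need.
    let cfg = LlamaInferenceConfig {
        prompt: "What is quantum computing?".to_string(),
        model: WhichModel::SmolLM2_135MInstruct,
        max_tokens: 64,
        ..Default::default()
    };

    // run_llama_inference returns an mpsc::Receiver; iterating it blocks until the
    // generation thread sends the next decoded fragment or closes the channel.
    for fragment in run_llama_inference(cfg)? {
        print!("{}", fragment?);
    }
    Ok(())
}
```

Dropping the receiver early is also safe: the generation thread notices the closed channel on its next send and stops.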

View File

@@ -0,0 +1,337 @@
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::EOS_TOKEN;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
#[value(name = "llama-3.2-1b")]
#[default]
Llama32_1B,
#[value(name = "llama-3.2-1b-instruct")]
Llama32_1BInstruct,
#[value(name = "llama-3.2-3b")]
Llama32_3B,
#[value(name = "llama-3.2-3b-instruct")]
Llama32_3BInstruct,
#[value(name = "smollm2-135m")]
SmolLM2_135M,
#[value(name = "smollm2-135m-instruct")]
SmolLM2_135MInstruct,
#[value(name = "smollm2-360m")]
SmolLM2_360M,
#[value(name = "smollm2-360m-instruct")]
SmolLM2_360MInstruct,
#[value(name = "smollm2-1.7b")]
SmolLM2_1_7B,
#[value(name = "smollm2-1.7b-instruct")]
SmolLM2_1_7BInstruct,
#[value(name = "tinyllama-1.1b-chat")]
TinyLlama1_1BChat,
}
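
Because `WhichModel` derives `clap::ValueEnum`, the `#[value(name = ...)]` strings above double as a programmatic lookup table. An illustrative sketch (not part of this commit; it assumes the type is consumed through the `llama-runner` re-export):

```rust
use clap::ValueEnum;
use llama_runner::WhichModel;

fn main() {
    // Resolve a CLI-style name to its variant via the #[value(name = "...")] attributes.
    let model = WhichModel::from_str("smollm2-360m-instruct", true).expect("unknown model name");
    assert_eq!(model, WhichModel::SmolLM2_360MInstruct);

    // Enumerate every supported model name, e.g. to build a custom help listing.
    for m in WhichModel::value_variants() {
        println!("{}", m.to_possible_value().unwrap().get_name());
    }
}
```
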
#[derive(Debug, Clone)]
pub struct LlamaInferenceConfig {
pub prompt: String,
pub model: WhichModel,
pub cpu: bool,
pub temperature: f64,
pub top_p: Option<f64>,
pub top_k: Option<usize>,
pub seed: u64,
pub max_tokens: usize,
pub no_kv_cache: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: Option<String>,
pub use_flash_attn: bool,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
}
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
// Leave prompt empty by default; let call sites set it.
prompt: String::new(),
// Keep your existing model choice; swap at call-site if needed.
model: WhichModel::Llama32_1BInstruct,
// Prefer GPU if available.
cpu: false,
// Sampling: balanced + stable
temperature: 0.7,
top_p: Some(0.95),
top_k: Some(50),
// Reproducible by default; override for variability.
seed: 42,
// Don't run unbounded generations.
max_tokens: 512,
// Performance flags
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: false, // candle's flash-attn kernel is not compiled into this crate (no `flash-attn` feature), so keep this off
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
// Optional model source pinning (None = app defaults)
model_id: None,
revision: None,
// Anti-repeat heuristics
repeat_penalty: 1.15,
repeat_last_n: 128,
}
}
}
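
With `Default` in place, call sites can use struct-update syntax and spell out only the fields they care about. A hedged sketch; `example_config` and its values are made up for illustration:

```rust
// CPU-only f32 run with a short output budget; everything else keeps the defaults above.
fn example_config() -> LlamaInferenceConfig {
    LlamaInferenceConfig {
        prompt: "Summarize the benefits of KV caching.".to_string(),
        model: WhichModel::SmolLM2_135M,
        cpu: true,
        dtype: Some("f32".to_string()),
        max_tokens: 128,
        ..Default::default()
    }
}
```
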
fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
let json_file = api.get(json_file)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
let weight_map = match json.get("weight_map") {
None => bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| api.get(v))
.collect::<anyhow::Result<Vec<_>, _>>()?;
Ok(safetensors_files)
}
pub fn run_llama_inference(
cfg: LlamaInferenceConfig,
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
// ---- Device & dtype -----------------------------------------------------
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
println!("Using dtype: {:?}", dtype);
// ---- Load model & tokenizer --------------------------------------------
let (llama, tokenizer, mut cache, config) = {
let api = Api::new()?;
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
match cfg.model {
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(cfg.use_flash_attn);
let filenames = match cfg.model {
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
hub_load_safetensors(&api, "model.safetensors.index.json")?
}
_ => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let llama = Llama::load(vb, &config)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
(llama, tokenizer, cache, config)
};
// ---- Prepare prompt & sampler ------------------------------------------
// Prefer EOS ids from the model config; fall back to the tokenizer's "</s>" entry.
let eos_token_id = config
.eos_token_id
.clone()
.or_else(|| tokenizer.token_to_id(EOS_TOKEN).map(model::LlamaEosToks::Single));
let mut tokens = tokenizer
.encode(cfg.prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("Starting inference...");
let mut logits_processor = {
let temperature = cfg.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (cfg.top_k, cfg.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(cfg.seed, sampling)
};
// Channel for streaming decoded fragments to the caller.
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
// ---- Spawn generation thread -------------------------------------------
std::thread::spawn(move || {
let start_gen = std::time::Instant::now();
let mut index_pos = 0usize;
let mut token_generated = 0usize;
for index in 0..cfg.max_tokens {
// Use KV-cache for single-token step after the first pass.
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match llama.forward(&input, context_index, &mut cache) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match logits.squeeze(0) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = if cfg.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
match candle_transformers::utils::apply_repeat_penalty(
&logits,
cfg.repeat_penalty,
&tokens[start_at..],
) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
}
};
index_pos += ctxt.len();
let next_token = match logits_processor.sample(&logits) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
token_generated += 1;
tokens.push(next_token);
// Early stop on EOS.
let stop = match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
None => false,
};
if stop {
break;
}
// Decode this token's text and stream it out.
match tokenizer.decode(&[next_token], false) {
Ok(text) => {
if !text.is_empty() {
// Best-effort send; if receiver is gone, just stop.
if tx.send(Ok(text)).is_err() {
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
break;
}
}
}
// Optional: final stats as a debug line (not sent through the stream).
let dt = start_gen.elapsed();
eprintln!(
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
// Dropping tx closes the stream.
});
Ok(rx)
}
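
Because the worker thread treats a failed `tx.send` as a cancel signal and drops `tx` when it finishes, the caller controls the stream's lifetime from the receiving side. A small illustrative sketch (not part of this commit) that stops generation early by letting the receiver go out of scope:

```rust
// Illustrative only: collect roughly the first 200 characters, then stop.
fn take_prefix(cfg: LlamaInferenceConfig) -> anyhow::Result<String> {
    let rx = run_llama_inference(cfg)?;
    let mut out = String::new();
    for fragment in rx {
        out.push_str(&fragment?);
        if out.len() >= 200 {
            break; // the loop's iterator (and rx) is dropped here; the worker exits on its next send
        }
    }
    Ok(out)
}
```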

View File

@@ -0,0 +1,109 @@
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug, Default)]
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
prompt: String,
/// The model to use
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples
#[arg(short, long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
max_tokens: usize,
/// Disable the key-value cache
#[arg(long)]
no_kv_cache: bool,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
model_id: Option<String>,
/// Model revision
#[arg(long)]
revision: Option<String>,
/// Use flash attention
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
impl From<Args> for LlamaInferenceConfig {
fn from(args: Args) -> Self {
Self {
prompt: args.prompt,
model: args.model,
cpu: args.cpu,
temperature: args.temperature,
top_p: args.top_p,
top_k: args.top_k,
seed: args.seed,
max_tokens: args.max_tokens,
no_kv_cache: args.no_kv_cache,
dtype: args.dtype,
model_id: args.model_id,
revision: args.revision,
use_flash_attn: args.use_flash_attn,
repeat_penalty: args.repeat_penalty,
repeat_last_n: args.repeat_last_n,
}
}
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = args.into();
let rx = run_llama_inference(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,20 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_cli;
mod llama_api;
use anyhow::Result;
use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> {
run_cli()
}