Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00. Supports small Llama and Gemma models.
Refactor inference into dedicated crates for Llama and Gemma inferencing (not integrated)
24  crates/llama-runner/Cargo.toml  Normal file
@@ -0,0 +1,24 @@
[package]
name = "llama-runner"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
188  crates/llama-runner/README.md  Normal file
@@ -0,0 +1,188 @@
# Llama Runner

A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.

## Features

- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
- ⚡ **Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults

## Supported Models

| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |

Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).

## Installation

```bash
# Clone the repository
git clone <repository-url>
cd llama-runner

# Build with GPU acceleration (recommended)
cargo build --release --features metal   # macOS
cargo build --release --features cuda    # Linux/Windows with NVIDIA GPU

# CPU-only build
cargo build --release
```

## Quick Start

```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"

# Specify a model and parameters
cargo run --features metal -- \
  --prompt "Write a short story about space exploration" \
  --model smollm2-360m \
  --max-tokens 100 \
  --temperature 0.8

# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```

## Usage Examples

### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"

# Creative writing with higher temperature
cargo run --features metal -- \
  --prompt "Once upon a time" \
  --temperature 1.0 \
  --max-tokens 200
```

### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
  --prompt "Explain artificial intelligence" \
  --top-k 40 \
  --top-p 0.9 \
  --temperature 0.7

# Reduce repetition
cargo run --features metal -- \
  --prompt "List the benefits of renewable energy" \
  --repeat-penalty 1.2 \
  --repeat-last-n 64
```

### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
  --prompt "Quick test" \
  --model smollm2-135m

# Better quality with larger model
cargo run --features metal -- \
  --prompt "Explain quantum physics" \
  --model llama-3.2-1b \
  --max-tokens 150
```

## Command-Line Options

| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `llama-3.2-1b-instruct` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |

## Performance

Typical performance on Apple M2 with Metal acceleration:

| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |

## Requirements

- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows

## GPU Support

### macOS (Metal)
```bash
cargo run --features metal -- [options]
```

### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```

### CPU Only
```bash
cargo run -- --cpu [options]
```

## Model Downloads

Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:

- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes

## Troubleshooting

### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model

### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues

### Model Download Issues
- Check internet connection
- Some models may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`

## Contributing

Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.

## License

MIT License - see LICENSE file for details.
8  crates/llama-runner/src/lib.rs  Normal file
@@ -0,0 +1,8 @@
pub mod llama_api;

use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";
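For orientation, a minimal sketch of how a downstream crate might consume the API re-exported above; the prompt, model choice, token budget, and the caller-side `anyhow` dependency are illustrative assumptions, not part of this commit:

```rust
// Hypothetical consumer of llama-runner's library API (illustrative only).
use llama_runner::{run_llama_inference, LlamaInferenceConfig, WhichModel};

fn main() -> anyhow::Result<()> {
    let cfg = LlamaInferenceConfig {
        prompt: "What is quantum computing?".to_string(),
        model: WhichModel::SmolLM2_135MInstruct,
        max_tokens: 64,
        ..Default::default()
    };

    // run_llama_inference spawns a generation thread and returns an mpsc::Receiver;
    // iterating it yields decoded text fragments until the sender is dropped.
    for fragment in run_llama_inference(cfg)? {
        print!("{}", fragment?);
    }
    Ok(())
}
```

Because the generation thread treats a failed send as its signal to stop, dropping the receiver early is enough to end generation.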
337  crates/llama-runner/src/llama_api.rs  Normal file
@@ -0,0 +1,337 @@
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::EOS_TOKEN;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
    #[value(name = "llama-3.2-1b")]
    #[default]
    Llama32_1B,
    #[value(name = "llama-3.2-1b-instruct")]
    Llama32_1BInstruct,
    #[value(name = "llama-3.2-3b")]
    Llama32_3B,
    #[value(name = "llama-3.2-3b-instruct")]
    Llama32_3BInstruct,
    #[value(name = "smollm2-135m")]
    SmolLM2_135M,
    #[value(name = "smollm2-135m-instruct")]
    SmolLM2_135MInstruct,
    #[value(name = "smollm2-360m")]
    SmolLM2_360M,
    #[value(name = "smollm2-360m-instruct")]
    SmolLM2_360MInstruct,
    #[value(name = "smollm2-1.7b")]
    SmolLM2_1_7B,
    #[value(name = "smollm2-1.7b-instruct")]
    SmolLM2_1_7BInstruct,
    #[value(name = "tinyllama-1.1b-chat")]
    TinyLlama1_1BChat,
}

#[derive(Debug, Clone)]
pub struct LlamaInferenceConfig {
    pub prompt: String,

    pub model: WhichModel,
    pub cpu: bool,
    pub temperature: f64,
    pub top_p: Option<f64>,
    pub top_k: Option<usize>,
    pub seed: u64,
    pub max_tokens: usize,
    pub no_kv_cache: bool,
    pub dtype: Option<String>,
    pub model_id: Option<String>,
    pub revision: Option<String>,
    pub use_flash_attn: bool,
    pub repeat_penalty: f32,
    pub repeat_last_n: usize,
}
impl Default for LlamaInferenceConfig {
    fn default() -> Self {
        Self {
            // Leave prompt empty by default; let call sites set it.
            prompt: String::new(),

            // Keep your existing model choice; swap at call-site if needed.
            model: WhichModel::Llama32_1BInstruct,

            // Prefer GPU if available.
            cpu: false,

            // Sampling: balanced + stable
            temperature: 0.7,
            top_p: Some(0.95),
            top_k: Some(50),

            // Reproducible by default; override for variability.
            seed: 42,

            // Don't run unbounded generations.
            max_tokens: 512,

            // Performance flags
            no_kv_cache: false,   // keep cache ON for speed
            use_flash_attn: true, // great speed boost if supported

            // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
            dtype: Some("bf16".to_string()),

            // Optional model source pinning (None = app defaults)
            model_id: None,
            revision: None,

            // Anti-repeat heuristics
            repeat_penalty: 1.15,
            repeat_last_n: 128,
        }
    }
}
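As a reference point, a caller that wants reproducible greedy decoding only needs to override a few of these defaults; this is a sketch, and it relies on the sampler in `run_llama_inference` below mapping non-positive temperatures to argmax:

```rust
// Hypothetical helper built on the defaults above (illustrative, not part of this commit).
fn greedy_config(prompt: &str) -> LlamaInferenceConfig {
    LlamaInferenceConfig {
        prompt: prompt.to_string(),
        temperature: 0.0,      // <= 0.0 makes run_llama_inference pick Sampling::ArgMax
        top_k: None,           // sampling knobs are ignored once argmax is selected
        top_p: None,
        use_flash_attn: false, // safer when candle isn't built with flash-attn support
        ..LlamaInferenceConfig::default()
    }
}
```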
fn device(cpu: bool) -> anyhow::Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        Ok(Device::Cpu)
    }
}

fn hub_load_safetensors(
    api: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
    let json_file = api.get(json_file)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: serde_json::Value = serde_json::from_reader(&json_file)?;
    let weight_map = match json.get("weight_map") {
        None => bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file.to_string());
        }
    }
    let safetensors_files = safetensors_files
        .iter()
        .map(|v| api.get(v))
        .collect::<anyhow::Result<Vec<_>, _>>()?;
    Ok(safetensors_files)
}

pub fn run_llama_inference(
    cfg: LlamaInferenceConfig,
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
    // ---- Device & dtype -----------------------------------------------------
    let device = device(cfg.cpu)?;
    println!("Device: {:?}", device);

    let dtype = match cfg.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    println!("Using dtype: {:?}", dtype);

    // ---- Load model & tokenizer --------------------------------------------
    let (llama, tokenizer, mut cache) = {
        let api = Api::new()?;
        let model_id = cfg.model_id.clone().unwrap_or_else(|| {
            match cfg.model {
                WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
                WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
                WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
                WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
                WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
                WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
                WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
                WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
                WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
                WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
                WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            }
            .to_string()
        });
        println!("Loading model: {}", model_id);
        let revision = cfg.revision.clone().unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(cfg.use_flash_attn);

        let filenames = match cfg.model {
            WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
                hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
            _ => vec![api.get("model.safetensors")?],
        };

        let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let llama = Llama::load(vb, &config)?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        (llama, tokenizer, cache)
    };

    // ---- Prepare prompt & sampler ------------------------------------------
    let eos_token_id = tokenizer
        .token_to_id(EOS_TOKEN)
        .map(model::LlamaEosToks::Single);

    let mut tokens = tokenizer
        .encode(cfg.prompt.as_str(), true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();

    println!("Starting inference...");

    let mut logits_processor = {
        let temperature = cfg.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (cfg.top_k, cfg.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(cfg.seed, sampling)
    };

    // Channel for streaming decoded fragments to the caller.
    let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();

    // ---- Spawn generation thread -------------------------------------------
    std::thread::spawn(move || {
        let start_gen = std::time::Instant::now();
        let mut index_pos = 0usize;
        let mut token_generated = 0usize;

        for index in 0..cfg.max_tokens {
            // Use KV-cache for single-token step after the first pass.
            let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
                (1, index_pos)
            } else {
                (tokens.len(), 0)
            };

            let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
            let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
                Ok(t) => t,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };

            let logits = match llama.forward(&input, context_index, &mut cache) {
                Ok(l) => l,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };
            let logits = match logits.squeeze(0) {
                Ok(l) => l,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };

            let logits = if cfg.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
                match candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    cfg.repeat_penalty,
                    &tokens[start_at..],
                ) {
                    Ok(l) => l,
                    Err(e) => {
                        let _ = tx.send(Err(e.into()));
                        break;
                    }
                }
            };

            index_pos += ctxt.len();

            let next_token = match logits_processor.sample(&logits) {
                Ok(t) => t,
                Err(e) => {
                    let _ = tx.send(Err(e.into()));
                    break;
                }
            };

            token_generated += 1;
            tokens.push(next_token);

            // Early stop on EOS.
            let stop = match eos_token_id {
                Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
                Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
                None => false,
            };
            if stop {
                break;
            }

            // Decode this token's text and stream it out.
            match tokenizer.decode(&[next_token], false) {
                Ok(text) => {
                    if !text.is_empty() {
                        // Best-effort send; if receiver is gone, just stop.
                        if tx.send(Ok(text)).is_err() {
                            break;
                        }
                    }
                }
                Err(e) => {
                    let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
                    break;
                }
            }
        }

        // Optional: final stats as a debug line (not sent through the stream).
        let dt = start_gen.elapsed();
        eprintln!(
            "[llama-runner] {} tokens generated ({:.2} tokens/s)",
            token_generated,
            token_generated as f64 / dt.as_secs_f64(),
        );
        // Dropping tx closes the stream.
    });

    Ok(rx)
}
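Since generation runs on a detached thread and communicates only through this channel, a caller that cannot block indefinitely can poll with a timeout instead of iterating. A sketch using std's `recv_timeout`; the per-token budget and helper name are illustrative, not part of this commit:

```rust
use std::sync::mpsc::RecvTimeoutError;
use std::time::Duration;

// Hypothetical bounded-wait consumer of the receiver returned by run_llama_inference.
fn collect_with_timeout(
    rx: std::sync::mpsc::Receiver<anyhow::Result<String>>,
    per_token_budget: Duration,
) -> anyhow::Result<String> {
    let mut out = String::new();
    loop {
        match rx.recv_timeout(per_token_budget) {
            Ok(Ok(fragment)) => out.push_str(&fragment),
            Ok(Err(e)) => return Err(e),                  // generation-side error
            Err(RecvTimeoutError::Timeout) => anyhow::bail!("generation stalled"),
            Err(RecvTimeoutError::Disconnected) => break, // sender dropped: stream finished
        }
    }
    Ok(out)
}
```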
109  crates/llama-runner/src/llama_cli.rs  Normal file
@@ -0,0 +1,109 @@
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;

#[derive(Parser, Debug, Default)]
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
struct Args {
    /// The prompt to generate text from
    #[arg(short, long, default_value = "The capital of France is")]
    prompt: String,

    /// The model to use
    #[arg(short, long, default_value = "llama-3.2-1b-instruct")]
    model: WhichModel,

    /// Run on CPU rather than GPU
    #[arg(long)]
    cpu: bool,

    /// The temperature used to generate samples
    #[arg(short, long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens)
    #[arg(short = 'n', long, default_value_t = 100)]
    max_tokens: usize,

    /// Disable the key-value cache
    #[arg(long)]
    no_kv_cache: bool,

    /// Use different dtype than f16
    #[arg(long)]
    dtype: Option<String>,

    /// Custom model ID from HuggingFace Hub
    #[arg(long)]
    model_id: Option<String>,

    /// Model revision
    #[arg(long)]
    revision: Option<String>,

    /// Use flash attention
    #[arg(long)]
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

impl Into<LlamaInferenceConfig> for Args {
    fn into(self) -> LlamaInferenceConfig {
        LlamaInferenceConfig {
            prompt: self.prompt,
            model: self.model,
            cpu: self.cpu,
            temperature: self.temperature,
            top_p: self.top_p,
            top_k: self.top_k,
            seed: self.seed,
            max_tokens: self.max_tokens,
            no_kv_cache: self.no_kv_cache,
            dtype: self.dtype,
            model_id: self.model_id,
            revision: self.revision,
            use_flash_attn: self.use_flash_attn,
            repeat_penalty: self.repeat_penalty,
            repeat_last_n: self.repeat_last_n,
        }
    }
}

pub fn run_cli() -> anyhow::Result<()> {
    let args = Args::parse();
    let cfg = args.into();
    let rx = run_llama_inference(cfg)?;
    for msg in rx {
        match msg {
            Ok(tok) => {
                print!("{tok}");
                let _ = std::io::stdout().flush(); // <- force it out now
            }
            Err(e) => {
                eprintln!("generation error: {e}");
                break;
            }
        }
    }
    Ok(())
}
20  crates/llama-runner/src/main.rs  Normal file
@@ -0,0 +1,20 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_cli;
mod llama_api;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use std::io::Write;

use crate::llama_cli::run_cli;

const EOS_TOKEN: &str = "</s>";

fn main() -> Result<()> {
    run_cli()
}