supports small llama and gemma models

Refactor inference

dedicated crates for llama and gemma inferencing, not integrated
This commit is contained in:
geoffsee
2025-08-29 18:15:29 -04:00
parent d06b16bb12
commit 315ef17605
26 changed files with 2136 additions and 1402 deletions

369
Cargo.lock generated
View File

@@ -686,6 +686,15 @@ dependencies = [
"generic-array",
]
[[package]]
name = "block2"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "340d2f0bdb2a43c1d3cd40513185b2bd7def0aa1052f956455114bc98f82dcf2"
dependencies = [
"objc2",
]
[[package]]
name = "brotli"
version = "3.5.0"
@@ -786,8 +795,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9f51e2ecf6efe9737af8f993433c839f956d2b6ed4fd2dd4a7c6d8b0fa667ff"
dependencies = [
"byteorder",
"candle-kernels",
"candle-metal-kernels",
"candle-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"candle-metal-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"cudarc",
"gemm 0.17.1",
"half",
@@ -807,6 +816,35 @@ dependencies = [
"zip",
]
[[package]]
name = "candle-core"
version = "0.9.1"
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
dependencies = [
"byteorder",
"candle-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
"cudarc",
"float8",
"gemm 0.17.1",
"half",
"memmap2",
"num-traits",
"num_cpus",
"objc2-foundation",
"objc2-metal",
"rand 0.9.2",
"rand_distr 0.5.1",
"rayon",
"safetensors",
"thiserror 1.0.69",
"ug",
"ug-cuda",
"ug-metal",
"yoke 0.7.5",
"zip",
]
[[package]]
name = "candle-datasets"
version = "0.9.1"
@@ -814,15 +852,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0a7c351dd50cda83f00f17c4412e35c69d840e453edf06064974de1cc59343d"
dependencies = [
"byteorder",
"candle-core",
"candle-nn",
"hf-hub",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"hf-hub 0.4.3",
"image",
"memmap2",
"parquet",
"rand 0.9.2",
"thiserror 1.0.69",
"tokenizers",
"tokenizers 0.21.4",
]
[[package]]
name = "candle-examples"
version = "0.9.1"
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
dependencies = [
"anyhow",
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
"csv",
"hf-hub 0.4.3",
"image",
"num-traits",
"rayon",
"safetensors",
"serde",
"serde_json",
"tokenizers 0.21.4",
]
[[package]]
@@ -833,7 +891,7 @@ checksum = "fb38a5bfae09c4ae73fd00039e5eaf97a7d6d9400cc35ee8e603fc4a5f9cb0a3"
dependencies = [
"anyhow",
"bindgen_cuda",
"candle-core",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"half",
]
@@ -846,6 +904,14 @@ dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-kernels"
version = "0.9.1"
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-metal-kernels"
version = "0.9.1"
@@ -859,13 +925,27 @@ dependencies = [
"tracing",
]
[[package]]
name = "candle-metal-kernels"
version = "0.9.1"
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
dependencies = [
"half",
"objc2",
"objc2-foundation",
"objc2-metal",
"once_cell",
"thiserror 1.0.69",
"tracing",
]
[[package]]
name = "candle-nn"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1980d53280c8f9e2c6cbe1785855d7ff8010208b46e21252b978badf13ad69d"
dependencies = [
"candle-core",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"half",
"num-traits",
"rayon",
@@ -874,14 +954,30 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "candle-nn"
version = "0.9.1"
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
dependencies = [
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
"half",
"num-traits",
"objc2-metal",
"rayon",
"safetensors",
"serde",
"thiserror 1.0.69",
]
[[package]]
name = "candle-onnx"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a8fa227a8176fd9b8fb58d63c908c08ad3af1503ee6fcd058be072a598044d2"
dependencies = [
"candle-core",
"candle-nn",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"prost",
"prost-build",
]
@@ -893,8 +989,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186cb80045dbe47e0b387ea6d3e906f02fb3056297080d9922984c90e90a72b0"
dependencies = [
"byteorder",
"candle-core",
"candle-nn",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"fancy-regex",
"num-traits",
"rand 0.9.2",
"rayon",
"serde",
"serde_json",
"serde_plain",
"tracing",
]
[[package]]
name = "candle-transformers"
version = "0.9.1"
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
dependencies = [
"byteorder",
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
"fancy-regex",
"num-traits",
"rand 0.9.2",
@@ -1523,6 +1637,15 @@ dependencies = [
"dirs-sys 0.4.1",
]
[[package]]
name = "dirs"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
"dirs-sys 0.4.1",
]
[[package]]
name = "dirs"
version = "6.0.0"
@@ -1556,6 +1679,16 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "dispatch2"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
dependencies = [
"bitflags 2.9.2",
"objc2",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
@@ -1715,6 +1848,9 @@ name = "esaxx-rs"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
dependencies = [
"cc",
]
[[package]]
name = "event-listener"
@@ -1793,14 +1929,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04c269a76bfc6cea69553b7d040acb16c793119cebd97c756d21e08d0f075ff8"
dependencies = [
"anyhow",
"hf-hub",
"hf-hub 0.4.3",
"image",
"ndarray",
"ort",
"ort-sys",
"rayon",
"serde_json",
"tokenizers",
"tokenizers 0.21.4",
]
[[package]]
@@ -1856,6 +1992,18 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "float8"
version = "0.2.1"
source = "git+https://github.com/zackangelo/float8?branch=cudarc_0_16#03c1f5fe7cdb2f9cb690823fdd40593be57c408f"
dependencies = [
"cudarc",
"half",
"num-traits",
"rand 0.9.2",
"rand_distr 0.5.1",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -2246,6 +2394,24 @@ dependencies = [
"seq-macro",
]
[[package]]
name = "gemma-runner"
version = "0.1.0"
dependencies = [
"anyhow",
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-examples",
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
"clap",
"hf-hub 0.4.3",
"serde_json",
"tokenizers 0.21.4",
"tracing",
"tracing-chrome",
"tracing-subscriber",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -2421,19 +2587,48 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "helm-chart-tool"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"serde",
"serde_json",
"toml 0.8.23",
"walkdir",
]
[[package]]
name = "hermit-abi"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
[[package]]
name = "hf-hub"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
dependencies = [
"dirs 5.0.1",
"indicatif",
"log",
"native-tls",
"rand 0.8.5",
"serde",
"serde_json",
"thiserror 1.0.69",
"ureq",
]
[[package]]
name = "hf-hub"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
dependencies = [
"dirs",
"dirs 6.0.0",
"futures",
"http",
"indicatif",
@@ -2842,12 +3037,12 @@ dependencies = [
"axum",
"bindgen_cuda",
"byteorder",
"candle-core",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"candle-datasets",
"candle-flash-attn",
"candle-nn",
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"candle-onnx",
"candle-transformers",
"candle-transformers 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"clap",
"cpal",
"csv",
@@ -2855,11 +3050,13 @@ dependencies = [
"either",
"enterpolation",
"futures-util",
"gemma-runner",
"half",
"hf-hub",
"hf-hub 0.4.3",
"image",
"imageproc",
"intel-mkl-src",
"llama-runner",
"memmap2",
"num-traits",
"palette",
@@ -2873,7 +3070,7 @@ dependencies = [
"serde",
"serde_json",
"symphonia",
"tokenizers",
"tokenizers 0.21.4",
"tokio",
"tokio-stream",
"tower",
@@ -2981,6 +3178,15 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
@@ -3405,7 +3611,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if",
"windows-targets 0.53.3",
"windows-targets 0.48.5",
]
[[package]]
@@ -3443,6 +3649,20 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
[[package]]
name = "llama-runner"
version = "0.1.0"
dependencies = [
"anyhow",
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
"candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
"clap",
"hf-hub 0.3.2",
"serde_json",
"tokenizers 0.20.4",
]
[[package]]
name = "lock_api"
version = "0.4.13"
@@ -3965,6 +4185,59 @@ dependencies = [
"objc_exception",
]
[[package]]
name = "objc2"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "561f357ba7f3a2a61563a186a163d0a3a5247e1089524a3981d49adb775078bc"
dependencies = [
"objc2-encode",
]
[[package]]
name = "objc2-core-foundation"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166"
dependencies = [
"bitflags 2.9.2",
"dispatch2",
"objc2",
]
[[package]]
name = "objc2-encode"
version = "4.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
[[package]]
name = "objc2-foundation"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
dependencies = [
"bitflags 2.9.2",
"block2",
"libc",
"objc2",
"objc2-core-foundation",
]
[[package]]
name = "objc2-metal"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874"
dependencies = [
"bitflags 2.9.2",
"block2",
"dispatch2",
"objc2",
"objc2-core-foundation",
"objc2-foundation",
]
[[package]]
name = "objc_exception"
version = "0.1.2"
@@ -4803,7 +5076,7 @@ dependencies = [
"once_cell",
"socket2 0.5.10",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -5006,6 +5279,17 @@ dependencies = [
"rayon-core",
]
[[package]]
name = "rayon-cond"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
dependencies = [
"either",
"itertools 0.11.0",
"rayon",
]
[[package]]
name = "rayon-cond"
version = "0.4.0"
@@ -6267,7 +6551,7 @@ dependencies = [
"getrandom 0.3.3",
"once_cell",
"rustix",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -6384,6 +6668,38 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom 0.2.16",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand 0.8.5",
"rayon",
"rayon-cond 0.3.0",
"regex",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror 1.0.69",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.21.4"
@@ -6397,7 +6713,8 @@ dependencies = [
"derive_builder",
"esaxx-rs",
"getrandom 0.3.3",
"hf-hub",
"hf-hub 0.4.3",
"indicatif",
"itertools 0.14.0",
"log",
"macro_rules_attribute",
@@ -6406,7 +6723,7 @@ dependencies = [
"paste",
"rand 0.9.2",
"rayon",
"rayon-cond",
"rayon-cond 0.4.0",
"regex",
"regex-syntax 0.8.5",
"serde",
@@ -7260,7 +7577,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]

View File

@@ -4,7 +4,9 @@ members = [
"crates/inference-engine",
"crates/embeddings-engine",
"crates/leptos-app",
"crates/helm-chart-tool"
"crates/helm-chart-tool",
"crates/llama-runner",
"crates/gemma-runner"
]
default-members = ["crates/predict-otron-9000"]
resolver = "2"

View File

@@ -10,7 +10,7 @@ Powerful local AI inference with OpenAI-compatible APIs
The predict-otron-9000 is a flexible AI platform that provides:
- **Local LLM Inference**: Run Gemma models locally with CPU or GPU acceleration
- **Local LLM Inference**: Run Gemma and Llama models locally with CPU or GPU acceleration
- **Embeddings Generation**: Create text embeddings with FastEmbed
- **Web Interface**: Interact with models through a Leptos WASM chat interface
- **TypeScript CLI**: Command-line client for testing and automation
@@ -22,7 +22,7 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
- **OpenAI Compatible**: API endpoints match OpenAI's format for easy integration
- **Text Embeddings**: Generate high-quality text embeddings using FastEmbed
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma models (1B, 2B, 7B variants including instruction-tuned models)
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants)
- **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput
- **Web Chat Interface**: Leptos-based WebAssembly (WASM) chat interface for browser-based interaction
- **Flexible Deployment**: Run as monolithic service or microservices architecture
@@ -31,15 +31,19 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
### Workspace Structure
The project uses a 4-crate Rust workspace plus TypeScript components:
The project uses a 7-crate Rust workspace plus TypeScript components:
```
crates/
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
├── inference-engine/ # Gemma inference via Candle (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
── leptos-app/ # WASM web frontend (Rust 2021)
cli.ts # TypeScript/Bun CLI client
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
── llama-runner/ # Llama model inference via Candle (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
├── leptos-app/ # WASM web frontend (Rust 2021)
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
└── scripts/
└── cli.ts # TypeScript/Bun CLI client
```
### Service Architecture
@@ -149,16 +153,16 @@ cd crates/leptos-app
#### TypeScript CLI Client
```bash
# List available models
bun run cli.ts --list-models
bun run scripts/cli.ts --list-models
# Chat completion
bun run cli.ts "What is the capital of France?"
bun run scripts/cli.ts "What is the capital of France?"
# With specific model
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run scripts/cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
# Show help
bun run cli.ts --help
bun run scripts/cli.ts --help
```
## API Usage
@@ -454,7 +458,7 @@ curl -s http://localhost:8080/v1/models | jq
**CLI client test:**
```bash
bun run cli.ts "What is 2+2?"
bun run scripts/cli.ts "What is 2+2?"
```
**Web frontend:**

View File

@@ -0,0 +1,28 @@
[package]
name = "gemma-runner"
version = "0.1.0"
edition = "2021"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-examples = { git = "https://github.com/huggingface/candle.git" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
hf-hub = "0.4"
tokenizers = "0.21"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -0,0 +1,137 @@
# Gemma Runner
Fast Gemma inference with Candle framework in Rust.
## Features
- Support for multiple Gemma model versions (v1, v2, v3)
- GPU acceleration with CUDA and Metal
- Configurable sampling parameters
- Multiple model variants including instruct and code models
## Supported Models
### Gemma v1
- `gemma-2b` - Base 2B model
- `gemma-7b` - Base 7B model
- `gemma-2b-it` - Instruct 2B model
- `gemma-7b-it` - Instruct 7B model
- `gemma-1.1-2b-it` - Instruct 2B v1.1 model
- `gemma-1.1-7b-it` - Instruct 7B v1.1 model
### CodeGemma
- `codegemma-2b` - Code base 2B model
- `codegemma-7b` - Code base 7B model
- `codegemma-2b-it` - Code instruct 2B model
- `codegemma-7b-it` - Code instruct 7B model
### Gemma v2
- `gemma-2-2b` - Base 2B v2 model (default)
- `gemma-2-2b-it` - Instruct 2B v2 model
- `gemma-2-9b` - Base 9B v2 model
- `gemma-2-9b-it` - Instruct 9B v2 model
### Gemma v3
- `gemma-3-1b` - Base 1B v3 model
- `gemma-3-1b-it` - Instruct 1B v3 model
## Installation
```bash
cd gemma-runner
cargo build --release
```
For GPU support:
```bash
# CUDA
cargo build --release --features cuda
# Metal (macOS)
cargo build --release --features metal
```
## Usage
### Basic Usage
```bash
# Run with default model (gemma-2-2b)
cargo run -- --prompt "The capital of France is"
# Specify a different model
cargo run -- --model gemma-2b-it --prompt "Explain quantum computing"
# Generate more tokens
cargo run -- --model codegemma-2b-it --prompt "Write a Python function to sort a list" --max-tokens 200
```
### Advanced Options
```bash
# Use CPU instead of GPU
cargo run -- --cpu --prompt "Hello world"
# Adjust sampling parameters
cargo run -- --temperature 0.8 --top-p 0.9 --prompt "Write a story about"
# Use custom model from HuggingFace Hub
cargo run -- --model-id "google/gemma-2-2b-it" --prompt "What is AI?"
# Enable tracing for performance analysis
cargo run -- --tracing --prompt "Explain machine learning"
```
### Command Line Arguments
- `--prompt, -p` - The prompt to generate text from (default: "The capital of France is")
- `--model, -m` - The model to use (default: "gemma-2-2b")
- `--cpu` - Run on CPU rather than GPU
- `--temperature, -t` - Sampling temperature (optional)
- `--top-p` - Nucleus sampling probability cutoff (optional)
- `--seed` - Random seed (default: 299792458)
- `--max-tokens, -n` - Maximum tokens to generate (default: 100)
- `--model-id` - Custom model ID from HuggingFace Hub
- `--revision` - Model revision (default: "main")
- `--use-flash-attn` - Use flash attention
- `--repeat-penalty` - Repetition penalty (default: 1.1)
- `--repeat-last-n` - Context size for repeat penalty (default: 64)
- `--dtype` - Data type (f16, bf16, f32)
- `--tracing` - Enable performance tracing
## Examples
### Text Generation
```bash
cargo run -- --model gemma-2b-it --prompt "Explain the theory of relativity" --max-tokens 150
```
### Code Generation
```bash
cargo run -- --model codegemma-7b-it --prompt "Write a Rust function to calculate factorial" --max-tokens 100
```
### Creative Writing
```bash
cargo run -- --model gemma-7b-it --temperature 0.9 --prompt "Once upon a time in a magical forest" --max-tokens 200
```
### Chat with Gemma 3 (Instruct format)
```bash
cargo run -- --model gemma-3-1b-it --prompt "How do I learn Rust programming?"
```
## Performance Notes
- GPU acceleration is automatically detected and used when available
- BF16 precision is used on CUDA for better performance
- F32 precision is used on CPU
- Flash attention can be enabled with `--use-flash-attn` for supported models
- Model files are cached locally after first download
## Requirements
- Rust 1.70+
- CUDA toolkit (for CUDA support)
- Metal (automatically available on macOS)
- Internet connection for first-time model download

View File

@@ -0,0 +1,389 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use clap::ValueEnum;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// Removed gemma_cli import as it's not needed for the API
use candle_core::{utils, DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use tokenizers::Tokenizer;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
#[value(name = "gemma-7b")]
Base7B,
#[value(name = "gemma-2b-it")]
Instruct2B,
#[value(name = "gemma-7b-it")]
Instruct7B,
#[value(name = "gemma-1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "gemma-1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "codegemma-2b")]
CodeBase2B,
#[value(name = "codegemma-7b")]
CodeBase7B,
#[value(name = "codegemma-2b-it")]
CodeInstruct2B,
#[value(name = "codegemma-7b-it")]
CodeInstruct7B,
#[value(name = "gemma-2-2b")]
BaseV2_2B,
#[value(name = "gemma-2-2b-it")]
InstructV2_2B,
#[value(name = "gemma-2-9b")]
BaseV2_9B,
#[value(name = "gemma-2-9b-it")]
InstructV2_9B,
#[value(name = "gemma-3-1b")]
BaseV3_1B,
#[value(name = "gemma-3-1b-it")]
InstructV3_1B,
}
enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle_core::Result<Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
pub struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
/// Stream-only generation: sends freshly generated token strings over `tx`.
/// (Does not send the prompt tokens; only newly generated model tokens.)
fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
self.tokenizer.clear();
// Encode prompt (context only; do not emit prompt tokens to the stream).
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Warm the tokenizer's internal state with prompt tokens (so merges are correct),
// but do not send them to the receiver.
for &t in tokens.iter() {
let _ = self.tokenizer.next_token(t)?;
}
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
eprintln!("Warning: <end_of_turn> token not found, using <eos> as backup");
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
// Best-effort send; ignore if receiver dropped.
let _ = tx.send(Ok(t));
}
}
let _dt = start_gen.elapsed();
// Flush any remaining buffered bytes as one final chunk.
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
let _ = tx.send(Ok(rest));
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
pub model: WhichModel,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: String,
pub use_flash_attn: bool,
pub seed: u64,
pub temperature: f64,
pub top_p: Option<f64>,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
pub max_tokens: usize,
}
impl Default for GemmaInferenceConfig {
fn default() -> Self {
Self {
tracing: false,
prompt: "Hello".to_string(),
model: WhichModel::InstructV2_2B,
cpu: false,
dtype: None,
model_id: None,
revision: "main".to_string(),
use_flash_attn: false,
seed: 299792458,
temperature: 0.8,
top_p: None,
repeat_penalty: 1.1,
repeat_last_n: 128,
max_tokens: 100,
}
}
}
// Removed From<Args> implementation as Args is not available and not needed for API usage
/// Builds the model and returns a channel that streams generated token strings.
/// If model setup fails, the `Result` is returned immediately.
pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String>>> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = if cfg.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
utils::with_avx(),
utils::with_neon(),
utils::with_simd128(),
utils::with_f16c()
);
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
None => {
if device.is_cuda() {
DType::BF16
} else {
DType::F16
}
}
};
println!("Using dtype: {:?}", dtype);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
WhichModel::Base2B => "google/gemma-2b",
WhichModel::Base7B => "google/gemma-7b",
WhichModel::Instruct2B => "google/gemma-2b-it",
WhichModel::Instruct7B => "google/gemma-7b-it",
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
WhichModel::CodeBase2B => "google/codegemma-2b",
WhichModel::CodeBase7B => "google/codegemma-7b",
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
WhichModel::BaseV2_2B => "google/gemma-2-2b",
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
WhichModel::BaseV2_9B => "google/gemma-2-9b",
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
}
.to_string()
});
println!("Loading model: {}", &model_id);
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, cfg.revision));
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model {
WhichModel::Base2B
| WhichModel::Base7B
| WhichModel::Instruct2B
| WhichModel::Instruct7B
| WhichModel::InstructV1_1_2B
| WhichModel::InstructV1_1_7B
| WhichModel::CodeBase2B
| WhichModel::CodeBase7B
| WhichModel::CodeInstruct2B
| WhichModel::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("Loaded model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
cfg.seed,
cfg.temperature.into(),
cfg.top_p,
cfg.repeat_penalty,
cfg.repeat_last_n,
&device,
);
let prompt = match cfg.model {
WhichModel::InstructV3_1B => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt
)
}
_ => cfg.prompt,
};
println!("Starting inference...");
// Create the channel after successful setup.
let (tx, rx) = mpsc::channel::<Result<String>>();
// Spawn generation thread; send tokens to the channel.
thread::spawn(move || {
// If generation fails, forward the error once.
if let Err(e) = pipeline.run_stream(&prompt, cfg.max_tokens, tx.clone()) {
let _ = tx.send(Err(e));
}
// Channel closes when tx is dropped.
});
Ok(rx)
}

View File

@@ -0,0 +1,97 @@
use std::io::Write;
use clap::Parser;
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
pub struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
pub(crate) prompt: String,
/// The model to use
#[arg(short, long, default_value = "gemma-2-2b")]
pub(crate) model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
pub(crate) cpu: bool,
/// The temperature used to generate samples
#[arg(short, long)]
pub(crate) temperature: Option<f64>,
/// Nucleus sampling probability cutoff
#[arg(long)]
pub(crate) top_p: Option<f64>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
pub(crate) seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
pub(crate) max_tokens: usize,
/// Use different dtype than default
#[arg(long)]
pub(crate) dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
pub(crate) model_id: Option<String>,
/// Model revision
#[arg(long, default_value = "main")]
pub(crate) revision: String,
/// Use flash attention
#[arg(long)]
pub(crate) use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty
#[arg(long, default_value_t = 1.1)]
pub(crate) repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 64)]
pub(crate) repeat_last_n: usize,
/// Enable tracing
#[arg(long)]
pub(crate) tracing: bool,
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = GemmaInferenceConfig {
tracing: args.tracing,
prompt: args.prompt,
model: args.model,
cpu: args.cpu,
dtype: args.dtype,
model_id: args.model_id,
revision: args.revision,
use_flash_attn: args.use_flash_attn,
seed: args.seed,
temperature: args.temperature.unwrap_or(0.8),
top_p: args.top_p,
repeat_penalty: args.repeat_penalty,
repeat_last_n: args.repeat_last_n,
max_tokens: args.max_tokens,
};
let rx = run_gemma_api(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,3 @@
pub mod gemma_api;
pub use gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};

View File

@@ -0,0 +1,17 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod gemma_cli;
mod gemma_api;
use anyhow::Error;
use clap::{Parser, ValueEnum};
use crate::gemma_cli::run_cli;
use std::io::Write;
/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
run_cli()
}

View File

@@ -3,8 +3,6 @@ name = "helm-chart-tool"
version = "0.1.0"
edition = "2021"
[workspace]
[[bin]]
name = "helm-chart-tool"
path = "src/main.rs"

View File

@@ -3,9 +3,16 @@ name = "inference-engine"
version = "0.1.0"
edition = "2021"
[[bin]]
name="cli"
path = "src/cli_main.rs"
name="gemma_inference"
path = "src/gemma_inference.rs"
required-features = ["bin"]
[[bin]]
name="llama_inference"
path = "src/llama_inference.rs"
required-features = ["bin"]
[dependencies]
@@ -50,6 +57,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
futures-util = "0.3.31"
gemma-runner = { path = "../gemma-runner" }
llama-runner = { path = "../llama-runner" }
# --- Add this section for conditional compilation ---
[target.'cfg(target_os = "macos")'.dependencies]
@@ -83,6 +92,9 @@ tokio = "1.43.0"
anyhow = { version = "1", features = ["backtrace"] }
bindgen_cuda = { version = "0.1.1", optional = true }
[features]
bin = []
[package.metadata.compose]

View File

@@ -1,72 +0,0 @@
use clap::Parser;
use crate::model::Which;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
pub cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
pub tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
pub server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
pub port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
pub prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
pub temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
pub top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
pub seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
pub sample_len: usize,
#[arg(long)]
pub model_id: Option<String>,
#[arg(long, default_value = "main")]
pub revision: String,
#[arg(long)]
pub tokenizer_file: Option<String>,
#[arg(long)]
pub config_file: Option<String>,
#[arg(long)]
pub weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
pub repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
pub repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
pub which: Which,
#[arg(long)]
pub use_flash_attn: bool,
}

View File

@@ -1,912 +0,0 @@
mod token_output_stream;
mod utilities_lib;
#[cfg(feature = "intel-mkl-src")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate-src")]
extern crate accelerate_src;
#[cfg(feature = "metal")]
extern crate metal_src;
use anyhow::{Error as E, Result};
use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Parser;
use either::Either;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use utoipa::ToSchema;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// OpenAI API compatible structs
/// Inner content structure for messages that can be either a string or key-value pairs
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
)
}
}
/// Function for MessageInnerContent Schema generation to handle `Either`
fn message_inner_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
// Either::Left - simple string
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
// Either::Right - object with string values
.item(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::T(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))))
.build(),
))
.build(),
)
}
/// Message content that can be either simple text or complex structured content
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
}
}
/// Function for MessageContent Schema generation to handle `Either`
fn message_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
.item(Schema::Array(
ArrayBuilder::new()
.items(RefOr::T(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::Ref(
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
)))
.build(),
)))
.build(),
))
.build(),
)
}
/// Represents a single message in a conversation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
/// The message content
pub content: Option<MessageContent>,
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
pub role: String,
pub name: Option<String>,
}
/// Stop token configuration for generation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
/// Multiple possible stop sequences
Multi(Vec<String>),
/// Single stop sequence
Single(String),
}
/// Default value helper
fn default_false() -> bool {
false
}
/// Default value helper
fn default_1usize() -> usize {
1
}
/// Default value helper
fn default_model() -> String {
"default".to_string()
}
/// Chat completion request following OpenAI's specification
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
pub messages: Vec<Message>,
#[schema(example = "gemma-3-1b-it")]
#[serde(default = "default_model")]
pub model: String,
#[serde(default = "default_false")]
#[schema(example = false)]
pub logprobs: bool,
#[schema(example = 256)]
pub max_tokens: Option<usize>,
#[serde(rename = "n")]
#[serde(default = "default_1usize")]
#[schema(example = 1)]
pub n_choices: usize,
#[schema(example = 0.7)]
pub temperature: Option<f64>,
#[schema(example = 0.9)]
pub top_p: Option<f64>,
#[schema(example = false)]
pub stream: Option<bool>,
}
/// Chat completion choice
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Chat completion response
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
}
/// Token usage information
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
// Application state shared between handlers
#[derive(Clone)]
struct AppState {
text_generation: Arc<Mutex<TextGeneration>>,
model_id: String,
}
// Chat completions endpoint handler
async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut prompt = String::new();
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
},
};
// Return the response as JSON
Ok(Json(response))
}
use candle_core::{DType, Device, MetalDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType, api::sync::Api};
use serde_json::json;
use tokenizers::Tokenizer;
use crate::token_output_stream::TokenOutputStream;
use crate::utilities_lib::device;
// Create the router with the chat completions endpoint
fn create_router(app_state: AppState) -> Router {
// CORS layer to allow requests from any origin
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
// OpenAI compatible endpoints
.route("/v1/chat/completions", post(chat_completions))
// Add more endpoints as needed
.layer(cors)
.with_state(app_state)
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
// Run text generation and print to stdout
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
// Run text generation and write to a buffer
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Write prompt tokens to output
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
write!(output, "{}", t)?;
}
}
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
eos_token
}
};
// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
Model::V3(_) => true,
_ => false,
};
// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Initial generation with the full prompt
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.model.forward(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&current_logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
// For the next iteration, just use the new token
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
logits = self.model.forward(&new_input, tokens.len() - 1)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
}
return Ok(());
}
// Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
}
// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id.clone(),
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let initial_device = utilities_lib::device(args.cpu)?;
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
let is_v3_model = matches!(args.which, Which::BaseV3_1B | Which::InstructV3_1B);
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.cpu;
// Use CPU for V3 models on Metal due to missing implementations
let device = if is_v3_model && is_metal {
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
Device::Cpu
} else {
initial_device
};
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
// Use the selected device and dtype
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("loaded the model in {:?}", start.elapsed());
let pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
if args.server {
// Start the server
println!("Starting server on port {}", args.port);
// Create app state
let app_state = AppState {
text_generation: Arc::new(Mutex::new(pipeline)),
model_id,
};
// Create router
let app = create_router(app_state);
// Run the server
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
// Use tokio to run the server
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
.await
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
})?;
Ok(())
} else {
// Run in CLI mode
if let Some(prompt_text) = &args.prompt {
let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => prompt_text.clone(),
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
prompt_text
)
}
};
let mut pipeline = pipeline;
pipeline.run(&prompt, args.sample_len)?;
Ok(())
} else {
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
}
}
}

View File

@@ -0,0 +1,33 @@
use anyhow::Result;
use candle_core::Tensor;
/// ModelInference trait defines the common interface for model inference operations
///
/// This trait serves as an abstraction for different model implementations (Gemma and Llama)
/// to provide a unified interface for the inference engine.
pub trait ModelInference {
/// Perform model inference for the given input tensor starting at the specified position
///
/// # Arguments
///
/// * `input_ids` - The input tensor containing token IDs
/// * `pos` - The position to start generation from
///
/// # Returns
///
/// A tensor containing the logits for the next token prediction
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> Result<Tensor>;
/// Reset the model's internal state, if applicable
///
/// This method can be used to clear any cached state between inference requests
fn reset_state(&mut self) -> Result<()>;
/// Get the model type name
///
/// Returns a string identifier for the model type (e.g., "Gemma", "Llama")
fn model_type(&self) -> &'static str;
}
/// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;

View File

@@ -4,14 +4,16 @@ pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod cli;
// pub mod cli;
pub mod server;
pub mod inference;
// Re-export key components for easier access
pub use model::{Model, Which};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
pub use inference::ModelInference;
use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -2,6 +2,7 @@
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
@@ -37,12 +38,17 @@ pub enum Which {
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
#[value(name = "llama-3.2-1b-it")]
LlamaInstruct3_2_1B,
#[value(name = "llama-3.2-3b-it")]
LlamaInstruct3_2_3B,
}
pub enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
Llama(LlamaModel),
}
impl Model {
@@ -51,6 +57,7 @@ impl Model {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
Self::Llama(m) => m.forward(input_ids, pos),
}
}
}
@@ -74,6 +81,8 @@ impl Which {
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
}
}
@@ -87,4 +96,8 @@ impl Which {
pub fn is_v3_model(&self) -> bool {
matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
}
pub fn is_llama_model(&self) -> bool {
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
}
}

View File

@@ -5,304 +5,85 @@ use axum::{
routing::{get, post},
Json, Router,
};
use candle_core::DType;
use candle_nn::VarBuilder;
use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible;
use std::{path::PathBuf, sync::Arc};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
use crate::text_generation::TextGeneration;
use crate::{utilities_lib, Model as GemmaModel, Which};
use crate::Which;
use either::Either;
use hf_hub::api::sync::{Api, ApiError};
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
// -------------------------
// Shared app state
// -------------------------
#[derive(Clone, Debug)]
pub enum ModelType {
Gemma,
Llama,
}
#[derive(Clone)]
pub struct AppState {
pub text_generation: Arc<Mutex<TextGeneration>>,
pub model_type: ModelType,
pub model_id: String,
// Store build args to recreate TextGeneration when needed
pub build_args: PipelineArgs,
pub gemma_config: Option<GemmaInferenceConfig>,
pub llama_config: Option<LlamaInferenceConfig>,
}
impl Default for AppState {
fn default() -> Self {
let args = PipelineArgs::default();
let text_generation = build_pipeline(args.clone());
let gemma_config = GemmaInferenceConfig {
model: gemma_runner::WhichModel::InstructV3_1B,
..Default::default()
};
Self {
text_generation: Arc::new(Mutex::new(text_generation)),
model_id: args.model_id.clone(),
build_args: args,
model_type: ModelType::Gemma,
model_id: "gemma-3-1b-it".to_string(),
gemma_config: Some(gemma_config),
llama_config: None,
}
}
}
// -------------------------
// Pipeline configuration
// Helper functions
// -------------------------
#[derive(Debug, Clone)]
pub struct PipelineArgs {
pub model_id: String,
pub which: Which,
pub revision: Option<String>,
pub tokenizer_path: Option<PathBuf>,
pub config_path: Option<PathBuf>,
pub weight_paths: Vec<PathBuf>,
pub use_flash_attn: bool,
pub force_cpu: bool,
pub seed: u64,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
}
impl Default for PipelineArgs {
fn default() -> Self {
Self {
model_id: Which::InstructV3_1B.to_model_id().to_string(),
which: Which::InstructV3_1B,
revision: None,
tokenizer_path: None,
config_path: None,
weight_paths: Vec::new(),
use_flash_attn: false,
force_cpu: false,
seed: 299792458, // Speed of light in vacuum (m/s)
temperature: Some(0.8), // Good balance between creativity and coherence
top_p: Some(0.9), // Keep diverse but reasonable options
repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
repeat_last_n: 64, // Consider last 64 tokens for repetition
}
}
}
fn normalize_model_id(model_id: &str) -> String {
if model_id.contains('/') {
model_id.to_string()
} else {
format!("google/{}", model_id)
}
}
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
match repo.get("config.json") {
Ok(_) => Ok(()),
Err(e) => match e {
ApiError::RequestError(resp) => {
let error_str = resp.to_string();
if error_str.contains("404") {
anyhow::bail!(
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
)
}
Err(anyhow::Error::new(ApiError::RequestError(resp)))
}
other => Err(anyhow::Error::new(other)),
},
}
}
// -------------------------
// Pipeline builder
// -------------------------
pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
let start = std::time::Instant::now();
let api = Api::new().unwrap();
let revision = args.revision.as_deref().unwrap_or("main");
if args.model_id.trim().is_empty() {
panic!("No model ID specified.");
}
args.model_id = normalize_model_id(&args.model_id);
match ensure_repo_exists(&api, &args.model_id, revision) {
Ok(_) => {}
Err(e) => panic!("{}", e),
};
let repo = api.repo(Repo::with_revision(
args.model_id.clone(),
RepoType::Model,
revision.to_string(),
));
let tokenizer_path = args
.tokenizer_path
.unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
let config_path = args
.config_path
.unwrap_or_else(|| repo.get("config.json").unwrap());
if !matches!(
args.which,
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B
| Which::InstructV3_1B
) {
if args.model_id.contains("gemma-2-2b-it") {
args.which = Which::InstructV2_2B;
} else if args.model_id.contains("gemma-3-1b-it") {
args.which = Which::InstructV3_1B;
} else if let Ok(file) = std::fs::File::open(config_path.clone()) {
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
if model_type.contains("gemma3") {
args.which = Which::InstructV3_1B;
} else if model_type.contains("gemma2") {
args.which = Which::InstructV2_2B;
} else {
args.which = Which::Instruct2B;
}
}
}
}
}
let weight_paths = if !args.weight_paths.is_empty() {
args.weight_paths
} else {
match repo.get("model.safetensors") {
Ok(single) => vec![single],
Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
Ok(paths) => paths,
Err(e) => {
panic!("Unable to locate model weights: {}", e);
}
},
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
let initial_device = utilities_lib::device(args.force_cpu).unwrap();
let is_v3_model = args.which.is_v3_model();
let is_metal = !initial_device.is_cpu()
&& candle_core::utils::metal_is_available()
&& !args.force_cpu;
let device = if is_v3_model && is_metal {
candle_core::Device::Cpu
} else {
initial_device
};
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };
let model = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
}
};
TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
)
model_id.to_lowercase().replace("_", "-")
}
fn build_gemma_prompt(messages: &[Message]) -> String {
let mut prompt = String::new();
let mut system_prompt: Option<String> = None;
for message in messages {
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(),
},
None => "".to_string(),
};
match message.role.as_str() {
"system" => system_prompt = Some(content),
"user" => {
prompt.push_str("<start_of_turn>user\n");
if let Some(sys_prompt) = system_prompt.take() {
prompt.push_str(&sys_prompt);
prompt.push_str("\n\n");
"system" => {
if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
}
}
"user" => {
if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>user\n{}<end_of_turn>\n", content));
}
prompt.push_str(&content);
prompt.push_str("<end_of_turn>\n");
}
"assistant" => {
prompt.push_str("<start_of_turn>model\n");
prompt.push_str(&content);
prompt.push_str("<end_of_turn>\n");
if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
}
}
_ => {}
}
}
prompt.push_str("<start_of_turn>model\n");
prompt
}
@@ -325,14 +106,13 @@ pub async fn chat_completions_non_streaming_proxy(
state: AppState,
request: ChatCompletionRequest,
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
let prompt = build_gemma_prompt(&request.messages);
// Enforce model selection behavior: reject if a different model is requested
let configured_model = state.build_args.model_id.clone();
let configured_model = state.model_id.clone();
let requested_model = request.model.clone();
if requested_model.to_lowercase() != "default" {
let normalized_requested = normalize_model_id(&requested_model);
if normalized_requested != configured_model {
let normalized_configured = normalize_model_id(&configured_model);
if normalized_requested != normalized_configured {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
@@ -349,35 +129,81 @@ pub async fn chat_completions_non_streaming_proxy(
}
let model_id = state.model_id.clone();
let max_tokens = request.max_tokens.unwrap_or(1000);
let mut buffer = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Reset per-request state without rebuilding the whole pipeline
text_gen.reset_state();
let max_tokens = request.max_tokens.unwrap_or(1000);
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Error generating text: {}", e) }
})),
));
}
}
let completion = match String::from_utf8(buffer) {
Ok(s) => s,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("UTF-8 conversion error: {}", e) }
})),
));
// Build prompt based on model type
let prompt = match state.model_type {
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
_ => None,
})
.unwrap_or_default()
}
};
// Get streaming receiver based on model type
let rx = match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
));
}
}
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
))?
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
));
}
}
};
// Collect all tokens from the stream
let mut completion = String::new();
while let Ok(token_result) = rx.recv() {
match token_result {
Ok(token) => completion.push_str(&token),
Err(e) => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Error generating text: {}", e) }
})),
));
}
}
}
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
object: "chat.completion".to_string(),
@@ -420,11 +246,12 @@ async fn handle_streaming_request(
request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
// Validate requested model vs configured model
let configured_model = state.build_args.model_id.clone();
let configured_model = state.model_id.clone();
let requested_model = request.model.clone();
if requested_model.to_lowercase() != "default" {
let normalized_requested = normalize_model_id(&requested_model);
if normalized_requested != configured_model {
let normalized_configured = normalize_model_id(&configured_model);
if normalized_requested != normalized_configured {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
@@ -447,9 +274,22 @@ async fn handle_streaming_request(
.unwrap_or_default()
.as_secs();
let model_id = state.model_id.clone();
let max_tokens = request.max_tokens.unwrap_or(1000);
// Build prompt
let prompt = build_gemma_prompt(&request.messages);
// Build prompt based on model type
let prompt = match state.model_type {
ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => {
// For Llama, just use the last user message for now
request.messages.last()
.and_then(|m| m.content.as_ref())
.and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()),
_ => None,
})
.unwrap_or_default()
}
};
tracing::debug!("Formatted prompt: {}", prompt);
// Channel for streaming SSE events
@@ -471,80 +311,121 @@ async fn handle_streaming_request(
let _ = tx.send(Ok(Event::default().data(json)));
}
// Spawn generation task that streams tokens as they are generated
let state_clone = state.clone();
let response_id_clone = response_id.clone();
tokio::spawn(async move {
let max_tokens = request.max_tokens.unwrap_or(1000);
let mut text_gen = state_clone.text_generation.lock().await;
text_gen.reset_state();
// Get streaming receiver based on model type
let model_rx = match state.model_type {
ModelType::Gemma => {
if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
match run_gemma_api(config) {
Ok(rx) => rx,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
));
}
}
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" }
}))
));
}
}
ModelType::Llama => {
if let Some(mut config) = state.llama_config {
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
match run_llama_inference(config) {
Ok(rx) => rx,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
));
}
}
} else {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": "Llama configuration not available" }
}))
));
}
}
};
// Stream tokens via callback with repetition detection
// Spawn task to receive tokens from model and forward as SSE events
let response_id_clone = response_id.clone();
let model_id_clone = model_id.clone();
tokio::spawn(async move {
// Stream tokens with repetition detection
let mut recent_tokens = Vec::new();
let mut repetition_count = 0;
const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns
let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
// Debug log to verify token content
tracing::debug!("Streaming token: '{}'", token);
// Skip sending empty tokens
if token.is_empty() {
tracing::debug!("Skipping empty token");
return Ok(());
}
// Add token to recent history for repetition detection
recent_tokens.push(token.to_string());
if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0);
}
// Check for repetitive patterns
if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2];
// Check if we're repeating the same token or pattern
if last_token == second_last ||
(last_token.trim() == "plus" && second_last.trim() == "plus") ||
(recent_tokens.len() >= 6 &&
recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition");
return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
const MAX_REPETITION_COUNT: usize = 5;
const REPETITION_WINDOW: usize = 8;
while let Ok(token_result) = model_rx.recv() {
match token_result {
Ok(token) => {
// Skip sending empty tokens
if token.is_empty() {
continue;
}
} else {
repetition_count = 0; // Reset counter if pattern breaks
// Add token to recent history for repetition detection
recent_tokens.push(token.clone());
if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0);
}
// Check for repetitive patterns
if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2];
if last_token == second_last {
repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition");
break;
}
} else {
repetition_count = 0;
}
}
let chunk = ChatCompletionChunk {
id: response_id_clone.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: Some(token) },
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
let _ = tx.send(Ok(Event::default().data(json)));
}
}
Err(e) => {
tracing::info!("Text generation stopped: {}", e);
break;
}
}
let chunk = ChatCompletionChunk {
id: response_id_clone.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: Some(token.to_string()) },
finish_reason: None,
}],
};
if let Ok(json) = serde_json::to_string(&chunk) {
tracing::debug!("Sending chunk with content: '{}'", token);
let _ = tx.send(Ok(Event::default().data(json)));
}
Ok(())
}).await;
// Log result of generation
match result {
Ok(_) => tracing::debug!("Text generation completed successfully"),
Err(e) => tracing::info!("Text generation stopped: {}", e),
}
// Send final stop chunk and DONE marker
@@ -552,7 +433,7 @@ async fn handle_streaming_request(
id: response_id_clone.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model_id.clone(),
model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice {
index: 0,
delta: Delta { role: None, content: None },
@@ -594,6 +475,7 @@ pub fn create_router(app_state: AppState) -> Router {
pub async fn list_models() -> Json<ModelListResponse> {
// Get all available model variants from the Which enum
let models = vec![
// Gemma models
Model {
id: "gemma-2b".to_string(),
object: "model".to_string(),
@@ -690,6 +572,73 @@ pub async fn list_models() -> Json<ModelListResponse> {
created: 1686935002,
owned_by: "google".to_string(),
},
// Llama models
Model {
id: "llama-3.2-1b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "llama-3.2-1b-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "llama-3.2-3b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "llama-3.2-3b-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "meta".to_string(),
},
Model {
id: "smollm2-135m".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-135m-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-360m".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-360m-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-1.7b".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "smollm2-1.7b-instruct".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "huggingface".to_string(),
},
Model {
id: "tinyllama-1.1b-chat".to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: "tinyllama".to_string(),
},
];
Json(ModelListResponse {

View File

@@ -0,0 +1,24 @@
[package]
name = "llama-runner"
version = "0.1.0"
edition = "2021"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -0,0 +1,188 @@
# Llama Runner
A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.
## Features
- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
-**Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults
## Supported Models
| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |
Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).
## Installation
```bash
# Clone the repository
git clone <repository-url>
cd llama-runner
# Build with GPU acceleration (recommended)
cargo build --release --features metal # macOS
cargo build --release --features cuda # Linux/Windows with NVIDIA GPU
# CPU-only build
cargo build --release
```
## Quick Start
```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"
# Specify a model and parameters
cargo run --features metal -- \
--prompt "Write a short story about space exploration" \
--model smollm2-360m \
--max-tokens 100 \
--temperature 0.8
# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```
## Usage Examples
### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"
# Creative writing with higher temperature
cargo run --features metal -- \
--prompt "Once upon a time" \
--temperature 1.0 \
--max-tokens 200
```
### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
--prompt "Explain artificial intelligence" \
--top-k 40 \
--top-p 0.9 \
--temperature 0.7
# Reduce repetition
cargo run --features metal -- \
--prompt "List the benefits of renewable energy" \
--repeat-penalty 1.2 \
--repeat-last-n 64
```
### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
--prompt "Quick test" \
--model smollm2-135m
# Better quality with larger model
cargo run --features metal -- \
--prompt "Explain quantum physics" \
--model llama-3.2-1b \
--max-tokens 150
```
## Command-Line Options
| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `smollm2-135m` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |
## Performance
Typical performance on Apple M2 with Metal acceleration:
| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |
## Requirements
- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows
## GPU Support
### macOS (Metal)
```bash
cargo run --features metal -- [options]
```
### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```
### CPU Only
```bash
cargo run -- --cpu [options]
```
## Model Downloads
Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:
- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes
## Troubleshooting
### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model
### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues
### Model Download Issues
- Check internet connection
- Some models may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`
## Contributing
Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.
## License
MIT License - see LICENSE file for details.

View File

@@ -0,0 +1,8 @@
pub mod llama_api;
use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";

View File

@@ -0,0 +1,337 @@
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::{EOS_TOKEN};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
#[value(name = "llama-3.2-1b")]
#[default]
Llama32_1B,
#[value(name = "llama-3.2-1b-instruct")]
Llama32_1BInstruct,
#[value(name = "llama-3.2-3b")]
Llama32_3B,
#[value(name = "llama-3.2-3b-instruct")]
Llama32_3BInstruct,
#[value(name = "smollm2-135m")]
SmolLM2_135M,
#[value(name = "smollm2-135m-instruct")]
SmolLM2_135MInstruct,
#[value(name = "smollm2-360m")]
SmolLM2_360M,
#[value(name = "smollm2-360m-instruct")]
SmolLM2_360MInstruct,
#[value(name = "smollm2-1.7b")]
SmolLM2_1_7B,
#[value(name = "smollm2-1.7b-instruct")]
SmolLM2_1_7BInstruct,
#[value(name = "tinyllama-1.1b-chat")]
TinyLlama1_1BChat,
}
#[derive(Debug, Clone)]
pub struct LlamaInferenceConfig {
pub prompt: String,
pub model: WhichModel,
pub cpu: bool,
pub temperature: f64,
pub top_p: Option<f64>,
pub top_k: Option<usize>,
pub seed: u64,
pub max_tokens: usize,
pub no_kv_cache: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: Option<String>,
pub use_flash_attn: bool,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
}
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
// Leave prompt empty by default; let call sites set it.
prompt: String::new(),
// Keep your existing model choice; swap at call-site if needed.
model: WhichModel::Llama32_1BInstruct,
// Prefer GPU if available.
cpu: false,
// Sampling: balanced + stable
temperature: 0.7,
top_p: Some(0.95),
top_k: Some(50),
// Reproducible by default; override for variability.
seed: 42,
// Dont run unbounded generations.
max_tokens: 512,
// Performance flags
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: true, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
// Optional model source pinning (None = app defaults)
model_id: None,
revision: None,
// Anti-repeat heuristics
repeat_penalty: 1.15,
repeat_last_n: 128,
}
}
}
fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
let json_file = api.get(json_file)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
let weight_map = match json.get("weight_map") {
None => bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| api.get(v))
.collect::<anyhow::Result<Vec<_>, _>>()?;
Ok(safetensors_files)
}
pub fn run_llama_inference(
cfg: LlamaInferenceConfig,
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
// ---- Device & dtype -----------------------------------------------------
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
println!("Using dtype: {:?}", dtype);
// ---- Load model & tokenizer --------------------------------------------
let (llama, tokenizer, mut cache) = {
let api = Api::new()?;
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
match cfg.model {
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(cfg.use_flash_attn);
let filenames = match cfg.model {
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
hub_load_safetensors(&api, "model.safetensors.index.json")?
}
_ => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let llama = Llama::load(vb, &config)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
(llama, tokenizer, cache)
};
// ---- Prepare prompt & sampler ------------------------------------------
let eos_token_id = tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single);
let mut tokens = tokenizer
.encode(cfg.prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("Starting inference...");
let mut logits_processor = {
let temperature = cfg.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (cfg.top_k, cfg.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(cfg.seed, sampling)
};
// Channel for streaming decoded fragments to the caller.
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
// ---- Spawn generation thread -------------------------------------------
std::thread::spawn(move || {
let start_gen = std::time::Instant::now();
let mut index_pos = 0usize;
let mut token_generated = 0usize;
for index in 0..cfg.max_tokens {
// Use KV-cache for single-token step after the first pass.
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match llama.forward(&input, context_index, &mut cache) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match logits.squeeze(0) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = if cfg.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
match candle_transformers::utils::apply_repeat_penalty(
&logits,
cfg.repeat_penalty,
&tokens[start_at..],
) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
}
};
index_pos += ctxt.len();
let next_token = match logits_processor.sample(&logits) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
token_generated += 1;
tokens.push(next_token);
// Early stop on EOS.
let stop = match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
None => false,
};
if stop {
break;
}
// Decode this token's text and stream it out.
match tokenizer.decode(&[next_token], false) {
Ok(text) => {
if !text.is_empty() {
// Best-effort send; if receiver is gone, just stop.
if tx.send(Ok(text)).is_err() {
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
break;
}
}
}
// Optional: final stats as a debug line (not sent through the stream).
let dt = start_gen.elapsed();
eprintln!(
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
// Dropping tx closes the stream.
});
Ok(rx)
}

View File

@@ -0,0 +1,109 @@
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug, Default)]
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
prompt: String,
/// The model to use
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples
#[arg(short, long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
max_tokens: usize,
/// Disable the key-value cache
#[arg(long)]
no_kv_cache: bool,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
model_id: Option<String>,
/// Model revision
#[arg(long)]
revision: Option<String>,
/// Use flash attention
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
impl Into<LlamaInferenceConfig> for Args {
fn into(self) -> LlamaInferenceConfig {
LlamaInferenceConfig {
prompt: self.prompt,
model: self.model,
cpu: self.cpu,
temperature: self.temperature,
top_p: self.top_p,
top_k: self.top_k,
seed: self.seed,
max_tokens: self.max_tokens,
no_kv_cache: self.no_kv_cache,
dtype: self.dtype,
model_id: self.model_id,
revision: self.revision,
use_flash_attn: self.use_flash_attn,
repeat_penalty: self.repeat_penalty,
repeat_last_n: self.repeat_last_n,
}
}
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = args.into();
let rx = run_llama_inference(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,20 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_cli;
mod llama_api;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use std::io::Write;
use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> {
run_cli()
}

View File

@@ -67,18 +67,7 @@ async fn main() {
let embeddings_router = embeddings_engine::create_embeddings_router();
// Create AppState with correct model configuration
use inference_engine::Which;
use inference_engine::server::{PipelineArgs, build_pipeline};
let mut pipeline_args = PipelineArgs::default();
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
pipeline_args.which = Which::InstructV3_1B;
let text_generation = build_pipeline(pipeline_args.clone());
let app_state = AppState {
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
model_id: "google/gemma-3-1b-it".to_string(),
build_args: pipeline_args,
};
let app_state = AppState::default();
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);

View File

@@ -22,7 +22,7 @@ The Predict-Otron-9000 is a comprehensive multi-service AI platform built around
graph TB
subgraph "Core Components"
A[Main Server<br/>predict-otron-9000]
B[Inference Engine<br/>Gemma via Candle]
B[Inference Engine<br/>Gemma/Llama via Candle]
C[Embeddings Engine<br/>FastEmbed]
D[Web Frontend<br/>Leptos WASM]
end
@@ -52,7 +52,7 @@ graph TB
## Workspace Structure
The project uses a 4-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
The project uses a 7-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
```mermaid
graph TD
@@ -62,24 +62,33 @@ graph TD
end
subgraph "AI Services"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Candle ML]
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
end
subgraph "Frontend"
D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
end
subgraph "Tooling"
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
end
end
subgraph "External Tooling"
E[cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
E[scripts/cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
end
subgraph "Dependencies"
A --> B
A --> C
A --> D
B -.-> F[Candle 0.9.1]
B --> J
B --> K
J -.-> F[Candle 0.9.1]
K -.-> F
C -.-> G[FastEmbed 4.x]
D -.-> H[Leptos 0.8.0]
E -.-> I[OpenAI SDK 5.16+]
@@ -87,9 +96,12 @@ graph TD
style A fill:#e1f5fe
style B fill:#f3e5f5
style J fill:#f3e5f5
style K fill:#f3e5f5
style C fill:#e8f5e8
style D fill:#fff3e0
style E fill:#fce4ec
style L fill:#fff9c4
```
## Deployment Configurations

30
scripts/run_llama.sh Normal file
View File

@@ -0,0 +1,30 @@
#!/usr/bin/env bash
set -euo pipefail
PROMPT=${1:-"Say hello in one short sentence."}
MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
MAX_NEW=${3:-64}
FORCE_CPU=${FORCE_CPU:-0}
# Optional: keep HF cache local to repo if not already set
export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}
BIN="$(dirname "$0")/../target/release/llama_infer"
if [[ ! -x "$BIN" ]]; then
echo "Building llama-runner (release)..."
cargo build -p llama-runner --release
fi
echo "Running llama inference..." >&2
ARGS=(
--model-id "$MODEL"
--prompt "$PROMPT"
--max-new-tokens "$MAX_NEW"
)
if [[ "$FORCE_CPU" == "1" || "$FORCE_CPU" == "true" ]]; then
ARGS+=( --force-cpu )
fi
"$BIN" "${ARGS[@]}"