mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
supports small llama and gemma models
Refactor inference dedicated crates for llama and gemma inferencing, not integrated
This commit is contained in:
369
Cargo.lock
generated
369
Cargo.lock
generated
@@ -686,6 +686,15 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block2"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "340d2f0bdb2a43c1d3cd40513185b2bd7def0aa1052f956455114bc98f82dcf2"
|
||||
dependencies = [
|
||||
"objc2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "brotli"
|
||||
version = "3.5.0"
|
||||
@@ -786,8 +795,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9f51e2ecf6efe9737af8f993433c839f956d2b6ed4fd2dd4a7c6d8b0fa667ff"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-kernels",
|
||||
"candle-metal-kernels",
|
||||
"candle-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"candle-metal-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cudarc",
|
||||
"gemm 0.17.1",
|
||||
"half",
|
||||
@@ -807,6 +816,35 @@ dependencies = [
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"cudarc",
|
||||
"float8",
|
||||
"gemm 0.17.1",
|
||||
"half",
|
||||
"memmap2",
|
||||
"num-traits",
|
||||
"num_cpus",
|
||||
"objc2-foundation",
|
||||
"objc2-metal",
|
||||
"rand 0.9.2",
|
||||
"rand_distr 0.5.1",
|
||||
"rayon",
|
||||
"safetensors",
|
||||
"thiserror 1.0.69",
|
||||
"ug",
|
||||
"ug-cuda",
|
||||
"ug-metal",
|
||||
"yoke 0.7.5",
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-datasets"
|
||||
version = "0.9.1"
|
||||
@@ -814,15 +852,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0a7c351dd50cda83f00f17c4412e35c69d840e453edf06064974de1cc59343d"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-core",
|
||||
"candle-nn",
|
||||
"hf-hub",
|
||||
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"hf-hub 0.4.3",
|
||||
"image",
|
||||
"memmap2",
|
||||
"parquet",
|
||||
"rand 0.9.2",
|
||||
"thiserror 1.0.69",
|
||||
"tokenizers",
|
||||
"tokenizers 0.21.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-examples"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"csv",
|
||||
"hf-hub 0.4.3",
|
||||
"image",
|
||||
"num-traits",
|
||||
"rayon",
|
||||
"safetensors",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers 0.21.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -833,7 +891,7 @@ checksum = "fb38a5bfae09c4ae73fd00039e5eaf97a7d6d9400cc35ee8e603fc4a5f9cb0a3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bindgen_cuda",
|
||||
"candle-core",
|
||||
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"half",
|
||||
]
|
||||
|
||||
@@ -846,6 +904,14 @@ dependencies = [
|
||||
"bindgen_cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
|
||||
dependencies = [
|
||||
"bindgen_cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.9.1"
|
||||
@@ -859,13 +925,27 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
|
||||
dependencies = [
|
||||
"half",
|
||||
"objc2",
|
||||
"objc2-foundation",
|
||||
"objc2-metal",
|
||||
"once_cell",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-nn"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1980d53280c8f9e2c6cbe1785855d7ff8010208b46e21252b978badf13ad69d"
|
||||
dependencies = [
|
||||
"candle-core",
|
||||
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"half",
|
||||
"num-traits",
|
||||
"rayon",
|
||||
@@ -874,14 +954,30 @@ dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-nn"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
|
||||
dependencies = [
|
||||
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"half",
|
||||
"num-traits",
|
||||
"objc2-metal",
|
||||
"rayon",
|
||||
"safetensors",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-onnx"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a8fa227a8176fd9b8fb58d63c908c08ad3af1503ee6fcd058be072a598044d2"
|
||||
dependencies = [
|
||||
"candle-core",
|
||||
"candle-nn",
|
||||
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"prost",
|
||||
"prost-build",
|
||||
]
|
||||
@@ -893,8 +989,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "186cb80045dbe47e0b387ea6d3e906f02fb3056297080d9922984c90e90a72b0"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-core",
|
||||
"candle-nn",
|
||||
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"fancy-regex",
|
||||
"num-traits",
|
||||
"rand 0.9.2",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_plain",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-transformers"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"fancy-regex",
|
||||
"num-traits",
|
||||
"rand 0.9.2",
|
||||
@@ -1523,6 +1637,15 @@ dependencies = [
|
||||
"dirs-sys 0.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "5.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
|
||||
dependencies = [
|
||||
"dirs-sys 0.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "6.0.0"
|
||||
@@ -1556,6 +1679,16 @@ dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dispatch2"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
|
||||
dependencies = [
|
||||
"bitflags 2.9.2",
|
||||
"objc2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "displaydoc"
|
||||
version = "0.2.5"
|
||||
@@ -1715,6 +1848,9 @@ name = "esaxx-rs"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
@@ -1793,14 +1929,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04c269a76bfc6cea69553b7d040acb16c793119cebd97c756d21e08d0f075ff8"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"hf-hub",
|
||||
"hf-hub 0.4.3",
|
||||
"image",
|
||||
"ndarray",
|
||||
"ort",
|
||||
"ort-sys",
|
||||
"rayon",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
"tokenizers 0.21.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1856,6 +1992,18 @@ dependencies = [
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float8"
|
||||
version = "0.2.1"
|
||||
source = "git+https://github.com/zackangelo/float8?branch=cudarc_0_16#03c1f5fe7cdb2f9cb690823fdd40593be57c408f"
|
||||
dependencies = [
|
||||
"cudarc",
|
||||
"half",
|
||||
"num-traits",
|
||||
"rand 0.9.2",
|
||||
"rand_distr 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@@ -2246,6 +2394,24 @@ dependencies = [
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemma-runner"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-examples",
|
||||
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"clap",
|
||||
"hf-hub 0.4.3",
|
||||
"serde_json",
|
||||
"tokenizers 0.21.4",
|
||||
"tracing",
|
||||
"tracing-chrome",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
@@ -2421,19 +2587,48 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "helm-chart-tool"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"toml 0.8.23",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||
|
||||
[[package]]
|
||||
name = "hf-hub"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
|
||||
dependencies = [
|
||||
"dirs 5.0.1",
|
||||
"indicatif",
|
||||
"log",
|
||||
"native-tls",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"ureq",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hf-hub"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
|
||||
dependencies = [
|
||||
"dirs",
|
||||
"dirs 6.0.0",
|
||||
"futures",
|
||||
"http",
|
||||
"indicatif",
|
||||
@@ -2842,12 +3037,12 @@ dependencies = [
|
||||
"axum",
|
||||
"bindgen_cuda",
|
||||
"byteorder",
|
||||
"candle-core",
|
||||
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"candle-datasets",
|
||||
"candle-flash-attn",
|
||||
"candle-nn",
|
||||
"candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"candle-onnx",
|
||||
"candle-transformers",
|
||||
"candle-transformers 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"clap",
|
||||
"cpal",
|
||||
"csv",
|
||||
@@ -2855,11 +3050,13 @@ dependencies = [
|
||||
"either",
|
||||
"enterpolation",
|
||||
"futures-util",
|
||||
"gemma-runner",
|
||||
"half",
|
||||
"hf-hub",
|
||||
"hf-hub 0.4.3",
|
||||
"image",
|
||||
"imageproc",
|
||||
"intel-mkl-src",
|
||||
"llama-runner",
|
||||
"memmap2",
|
||||
"num-traits",
|
||||
"palette",
|
||||
@@ -2873,7 +3070,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"symphonia",
|
||||
"tokenizers",
|
||||
"tokenizers 0.21.4",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower",
|
||||
@@ -2981,6 +3178,15 @@ version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.12.1"
|
||||
@@ -3405,7 +3611,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.53.3",
|
||||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3443,6 +3649,20 @@ version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
|
||||
|
||||
[[package]]
|
||||
name = "llama-runner"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)",
|
||||
"clap",
|
||||
"hf-hub 0.3.2",
|
||||
"serde_json",
|
||||
"tokenizers 0.20.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.13"
|
||||
@@ -3965,6 +4185,59 @@ dependencies = [
|
||||
"objc_exception",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc2"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "561f357ba7f3a2a61563a186a163d0a3a5247e1089524a3981d49adb775078bc"
|
||||
dependencies = [
|
||||
"objc2-encode",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc2-core-foundation"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166"
|
||||
dependencies = [
|
||||
"bitflags 2.9.2",
|
||||
"dispatch2",
|
||||
"objc2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc2-encode"
|
||||
version = "4.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
|
||||
|
||||
[[package]]
|
||||
name = "objc2-foundation"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
|
||||
dependencies = [
|
||||
"bitflags 2.9.2",
|
||||
"block2",
|
||||
"libc",
|
||||
"objc2",
|
||||
"objc2-core-foundation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc2-metal"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874"
|
||||
dependencies = [
|
||||
"bitflags 2.9.2",
|
||||
"block2",
|
||||
"dispatch2",
|
||||
"objc2",
|
||||
"objc2-core-foundation",
|
||||
"objc2-foundation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc_exception"
|
||||
version = "0.1.2"
|
||||
@@ -4803,7 +5076,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"socket2 0.5.10",
|
||||
"tracing",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5006,6 +5279,17 @@ dependencies = [
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-cond"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9"
|
||||
dependencies = [
|
||||
"either",
|
||||
"itertools 0.11.0",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-cond"
|
||||
version = "0.4.0"
|
||||
@@ -6267,7 +6551,7 @@ dependencies = [
|
||||
"getrandom 0.3.3",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6384,6 +6668,38 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.20.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"derive_builder",
|
||||
"esaxx-rs",
|
||||
"getrandom 0.2.16",
|
||||
"indicatif",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"macro_rules_attribute",
|
||||
"monostate",
|
||||
"onig",
|
||||
"paste",
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
"rayon-cond 0.3.0",
|
||||
"regex",
|
||||
"regex-syntax 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
"thiserror 1.0.69",
|
||||
"unicode-normalization-alignments",
|
||||
"unicode-segmentation",
|
||||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.21.4"
|
||||
@@ -6397,7 +6713,8 @@ dependencies = [
|
||||
"derive_builder",
|
||||
"esaxx-rs",
|
||||
"getrandom 0.3.3",
|
||||
"hf-hub",
|
||||
"hf-hub 0.4.3",
|
||||
"indicatif",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"macro_rules_attribute",
|
||||
@@ -6406,7 +6723,7 @@ dependencies = [
|
||||
"paste",
|
||||
"rand 0.9.2",
|
||||
"rayon",
|
||||
"rayon-cond",
|
||||
"rayon-cond 0.4.0",
|
||||
"regex",
|
||||
"regex-syntax 0.8.5",
|
||||
"serde",
|
||||
@@ -7260,7 +7577,7 @@ version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@@ -4,7 +4,9 @@ members = [
|
||||
"crates/inference-engine",
|
||||
"crates/embeddings-engine",
|
||||
"crates/leptos-app",
|
||||
"crates/helm-chart-tool"
|
||||
"crates/helm-chart-tool",
|
||||
"crates/llama-runner",
|
||||
"crates/gemma-runner"
|
||||
]
|
||||
default-members = ["crates/predict-otron-9000"]
|
||||
resolver = "2"
|
||||
|
28
README.md
28
README.md
@@ -10,7 +10,7 @@ Powerful local AI inference with OpenAI-compatible APIs
|
||||
|
||||
The predict-otron-9000 is a flexible AI platform that provides:
|
||||
|
||||
- **Local LLM Inference**: Run Gemma models locally with CPU or GPU acceleration
|
||||
- **Local LLM Inference**: Run Gemma and Llama models locally with CPU or GPU acceleration
|
||||
- **Embeddings Generation**: Create text embeddings with FastEmbed
|
||||
- **Web Interface**: Interact with models through a Leptos WASM chat interface
|
||||
- **TypeScript CLI**: Command-line client for testing and automation
|
||||
@@ -22,7 +22,7 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
|
||||
|
||||
- **OpenAI Compatible**: API endpoints match OpenAI's format for easy integration
|
||||
- **Text Embeddings**: Generate high-quality text embeddings using FastEmbed
|
||||
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma models (1B, 2B, 7B variants including instruction-tuned models)
|
||||
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants)
|
||||
- **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput
|
||||
- **Web Chat Interface**: Leptos-based WebAssembly (WASM) chat interface for browser-based interaction
|
||||
- **Flexible Deployment**: Run as monolithic service or microservices architecture
|
||||
@@ -31,15 +31,19 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
|
||||
|
||||
### Workspace Structure
|
||||
|
||||
The project uses a 4-crate Rust workspace plus TypeScript components:
|
||||
The project uses a 7-crate Rust workspace plus TypeScript components:
|
||||
|
||||
```
|
||||
crates/
|
||||
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
|
||||
├── inference-engine/ # Gemma inference via Candle (Rust 2021)
|
||||
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
|
||||
└── leptos-app/ # WASM web frontend (Rust 2021)
|
||||
cli.ts # TypeScript/Bun CLI client
|
||||
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
|
||||
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
|
||||
├── llama-runner/ # Llama model inference via Candle (Rust 2021)
|
||||
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
|
||||
├── leptos-app/ # WASM web frontend (Rust 2021)
|
||||
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
|
||||
└── scripts/
|
||||
└── cli.ts # TypeScript/Bun CLI client
|
||||
```
|
||||
|
||||
### Service Architecture
|
||||
@@ -149,16 +153,16 @@ cd crates/leptos-app
|
||||
#### TypeScript CLI Client
|
||||
```bash
|
||||
# List available models
|
||||
bun run cli.ts --list-models
|
||||
bun run scripts/cli.ts --list-models
|
||||
|
||||
# Chat completion
|
||||
bun run cli.ts "What is the capital of France?"
|
||||
bun run scripts/cli.ts "What is the capital of France?"
|
||||
|
||||
# With specific model
|
||||
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
bun run scripts/cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
|
||||
# Show help
|
||||
bun run cli.ts --help
|
||||
bun run scripts/cli.ts --help
|
||||
```
|
||||
|
||||
## API Usage
|
||||
@@ -454,7 +458,7 @@ curl -s http://localhost:8080/v1/models | jq
|
||||
|
||||
**CLI client test:**
|
||||
```bash
|
||||
bun run cli.ts "What is 2+2?"
|
||||
bun run scripts/cli.ts "What is 2+2?"
|
||||
```
|
||||
|
||||
**Web frontend:**
|
||||
|
28
crates/gemma-runner/Cargo.toml
Normal file
28
crates/gemma-runner/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "gemma-runner"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-examples = { git = "https://github.com/huggingface/candle.git" }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
hf-hub = "0.4"
|
||||
tokenizers = "0.21"
|
||||
anyhow = "1.0"
|
||||
clap = { version = "4.0", features = ["derive", "string"] }
|
||||
serde_json = "1.0"
|
||||
tracing = "0.1"
|
||||
tracing-chrome = "0.7"
|
||||
tracing-subscriber = "0.3"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
|
137
crates/gemma-runner/README.md
Normal file
137
crates/gemma-runner/README.md
Normal file
@@ -0,0 +1,137 @@
|
||||
# Gemma Runner
|
||||
|
||||
Fast Gemma inference with Candle framework in Rust.
|
||||
|
||||
## Features
|
||||
|
||||
- Support for multiple Gemma model versions (v1, v2, v3)
|
||||
- GPU acceleration with CUDA and Metal
|
||||
- Configurable sampling parameters
|
||||
- Multiple model variants including instruct and code models
|
||||
|
||||
## Supported Models
|
||||
|
||||
### Gemma v1
|
||||
- `gemma-2b` - Base 2B model
|
||||
- `gemma-7b` - Base 7B model
|
||||
- `gemma-2b-it` - Instruct 2B model
|
||||
- `gemma-7b-it` - Instruct 7B model
|
||||
- `gemma-1.1-2b-it` - Instruct 2B v1.1 model
|
||||
- `gemma-1.1-7b-it` - Instruct 7B v1.1 model
|
||||
|
||||
### CodeGemma
|
||||
- `codegemma-2b` - Code base 2B model
|
||||
- `codegemma-7b` - Code base 7B model
|
||||
- `codegemma-2b-it` - Code instruct 2B model
|
||||
- `codegemma-7b-it` - Code instruct 7B model
|
||||
|
||||
### Gemma v2
|
||||
- `gemma-2-2b` - Base 2B v2 model (default)
|
||||
- `gemma-2-2b-it` - Instruct 2B v2 model
|
||||
- `gemma-2-9b` - Base 9B v2 model
|
||||
- `gemma-2-9b-it` - Instruct 9B v2 model
|
||||
|
||||
### Gemma v3
|
||||
- `gemma-3-1b` - Base 1B v3 model
|
||||
- `gemma-3-1b-it` - Instruct 1B v3 model
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
cd gemma-runner
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
For GPU support:
|
||||
```bash
|
||||
# CUDA
|
||||
cargo build --release --features cuda
|
||||
|
||||
# Metal (macOS)
|
||||
cargo build --release --features metal
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Run with default model (gemma-2-2b)
|
||||
cargo run -- --prompt "The capital of France is"
|
||||
|
||||
# Specify a different model
|
||||
cargo run -- --model gemma-2b-it --prompt "Explain quantum computing"
|
||||
|
||||
# Generate more tokens
|
||||
cargo run -- --model codegemma-2b-it --prompt "Write a Python function to sort a list" --max-tokens 200
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
```bash
|
||||
# Use CPU instead of GPU
|
||||
cargo run -- --cpu --prompt "Hello world"
|
||||
|
||||
# Adjust sampling parameters
|
||||
cargo run -- --temperature 0.8 --top-p 0.9 --prompt "Write a story about"
|
||||
|
||||
# Use custom model from HuggingFace Hub
|
||||
cargo run -- --model-id "google/gemma-2-2b-it" --prompt "What is AI?"
|
||||
|
||||
# Enable tracing for performance analysis
|
||||
cargo run -- --tracing --prompt "Explain machine learning"
|
||||
```
|
||||
|
||||
### Command Line Arguments
|
||||
|
||||
- `--prompt, -p` - The prompt to generate text from (default: "The capital of France is")
|
||||
- `--model, -m` - The model to use (default: "gemma-2-2b")
|
||||
- `--cpu` - Run on CPU rather than GPU
|
||||
- `--temperature, -t` - Sampling temperature (optional)
|
||||
- `--top-p` - Nucleus sampling probability cutoff (optional)
|
||||
- `--seed` - Random seed (default: 299792458)
|
||||
- `--max-tokens, -n` - Maximum tokens to generate (default: 100)
|
||||
- `--model-id` - Custom model ID from HuggingFace Hub
|
||||
- `--revision` - Model revision (default: "main")
|
||||
- `--use-flash-attn` - Use flash attention
|
||||
- `--repeat-penalty` - Repetition penalty (default: 1.1)
|
||||
- `--repeat-last-n` - Context size for repeat penalty (default: 64)
|
||||
- `--dtype` - Data type (f16, bf16, f32)
|
||||
- `--tracing` - Enable performance tracing
|
||||
|
||||
## Examples
|
||||
|
||||
### Text Generation
|
||||
```bash
|
||||
cargo run -- --model gemma-2b-it --prompt "Explain the theory of relativity" --max-tokens 150
|
||||
```
|
||||
|
||||
### Code Generation
|
||||
```bash
|
||||
cargo run -- --model codegemma-7b-it --prompt "Write a Rust function to calculate factorial" --max-tokens 100
|
||||
```
|
||||
|
||||
### Creative Writing
|
||||
```bash
|
||||
cargo run -- --model gemma-7b-it --temperature 0.9 --prompt "Once upon a time in a magical forest" --max-tokens 200
|
||||
```
|
||||
|
||||
### Chat with Gemma 3 (Instruct format)
|
||||
```bash
|
||||
cargo run -- --model gemma-3-1b-it --prompt "How do I learn Rust programming?"
|
||||
```
|
||||
|
||||
## Performance Notes
|
||||
|
||||
- GPU acceleration is automatically detected and used when available
|
||||
- BF16 precision is used on CUDA for better performance
|
||||
- F32 precision is used on CPU
|
||||
- Flash attention can be enabled with `--use-flash-attn` for supported models
|
||||
- Model files are cached locally after first download
|
||||
|
||||
## Requirements
|
||||
|
||||
- Rust 1.70+
|
||||
- CUDA toolkit (for CUDA support)
|
||||
- Metal (automatically available on macOS)
|
||||
- Internet connection for first-time model download
|
389
crates/gemma-runner/src/gemma_api.rs
Normal file
389
crates/gemma-runner/src/gemma_api.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::ValueEnum;
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
|
||||
// Removed gemma_cli import as it's not needed for the API
|
||||
use candle_core::{utils, DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use std::sync::mpsc::{self, Receiver, Sender};
|
||||
use std::thread;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "gemma-2b")]
|
||||
Base2B,
|
||||
#[value(name = "gemma-7b")]
|
||||
Base7B,
|
||||
#[value(name = "gemma-2b-it")]
|
||||
Instruct2B,
|
||||
#[value(name = "gemma-7b-it")]
|
||||
Instruct7B,
|
||||
#[value(name = "gemma-1.1-2b-it")]
|
||||
InstructV1_1_2B,
|
||||
#[value(name = "gemma-1.1-7b-it")]
|
||||
InstructV1_1_7B,
|
||||
#[value(name = "codegemma-2b")]
|
||||
CodeBase2B,
|
||||
#[value(name = "codegemma-7b")]
|
||||
CodeBase7B,
|
||||
#[value(name = "codegemma-2b-it")]
|
||||
CodeInstruct2B,
|
||||
#[value(name = "codegemma-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "gemma-2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "gemma-2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "gemma-2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "gemma-2-9b-it")]
|
||||
InstructV2_9B,
|
||||
#[value(name = "gemma-3-1b")]
|
||||
BaseV3_1B,
|
||||
#[value(name = "gemma-3-1b-it")]
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
V3(Model3),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle_core::Result<Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
Self::V3(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn device(cpu: bool) -> Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if utils::cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if utils::metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Stream-only generation: sends freshly generated token strings over `tx`.
|
||||
/// (Does not send the prompt tokens; only newly generated model tokens.)
|
||||
fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
|
||||
self.tokenizer.clear();
|
||||
|
||||
// Encode prompt (context only; do not emit prompt tokens to the stream).
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Warm the tokenizer's internal state with prompt tokens (so merges are correct),
|
||||
// but do not send them to the receiver.
|
||||
for &t in tokens.iter() {
|
||||
let _ = self.tokenizer.next_token(t)?;
|
||||
}
|
||||
// Make sure stdout isn't holding anything (if caller also prints).
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
eprintln!("Warning: <end_of_turn> token not found, using <eos> as backup");
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
// Best-effort send; ignore if receiver dropped.
|
||||
let _ = tx.send(Ok(t));
|
||||
}
|
||||
}
|
||||
|
||||
let _dt = start_gen.elapsed();
|
||||
|
||||
// Flush any remaining buffered bytes as one final chunk.
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
let _ = tx.send(Ok(rest));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GemmaInferenceConfig {
|
||||
pub tracing: bool,
|
||||
pub prompt: String,
|
||||
pub model: WhichModel,
|
||||
pub cpu: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
pub revision: String,
|
||||
pub use_flash_attn: bool,
|
||||
pub seed: u64,
|
||||
pub temperature: f64,
|
||||
pub top_p: Option<f64>,
|
||||
pub repeat_penalty: f32,
|
||||
pub repeat_last_n: usize,
|
||||
pub max_tokens: usize,
|
||||
}
|
||||
|
||||
impl Default for GemmaInferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tracing: false,
|
||||
prompt: "Hello".to_string(),
|
||||
model: WhichModel::InstructV2_2B,
|
||||
cpu: false,
|
||||
dtype: None,
|
||||
model_id: None,
|
||||
revision: "main".to_string(),
|
||||
use_flash_attn: false,
|
||||
seed: 299792458,
|
||||
temperature: 0.8,
|
||||
top_p: None,
|
||||
repeat_penalty: 1.1,
|
||||
repeat_last_n: 128,
|
||||
max_tokens: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Removed From<Args> implementation as Args is not available and not needed for API usage
|
||||
|
||||
/// Builds the model and returns a channel that streams generated token strings.
|
||||
/// If model setup fails, the `Result` is returned immediately.
|
||||
pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String>>> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let _guard = if cfg.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
utils::with_avx(),
|
||||
utils::with_neon(),
|
||||
utils::with_simd128(),
|
||||
utils::with_f16c()
|
||||
);
|
||||
|
||||
let device = device(cfg.cpu)?;
|
||||
println!("Device: {:?}", device);
|
||||
|
||||
let dtype = match cfg.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
|
||||
None => {
|
||||
if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F16
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
|
||||
let model_id = cfg.model_id.unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Base2B => "google/gemma-2b",
|
||||
WhichModel::Base7B => "google/gemma-7b",
|
||||
WhichModel::Instruct2B => "google/gemma-2b-it",
|
||||
WhichModel::Instruct7B => "google/gemma-7b-it",
|
||||
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
|
||||
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
|
||||
WhichModel::CodeBase2B => "google/codegemma-2b",
|
||||
WhichModel::CodeBase7B => "google/codegemma-7b",
|
||||
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
|
||||
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
|
||||
WhichModel::BaseV2_2B => "google/gemma-2-2b",
|
||||
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
|
||||
WhichModel::BaseV2_9B => "google/gemma-2-9b",
|
||||
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
|
||||
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
|
||||
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
|
||||
println!("Loading model: {}", &model_id);
|
||||
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, cfg.revision));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("Retrieved files in {:?}", start.elapsed());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
|
||||
let model: Model = match cfg.model {
|
||||
WhichModel::Base2B
|
||||
| WhichModel::Base7B
|
||||
| WhichModel::Instruct2B
|
||||
| WhichModel::Instruct7B
|
||||
| WhichModel::InstructV1_1_2B
|
||||
| WhichModel::InstructV1_1_7B
|
||||
| WhichModel::CodeBase2B
|
||||
| WhichModel::CodeBase7B
|
||||
| WhichModel::CodeInstruct2B
|
||||
| WhichModel::CodeInstruct7B => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
}
|
||||
WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
}
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V3(model)
|
||||
}
|
||||
};
|
||||
println!("Loaded model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
cfg.seed,
|
||||
cfg.temperature.into(),
|
||||
cfg.top_p,
|
||||
cfg.repeat_penalty,
|
||||
cfg.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
|
||||
let prompt = match cfg.model {
|
||||
WhichModel::InstructV3_1B => {
|
||||
format!(
|
||||
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
|
||||
cfg.prompt
|
||||
)
|
||||
}
|
||||
_ => cfg.prompt,
|
||||
};
|
||||
|
||||
println!("Starting inference...");
|
||||
|
||||
// Create the channel after successful setup.
|
||||
let (tx, rx) = mpsc::channel::<Result<String>>();
|
||||
|
||||
// Spawn generation thread; send tokens to the channel.
|
||||
thread::spawn(move || {
|
||||
// If generation fails, forward the error once.
|
||||
if let Err(e) = pipeline.run_stream(&prompt, cfg.max_tokens, tx.clone()) {
|
||||
let _ = tx.send(Err(e));
|
||||
}
|
||||
// Channel closes when tx is dropped.
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
97
crates/gemma-runner/src/gemma_cli.rs
Normal file
97
crates/gemma-runner/src/gemma_cli.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use std::io::Write;
|
||||
use clap::Parser;
|
||||
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
|
||||
pub struct Args {
|
||||
/// The prompt to generate text from
|
||||
#[arg(short, long, default_value = "The capital of France is")]
|
||||
pub(crate) prompt: String,
|
||||
|
||||
/// The model to use
|
||||
#[arg(short, long, default_value = "gemma-2-2b")]
|
||||
pub(crate) model: WhichModel,
|
||||
|
||||
/// Run on CPU rather than GPU
|
||||
#[arg(long)]
|
||||
pub(crate) cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples
|
||||
#[arg(short, long)]
|
||||
pub(crate) temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff
|
||||
#[arg(long)]
|
||||
pub(crate) top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
pub(crate) seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens)
|
||||
#[arg(short = 'n', long, default_value_t = 100)]
|
||||
pub(crate) max_tokens: usize,
|
||||
|
||||
/// Use different dtype than default
|
||||
#[arg(long)]
|
||||
pub(crate) dtype: Option<String>,
|
||||
|
||||
/// Custom model ID from HuggingFace Hub
|
||||
#[arg(long)]
|
||||
pub(crate) model_id: Option<String>,
|
||||
|
||||
/// Model revision
|
||||
#[arg(long, default_value = "main")]
|
||||
pub(crate) revision: String,
|
||||
|
||||
/// Use flash attention
|
||||
#[arg(long)]
|
||||
pub(crate) use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
pub(crate) repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty
|
||||
#[arg(long, default_value_t = 64)]
|
||||
pub(crate) repeat_last_n: usize,
|
||||
|
||||
/// Enable tracing
|
||||
#[arg(long)]
|
||||
pub(crate) tracing: bool,
|
||||
}
|
||||
|
||||
pub fn run_cli() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let cfg = GemmaInferenceConfig {
|
||||
tracing: args.tracing,
|
||||
prompt: args.prompt,
|
||||
model: args.model,
|
||||
cpu: args.cpu,
|
||||
dtype: args.dtype,
|
||||
model_id: args.model_id,
|
||||
revision: args.revision,
|
||||
use_flash_attn: args.use_flash_attn,
|
||||
seed: args.seed,
|
||||
temperature: args.temperature.unwrap_or(0.8),
|
||||
top_p: args.top_p,
|
||||
repeat_penalty: args.repeat_penalty,
|
||||
repeat_last_n: args.repeat_last_n,
|
||||
max_tokens: args.max_tokens,
|
||||
};
|
||||
let rx = run_gemma_api(cfg)?;
|
||||
for msg in rx {
|
||||
match msg {
|
||||
Ok(tok) => {
|
||||
print!("{tok}");
|
||||
let _ = std::io::stdout().flush(); // <- force it out now
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("generation error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
3
crates/gemma-runner/src/lib.rs
Normal file
3
crates/gemma-runner/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod gemma_api;
|
||||
|
||||
pub use gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
|
17
crates/gemma-runner/src/main.rs
Normal file
17
crates/gemma-runner/src/main.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
mod gemma_cli;
|
||||
mod gemma_api;
|
||||
|
||||
use anyhow::Error;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use crate::gemma_cli::run_cli;
|
||||
use std::io::Write;
|
||||
|
||||
/// just a placeholder, not used for anything
|
||||
fn main() -> std::result::Result<(), Error> {
|
||||
run_cli()
|
||||
}
|
@@ -3,8 +3,6 @@ name = "helm-chart-tool"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[workspace]
|
||||
|
||||
[[bin]]
|
||||
name = "helm-chart-tool"
|
||||
path = "src/main.rs"
|
||||
|
@@ -3,9 +3,16 @@ name = "inference-engine"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
||||
[[bin]]
|
||||
name="cli"
|
||||
path = "src/cli_main.rs"
|
||||
name="gemma_inference"
|
||||
path = "src/gemma_inference.rs"
|
||||
required-features = ["bin"]
|
||||
|
||||
[[bin]]
|
||||
name="llama_inference"
|
||||
path = "src/llama_inference.rs"
|
||||
required-features = ["bin"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
@@ -50,6 +57,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
reborrow = "0.5.5"
|
||||
futures-util = "0.3.31"
|
||||
gemma-runner = { path = "../gemma-runner" }
|
||||
llama-runner = { path = "../llama-runner" }
|
||||
|
||||
# --- Add this section for conditional compilation ---
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
@@ -83,6 +92,9 @@ tokio = "1.43.0"
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
|
||||
[features]
|
||||
bin = []
|
||||
|
||||
|
||||
|
||||
[package.metadata.compose]
|
||||
|
@@ -1,72 +0,0 @@
|
||||
use clap::Parser;
|
||||
use crate::model::Which;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
pub cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
pub tracing: bool,
|
||||
|
||||
/// Run in server mode with OpenAI compatible API
|
||||
#[arg(long)]
|
||||
pub server: bool,
|
||||
|
||||
/// Port to use for the server
|
||||
#[arg(long, default_value_t = 3777)]
|
||||
pub port: u16,
|
||||
|
||||
/// Prompt for text generation (not used in server mode)
|
||||
#[arg(long)]
|
||||
pub prompt: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
pub temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
pub top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
pub seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
pub sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
pub model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
pub revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
pub tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
pub config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
pub weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
pub repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
pub repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "3-1b-it")]
|
||||
pub which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
pub use_flash_attn: bool,
|
||||
}
|
@@ -1,912 +0,0 @@
|
||||
mod token_output_stream;
|
||||
mod utilities_lib;
|
||||
|
||||
#[cfg(feature = "intel-mkl-src")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate-src")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
extern crate metal_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use clap::Parser;
|
||||
use either::Either;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
|
||||
// OpenAI API compatible structs
|
||||
|
||||
/// Inner content structure for messages that can be either a string or key-value pairs
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageInnerContent(
|
||||
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageInnerContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
(
|
||||
"MessageInnerContent",
|
||||
utoipa::openapi::RefOr::T(message_inner_content_schema()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageInnerContent Schema generation to handle `Either`
|
||||
fn message_inner_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
// Either::Left - simple string
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
// Either::Right - object with string values
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Message content that can be either simple text or complex structured content
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MessageContent(
|
||||
#[serde(with = "either::serde_untagged")]
|
||||
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
|
||||
);
|
||||
|
||||
impl ToSchema<'_> for MessageContent {
|
||||
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
|
||||
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Function for MessageContent Schema generation to handle `Either`
|
||||
fn message_content_schema() -> utoipa::openapi::Schema {
|
||||
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
|
||||
|
||||
Schema::OneOf(
|
||||
OneOfBuilder::new()
|
||||
.item(Schema::Object(
|
||||
ObjectBuilder::new().schema_type(SchemaType::String).build(),
|
||||
))
|
||||
.item(Schema::Array(
|
||||
ArrayBuilder::new()
|
||||
.items(RefOr::T(Schema::Object(
|
||||
ObjectBuilder::new()
|
||||
.schema_type(SchemaType::Object)
|
||||
.additional_properties(Some(RefOr::Ref(
|
||||
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
|
||||
)))
|
||||
.build(),
|
||||
)))
|
||||
.build(),
|
||||
))
|
||||
.build(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Represents a single message in a conversation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct Message {
|
||||
/// The message content
|
||||
pub content: Option<MessageContent>,
|
||||
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
|
||||
pub role: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Stop token configuration for generation
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
pub enum StopTokens {
|
||||
/// Multiple possible stop sequences
|
||||
Multi(Vec<String>),
|
||||
/// Single stop sequence
|
||||
Single(String),
|
||||
}
|
||||
|
||||
/// Default value helper
|
||||
fn default_false() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Default value helper
|
||||
fn default_1usize() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
/// Default value helper
|
||||
fn default_model() -> String {
|
||||
"default".to_string()
|
||||
}
|
||||
|
||||
/// Chat completion request following OpenAI's specification
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionRequest {
|
||||
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
|
||||
pub messages: Vec<Message>,
|
||||
#[schema(example = "gemma-3-1b-it")]
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
#[serde(default = "default_false")]
|
||||
#[schema(example = false)]
|
||||
pub logprobs: bool,
|
||||
#[schema(example = 256)]
|
||||
pub max_tokens: Option<usize>,
|
||||
#[serde(rename = "n")]
|
||||
#[serde(default = "default_1usize")]
|
||||
#[schema(example = 1)]
|
||||
pub n_choices: usize,
|
||||
#[schema(example = 0.7)]
|
||||
pub temperature: Option<f64>,
|
||||
#[schema(example = 0.9)]
|
||||
pub top_p: Option<f64>,
|
||||
#[schema(example = false)]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
/// Chat completion choice
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionChoice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
/// Chat completion response
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatCompletionChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
/// Token usage information
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
// Application state shared between handlers
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
text_generation: Arc<Mutex<TextGeneration>>,
|
||||
model_id: String,
|
||||
}
|
||||
|
||||
// Chat completions endpoint handler
|
||||
async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// Convert messages to a prompt string
|
||||
for message in &request.messages {
|
||||
let role = &message.role;
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(), // Handle complex content if needed
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
// Format based on role
|
||||
match role.as_str() {
|
||||
"system" => prompt.push_str(&format!("System: {}\n", content)),
|
||||
"user" => prompt.push_str(&format!("User: {}\n", content)),
|
||||
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
|
||||
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
|
||||
}
|
||||
}
|
||||
|
||||
// Add the assistant prefix for the response
|
||||
prompt.push_str("Assistant: ");
|
||||
|
||||
// Capture the output
|
||||
let mut output = Vec::new();
|
||||
{
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
|
||||
// Buffer to capture the output
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
// Run text generation
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
|
||||
|
||||
if let Err(e) = result {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
|
||||
"type": "unsupported_api"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
// Convert buffer to string
|
||||
if let Ok(text) = String::from_utf8(buffer) {
|
||||
output.push(text);
|
||||
}
|
||||
}
|
||||
|
||||
// Create response
|
||||
let response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
|
||||
object: "chat.completion".to_string(),
|
||||
created: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
model: request.model,
|
||||
choices: vec![ChatCompletionChoice {
|
||||
index: 0,
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(MessageContent(Either::Left(output.join("")))),
|
||||
name: None,
|
||||
},
|
||||
finish_reason: "stop".to_string(),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: prompt.len() / 4, // Rough estimate
|
||||
completion_tokens: output.join("").len() / 4, // Rough estimate
|
||||
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
|
||||
},
|
||||
};
|
||||
|
||||
// Return the response as JSON
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
use candle_core::{DType, Device, MetalDevice, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{Repo, RepoType, api::sync::Api};
|
||||
use serde_json::json;
|
||||
use tokenizers::Tokenizer;
|
||||
use crate::token_output_stream::TokenOutputStream;
|
||||
use crate::utilities_lib::device;
|
||||
|
||||
// Create the router with the chat completions endpoint
|
||||
fn create_router(app_state: AppState) -> Router {
|
||||
// CORS layer to allow requests from any origin
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
Router::new()
|
||||
// OpenAI compatible endpoints
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
// Add more endpoints as needed
|
||||
.layer(cors)
|
||||
.with_state(app_state)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "2b")]
|
||||
Base2B,
|
||||
#[value(name = "7b")]
|
||||
Base7B,
|
||||
#[value(name = "2b-it")]
|
||||
Instruct2B,
|
||||
#[value(name = "7b-it")]
|
||||
Instruct7B,
|
||||
#[value(name = "1.1-2b-it")]
|
||||
InstructV1_1_2B,
|
||||
#[value(name = "1.1-7b-it")]
|
||||
InstructV1_1_7B,
|
||||
#[value(name = "code-2b")]
|
||||
CodeBase2B,
|
||||
#[value(name = "code-7b")]
|
||||
CodeBase7B,
|
||||
#[value(name = "code-2b-it")]
|
||||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
#[value(name = "3-1b")]
|
||||
BaseV3_1B,
|
||||
#[value(name = "3-1b-it")]
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
V3(Model3),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
Self::V3(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run text generation and print to stdout
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
println!(
|
||||
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||
);
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Run text generation and write to a buffer
|
||||
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
// Write prompt tokens to output
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
|
||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||
Some(token) => token,
|
||||
None => {
|
||||
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
||||
eos_token
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if we're using a Model3 (gemma-3) variant
|
||||
let is_model3 = match &self.model {
|
||||
Model::V3(_) => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
// For Model3, we need to use a different approach
|
||||
if is_model3 {
|
||||
// For gemma-3 models, we'll generate one token at a time with the full context
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
// Initial generation with the full prompt
|
||||
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||
let mut logits = self.model.forward(&input, 0)?;
|
||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
for _ in 0..sample_len {
|
||||
// Apply repeat penalty if needed
|
||||
let current_logits = if self.repeat_penalty == 1. {
|
||||
logits.clone()
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(¤t_logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
|
||||
// For the next iteration, just use the new token
|
||||
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
|
||||
logits = self.model.forward(&new_input, tokens.len() - 1)?;
|
||||
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Standard approach for other models
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
|
||||
// Manual implementation of repeat penalty to avoid type conflicts
|
||||
let mut logits_vec = logits.to_vec1::<f32>()?;
|
||||
|
||||
for &token_id in &tokens[start_at..] {
|
||||
let token_id = token_id as usize;
|
||||
if token_id < logits_vec.len() {
|
||||
let score = logits_vec[token_id];
|
||||
let sign = if score < 0.0 { -1.0 } else { 1.0 };
|
||||
logits_vec[token_id] = sign * score / self.repeat_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new tensor with the modified logits
|
||||
let device = logits.device().clone();
|
||||
let shape = logits.shape().clone();
|
||||
let new_logits = Tensor::new(&logits_vec[..], &device)?;
|
||||
new_logits.reshape(shape)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
write!(output, "{}", t)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Write any remaining tokens
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
write!(output, "{}", rest)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Run in server mode with OpenAI compatible API
|
||||
#[arg(long)]
|
||||
server: bool,
|
||||
|
||||
/// Port to use for the server
|
||||
#[arg(long, default_value_t = 3777)]
|
||||
port: u16,
|
||||
|
||||
/// Prompt for text generation (not used in server mode)
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "3-1b-it")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle_core::utils::with_avx(),
|
||||
candle_core::utils::with_neon(),
|
||||
candle_core::utils::with_simd128(),
|
||||
candle_core::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => match args.which {
|
||||
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||
Which::Base2B => "google/gemma-2b".to_string(),
|
||||
Which::Base7B => "google/gemma-7b".to_string(),
|
||||
Which::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||
Which::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||
Which::CodeBase2B => "google/codegemma-2b".to_string(),
|
||||
Which::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id.clone(),
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.which {
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
_ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
},
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let initial_device = utilities_lib::device(args.cpu)?;
|
||||
|
||||
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
|
||||
let is_v3_model = matches!(args.which, Which::BaseV3_1B | Which::InstructV3_1B);
|
||||
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.cpu;
|
||||
|
||||
// Use CPU for V3 models on Metal due to missing implementations
|
||||
let device = if is_v3_model && is_metal {
|
||||
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
|
||||
Device::Cpu
|
||||
} else {
|
||||
initial_device
|
||||
};
|
||||
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
|
||||
// Use the selected device and dtype
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
}
|
||||
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
}
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model3::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V3(model)
|
||||
}
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
|
||||
if args.server {
|
||||
// Start the server
|
||||
println!("Starting server on port {}", args.port);
|
||||
|
||||
// Create app state
|
||||
let app_state = AppState {
|
||||
text_generation: Arc::new(Mutex::new(pipeline)),
|
||||
model_id,
|
||||
};
|
||||
|
||||
// Create router
|
||||
let app = create_router(app_state);
|
||||
|
||||
// Run the server
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
|
||||
|
||||
// Use tokio to run the server
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(async {
|
||||
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
// Run in CLI mode
|
||||
if let Some(prompt_text) = &args.prompt {
|
||||
let prompt = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B
|
||||
| Which::BaseV2_2B
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B => prompt_text.clone(),
|
||||
Which::InstructV3_1B => {
|
||||
format!(
|
||||
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
||||
prompt_text
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mut pipeline = pipeline;
|
||||
pipeline.run(&prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
|
||||
}
|
||||
}
|
||||
}
|
33
crates/inference-engine/src/inference.rs
Normal file
33
crates/inference-engine/src/inference.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use anyhow::Result;
|
||||
use candle_core::Tensor;
|
||||
|
||||
/// ModelInference trait defines the common interface for model inference operations
|
||||
///
|
||||
/// This trait serves as an abstraction for different model implementations (Gemma and Llama)
|
||||
/// to provide a unified interface for the inference engine.
|
||||
pub trait ModelInference {
|
||||
/// Perform model inference for the given input tensor starting at the specified position
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_ids` - The input tensor containing token IDs
|
||||
/// * `pos` - The position to start generation from
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor containing the logits for the next token prediction
|
||||
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> Result<Tensor>;
|
||||
|
||||
/// Reset the model's internal state, if applicable
|
||||
///
|
||||
/// This method can be used to clear any cached state between inference requests
|
||||
fn reset_state(&mut self) -> Result<()>;
|
||||
|
||||
/// Get the model type name
|
||||
///
|
||||
/// Returns a string identifier for the model type (e.g., "Gemma", "Llama")
|
||||
fn model_type(&self) -> &'static str;
|
||||
}
|
||||
|
||||
/// Factory function type for creating model inference implementations
|
||||
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
|
@@ -4,14 +4,16 @@ pub mod model;
|
||||
pub mod text_generation;
|
||||
pub mod utilities_lib;
|
||||
pub mod openai_types;
|
||||
pub mod cli;
|
||||
// pub mod cli;
|
||||
pub mod server;
|
||||
pub mod inference;
|
||||
|
||||
// Re-export key components for easier access
|
||||
pub use model::{Model, Which};
|
||||
pub use text_generation::TextGeneration;
|
||||
pub use token_output_stream::TokenOutputStream;
|
||||
pub use server::{AppState, create_router};
|
||||
pub use inference::ModelInference;
|
||||
|
||||
use std::env;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
@@ -2,6 +2,7 @@
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum Which {
|
||||
@@ -37,12 +38,17 @@ pub enum Which {
|
||||
BaseV3_1B,
|
||||
#[value(name = "3-1b-it")]
|
||||
InstructV3_1B,
|
||||
#[value(name = "llama-3.2-1b-it")]
|
||||
LlamaInstruct3_2_1B,
|
||||
#[value(name = "llama-3.2-3b-it")]
|
||||
LlamaInstruct3_2_3B,
|
||||
}
|
||||
|
||||
pub enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
V3(Model3),
|
||||
Llama(LlamaModel),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@@ -51,6 +57,7 @@ impl Model {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
Self::V3(m) => m.forward(input_ids, pos),
|
||||
Self::Llama(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -74,6 +81,8 @@ impl Which {
|
||||
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||
Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
|
||||
Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,4 +96,8 @@ impl Which {
|
||||
pub fn is_v3_model(&self) -> bool {
|
||||
matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
|
||||
}
|
||||
|
||||
pub fn is_llama_model(&self) -> bool {
|
||||
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
|
||||
}
|
||||
}
|
@@ -5,304 +5,85 @@ use axum::{
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use candle_core::DType;
|
||||
use candle_nn::VarBuilder;
|
||||
use futures_util::stream::{self, Stream};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use std::convert::Infallible;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
|
||||
use crate::text_generation::TextGeneration;
|
||||
use crate::{utilities_lib, Model as GemmaModel, Which};
|
||||
use crate::Which;
|
||||
use either::Either;
|
||||
use hf_hub::api::sync::{Api, ApiError};
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use serde_json::Value;
|
||||
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
|
||||
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
|
||||
// -------------------------
|
||||
// Shared app state
|
||||
// -------------------------
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ModelType {
|
||||
Gemma,
|
||||
Llama,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub text_generation: Arc<Mutex<TextGeneration>>,
|
||||
pub model_type: ModelType,
|
||||
pub model_id: String,
|
||||
// Store build args to recreate TextGeneration when needed
|
||||
pub build_args: PipelineArgs,
|
||||
pub gemma_config: Option<GemmaInferenceConfig>,
|
||||
pub llama_config: Option<LlamaInferenceConfig>,
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
let args = PipelineArgs::default();
|
||||
let text_generation = build_pipeline(args.clone());
|
||||
let gemma_config = GemmaInferenceConfig {
|
||||
model: gemma_runner::WhichModel::InstructV3_1B,
|
||||
..Default::default()
|
||||
};
|
||||
Self {
|
||||
text_generation: Arc::new(Mutex::new(text_generation)),
|
||||
model_id: args.model_id.clone(),
|
||||
build_args: args,
|
||||
model_type: ModelType::Gemma,
|
||||
model_id: "gemma-3-1b-it".to_string(),
|
||||
gemma_config: Some(gemma_config),
|
||||
llama_config: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// Pipeline configuration
|
||||
// Helper functions
|
||||
// -------------------------
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PipelineArgs {
|
||||
pub model_id: String,
|
||||
pub which: Which,
|
||||
pub revision: Option<String>,
|
||||
pub tokenizer_path: Option<PathBuf>,
|
||||
pub config_path: Option<PathBuf>,
|
||||
pub weight_paths: Vec<PathBuf>,
|
||||
pub use_flash_attn: bool,
|
||||
pub force_cpu: bool,
|
||||
pub seed: u64,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
pub repeat_penalty: f32,
|
||||
pub repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Default for PipelineArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_id: Which::InstructV3_1B.to_model_id().to_string(),
|
||||
which: Which::InstructV3_1B,
|
||||
revision: None,
|
||||
tokenizer_path: None,
|
||||
config_path: None,
|
||||
weight_paths: Vec::new(),
|
||||
use_flash_attn: false,
|
||||
force_cpu: false,
|
||||
seed: 299792458, // Speed of light in vacuum (m/s)
|
||||
temperature: Some(0.8), // Good balance between creativity and coherence
|
||||
top_p: Some(0.9), // Keep diverse but reasonable options
|
||||
repeat_penalty: 1.2, // Stronger penalty for repetition to prevent looping
|
||||
repeat_last_n: 64, // Consider last 64 tokens for repetition
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_model_id(model_id: &str) -> String {
|
||||
if model_id.contains('/') {
|
||||
model_id.to_string()
|
||||
} else {
|
||||
format!("google/{}", model_id)
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_repo_exists(api: &Api, model_id: &str, revision: &str) -> anyhow::Result<()> {
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id.to_string(),
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
));
|
||||
match repo.get("config.json") {
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => match e {
|
||||
ApiError::RequestError(resp) => {
|
||||
let error_str = resp.to_string();
|
||||
if error_str.contains("404") {
|
||||
anyhow::bail!(
|
||||
"Hugging Face model repo not found: '{model_id}' at revision '{revision}'."
|
||||
)
|
||||
}
|
||||
Err(anyhow::Error::new(ApiError::RequestError(resp)))
|
||||
}
|
||||
other => Err(anyhow::Error::new(other)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------
|
||||
// Pipeline builder
|
||||
// -------------------------
|
||||
|
||||
pub fn build_pipeline(mut args: PipelineArgs) -> TextGeneration {
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle_core::utils::with_avx(),
|
||||
candle_core::utils::with_neon(),
|
||||
candle_core::utils::with_simd128(),
|
||||
candle_core::utils::with_f16c()
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new().unwrap();
|
||||
let revision = args.revision.as_deref().unwrap_or("main");
|
||||
|
||||
if args.model_id.trim().is_empty() {
|
||||
panic!("No model ID specified.");
|
||||
}
|
||||
args.model_id = normalize_model_id(&args.model_id);
|
||||
|
||||
match ensure_repo_exists(&api, &args.model_id, revision) {
|
||||
Ok(_) => {}
|
||||
Err(e) => panic!("{}", e),
|
||||
};
|
||||
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id.clone(),
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
));
|
||||
|
||||
let tokenizer_path = args
|
||||
.tokenizer_path
|
||||
.unwrap_or_else(|| repo.get("tokenizer.json").unwrap());
|
||||
let config_path = args
|
||||
.config_path
|
||||
.unwrap_or_else(|| repo.get("config.json").unwrap());
|
||||
|
||||
if !matches!(
|
||||
args.which,
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B
|
||||
| Which::BaseV2_2B
|
||||
| Which::InstructV2_2B
|
||||
| Which::BaseV2_9B
|
||||
| Which::InstructV2_9B
|
||||
| Which::BaseV3_1B
|
||||
| Which::InstructV3_1B
|
||||
) {
|
||||
if args.model_id.contains("gemma-2-2b-it") {
|
||||
args.which = Which::InstructV2_2B;
|
||||
} else if args.model_id.contains("gemma-3-1b-it") {
|
||||
args.which = Which::InstructV3_1B;
|
||||
} else if let Ok(file) = std::fs::File::open(config_path.clone()) {
|
||||
if let Ok(cfg_val) = serde_json::from_reader::<_, serde_json::Value>(file) {
|
||||
if let Some(model_type) = cfg_val.get("model_type").and_then(|v| v.as_str()) {
|
||||
if model_type.contains("gemma3") {
|
||||
args.which = Which::InstructV3_1B;
|
||||
} else if model_type.contains("gemma2") {
|
||||
args.which = Which::InstructV2_2B;
|
||||
} else {
|
||||
args.which = Which::Instruct2B;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let weight_paths = if !args.weight_paths.is_empty() {
|
||||
args.weight_paths
|
||||
} else {
|
||||
match repo.get("model.safetensors") {
|
||||
Ok(single) => vec![single],
|
||||
Err(_) => match utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json") {
|
||||
Ok(paths) => paths,
|
||||
Err(e) => {
|
||||
panic!("Unable to locate model weights: {}", e);
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).unwrap();
|
||||
|
||||
let initial_device = utilities_lib::device(args.force_cpu).unwrap();
|
||||
let is_v3_model = args.which.is_v3_model();
|
||||
let is_metal = !initial_device.is_cpu()
|
||||
&& candle_core::utils::metal_is_available()
|
||||
&& !args.force_cpu;
|
||||
|
||||
let device = if is_v3_model && is_metal {
|
||||
candle_core::Device::Cpu
|
||||
} else {
|
||||
initial_device
|
||||
};
|
||||
|
||||
let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_paths, dtype, &device).unwrap() };
|
||||
|
||||
let model = match args.which {
|
||||
Which::Base2B
|
||||
| Which::Base7B
|
||||
| Which::Instruct2B
|
||||
| Which::Instruct7B
|
||||
| Which::InstructV1_1_2B
|
||||
| Which::InstructV1_1_7B
|
||||
| Which::CodeBase2B
|
||||
| Which::CodeBase7B
|
||||
| Which::CodeInstruct2B
|
||||
| Which::CodeInstruct7B => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
|
||||
GemmaModel::V1(Model1::new(args.use_flash_attn, &config, vb).unwrap())
|
||||
}
|
||||
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_path.clone()).unwrap()).unwrap();
|
||||
GemmaModel::V2(Model2::new(args.use_flash_attn, &config, vb).unwrap())
|
||||
}
|
||||
Which::BaseV3_1B | Which::InstructV3_1B => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_path).unwrap()).unwrap();
|
||||
GemmaModel::V3(Model3::new(args.use_flash_attn, &config, vb).unwrap())
|
||||
}
|
||||
};
|
||||
|
||||
TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
)
|
||||
model_id.to_lowercase().replace("_", "-")
|
||||
}
|
||||
|
||||
fn build_gemma_prompt(messages: &[Message]) -> String {
|
||||
let mut prompt = String::new();
|
||||
let mut system_prompt: Option<String> = None;
|
||||
|
||||
|
||||
for message in messages {
|
||||
let content = match &message.content {
|
||||
Some(content) => match &content.0 {
|
||||
Either::Left(text) => text.clone(),
|
||||
Either::Right(_) => "".to_string(),
|
||||
},
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
match message.role.as_str() {
|
||||
"system" => system_prompt = Some(content),
|
||||
"user" => {
|
||||
prompt.push_str("<start_of_turn>user\n");
|
||||
if let Some(sys_prompt) = system_prompt.take() {
|
||||
prompt.push_str(&sys_prompt);
|
||||
prompt.push_str("\n\n");
|
||||
"system" => {
|
||||
if let Some(MessageContent(Either::Left(content))) = &message.content {
|
||||
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
|
||||
}
|
||||
}
|
||||
"user" => {
|
||||
if let Some(MessageContent(Either::Left(content))) = &message.content {
|
||||
prompt.push_str(&format!("<start_of_turn>user\n{}<end_of_turn>\n", content));
|
||||
}
|
||||
prompt.push_str(&content);
|
||||
prompt.push_str("<end_of_turn>\n");
|
||||
}
|
||||
"assistant" => {
|
||||
prompt.push_str("<start_of_turn>model\n");
|
||||
prompt.push_str(&content);
|
||||
prompt.push_str("<end_of_turn>\n");
|
||||
if let Some(MessageContent(Either::Left(content))) = &message.content {
|
||||
prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
prompt.push_str("<start_of_turn>model\n");
|
||||
prompt
|
||||
}
|
||||
@@ -325,14 +106,13 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
state: AppState,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
let prompt = build_gemma_prompt(&request.messages);
|
||||
|
||||
// Enforce model selection behavior: reject if a different model is requested
|
||||
let configured_model = state.build_args.model_id.clone();
|
||||
let configured_model = state.model_id.clone();
|
||||
let requested_model = request.model.clone();
|
||||
if requested_model.to_lowercase() != "default" {
|
||||
let normalized_requested = normalize_model_id(&requested_model);
|
||||
if normalized_requested != configured_model {
|
||||
let normalized_configured = normalize_model_id(&configured_model);
|
||||
if normalized_requested != normalized_configured {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
@@ -349,35 +129,81 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
}
|
||||
|
||||
let model_id = state.model_id.clone();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
{
|
||||
let mut text_gen = state.text_generation.lock().await;
|
||||
// Reset per-request state without rebuilding the whole pipeline
|
||||
text_gen.reset_state();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
if let Err(e) = text_gen.run_with_output(&prompt, max_tokens, &mut buffer) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error generating text: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let completion = match String::from_utf8(buffer) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("UTF-8 conversion error: {}", e) }
|
||||
})),
|
||||
));
|
||||
// Build prompt based on model type
|
||||
let prompt = match state.model_type {
|
||||
ModelType::Gemma => build_gemma_prompt(&request.messages),
|
||||
ModelType::Llama => {
|
||||
// For Llama, just use the last user message for now
|
||||
request.messages.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
};
|
||||
|
||||
// Get streaming receiver based on model type
|
||||
let rx = match state.model_type {
|
||||
ModelType::Gemma => {
|
||||
if let Some(mut config) = state.gemma_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_gemma_api(config).map_err(|e| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
}))
|
||||
))?
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
}))
|
||||
));
|
||||
}
|
||||
}
|
||||
ModelType::Llama => {
|
||||
if let Some(mut config) = state.llama_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_llama_inference(config).map_err(|e| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
}))
|
||||
))?
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
}))
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Collect all tokens from the stream
|
||||
let mut completion = String::new();
|
||||
while let Ok(token_result) = rx.recv() {
|
||||
match token_result {
|
||||
Ok(token) => completion.push_str(&token),
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error generating text: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = ChatCompletionResponse {
|
||||
id: format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")),
|
||||
object: "chat.completion".to_string(),
|
||||
@@ -420,11 +246,12 @@ async fn handle_streaming_request(
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
|
||||
// Validate requested model vs configured model
|
||||
let configured_model = state.build_args.model_id.clone();
|
||||
let configured_model = state.model_id.clone();
|
||||
let requested_model = request.model.clone();
|
||||
if requested_model.to_lowercase() != "default" {
|
||||
let normalized_requested = normalize_model_id(&requested_model);
|
||||
if normalized_requested != configured_model {
|
||||
let normalized_configured = normalize_model_id(&configured_model);
|
||||
if normalized_requested != normalized_configured {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
@@ -447,9 +274,22 @@ async fn handle_streaming_request(
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
let model_id = state.model_id.clone();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
|
||||
// Build prompt
|
||||
let prompt = build_gemma_prompt(&request.messages);
|
||||
// Build prompt based on model type
|
||||
let prompt = match state.model_type {
|
||||
ModelType::Gemma => build_gemma_prompt(&request.messages),
|
||||
ModelType::Llama => {
|
||||
// For Llama, just use the last user message for now
|
||||
request.messages.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
};
|
||||
tracing::debug!("Formatted prompt: {}", prompt);
|
||||
|
||||
// Channel for streaming SSE events
|
||||
@@ -471,80 +311,121 @@ async fn handle_streaming_request(
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
|
||||
// Spawn generation task that streams tokens as they are generated
|
||||
let state_clone = state.clone();
|
||||
let response_id_clone = response_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
let mut text_gen = state_clone.text_generation.lock().await;
|
||||
text_gen.reset_state();
|
||||
// Get streaming receiver based on model type
|
||||
let model_rx = match state.model_type {
|
||||
ModelType::Gemma => {
|
||||
if let Some(mut config) = state.gemma_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_gemma_api(config) {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
}))
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
}))
|
||||
));
|
||||
}
|
||||
}
|
||||
ModelType::Llama => {
|
||||
if let Some(mut config) = state.llama_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_llama_inference(config) {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
}))
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
}))
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Stream tokens via callback with repetition detection
|
||||
// Spawn task to receive tokens from model and forward as SSE events
|
||||
let response_id_clone = response_id.clone();
|
||||
let model_id_clone = model_id.clone();
|
||||
tokio::spawn(async move {
|
||||
// Stream tokens with repetition detection
|
||||
let mut recent_tokens = Vec::new();
|
||||
let mut repetition_count = 0;
|
||||
const MAX_REPETITION_COUNT: usize = 5; // Stop after 5 consecutive repetitions
|
||||
const REPETITION_WINDOW: usize = 8; // Look at last 8 tokens for patterns
|
||||
|
||||
let result = text_gen.run_with_streaming(&prompt, max_tokens, |token| {
|
||||
// Debug log to verify token content
|
||||
tracing::debug!("Streaming token: '{}'", token);
|
||||
|
||||
// Skip sending empty tokens
|
||||
if token.is_empty() {
|
||||
tracing::debug!("Skipping empty token");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Add token to recent history for repetition detection
|
||||
recent_tokens.push(token.to_string());
|
||||
if recent_tokens.len() > REPETITION_WINDOW {
|
||||
recent_tokens.remove(0);
|
||||
}
|
||||
|
||||
// Check for repetitive patterns
|
||||
if recent_tokens.len() >= 4 {
|
||||
let last_token = &recent_tokens[recent_tokens.len() - 1];
|
||||
let second_last = &recent_tokens[recent_tokens.len() - 2];
|
||||
|
||||
// Check if we're repeating the same token or pattern
|
||||
if last_token == second_last ||
|
||||
(last_token.trim() == "plus" && second_last.trim() == "plus") ||
|
||||
(recent_tokens.len() >= 6 &&
|
||||
recent_tokens[recent_tokens.len()-3..].iter().all(|t| t.trim() == "plus" || t.trim().is_empty())) {
|
||||
repetition_count += 1;
|
||||
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
|
||||
|
||||
if repetition_count >= MAX_REPETITION_COUNT {
|
||||
tracing::info!("Stopping generation due to excessive repetition");
|
||||
return Err(anyhow::Error::msg("Repetition detected - stopping generation"));
|
||||
const MAX_REPETITION_COUNT: usize = 5;
|
||||
const REPETITION_WINDOW: usize = 8;
|
||||
|
||||
while let Ok(token_result) = model_rx.recv() {
|
||||
match token_result {
|
||||
Ok(token) => {
|
||||
// Skip sending empty tokens
|
||||
if token.is_empty() {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
repetition_count = 0; // Reset counter if pattern breaks
|
||||
|
||||
// Add token to recent history for repetition detection
|
||||
recent_tokens.push(token.clone());
|
||||
if recent_tokens.len() > REPETITION_WINDOW {
|
||||
recent_tokens.remove(0);
|
||||
}
|
||||
|
||||
// Check for repetitive patterns
|
||||
if recent_tokens.len() >= 4 {
|
||||
let last_token = &recent_tokens[recent_tokens.len() - 1];
|
||||
let second_last = &recent_tokens[recent_tokens.len() - 2];
|
||||
|
||||
if last_token == second_last {
|
||||
repetition_count += 1;
|
||||
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
|
||||
|
||||
if repetition_count >= MAX_REPETITION_COUNT {
|
||||
tracing::info!("Stopping generation due to excessive repetition");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
repetition_count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: response_id_clone.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id_clone.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: None, content: Some(token) },
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&chunk) {
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::info!("Text generation stopped: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: response_id_clone.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: None, content: Some(token.to_string()) },
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&chunk) {
|
||||
tracing::debug!("Sending chunk with content: '{}'", token);
|
||||
let _ = tx.send(Ok(Event::default().data(json)));
|
||||
}
|
||||
Ok(())
|
||||
}).await;
|
||||
|
||||
// Log result of generation
|
||||
match result {
|
||||
Ok(_) => tracing::debug!("Text generation completed successfully"),
|
||||
Err(e) => tracing::info!("Text generation stopped: {}", e),
|
||||
}
|
||||
|
||||
// Send final stop chunk and DONE marker
|
||||
@@ -552,7 +433,7 @@ async fn handle_streaming_request(
|
||||
id: response_id_clone.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created,
|
||||
model: model_id.clone(),
|
||||
model: model_id_clone.clone(),
|
||||
choices: vec![ChatCompletionChunkChoice {
|
||||
index: 0,
|
||||
delta: Delta { role: None, content: None },
|
||||
@@ -594,6 +475,7 @@ pub fn create_router(app_state: AppState) -> Router {
|
||||
pub async fn list_models() -> Json<ModelListResponse> {
|
||||
// Get all available model variants from the Which enum
|
||||
let models = vec![
|
||||
// Gemma models
|
||||
Model {
|
||||
id: "gemma-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
@@ -690,6 +572,73 @@ pub async fn list_models() -> Json<ModelListResponse> {
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
// Llama models
|
||||
Model {
|
||||
id: "llama-3.2-1b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "llama-3.2-1b-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "llama-3.2-3b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "llama-3.2-3b-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-135m".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-135m-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-360m".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-360m-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-1.7b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-1.7b-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "tinyllama-1.1b-chat".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "tinyllama".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
Json(ModelListResponse {
|
||||
|
24
crates/llama-runner/Cargo.toml
Normal file
24
crates/llama-runner/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "llama-runner"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
|
||||
hf-hub = "0.3"
|
||||
tokenizers = "0.20"
|
||||
anyhow = "1.0"
|
||||
clap = { version = "4.0", features = ["derive", "string"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
|
188
crates/llama-runner/README.md
Normal file
188
crates/llama-runner/README.md
Normal file
@@ -0,0 +1,188 @@
|
||||
# Llama Runner
|
||||
|
||||
A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.
|
||||
|
||||
## Features
|
||||
|
||||
- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
|
||||
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
|
||||
- ⚡ **Fast Inference**: Optimized with F16 precision and KV caching
|
||||
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
|
||||
- 📊 **Performance Metrics**: Real-time tokens/second reporting
|
||||
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Size | Command | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
|
||||
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
|
||||
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
|
||||
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
|
||||
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
|
||||
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |
|
||||
|
||||
Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone <repository-url>
|
||||
cd llama-runner
|
||||
|
||||
# Build with GPU acceleration (recommended)
|
||||
cargo build --release --features metal # macOS
|
||||
cargo build --release --features cuda # Linux/Windows with NVIDIA GPU
|
||||
|
||||
# CPU-only build
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Fast inference with GPU acceleration
|
||||
cargo run --features metal -- --prompt "What is quantum computing?"
|
||||
|
||||
# Specify a model and parameters
|
||||
cargo run --features metal -- \
|
||||
--prompt "Write a short story about space exploration" \
|
||||
--model smollm2-360m \
|
||||
--max-tokens 100 \
|
||||
--temperature 0.8
|
||||
|
||||
# Use CPU (slower but works everywhere)
|
||||
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Text Generation
|
||||
```bash
|
||||
# Simple completion
|
||||
cargo run --features metal -- --prompt "The capital of France is"
|
||||
|
||||
# Creative writing with higher temperature
|
||||
cargo run --features metal -- \
|
||||
--prompt "Once upon a time" \
|
||||
--temperature 1.0 \
|
||||
--max-tokens 200
|
||||
```
|
||||
|
||||
### Advanced Sampling
|
||||
```bash
|
||||
# Top-k and top-p sampling
|
||||
cargo run --features metal -- \
|
||||
--prompt "Explain artificial intelligence" \
|
||||
--top-k 40 \
|
||||
--top-p 0.9 \
|
||||
--temperature 0.7
|
||||
|
||||
# Reduce repetition
|
||||
cargo run --features metal -- \
|
||||
--prompt "List the benefits of renewable energy" \
|
||||
--repeat-penalty 1.2 \
|
||||
--repeat-last-n 64
|
||||
```
|
||||
|
||||
### Different Models
|
||||
```bash
|
||||
# Ultra-fast with tiny model
|
||||
cargo run --features metal -- \
|
||||
--prompt "Quick test" \
|
||||
--model smollm2-135m
|
||||
|
||||
# Better quality with larger model
|
||||
cargo run --features metal -- \
|
||||
--prompt "Explain quantum physics" \
|
||||
--model llama-3.2-1b \
|
||||
--max-tokens 150
|
||||
```
|
||||
|
||||
## Command-Line Options
|
||||
|
||||
| Option | Short | Default | Description |
|
||||
|--------|-------|---------|-------------|
|
||||
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
|
||||
| `--model` | `-m` | `smollm2-135m` | Model to use |
|
||||
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
|
||||
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
|
||||
| `--top-k` | | None | Top-k sampling |
|
||||
| `--top-p` | | None | Top-p (nucleus) sampling |
|
||||
| `--seed` | | 299792458 | Random seed for reproducibility |
|
||||
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
|
||||
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
|
||||
| `--cpu` | | false | Force CPU usage |
|
||||
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
|
||||
| `--no-kv-cache` | | false | Disable key-value caching |
|
||||
|
||||
## Performance
|
||||
|
||||
Typical performance on Apple M2 with Metal acceleration:
|
||||
|
||||
| Model | Size | Speed | Memory |
|
||||
|-------|------|-------|--------|
|
||||
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
|
||||
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
|
||||
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
|
||||
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |
|
||||
|
||||
## Requirements
|
||||
|
||||
- **Rust**: 1.70+ (latest stable recommended)
|
||||
- **Memory**: 2-8GB RAM depending on model size
|
||||
- **Storage**: 1-10GB for model weights
|
||||
- **Network**: Internet connection for first-time model download
|
||||
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows
|
||||
|
||||
## GPU Support
|
||||
|
||||
### macOS (Metal)
|
||||
```bash
|
||||
cargo run --features metal -- [options]
|
||||
```
|
||||
|
||||
### Linux/Windows (CUDA)
|
||||
```bash
|
||||
cargo run --features cuda -- [options]
|
||||
```
|
||||
|
||||
### CPU Only
|
||||
```bash
|
||||
cargo run -- --cpu [options]
|
||||
```
|
||||
|
||||
## Model Downloads
|
||||
|
||||
Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:
|
||||
|
||||
- SmolLM2-135M: ~1 minute
|
||||
- SmolLM2-360M: ~2 minutes
|
||||
- Llama-3.2-1B: ~5 minutes
|
||||
- Larger models: 10+ minutes
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Slow Performance
|
||||
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
|
||||
- Try smaller models like `smollm2-135m` for faster inference
|
||||
- Ensure sufficient RAM for your chosen model
|
||||
|
||||
### Out of Memory
|
||||
- Use `--cpu` to use system RAM instead of GPU memory
|
||||
- Try smaller models or reduce `--max-tokens`
|
||||
- Use `--dtype f32` if f16 causes issues
|
||||
|
||||
### Model Download Issues
|
||||
- Check internet connection
|
||||
- Some models may require HuggingFace Hub authentication
|
||||
- Verify sufficient disk space in `~/.cache/huggingface/`
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see LICENSE file for details.
|
8
crates/llama-runner/src/lib.rs
Normal file
8
crates/llama-runner/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod llama_api;
|
||||
|
||||
use clap::ValueEnum;
|
||||
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
|
||||
|
||||
// Re-export constants and types that might be needed
|
||||
pub const EOS_TOKEN: &str = "</s>";
|
||||
|
337
crates/llama-runner/src/llama_api.rs
Normal file
337
crates/llama-runner/src/llama_api.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
use anyhow::{bail, Error as E};
|
||||
use candle_core::{utils, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use candle_transformers::models::llama::{Llama, LlamaConfig};
|
||||
use candle_transformers::models::llama as model;
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use std::sync::mpsc::{self, Receiver};
|
||||
use clap::ValueEnum;
|
||||
use crate::{EOS_TOKEN};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "llama-3.2-1b")]
|
||||
#[default]
|
||||
Llama32_1B,
|
||||
#[value(name = "llama-3.2-1b-instruct")]
|
||||
Llama32_1BInstruct,
|
||||
#[value(name = "llama-3.2-3b")]
|
||||
Llama32_3B,
|
||||
#[value(name = "llama-3.2-3b-instruct")]
|
||||
Llama32_3BInstruct,
|
||||
#[value(name = "smollm2-135m")]
|
||||
SmolLM2_135M,
|
||||
#[value(name = "smollm2-135m-instruct")]
|
||||
SmolLM2_135MInstruct,
|
||||
#[value(name = "smollm2-360m")]
|
||||
SmolLM2_360M,
|
||||
#[value(name = "smollm2-360m-instruct")]
|
||||
SmolLM2_360MInstruct,
|
||||
#[value(name = "smollm2-1.7b")]
|
||||
SmolLM2_1_7B,
|
||||
#[value(name = "smollm2-1.7b-instruct")]
|
||||
SmolLM2_1_7BInstruct,
|
||||
#[value(name = "tinyllama-1.1b-chat")]
|
||||
TinyLlama1_1BChat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlamaInferenceConfig {
|
||||
pub prompt: String,
|
||||
|
||||
pub model: WhichModel,
|
||||
pub cpu: bool,
|
||||
pub temperature: f64,
|
||||
pub top_p: Option<f64>,
|
||||
pub top_k: Option<usize>,
|
||||
pub seed: u64,
|
||||
pub max_tokens: usize,
|
||||
pub no_kv_cache: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
pub revision: Option<String>,
|
||||
pub use_flash_attn: bool,
|
||||
pub repeat_penalty: f32,
|
||||
pub repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Default for LlamaInferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// Leave prompt empty by default; let call sites set it.
|
||||
prompt: String::new(),
|
||||
|
||||
// Keep your existing model choice; swap at call-site if needed.
|
||||
model: WhichModel::Llama32_1BInstruct,
|
||||
|
||||
// Prefer GPU if available.
|
||||
cpu: false,
|
||||
|
||||
// Sampling: balanced + stable
|
||||
temperature: 0.7,
|
||||
top_p: Some(0.95),
|
||||
top_k: Some(50),
|
||||
|
||||
// Reproducible by default; override for variability.
|
||||
seed: 42,
|
||||
|
||||
// Don’t run unbounded generations.
|
||||
max_tokens: 512,
|
||||
|
||||
// Performance flags
|
||||
no_kv_cache: false, // keep cache ON for speed
|
||||
use_flash_attn: true, // great speed boost if supported
|
||||
|
||||
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
|
||||
dtype: Some("bf16".to_string()),
|
||||
|
||||
// Optional model source pinning (None = app defaults)
|
||||
model_id: None,
|
||||
revision: None,
|
||||
|
||||
// Anti-repeat heuristics
|
||||
repeat_penalty: 1.15,
|
||||
repeat_last_n: 128,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
fn device(cpu: bool) -> anyhow::Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if utils::cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if utils::metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn hub_load_safetensors(
|
||||
api: &hf_hub::api::sync::ApiRepo,
|
||||
json_file: &str,
|
||||
) -> anyhow::Result<Vec<std::path::PathBuf>> {
|
||||
let json_file = api.get(json_file)?;
|
||||
let json_file = std::fs::File::open(json_file)?;
|
||||
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file.to_string());
|
||||
}
|
||||
}
|
||||
let safetensors_files = safetensors_files
|
||||
.iter()
|
||||
.map(|v| api.get(v))
|
||||
.collect::<anyhow::Result<Vec<_>, _>>()?;
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
|
||||
pub fn run_llama_inference(
|
||||
cfg: LlamaInferenceConfig,
|
||||
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
|
||||
// ---- Device & dtype -----------------------------------------------------
|
||||
let device = device(cfg.cpu)?;
|
||||
println!("Device: {:?}", device);
|
||||
|
||||
let dtype = match cfg.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
|
||||
// ---- Load model & tokenizer --------------------------------------------
|
||||
let (llama, tokenizer, mut cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
|
||||
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
|
||||
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
|
||||
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
|
||||
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
|
||||
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
|
||||
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
|
||||
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
|
||||
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
println!("Loading model: {}", model_id);
|
||||
let revision = cfg.revision.clone().unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(cfg.use_flash_attn);
|
||||
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
|
||||
hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
_ => vec![api.get("model.safetensors")?],
|
||||
};
|
||||
|
||||
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let llama = Llama::load(vb, &config)?;
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
(llama, tokenizer, cache)
|
||||
};
|
||||
|
||||
// ---- Prepare prompt & sampler ------------------------------------------
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(EOS_TOKEN)
|
||||
.map(model::LlamaEosToks::Single);
|
||||
|
||||
let mut tokens = tokenizer
|
||||
.encode(cfg.prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
println!("Starting inference...");
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = cfg.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (cfg.top_k, cfg.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(cfg.seed, sampling)
|
||||
};
|
||||
|
||||
// Channel for streaming decoded fragments to the caller.
|
||||
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
|
||||
|
||||
// ---- Spawn generation thread -------------------------------------------
|
||||
std::thread::spawn(move || {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0usize;
|
||||
let mut token_generated = 0usize;
|
||||
|
||||
for index in 0..cfg.max_tokens {
|
||||
// Use KV-cache for single-token step after the first pass.
|
||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||
(1, index_pos)
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let logits = match llama.forward(&input, context_index, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
let logits = match logits.squeeze(0) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let logits = if cfg.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
|
||||
match candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
cfg.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = match logits_processor.sample(&logits) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
token_generated += 1;
|
||||
tokens.push(next_token);
|
||||
|
||||
// Early stop on EOS.
|
||||
let stop = match eos_token_id {
|
||||
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
|
||||
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
|
||||
None => false,
|
||||
};
|
||||
if stop {
|
||||
break;
|
||||
}
|
||||
|
||||
// Decode this token's text and stream it out.
|
||||
match tokenizer.decode(&[next_token], false) {
|
||||
Ok(text) => {
|
||||
if !text.is_empty() {
|
||||
// Best-effort send; if receiver is gone, just stop.
|
||||
if tx.send(Ok(text)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: final stats as a debug line (not sent through the stream).
|
||||
let dt = start_gen.elapsed();
|
||||
eprintln!(
|
||||
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
|
||||
token_generated,
|
||||
token_generated as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
// Dropping tx closes the stream.
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
|
109
crates/llama-runner/src/llama_cli.rs
Normal file
109
crates/llama-runner/src/llama_cli.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
|
||||
use clap::Parser;
|
||||
use std::io::Write;
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to generate text from
|
||||
#[arg(short, long, default_value = "The capital of France is")]
|
||||
prompt: String,
|
||||
|
||||
/// The model to use
|
||||
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
|
||||
model: WhichModel,
|
||||
|
||||
/// Run on CPU rather than GPU
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples
|
||||
#[arg(short, long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens)
|
||||
#[arg(short = 'n', long, default_value_t = 100)]
|
||||
max_tokens: usize,
|
||||
|
||||
/// Disable the key-value cache
|
||||
#[arg(long)]
|
||||
no_kv_cache: bool,
|
||||
|
||||
/// Use different dtype than f16
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
/// Custom model ID from HuggingFace Hub
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// Model revision
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// Use flash attention
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty
|
||||
#[arg(long, default_value_t = 128)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Into<LlamaInferenceConfig> for Args {
|
||||
fn into(self) -> LlamaInferenceConfig {
|
||||
LlamaInferenceConfig {
|
||||
prompt: self.prompt,
|
||||
model: self.model,
|
||||
cpu: self.cpu,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
top_k: self.top_k,
|
||||
seed: self.seed,
|
||||
max_tokens: self.max_tokens,
|
||||
no_kv_cache: self.no_kv_cache,
|
||||
dtype: self.dtype,
|
||||
model_id: self.model_id,
|
||||
revision: self.revision,
|
||||
use_flash_attn: self.use_flash_attn,
|
||||
repeat_penalty: self.repeat_penalty,
|
||||
repeat_last_n: self.repeat_last_n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn run_cli() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let cfg = args.into();
|
||||
let rx = run_llama_inference(cfg)?;
|
||||
for msg in rx {
|
||||
match msg {
|
||||
Ok(tok) => {
|
||||
print!("{tok}");
|
||||
let _ = std::io::stdout().flush(); // <- force it out now
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("generation error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
20
crates/llama-runner/src/main.rs
Normal file
20
crates/llama-runner/src/main.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
mod llama_cli;
|
||||
mod llama_api;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use std::io::Write;
|
||||
|
||||
use crate::llama_cli::run_cli;
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
|
||||
|
||||
fn main() -> Result<()> {
|
||||
run_cli()
|
||||
}
|
@@ -67,18 +67,7 @@ async fn main() {
|
||||
let embeddings_router = embeddings_engine::create_embeddings_router();
|
||||
|
||||
// Create AppState with correct model configuration
|
||||
use inference_engine::Which;
|
||||
use inference_engine::server::{PipelineArgs, build_pipeline};
|
||||
let mut pipeline_args = PipelineArgs::default();
|
||||
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
|
||||
pipeline_args.which = Which::InstructV3_1B;
|
||||
|
||||
let text_generation = build_pipeline(pipeline_args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
|
||||
model_id: "google/gemma-3-1b-it".to_string(),
|
||||
build_args: pipeline_args,
|
||||
};
|
||||
let app_state = AppState::default();
|
||||
|
||||
// Get the inference router directly from the inference engine
|
||||
let inference_router = inference_engine::create_router(app_state);
|
||||
|
@@ -22,7 +22,7 @@ The Predict-Otron-9000 is a comprehensive multi-service AI platform built around
|
||||
graph TB
|
||||
subgraph "Core Components"
|
||||
A[Main Server<br/>predict-otron-9000]
|
||||
B[Inference Engine<br/>Gemma via Candle]
|
||||
B[Inference Engine<br/>Gemma/Llama via Candle]
|
||||
C[Embeddings Engine<br/>FastEmbed]
|
||||
D[Web Frontend<br/>Leptos WASM]
|
||||
end
|
||||
@@ -52,7 +52,7 @@ graph TB
|
||||
|
||||
## Workspace Structure
|
||||
|
||||
The project uses a 4-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
|
||||
The project uses a 7-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
@@ -62,24 +62,33 @@ graph TD
|
||||
end
|
||||
|
||||
subgraph "AI Services"
|
||||
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Candle ML]
|
||||
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
|
||||
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
|
||||
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
|
||||
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
|
||||
end
|
||||
|
||||
subgraph "Frontend"
|
||||
D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
|
||||
end
|
||||
|
||||
subgraph "Tooling"
|
||||
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph "External Tooling"
|
||||
E[cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
|
||||
E[scripts/cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
|
||||
end
|
||||
|
||||
subgraph "Dependencies"
|
||||
A --> B
|
||||
A --> C
|
||||
A --> D
|
||||
B -.-> F[Candle 0.9.1]
|
||||
B --> J
|
||||
B --> K
|
||||
J -.-> F[Candle 0.9.1]
|
||||
K -.-> F
|
||||
C -.-> G[FastEmbed 4.x]
|
||||
D -.-> H[Leptos 0.8.0]
|
||||
E -.-> I[OpenAI SDK 5.16+]
|
||||
@@ -87,9 +96,12 @@ graph TD
|
||||
|
||||
style A fill:#e1f5fe
|
||||
style B fill:#f3e5f5
|
||||
style J fill:#f3e5f5
|
||||
style K fill:#f3e5f5
|
||||
style C fill:#e8f5e8
|
||||
style D fill:#fff3e0
|
||||
style E fill:#fce4ec
|
||||
style L fill:#fff9c4
|
||||
```
|
||||
|
||||
## Deployment Configurations
|
||||
|
30
scripts/run_llama.sh
Normal file
30
scripts/run_llama.sh
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
PROMPT=${1:-"Say hello in one short sentence."}
|
||||
MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
|
||||
MAX_NEW=${3:-64}
|
||||
FORCE_CPU=${FORCE_CPU:-0}
|
||||
|
||||
# Optional: keep HF cache local to repo if not already set
|
||||
export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}
|
||||
|
||||
BIN="$(dirname "$0")/../target/release/llama_infer"
|
||||
|
||||
if [[ ! -x "$BIN" ]]; then
|
||||
echo "Building llama-runner (release)..."
|
||||
cargo build -p llama-runner --release
|
||||
fi
|
||||
|
||||
echo "Running llama inference..." >&2
|
||||
ARGS=(
|
||||
--model-id "$MODEL"
|
||||
--prompt "$PROMPT"
|
||||
--max-new-tokens "$MAX_NEW"
|
||||
)
|
||||
|
||||
if [[ "$FORCE_CPU" == "1" || "$FORCE_CPU" == "true" ]]; then
|
||||
ARGS+=( --force-cpu )
|
||||
fi
|
||||
|
||||
"$BIN" "${ARGS[@]}"
|
Reference in New Issue
Block a user