From e6c417bd8317ce341c39f6a3c7fd68b696661990 Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Sun, 31 Aug 2025 10:49:04 -0400 Subject: [PATCH] align dependencies across inference features --- Cargo.lock | 931 +------------- crates/inference-engine/Cargo.toml | 34 +- crates/inference-engine/src/lib.rs | 5 - crates/inference-engine/src/openai_types.rs | 1 + .../inference-engine/src/text_generation.rs | 1106 ----------------- .../src/token_output_stream.rs | 87 -- crates/inference-engine/src/utilities_lib.rs | 168 --- .../tests/text_generation_tests.rs | 554 --------- .../tests/token_output_stream_tests.rs | 135 -- crates/llama-runner/Cargo.toml | 5 - 10 files changed, 17 insertions(+), 3009 deletions(-) delete mode 100644 crates/inference-engine/src/text_generation.rs delete mode 100644 crates/inference-engine/src/token_output_stream.rs delete mode 100644 crates/inference-engine/src/utilities_lib.rs delete mode 100644 crates/inference-engine/tests/text_generation_tests.rs delete mode 100644 crates/inference-engine/tests/token_output_stream_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 4d90019..575d6ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,12 +18,6 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "366ffbaa4442f4684d91e2cd7c5ea7c4ed8add41959a31447066e279e432b618" -[[package]] -name = "accelerate-src" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" - [[package]] name = "addr2line" version = "0.24.2" @@ -46,7 +40,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", - "const-random", "getrandom 0.3.3", "once_cell", "serde", @@ -72,21 +65,6 @@ dependencies = [ "equator", ] -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "alsa" version = "0.9.1" @@ -109,21 +87,6 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - [[package]] name = "anstream" version = "0.6.20" @@ -230,105 +193,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "arrow-array" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" -dependencies = [ - "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "hashbrown 0.14.5", - "num", -] - -[[package]] -name = "arrow-buffer" -version = "51.0.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" -dependencies = [ - "bytes", - "half", - "num", -] - -[[package]] -name = "arrow-cast" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "atoi", - "base64 0.22.1", - "chrono", - "half", - "lexical-core", - "num", - "ryu", -] - -[[package]] -name = "arrow-data" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" -dependencies = [ - "arrow-buffer", - "arrow-schema", - "half", - "num", -] - -[[package]] -name = "arrow-ipc" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-schema", - "flatbuffers", -] - -[[package]] -name = "arrow-schema" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" - -[[package]] -name = "arrow-select" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "num", -] - -[[package]] -name = "assert_float_eq" -version = "1.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10d2119f741b79fe9907f5396d19bffcb46568cfcc315e78677d731972ac7085" - [[package]] name = "async-lock" version = "3.4.1" @@ -418,15 +282,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ - "num-traits", -] - [[package]] name = "atomic-waker" version = "1.1.2" @@ -585,12 +440,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d27c3610c36aee21ce8ac510e6224498de4228ad772a171ed65643a24693a5a8" -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - [[package]] name = "base64" version = "0.13.1" @@ -695,27 +544,6 @@ dependencies = [ "objc2", ] -[[package]] -name = "brotli" -version = "3.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "2.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" version = "1.12.0" @@ -738,12 +566,6 @@ version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" 
-[[package]] -name = "by_address" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64fa3c856b712db6612c019f14756e64e4bcea13337a6b33b696333a9eaa2d06" - [[package]] name = "bytemuck" version = "1.23.2" @@ -796,12 +618,10 @@ checksum = "a9f51e2ecf6efe9737af8f993433c839f956d2b6ed4fd2dd4a7c6d8b0fa667ff" dependencies = [ "byteorder", "candle-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "candle-metal-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "cudarc", "gemm 0.17.1", "half", "memmap2", - "metal 0.27.0", "num-traits", "num_cpus", "rand 0.9.2", @@ -811,7 +631,6 @@ dependencies = [ "thiserror 1.0.69", "ug", "ug-cuda", - "ug-metal", "yoke 0.7.5", "zip", ] @@ -823,7 +642,7 @@ source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0 dependencies = [ "byteorder", "candle-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)", + "candle-metal-kernels", "cudarc", "float8", "gemm 0.17.1", @@ -845,24 +664,6 @@ dependencies = [ "zip", ] -[[package]] -name = "candle-datasets" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0a7c351dd50cda83f00f17c4412e35c69d840e453edf06064974de1cc59343d" -dependencies = [ - "byteorder", - "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "hf-hub 0.4.3", - "image", - "memmap2", - "parquet", - "rand 0.9.2", - "thiserror 1.0.69", - "tokenizers 0.21.4", -] - [[package]] name = "candle-examples" version = "0.9.1" @@ -871,7 +672,7 @@ dependencies = [ "anyhow", "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)", + "candle-transformers", "csv", "hf-hub 0.4.3", "image", @@ -912,19 +713,6 @@ dependencies = [ "bindgen_cuda", ] -[[package]] -name = "candle-metal-kernels" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a323ee9c813707f73b6e59300661b354a70410f31fe4135170c4eda8a061534" -dependencies = [ - "half", - "metal 0.27.0", - "once_cell", - "thiserror 1.0.69", - "tracing", -] - [[package]] name = "candle-metal-kernels" version = "0.9.1" @@ -960,7 +748,7 @@ version = "0.9.1" source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af" dependencies = [ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)", + "candle-metal-kernels", "half", "num-traits", "objc2-metal", @@ -982,25 +770,6 @@ dependencies = [ "prost-build", ] -[[package]] -name = "candle-transformers" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186cb80045dbe47e0b387ea6d3e906f02fb3056297080d9922984c90e90a72b0" -dependencies = [ - "byteorder", - "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "fancy-regex", - "num-traits", - "rand 0.9.2", - "rayon", - "serde", - "serde_json", - "serde_plain", - "tracing", -] - [[package]] name = "candle-transformers" version = "0.9.1" @@ -1076,20 +845,6 @@ version = "0.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" -[[package]] -name = "chrono" -version = "0.4.41" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" -dependencies = [ - "android-tzdata", - "iana-time-zone", - "js-sys", - "num-traits", - "wasm-bindgen", - "windows-link", -] - [[package]] name = "clang-sys" version = "1.8.1" @@ -1240,26 +995,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "const-random" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" -dependencies = [ - "const-random-macro", -] - -[[package]] -name = "const-random-macro" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" -dependencies = [ - "getrandom 0.2.16", - "once_cell", - "tiny-keccak", -] - [[package]] name = "const-str" version = "0.6.4" @@ -1628,15 +1363,6 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "directories" -version = "5.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" -dependencies = [ - "dirs-sys 0.4.1", -] - [[package]] name = "dirs" version = "5.0.1" @@ -1777,18 +1503,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "enterpolation" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fadf5c8cbf7c6765ff05ccbd8811cd7bc3a763e4671755204552bf8740d042a" -dependencies = [ - "assert_float_eq", - "num-traits", - "serde", - "topology-traits", -] - [[package]] name = "enum-as-inner" version = "0.6.1" @@ -1916,12 +1630,6 @@ dependencies = [ "regex-syntax 0.8.5", ] -[[package]] -name = "fast-srgb8" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd2e7510819d6fbf51a5545c8f922716ecfb14df168a3242f7d33e0239efe6a1" - [[package]] name = "fastembed" version = "4.9.1" @@ -1972,16 +1680,6 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" -[[package]] -name = "flatbuffers" -version = "23.5.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" -dependencies = [ - "bitflags 1.3.2", - "rustc_version", -] - [[package]] name = "flate2" version = "1.1.2" @@ -2402,7 +2100,7 @@ dependencies = [ "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-examples", "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)", + "candle-transformers", "clap", "hf-hub 0.4.3", "serde_json", @@ -2449,18 +2147,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "getset" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" -dependencies = [ - "proc-macro-error2", - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "gif" version = "0.13.3" @@ -2803,30 +2489,6 @@ dependencies = [ "windows-registry", ] -[[package]] -name = "iana-time-zone" -version = "0.1.63" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "log", - "wasm-bindgen", - "windows-core 0.61.2", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - [[package]] name = "icu_collections" version = "2.0.0" @@ -3021,52 +2683,31 @@ dependencies = [ "web-time", ] -[[package]] -name = "indoc" -version = "2.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" - [[package]] name = "inference-engine" version = "0.1.0" dependencies = [ "ab_glyph", - "accelerate-src", "anyhow", "axum", "bindgen_cuda", "byteorder", - "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "candle-datasets", + "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-flash-attn", - "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-onnx", - "candle-transformers 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "candle-transformers", "clap", "cpal", - "csv", - "cudarc", "either", - "enterpolation", "futures-util", "gemma-runner", - "half", - "hf-hub 0.4.3", - "image", "imageproc", - "intel-mkl-src", "llama-runner", "memmap2", - "num-traits", - "palette", "pdf2image", - "pyo3", "rand 0.9.2", - "rayon", "reborrow", - "rubato", - "safetensors", "serde", "serde_json", "symphonia", @@ -3091,34 +2732,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "integer-encoding" -version = "3.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" - -[[package]] -name = "intel-mkl-src" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee70586cd5b3e772a8739a1bd43eaa90d4f4bf0fb2a4edc202e979937ee7f5e" -dependencies = [ - "anyhow", - "intel-mkl-tool", - "ocipkg", -] - -[[package]] -name = "intel-mkl-tool" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "887a16b4537d82227af54d3372971cfa5e0cde53322e60f57584056c16ada1b4" -dependencies = [ - "anyhow", - "log", - "walkdir", -] - [[package]] name = "interpolate_name" version = "0.2.4" @@ -3524,70 +3137,6 @@ dependencies = [ "tachys", ] -[[package]] -name = "lexical-core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" -dependencies = [ - "lexical-parse-float", - "lexical-parse-integer", - "lexical-util", - "lexical-write-float", - "lexical-write-integer", -] - -[[package]] -name = "lexical-parse-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" -dependencies = [ - "lexical-parse-integer", - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-parse-integer" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" -dependencies = [ - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-util" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" -dependencies = [ - "static_assertions", -] - -[[package]] -name = "lexical-write-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" -dependencies = [ - "lexical-util", - "lexical-write-integer", - "static_assertions", -] - -[[package]] -name = "lexical-write-integer" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" -dependencies = [ - "lexical-util", - "static_assertions", -] - [[package]] name = "libc" version = "0.2.175" @@ -3656,7 +3205,7 @@ dependencies = [ "anyhow", "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)", + "candle-transformers", "clap", "hf-hub 0.3.2", "serde_json", @@ -3694,15 +3243,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" -[[package]] -name = "lz4_flex" -version = "0.11.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a" -dependencies = [ - "twox-hash 2.1.1", -] - [[package]] name = "mach2" version = "0.4.3" @@ -3811,30 +3351,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - -[[package]] -name = "metal" -version = "0.27.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" -dependencies = [ - "bitflags 2.9.2", - "block", - "core-graphics-types", - "foreign-types 0.5.0", - "log", - "objc", - "paste", -] - [[package]] name = "metal" version = "0.29.0" @@ -4182,7 +3698,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" dependencies = [ "malloc_buf", - "objc_exception", ] [[package]] @@ -4238,15 +3753,6 @@ dependencies = [ "objc2-foundation", ] -[[package]] -name = "objc_exception" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" -dependencies = [ - "cc", -] - [[package]] name = "object" version = "0.36.7" @@ -4279,48 +3785,6 @@ dependencies = [ "cc", ] -[[package]] -name = "oci-spec" -version = "0.6.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdf88ddc01cc6bccbe1044adb6a29057333f523deadcb4953c011a73158cfa5e" -dependencies = [ - "derive_builder", - "getset", - "serde", - "serde_json", - "strum", - "strum_macros", - "thiserror 1.0.69", -] - -[[package]] -name = "ocipkg" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9bb3293021f06540803301af45e7ab81693d50e89a7398a3420bdab139e7ba5e" -dependencies = [ - "base16ct", - "base64 0.22.1", - "chrono", - "directories", - "flate2", - "lazy_static", - "log", - "oci-spec", - "regex", - "serde", - "serde_json", - "sha2", - "tar", - "thiserror 1.0.69", - "toml 0.8.23", - "ureq", - "url", - "uuid", - "walkdir", -] - [[package]] name = "oco_ref" version = "0.2.1" @@ -4421,15 +3885,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c04f5d74368e4d0dfe06c45c8627c81bd7c317d52762d118fb9b3076f6420fd" -[[package]] -name = "ordered-float" -version = "2.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" -dependencies = [ - "num-traits", -] - [[package]] name = "ort" version = "2.0.0-rc.9" @@ -4469,30 +3924,6 @@ dependencies = [ "ttf-parser", ] -[[package]] -name = "palette" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cbf71184cc5ecc2e4e1baccdb21026c20e5fc3dcf63028a086131b3ab00b6e6" -dependencies = [ - "approx", - "fast-srgb8", - "palette_derive", - "phf", -] - -[[package]] -name = "palette_derive" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5030daf005bface118c096f510ffb781fc28f9ab6a32ab224d8631be6851d30" -dependencies = [ - "by_address", - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "parking" version = "2.2.1" @@ -4522,38 +3953,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "parquet" -version = "51.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "096795d4f47f65fd3ee1ec5a98b77ab26d602f2cc785b0e4be5443add17ecc32" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-ipc", - "arrow-schema", - "arrow-select", - "base64 0.22.1", - "brotli", - "bytes", - "chrono", - "flate2", - "half", - "hashbrown 0.14.5", - "lz4_flex", - "num", - "num-bigint", - "paste", - "seq-macro", - "snap", - "thrift", - "twox-hash 1.6.3", - "zstd", -] - [[package]] name = "paste" version = "1.0.15" @@ -4594,48 +3993,6 @@ dependencies = [ "indexmap", ] -[[package]] -name = "phf" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" -dependencies = [ - "phf_macros", - "phf_shared", -] - -[[package]] -name = "phf_generator" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" -dependencies = [ - "phf_shared", - "rand 0.8.5", -] - -[[package]] -name = "phf_macros" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" -dependencies = [ - "phf_generator", - "phf_shared", - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "phf_shared" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" -dependencies = [ - "siphasher", -] - [[package]] name = "pin-project" version = "1.1.10" @@ -4751,15 +4108,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "primal-check" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" -dependencies = [ - "num-integer", -] - [[package]] name = "proc-macro-crate" version = "3.3.0" @@ -4946,69 +4294,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "pyo3" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "once_cell", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.22.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" -dependencies = [ - "heck", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.106", -] - [[package]] name = "qoi" version = "0.4.1" @@ -5364,15 +4649,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "realfft" -version = "3.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677" -dependencies = [ - "rustfft", -] - [[package]] name = "reborrow" version = "0.5.5" @@ -5553,18 +4829,6 @@ dependencies = [ "thiserror 2.0.15", ] -[[package]] -name = "rubato" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5d18b486e7d29a408ef3f825bc1327d8f87af091c987ca2f5b734625940e234" -dependencies = [ - "num-complex", - "num-integer", - "num-traits", - "realfft", -] - [[package]] name = "rust-embed" version = "8.7.2" @@ -5621,20 +4885,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rustfft" -version = "6.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6f140db74548f7c9d7cce60912c9ac414e74df5e718dc947d514b051b42f3f4" -dependencies = [ - "num-complex", - "num-integer", - "num-traits", - "primal-check", - "strength_reduce", - "transpose", -] - [[package]] name = "rustix" version = "1.0.8" @@ -6049,12 +5299,6 @@ dependencies = [ "quote", ] -[[package]] -name = "siphasher" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" - [[package]] name = "slab" version = "0.4.11" @@ -6076,12 +5320,6 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -[[package]] -name = "snap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "socket2" version = "0.5.10" @@ -6143,37 +5381,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "strength_reduce" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" - [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" -[[package]] -name = "strum" -version = "0.26.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" - -[[package]] -name = "strum_macros" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.106", -] - [[package]] name = "subtle" version = "2.6.1" @@ -6603,17 +5816,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "thrift" -version = "0.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" -dependencies = [ - "byteorder", - "integer-encoding", - "ordered-float", -] - [[package]] name = "throw_error" version = "0.3.0" @@ -6634,15 +5836,6 @@ dependencies = [ "weezl", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinystr" version = "0.8.1" @@ -6895,15 +6088,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" -[[package]] -name = "topology-traits" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0c8dab428531e30115d3bfd6e3092b55256a4a7b4f87cb3abe37a000b1f4032" -dependencies = [ - "num-traits", -] - [[package]] name = "tower" version = "0.5.2" @@ -7033,16 +6217,6 @@ dependencies = [ "tracing-log", ] -[[package]] -name = "transpose" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" -dependencies = [ - "num-integer", - "strength_reduce", -] - [[package]] name = "try-lock" version = "0.2.5" @@ -7072,22 +6246,6 @@ dependencies = [ "utf-8", ] -[[package]] -name = "twox-hash" -version = "1.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" -dependencies = [ - "cfg-if", - "static_assertions", -] - -[[package]] -name = "twox-hash" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b907da542cbced5261bd3256de1b3a1bf340a3d37f93425a07362a1d687de56" - [[package]] name = "typed-builder" version = "0.21.2" @@ -7155,7 +6313,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76daec3c7a32a1b4a0e3307b6b057fa067aa64e750713987410a2c402e5cd731" dependencies = [ "half", - "metal 0.29.0", + 
"metal", "objc", "serde", "thiserror 1.0.69", @@ -7207,12 +6365,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" -[[package]] -name = "unindent" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" - [[package]] name = "untrusted" version = "0.9.0" @@ -7592,7 +6744,7 @@ version = "0.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49" dependencies = [ - "windows-core 0.54.0", + "windows-core", "windows-targets 0.52.6", ] @@ -7606,41 +6758,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-core" -version = "0.61.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" -dependencies = [ - "windows-implement", - "windows-interface", - "windows-link", - "windows-result 0.3.4", - "windows-strings", -] - -[[package]] -name = "windows-implement" -version = "0.60.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "windows-interface" -version = "0.59.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "windows-link" version = "0.1.3" @@ -8162,34 +7279,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "zstd" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "7.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" -dependencies = [ - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "2.0.15+zstd.1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" -dependencies = [ - "cc", - "pkg-config", -] - [[package]] name = "zune-core" version = "0.4.12" diff --git a/crates/inference-engine/Cargo.toml b/crates/inference-engine/Cargo.toml index ebaea5e..30739b7 100644 --- a/crates/inference-engine/Cargo.toml +++ b/crates/inference-engine/Cargo.toml @@ -4,26 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] -accelerate-src = { version = "0.3.2", optional = true } -candle-datasets = { version = "=0.9.1", optional = true } -candle-nn = { version = "=0.9.1" } -candle-transformers = { version = "=0.9.1" } +candle-core = { git = "https://github.com/huggingface/candle.git" } +candle-nn = { git = "https://github.com/huggingface/candle.git" } +candle-transformers = { git = "https://github.com/huggingface/candle.git" } candle-flash-attn = { version = "=0.9.1", optional = true } candle-onnx = { version = "=0.9.1", optional = true } -csv = "1.3.0" -cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], 
default-features=false, optional = true } -half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true } -hf-hub = { version = "0.4.1", features = ["tokio"] } -image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } -intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true } -num-traits = { version = "0.2.15" } -palette = { version = "0.7.6", optional = true } -enterpolation = { version = "0.2.1", optional = true} -pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true } -rayon = "1.7.0" -rubato = { version = "0.15.0", optional = true } -safetensors = "0.4.1" serde = { version = "1.0.171", features = ["derive"] } serde_json = "1.0.99" symphonia = { version = "0.5.3", features = ["all"], optional = true } @@ -48,19 +34,11 @@ futures-util = "0.3.31" gemma-runner = { path = "../gemma-runner" } llama-runner = { path = "../llama-runner" } -# --- Add this section for conditional compilation --- [target.'cfg(target_os = "macos")'.dependencies] -# Use CPU backend for macOS to avoid Metal rotary-emb implementation issues -candle-core = { version = "=0.9.1", features = ["metal"], optional = false } +candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } +candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } +candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } -[target.'cfg(not(target_os = "macos"))'.dependencies] -# For Linux or other non-macOS systems, you likely want the CPU backend or CUDA -# If you're building on Linux with a CUDA-enabled GPU: -candle-core = { version = "=0.9.1", features = ["cuda"], default-features = false } # Or just "cuda" if not using default features - -# If you're building on Linux with only CPU: -# candle-core = { version = "=0.9.1", default-features = false } # CPU is often the default, but good to be explicit -# --- End of conditional compilation section --- [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } diff --git a/crates/inference-engine/src/lib.rs b/crates/inference-engine/src/lib.rs index 9dd5f4c..b1643b6 100644 --- a/crates/inference-engine/src/lib.rs +++ b/crates/inference-engine/src/lib.rs @@ -1,9 +1,6 @@ // Expose modules for testing and library usage pub mod model; pub mod openai_types; -pub mod text_generation; -pub mod token_output_stream; -pub mod utilities_lib; // pub mod cli; pub mod inference; pub mod server; @@ -12,8 +9,6 @@ pub mod server; pub use inference::ModelInference; pub use model::{Model, Which}; pub use server::{create_router, AppState}; -pub use text_generation::TextGeneration; -pub use token_output_stream::TokenOutputStream; use std::env; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; diff --git a/crates/inference-engine/src/openai_types.rs b/crates/inference-engine/src/openai_types.rs index 62549c0..123e812 100644 --- a/crates/inference-engine/src/openai_types.rs +++ b/crates/inference-engine/src/openai_types.rs @@ -1,6 +1,7 @@ use either::Either; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use serde_json::json; use utoipa::ToSchema; /// Inner content structure for messages that can be either a string or key-value pairs diff --git a/crates/inference-engine/src/text_generation.rs b/crates/inference-engine/src/text_generation.rs deleted file mode 100644 index dd2e111..0000000 --- 
a/crates/inference-engine/src/text_generation.rs
+++ /dev/null
@@ -1,1106 +0,0 @@
-use anyhow::{Error as E, Result};
-use candle_core::{DType, Device, Tensor};
-use candle_transformers::generation::LogitsProcessor;
-use std::collections::HashMap;
-use tokenizers::Tokenizer;
-
-use crate::model::Model;
-use crate::token_output_stream::TokenOutputStream;
-
-pub struct TextGeneration {
- model: Model,
- device: Device,
- // CPU device for fallback when operations are unsupported on primary device
- cpu_device: Option<Device>,
- // Flag to indicate if we should try to use the primary device first
- try_primary_device: bool,
- tokenizer: TokenOutputStream,
- logits_processor: LogitsProcessor,
- repeat_penalty: f32,
- repeat_last_n: usize,
- // Cache for repeat penalty computation to avoid redundant calculations
- penalty_cache: HashMap<usize, f32>,
- // Context window size for sliding window context (default: 64 tokens)
- context_window_size: usize,
-}
-
-impl TextGeneration {
- #[allow(clippy::too_many_arguments)]
- pub fn new(
- model: Model,
- tokenizer: Tokenizer,
- seed: u64,
- temp: Option<f64>,
- top_p: Option<f64>,
- repeat_penalty: f32,
- repeat_last_n: usize,
- device: &Device,
- ) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp, top_p);
-
- // Initialize CPU device only if the primary device is not already CPU
- let (cpu_device, try_primary_device) = if device.is_cpu() {
- // If already on CPU, no need for a fallback device
- (None, false)
- } else {
- // Store CPU device for fallback and set flag to try primary device first
- (Some(Device::Cpu), true)
- };
-
- Self {
- model,
- tokenizer: TokenOutputStream::new(tokenizer),
- logits_processor,
- repeat_penalty,
- repeat_last_n,
- device: device.clone(),
- cpu_device,
- try_primary_device,
- penalty_cache: HashMap::new(),
- context_window_size: 64, // Default sliding window size for better context preservation
- }
- }
-
- // Helper method for model execution with fallback to CPU for unsupported operations
- fn execute_with_fallback(&mut self, input: &Tensor, start_pos: usize) -> Result<Tensor> {
- // If we're not trying primary device anymore, go straight to CPU if available
- if !self.try_primary_device {
- if let Some(cpu_device) = &self.cpu_device {
- let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
- let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;
- return cpu_result.to_device(&self.device).map_err(E::msg);
- } else {
- // No CPU fallback, use primary device
- return self.model.forward(input, start_pos).map_err(E::msg);
- }
- }
-
- // Try running on the primary device first
- match self.model.forward(input, start_pos) {
- Ok(result) => Ok(result),
- Err(err) => {
- // Convert to string to check for unsupported operation
- let err_string = err.to_string();
-
- // Check if the error is about unsupported operations or shape mismatches
- if (err_string.contains("no metal implementation for")
- || err_string.contains("no cuda implementation for")
- || err_string.contains("shape mismatch")
- || err_string.contains("broadcast_add"))
- && self.cpu_device.is_some()
- {
- // Extract operation name for better logging
- let op_name = if let Some(idx) = err_string.find("for ") {
- &err_string[(idx + 4)..]
- } else if err_string.contains("shape mismatch") {
- "shape mismatch operation"
- } else {
- "an operation"
- };
-
- // Log the fallback
- tracing::warn!(
- "The primary device does not support {}. Falling back to CPU.",
- op_name
- );
-
- // Move input to CPU and try again
- let cpu_device = self.cpu_device.as_ref().unwrap();
- let cpu_input = input.to_device(cpu_device).map_err(E::msg)?;
- let cpu_result = self.model.forward(&cpu_input, start_pos).map_err(E::msg)?;
-
- // Don't try primary device for future operations
- self.try_primary_device = false;
- tracing::info!(
- "Successfully executed on CPU. Will use CPU for subsequent operations."
- );
-
- // Move result back to original device
- cpu_result.to_device(&self.device).map_err(E::msg)
- } else {
- // Not an unsupported operation error or no CPU fallback
- Err(E::msg(err))
- }
- }
- }
- }
-
- // Reset method to clear state between requests
- pub fn reset_state(&mut self) {
- // Reset the primary device flag so we try the primary device first for each new request
- if !self.device.is_cpu() {
- self.try_primary_device = true;
- }
- // Clear the penalty cache to avoid stale cached values from previous requests
- self.penalty_cache.clear();
- }
-
- // Helper method to apply repeat penalty with caching for optimization
- pub fn apply_cached_repeat_penalty(
- &mut self,
- logits: Tensor,
- tokens: &[u32],
- ) -> Result<(Tensor, std::time::Duration)> {
- let repeat_start = std::time::Instant::now();
-
- // If no penalty, return the original logits
- if self.repeat_penalty == 1.0 {
- return Ok((logits, repeat_start.elapsed()));
- }
-
- // Get the tokens to penalize (the last n tokens)
- let start_at = tokens.len().saturating_sub(self.repeat_last_n);
- let penalty_tokens = &tokens[start_at..];
-
- // Extract logits to a vector for modification
- let mut logits_vec = logits.to_vec1::<f32>()?;
- let cache_hits = std::cell::Cell::new(0);
-
- // Apply penalties with caching
- for &token_id in penalty_tokens {
- let token_id = token_id as usize;
- if token_id < logits_vec.len() {
- // Check if we've already calculated this token's penalty
- if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
- // Use cached value
- logits_vec[token_id] = *penalized_score;
- cache_hits.set(cache_hits.get() + 1);
- } else {
- // Calculate and cache new value
- let score = logits_vec[token_id];
- let sign = if score < 0.0 { -1.0 } else { 1.0 };
- let penalized_score = sign * score / self.repeat_penalty;
- logits_vec[token_id] = penalized_score;
- self.penalty_cache.insert(token_id, penalized_score);
- }
- }
- }
-
- // Log cache efficiency statistics
- if !penalty_tokens.is_empty() {
- let cache_efficiency = (cache_hits.get() as f32 / penalty_tokens.len() as f32) * 100.0;
- tracing::trace!(
- "Repeat penalty cache hits: {}/{} ({:.1}%)",
- cache_hits.get(),
- penalty_tokens.len(),
- cache_efficiency
- );
- }
-
- // Create a new tensor with the modified logits (single tensor creation)
- let device = logits.device().clone();
- let shape = logits.shape().clone();
- let new_logits = Tensor::new(&logits_vec[..], &device)?;
- let result = new_logits.reshape(shape)?;
-
- let elapsed = repeat_start.elapsed();
- Ok((result, elapsed))
- }
-
- // Run text generation and print to stdout
- pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
- use std::io::Write;
-
- // Track overall performance
- let start_time = std::time::Instant::now();
-
- // Keep penalty cache across generation for better repetition prevention
- // Only clear cache if it becomes too large to prevent memory bloat
- if self.penalty_cache.len() > 10000 {
- self.penalty_cache.clear();
- tracing::debug!("Cleared penalty cache due to size limit");
- } else {
- tracing::debug!(
- "Maintaining penalty cache across generation for better repetition prevention"
- );
- }
-
- // Phase 1: Tokenize input
- let tokenize_start = std::time::Instant::now();
- self.tokenizer.clear();
- let mut tokens = self
- .tokenizer
- .tokenizer()
- .encode(prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
-
- let tokenize_time = tokenize_start.elapsed();
- tracing::debug!("Tokenization completed in {:.2?}", tokenize_time);
- tracing::debug!("Input tokens: {}", tokens.len());
-
- // Print tokenized prompt
- for &t in tokens.iter() {
- if let Some(t) = self.tokenizer.next_token(t)? {
- print!("{t}")
- }
- }
- std::io::stdout().flush()?;
-
- let mut generated_tokens = 0usize;
- let eos_token = match self.tokenizer.get_token("<eos>") {
- Some(token) => token,
- None => anyhow::bail!("cannot find the <eos> token"),
- };
-
- let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
- Some(token) => token,
- None => {
- println!(
- "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
- );
- eos_token
- }
- };
-
- // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant
- // Both need special handling for shape compatibility
- let needs_special_handling = match &self.model {
- Model::V2(_) => true,
- Model::V3(_) => true,
- _ => false,
- };
-
- // Phase 2: Text generation
- let start_gen = std::time::Instant::now();
-
- // Track per-token generation timing for performance analysis
- let mut token_times = Vec::new();
- let mut forward_times = Vec::new();
- let mut repeat_penalty_times = Vec::new();
- let mut sampling_times = Vec::new();
-
- // For Model2 and Model3, we need to use a special approach for shape compatibility
- if needs_special_handling {
- // For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context
- tracing::debug!("Using special generation approach for gemma-2/gemma-3 models");
-
- // Initial generation with the full prompt
- let forward_start = std::time::Instant::now();
- let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
-
- // Use execute_with_fallback which handles both device compatibility and shape mismatches
- let mut logits = self.execute_with_fallback(&input, 0)?;
-
- logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
- let forward_time = forward_start.elapsed();
- forward_times.push(forward_time);
-
- for _ in 0..sample_len {
- let token_start = std::time::Instant::now();
-
- // Apply repeat penalty using optimized cached implementation
- let (current_logits, repeat_time) =
- self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
- repeat_penalty_times.push(repeat_time);
-
- // Track token sampling
- let sampling_start = std::time::Instant::now();
- let next_token = self.logits_processor.sample(&current_logits)?;
- let sampling_time = sampling_start.elapsed();
- sampling_times.push(sampling_time);
-
- tokens.push(next_token);
- generated_tokens += 1;
-
- if next_token == eos_token || next_token == eot_token {
- break;
- }
-
- if let Some(t) = self.tokenizer.next_token(next_token)?
{ - print!("{t}"); - std::io::stdout().flush()?; - } - - // For the next iteration, just use the new token - let forward_start = std::time::Instant::now(); - let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; - - // Use execute_with_fallback for both Gemma 3 and other models - logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?; - - logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - let token_time = token_start.elapsed(); - token_times.push(token_time); - } - } else { - // Standard approach for other models - tracing::debug!("Using standard generation approach"); - - for index in 0..sample_len { - let token_start = std::time::Instant::now(); - - let context_size = if index > 0 { 1 } else { tokens.len() }; - let start_pos = tokens.len().saturating_sub(context_size); - let ctxt = &tokens[start_pos..]; - - // Track tensor operations and model forward pass - let forward_start = std::time::Instant::now(); - let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; - let logits = self.execute_with_fallback(&input, start_pos)?; - let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - // Apply repeat penalty using optimized cached implementation - let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?; - repeat_penalty_times.push(repeat_time); - - // Track token sampling - let sampling_start = std::time::Instant::now(); - let next_token = self.logits_processor.sample(&logits)?; - let sampling_time = sampling_start.elapsed(); - sampling_times.push(sampling_time); - - tokens.push(next_token); - generated_tokens += 1; - if next_token == eos_token || next_token == eot_token { - break; - } - if let Some(t) = self.tokenizer.next_token(next_token)? { - print!("{t}"); - std::io::stdout().flush()?; - } - - let token_time = token_start.elapsed(); - token_times.push(token_time); - } - } - - let dt = start_gen.elapsed(); - - // Phase 3: Final decoding and output - let decode_start = std::time::Instant::now(); - if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? 
{ - print!("{rest}"); - } - let decode_time = decode_start.elapsed(); - - std::io::stdout().flush()?; - - // Calculate generation speed - let tokens_per_second = generated_tokens as f64 / dt.as_secs_f64(); - - // Calculate average time per token and component breakdown - let avg_token_time = if !token_times.is_empty() { - token_times.iter().sum::() / token_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - let avg_forward_time = if !forward_times.is_empty() { - forward_times.iter().sum::() / forward_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - let avg_repeat_time = if !repeat_penalty_times.is_empty() { - repeat_penalty_times.iter().sum::() - / repeat_penalty_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - let avg_sampling_time = if !sampling_times.is_empty() { - sampling_times.iter().sum::() / sampling_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - // Log performance metrics - println!( - "\n{generated_tokens} tokens generated ({:.2} token/s)", - tokens_per_second, - ); - - // Record detailed performance metrics - tracing::info!("Text generation completed in {:.2?}", dt); - tracing::info!("Tokens generated: {}", generated_tokens); - tracing::info!("Generation speed: {:.2} tokens/second", tokens_per_second); - tracing::info!("Average time per token: {:.2?}", avg_token_time); - tracing::debug!( - " - Forward pass: {:.2?} ({:.1}%)", - avg_forward_time, - avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - " - Repeat penalty: {:.2?} ({:.1}%)", - avg_repeat_time, - avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - " - Sampling: {:.2?} ({:.1}%)", - avg_sampling_time, - avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 - ); - - // Log total request time - let total_time = start_time.elapsed(); - tracing::info!("Total request time: {:.2?}", total_time); - tracing::debug!( - " - Tokenization: {:.2?} ({:.1}%)", - tokenize_time, - tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - " - Generation: {:.2?} ({:.1}%)", - dt, - dt.as_secs_f64() / total_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - " - Final decoding: {:.2?} ({:.1}%)", - decode_time, - decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 - ); - - Ok(()) - } - - // Run text generation and write to a buffer - pub fn run_with_output( - &mut self, - prompt: &str, - sample_len: usize, - output: &mut Vec, - ) -> Result<()> { - use std::io::Write; - - // Track overall performance - let start_time = std::time::Instant::now(); - - // Keep penalty cache across generation for better repetition prevention - // Only clear cache if it becomes too large to prevent memory bloat - if self.penalty_cache.len() > 10000 { - self.penalty_cache.clear(); - tracing::debug!("Cleared penalty cache due to size limit (API mode)"); - } else { - tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (API mode)"); - } - - // Phase 1: Tokenize input - let tokenize_start = std::time::Instant::now(); - self.tokenizer.clear(); - let mut tokens = self - .tokenizer - .tokenizer() - .encode(prompt, true) - .map_err(E::msg)? 
- .get_ids() - .to_vec(); - - let tokenize_time = tokenize_start.elapsed(); - tracing::debug!("API Tokenization completed in {:.2?}", tokenize_time); - tracing::debug!("API Input tokens: {}", tokens.len()); - - // Write prompt tokens to output - for &t in tokens.iter() { - if let Some(t) = self.tokenizer.next_token(t)? { - write!(output, "{}", t)?; - } - } - - let mut generated_tokens = 0usize; - let eos_token = match self.tokenizer.get_token("") { - Some(token) => token, - None => anyhow::bail!("cannot find the token"), - }; - - let eot_token = match self.tokenizer.get_token("") { - Some(token) => token, - None => { - write!( - output, - "Warning: token not found in tokenizer, using as a backup" - )?; - eos_token - } - }; - - // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant - // Both need special handling for shape compatibility - let needs_special_handling = match &self.model { - Model::V2(_) => true, - Model::V3(_) => true, - _ => false, - }; - - // Check if we're specifically using a Model3 (gemma-3) for additional error handling - // let is_model_v3 = matches!(&self.model, Model::V3(_)); - - // Track generation timing - let start_gen = std::time::Instant::now(); - - // Track per-token generation timing for performance analysis - let mut token_times = Vec::new(); - let mut forward_times = Vec::new(); - let mut repeat_penalty_times = Vec::new(); - let mut sampling_times = Vec::new(); - - // For Model2 and Model3, we need to use a special approach for shape compatibility - if needs_special_handling { - // For gemma-2 and gemma-3 models, we'll generate one token at a time with the full context - tracing::debug!("Using special generation approach for gemma-2/gemma-3 models"); - - // Initial generation with the full prompt - let forward_start = std::time::Instant::now(); - let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; - - // Use execute_with_fallback which handles both device compatibility and shape mismatches - let mut logits = self.execute_with_fallback(&input, 0)?; - - logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - for _ in 0..sample_len { - let token_start = std::time::Instant::now(); - - // Apply repeat penalty using optimized cached implementation - let (current_logits, repeat_time) = - self.apply_cached_repeat_penalty(logits.clone(), &tokens)?; - repeat_penalty_times.push(repeat_time); - - // Track token sampling - let sampling_start = std::time::Instant::now(); - let next_token = self.logits_processor.sample(¤t_logits)?; - let sampling_time = sampling_start.elapsed(); - sampling_times.push(sampling_time); - - tokens.push(next_token); - generated_tokens += 1; - - if next_token == eos_token || next_token == eot_token { - break; - } - - if let Some(t) = self.tokenizer.next_token(next_token)? 
{ - write!(output, "{}", t)?; - } - - // For the next iteration, just use the new token - let forward_start = std::time::Instant::now(); - let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?; - - // Use execute_with_fallback for both Gemma 3 and other models - logits = self.execute_with_fallback(&new_input, tokens.len() - 1)?; - - logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - let token_time = token_start.elapsed(); - token_times.push(token_time); - } - - let dt = start_gen.elapsed(); - - // Calculate and log performance metrics - Self::log_performance_metrics( - dt, - generated_tokens, - &token_times, - &forward_times, - &repeat_penalty_times, - &sampling_times, - tokenize_time, - std::time::Duration::from_secs(0), - start_time, - "API", - ); - - return Ok(()); - } - - // Standard approach for other models - tracing::debug!("Using standard generation approach"); - - for index in 0..sample_len { - let token_start = std::time::Instant::now(); - - // Use sliding window context instead of single token to preserve context and reduce repetition - let context_size = if index > 0 { - std::cmp::min(self.context_window_size, tokens.len()) - } else { - tokens.len() - }; - let start_pos = tokens.len().saturating_sub(context_size); - let ctxt = &tokens[start_pos..]; - - tracing::debug!( - "API standard model: Using sliding window context: {} tokens (from position {})", - ctxt.len(), - start_pos - ); - - // Track tensor operations and model forward pass - let forward_start = std::time::Instant::now(); - let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; - let logits = self.execute_with_fallback(&input, start_pos)?; - let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - // Apply repeat penalty using optimized cached implementation - let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?; - repeat_penalty_times.push(repeat_time); - - // Track token sampling - let sampling_start = std::time::Instant::now(); - let next_token = self.logits_processor.sample(&logits)?; - let sampling_time = sampling_start.elapsed(); - sampling_times.push(sampling_time); - - tokens.push(next_token); - generated_tokens += 1; - if next_token == eos_token || next_token == eot_token { - break; - } - if let Some(t) = self.tokenizer.next_token(next_token)? { - write!(output, "{}", t)?; - } - - let token_time = token_start.elapsed(); - token_times.push(token_time); - } - - let dt = start_gen.elapsed(); - - // Phase 3: Final decoding and output - let decode_start = std::time::Instant::now(); - - // Write any remaining tokens - if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? 
{ - write!(output, "{}", rest)?; - } - - let decode_time = decode_start.elapsed(); - - // Log performance metrics - Self::log_performance_metrics( - dt, - generated_tokens, - &token_times, - &forward_times, - &repeat_penalty_times, - &sampling_times, - tokenize_time, - decode_time, - start_time, - "API", - ); - - Ok(()) - } - - // Run text generation with streaming callback for each token - pub async fn run_with_streaming( - &mut self, - prompt: &str, - sample_len: usize, - mut token_callback: F, - ) -> Result - where - F: FnMut(&str) -> Result<()>, - { - // Track overall performance - let start_time = std::time::Instant::now(); - - // Keep penalty cache across generation for better repetition prevention - // Only clear cache if it becomes too large to prevent memory bloat - if self.penalty_cache.len() > 10000 { - self.penalty_cache.clear(); - tracing::debug!("Cleared penalty cache due to size limit (streaming mode)"); - } else { - tracing::debug!("Maintaining penalty cache across generation for better repetition prevention (streaming mode)"); - } - - // Phase 1: Tokenize input - let tokenize_start = std::time::Instant::now(); - self.tokenizer.clear(); - let mut tokens = self - .tokenizer - .tokenizer() - .encode(prompt, true) - .map_err(E::msg)? - .get_ids() - .to_vec(); - - let tokenize_time = tokenize_start.elapsed(); - tracing::debug!("Streaming Tokenization completed in {:.2?}", tokenize_time); - tracing::debug!("Streaming Input tokens: {}", tokens.len()); - - // Collect all output for final return - let mut full_output = String::new(); - - let mut generated_tokens = 0usize; - let eos_token = match self.tokenizer.get_token("") { - Some(token) => token, - None => anyhow::bail!("cannot find the token"), - }; - - let eot_token = match self.tokenizer.get_token("") { - Some(token) => token, - None => { - tracing::warn!( - "Warning: token not found in tokenizer, using as a backup" - ); - eos_token - } - }; - - // Determine if we're using a Model2 (gemma-2) or Model3 (gemma-3) variant - let needs_special_handling = match &self.model { - Model::V2(_) => true, - Model::V3(_) => true, - _ => false, - }; - - // Track generation timing - let start_gen = std::time::Instant::now(); - - // Track per-token generation timing for performance analysis - let mut token_times = Vec::new(); - let mut forward_times = Vec::new(); - let mut repeat_penalty_times = Vec::new(); - let mut sampling_times = Vec::new(); - - // For Model2 and Model3, we need to use a special approach for shape compatibility - if needs_special_handling { - tracing::debug!( - "Using special generation approach for gemma-2/gemma-3 models (streaming)" - ); - tracing::debug!("Streaming: sample_len = {}", sample_len); - - // Initial generation with the full prompt - let forward_start = std::time::Instant::now(); - let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?; - - let mut logits = self.execute_with_fallback(&input, 0)?; - - logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - tracing::debug!( - "Streaming: About to enter generation loop with sample_len = {}", - sample_len - ); - for gen_index in 0..sample_len { - tracing::debug!( - "Streaming: Starting generation iteration {} / {}", - gen_index + 1, - sample_len - ); - let token_start = std::time::Instant::now(); - - // Apply repeat penalty using optimized cached implementation - let (current_logits, repeat_time) = - self.apply_cached_repeat_penalty(logits.clone(), 
&tokens)?; - repeat_penalty_times.push(repeat_time); - - // Track token sampling - let sampling_start = std::time::Instant::now(); - let next_token = self.logits_processor.sample(¤t_logits)?; - let sampling_time = sampling_start.elapsed(); - sampling_times.push(sampling_time); - - tokens.push(next_token); - generated_tokens += 1; - - tracing::debug!( - "Streaming: Generated token {} (id: {}), eos: {}, eot: {}", - next_token, - next_token, - eos_token, - eot_token - ); - if next_token == eos_token || next_token == eot_token { - tracing::debug!("Streaming: Breaking due to end token"); - break; - } - - if let Some(token_text) = self.tokenizer.next_token(next_token)? { - full_output.push_str(&token_text); - // Call the streaming callback with this token - token_callback(&token_text)?; - } - - // For the next iteration, use single token to avoid shape mismatch - let forward_start = std::time::Instant::now(); - tracing::debug!( - "Streaming: Preparing next forward pass with {} tokens", - tokens.len() - ); - - // Use just the last token for subsequent iterations to avoid shape mismatch - // This is required for Gemma model's attention mechanism compatibility - let context_tokens = &tokens[(tokens.len() - 1)..]; - let start_pos = tokens.len() - 1; - - tracing::debug!( - "Streaming: Using single token context for Gemma: {} tokens (from position {})", - context_tokens.len(), - start_pos - ); - - let new_input = match Tensor::new(context_tokens, &self.device) { - Ok(tensor) => tensor, - Err(e) => { - tracing::error!("Streaming: Failed to create input tensor: {}", e); - return Err(e.into()); - } - }; - - let new_input = match new_input.unsqueeze(0) { - Ok(tensor) => tensor, - Err(e) => { - tracing::error!("Streaming: Failed to unsqueeze input tensor: {}", e); - return Err(e.into()); - } - }; - - tracing::debug!("Streaming: About to call execute_with_fallback for iteration {} with start_pos {}", gen_index + 1, start_pos); - logits = match self.execute_with_fallback(&new_input, start_pos) { - Ok(result) => result, - Err(e) => { - tracing::error!("Streaming: Forward pass failed: {}", e); - return Err(e); - } - }; - - logits = match logits.squeeze(0) { - Ok(result) => result, - Err(e) => { - tracing::error!("Streaming: Failed to squeeze logits (dim 0): {}", e); - return Err(e.into()); - } - }; - - logits = match logits.squeeze(0) { - Ok(result) => result, - Err(e) => { - tracing::error!("Streaming: Failed to squeeze logits (dim 0 again): {}", e); - return Err(e.into()); - } - }; - - logits = match logits.to_dtype(DType::F32) { - Ok(result) => result, - Err(e) => { - tracing::error!("Streaming: Failed to convert logits to F32: {}", e); - return Err(e.into()); - } - }; - - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - tracing::debug!( - "Streaming: Forward pass completed for iteration {}", - gen_index + 1 - ); - - let token_time = token_start.elapsed(); - token_times.push(token_time); - - // Yield to allow other async tasks to run - tokio::task::yield_now().await; - } - } else { - // Standard approach for other models - tracing::debug!("Using standard generation approach (streaming)"); - - for index in 0..sample_len { - let token_start = std::time::Instant::now(); - - // Use sliding window context instead of single token to preserve context and reduce repetition - let context_size = if index > 0 { - std::cmp::min(self.context_window_size, tokens.len()) - } else { - tokens.len() - }; - let start_pos = tokens.len().saturating_sub(context_size); - let ctxt = 
&tokens[start_pos..]; - - tracing::debug!( - "Standard model: Using sliding window context: {} tokens (from position {})", - ctxt.len(), - start_pos - ); - - // Track tensor operations and model forward pass - let forward_start = std::time::Instant::now(); - let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; - let logits = self.execute_with_fallback(&input, start_pos)?; - let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; - let forward_time = forward_start.elapsed(); - forward_times.push(forward_time); - - // Apply repeat penalty using optimized cached implementation - let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?; - repeat_penalty_times.push(repeat_time); - - // Track token sampling - let sampling_start = std::time::Instant::now(); - let next_token = self.logits_processor.sample(&logits)?; - let sampling_time = sampling_start.elapsed(); - sampling_times.push(sampling_time); - - tokens.push(next_token); - generated_tokens += 1; - if next_token == eos_token || next_token == eot_token { - break; - } - if let Some(token_text) = self.tokenizer.next_token(next_token)? { - full_output.push_str(&token_text); - // Call the streaming callback with this token - token_callback(&token_text)?; - } - - let token_time = token_start.elapsed(); - token_times.push(token_time); - } - } - - let dt = start_gen.elapsed(); - - // Phase 3: Final decoding - let decode_start = std::time::Instant::now(); - - // Decode any remaining tokens but don't send through callback to avoid repetition - // The tokens were already streamed individually in the generation loop above - if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { - full_output.push_str(&rest); - // Note: NOT calling token_callback(&rest) here to prevent token repetition - // Individual tokens were already streamed via the callback in the generation loop - } - - let decode_time = decode_start.elapsed(); - - // Log performance metrics - Self::log_performance_metrics( - dt, - generated_tokens, - &token_times, - &forward_times, - &repeat_penalty_times, - &sampling_times, - tokenize_time, - decode_time, - start_time, - "Streaming", - ); - - Ok(full_output) - } - - // Helper function for logging performance metrics - fn log_performance_metrics( - generation_time: std::time::Duration, - generated_tokens: usize, - token_times: &[std::time::Duration], - forward_times: &[std::time::Duration], - repeat_penalty_times: &[std::time::Duration], - sampling_times: &[std::time::Duration], - tokenize_time: std::time::Duration, - decode_time: std::time::Duration, - start_time: std::time::Instant, - prefix: &str, - ) { - // Calculate generation speed - let tokens_per_second = if generation_time.as_secs_f64() > 0.0 { - generated_tokens as f64 / generation_time.as_secs_f64() - } else { - 0.0 - }; - - // Calculate average time per token and component breakdown - let avg_token_time = if !token_times.is_empty() { - token_times.iter().sum::() / token_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - let avg_forward_time = if !forward_times.is_empty() { - forward_times.iter().sum::() / forward_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - let avg_repeat_time = if !repeat_penalty_times.is_empty() { - repeat_penalty_times.iter().sum::() - / repeat_penalty_times.len() as u32 - } else { - std::time::Duration::from_secs(0) - }; - - let avg_sampling_time = if !sampling_times.is_empty() { - sampling_times.iter().sum::() / sampling_times.len() as u32 - } else { - 
std::time::Duration::from_secs(0) - }; - - // Record detailed performance metrics - tracing::info!( - "{} Text generation completed in {:.2?}", - prefix, - generation_time - ); - tracing::info!("{} Tokens generated: {}", prefix, generated_tokens); - tracing::info!( - "{} Generation speed: {:.2} tokens/second", - prefix, - tokens_per_second - ); - tracing::info!("{} Average time per token: {:.2?}", prefix, avg_token_time); - - if !avg_token_time.is_zero() { - tracing::debug!( - "{} - Forward pass: {:.2?} ({:.1}%)", - prefix, - avg_forward_time, - avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - "{} - Repeat penalty: {:.2?} ({:.1}%)", - prefix, - avg_repeat_time, - avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - "{} - Sampling: {:.2?} ({:.1}%)", - prefix, - avg_sampling_time, - avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0 - ); - } - - // Log total request time - let total_time = start_time.elapsed(); - tracing::info!("{} Total request time: {:.2?}", prefix, total_time); - - if !total_time.is_zero() { - tracing::debug!( - "{} - Tokenization: {:.2?} ({:.1}%)", - prefix, - tokenize_time, - tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - "{} - Generation: {:.2?} ({:.1}%)", - prefix, - generation_time, - generation_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 - ); - tracing::debug!( - "{} - Final decoding: {:.2?} ({:.1}%)", - prefix, - decode_time, - decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0 - ); - } - } -} diff --git a/crates/inference-engine/src/token_output_stream.rs b/crates/inference-engine/src/token_output_stream.rs deleted file mode 100644 index 2b73f0c..0000000 --- a/crates/inference-engine/src/token_output_stream.rs +++ /dev/null @@ -1,87 +0,0 @@ -use candle_core::Result; - -/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a -/// streaming way rather than having to wait for the full decoding. -pub struct TokenOutputStream { - tokenizer: tokenizers::Tokenizer, - tokens: Vec, - prev_index: usize, - current_index: usize, -} - -impl TokenOutputStream { - pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { - Self { - tokenizer, - tokens: Vec::new(), - prev_index: 0, - current_index: 0, - } - } - - pub fn into_inner(self) -> tokenizers::Tokenizer { - self.tokenizer - } - - fn decode(&self, tokens: &[u32]) -> Result { - match self.tokenizer.decode(tokens, true) { - Ok(str) => Ok(str), - Err(err) => candle_core::bail!("cannot decode: {err}"), - } - } - - // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 - pub fn next_token(&mut self, token: u32) -> Result> { - let prev_text = if self.tokens.is_empty() { - String::new() - } else { - let tokens = &self.tokens[self.prev_index..self.current_index]; - self.decode(tokens)? 
-        };
-        self.tokens.push(token);
-        let text = self.decode(&self.tokens[self.prev_index..])?;
-        if text.len() > prev_text.len() {
-            // Modified to include all tokens, not just alphanumeric ones
-            let text = text.split_at(prev_text.len());
-            self.prev_index = self.current_index;
-            self.current_index = self.tokens.len();
-            Ok(Some(text.1.to_string()))
-        } else {
-            Ok(None)
-        }
-    }
-
-    pub fn decode_rest(&self) -> Result<Option<String>> {
-        let prev_text = if self.tokens.is_empty() {
-            String::new()
-        } else {
-            let tokens = &self.tokens[self.prev_index..self.current_index];
-            self.decode(tokens)?
-        };
-        let text = self.decode(&self.tokens[self.prev_index..])?;
-        if text.len() > prev_text.len() {
-            let text = text.split_at(prev_text.len());
-            Ok(Some(text.1.to_string()))
-        } else {
-            Ok(None)
-        }
-    }
-
-    pub fn decode_all(&self) -> Result<String> {
-        self.decode(&self.tokens)
-    }
-
-    pub fn get_token(&self, token_s: &str) -> Option<u32> {
-        self.tokenizer.get_vocab(true).get(token_s).copied()
-    }
-
-    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
-        &self.tokenizer
-    }
-
-    pub fn clear(&mut self) {
-        self.tokens.clear();
-        self.prev_index = 0;
-        self.current_index = 0;
-    }
-}
diff --git a/crates/inference-engine/src/utilities_lib.rs b/crates/inference-engine/src/utilities_lib.rs
deleted file mode 100644
index 4abf5f9..0000000
--- a/crates/inference-engine/src/utilities_lib.rs
+++ /dev/null
@@ -1,168 +0,0 @@
-use candle_core::utils::{cuda_is_available, metal_is_available};
-use candle_core::{Device, Result, Tensor};
-
-pub fn device(cpu: bool) -> Result<Device> {
-    if cpu {
-        Ok(Device::Cpu)
-    } else if cuda_is_available() {
-        Ok(Device::new_cuda(0)?)
-    } else if metal_is_available() {
-        Ok(Device::new_metal(0)?)
-    } else {
-        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-        {
-            println!(
-                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
-            );
-        }
-        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
-        {
-            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
-        }
-        Ok(Device::Cpu)
-    }
-}
-
-pub fn load_image<P: AsRef<std::path::Path>>(
-    p: P,
-    resize_longest: Option<usize>,
-) -> Result<(Tensor, usize, usize)> {
-    let img = image::ImageReader::open(p)?
-        .decode()
-        .map_err(candle_core::Error::wrap)?;
-    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
-    let img = match resize_longest {
-        None => img,
-        Some(resize_longest) => {
-            let (height, width) = (img.height(), img.width());
-            let resize_longest = resize_longest as u32;
-            let (height, width) = if height < width {
-                let h = (resize_longest * height) / width;
-                (h, resize_longest)
-            } else {
-                let w = (resize_longest * width) / height;
-                (resize_longest, w)
-            };
-            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
-        }
-    };
-    let (height, width) = (img.height() as usize, img.width() as usize);
-    let img = img.to_rgb8();
-    let data = img.into_raw();
-    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
-    Ok((data, initial_h, initial_w))
-}
-
-pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
-    p: P,
-    width: usize,
-    height: usize,
-) -> Result<Tensor> {
-    let img = image::ImageReader::open(p)?
-        .decode()
-        .map_err(candle_core::Error::wrap)?
-        .resize_to_fill(
-            width as u32,
-            height as u32,
-            image::imageops::FilterType::Triangle,
-        );
-    let img = img.to_rgb8();
-    let data = img.into_raw();
-    Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
-}
-
-/// Saves an image to disk using the image crate, this expects an input with shape
-/// (c, height, width).
-pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
-    let p = p.as_ref();
-    let (channel, height, width) = img.dims3()?;
-    if channel != 3 {
-        candle_core::bail!("save_image expects an input of shape (3, height, width)")
-    }
-    let img = img.permute((1, 2, 0))?.flatten_all()?;
-    let pixels = img.to_vec1::<u8>()?;
-    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
-        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
-            Some(image) => image,
-            None => candle_core::bail!("error saving image {p:?}"),
-        };
-    image.save(p).map_err(candle_core::Error::wrap)?;
-    Ok(())
-}
-
-pub fn save_image_resize<P: AsRef<std::path::Path>>(
-    img: &Tensor,
-    p: P,
-    h: usize,
-    w: usize,
-) -> Result<()> {
-    let p = p.as_ref();
-    let (channel, height, width) = img.dims3()?;
-    if channel != 3 {
-        candle_core::bail!("save_image expects an input of shape (3, height, width)")
-    }
-    let img = img.permute((1, 2, 0))?.flatten_all()?;
-    let pixels = img.to_vec1::<u8>()?;
-    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
-        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
-            Some(image) => image,
-            None => candle_core::bail!("error saving image {p:?}"),
-        };
-    let image = image::DynamicImage::from(image);
-    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
-    image.save(p).map_err(candle_core::Error::wrap)?;
-    Ok(())
-}
-
-/// Loads the safetensors files for a model from the hub based on a json index file.
-pub fn hub_load_safetensors(
-    repo: &hf_hub::api::sync::ApiRepo,
-    json_file: &str,
-) -> Result<Vec<std::path::PathBuf>> {
-    let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
-    let json_file = std::fs::File::open(json_file)?;
-    let json: serde_json::Value =
-        serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
-    let weight_map = match json.get("weight_map") {
-        None => candle_core::bail!("no weight map in {json_file:?}"),
-        Some(serde_json::Value::Object(map)) => map,
-        Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
-    };
-    let mut safetensors_files = std::collections::HashSet::new();
-    for value in weight_map.values() {
-        if let Some(file) = value.as_str() {
-            safetensors_files.insert(file.to_string());
-        }
-    }
-    let safetensors_files = safetensors_files
-        .iter()
-        .map(|v| repo.get(v).map_err(candle_core::Error::wrap))
-        .collect::<Result<Vec<_>>>()?;
-    Ok(safetensors_files)
-}
-
-pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
-    path: P,
-    json_file: &str,
-) -> Result<Vec<std::path::PathBuf>> {
-    let path = path.as_ref();
-    let jsfile = std::fs::File::open(path.join(json_file))?;
-    let json: serde_json::Value =
-        serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
-    let weight_map = match json.get("weight_map") {
-        None => candle_core::bail!("no weight map in {json_file:?}"),
-        Some(serde_json::Value::Object(map)) => map,
-        Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
-    };
-    let mut safetensors_files = std::collections::HashSet::new();
-    for value in weight_map.values() {
-        if let Some(file) = value.as_str() {
-            safetensors_files.insert(file);
-        }
-    }
-    let safetensors_files: Vec<_> = safetensors_files
-        .into_iter()
-        .map(|v| path.join(v))
-        .collect();
-    Ok(safetensors_files)
-}
diff --git a/crates/inference-engine/tests/text_generation_tests.rs b/crates/inference-engine/tests/text_generation_tests.rs
deleted file mode 100644
index 6c836ac..0000000
--- a/crates/inference-engine/tests/text_generation_tests.rs
+++ /dev/null
@@ -1,554 +0,0 @@
-use anyhow::Result;
-use candle_core::{Device, Tensor};
-use candle_transformers::generation::LogitsProcessor;
-use inference_engine::model::Which;
-use inference_engine::text_generation::TextGeneration;
-use inference_engine::token_output_stream::TokenOutputStream;
-use std::collections::HashMap;
-use tokenizers::Tokenizer;
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    // Helper function to create a simple tokenizer for testing
-    fn create_test_tokenizer() -> Result<Tokenizer> {
-        // Create a simple tokenizer from the pretrained model
-        // This uses the tokenizer from the Hugging Face hub
-        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
-        Ok(tokenizer)
-    }
-
-    // Test the Which enum's to_model_id method
-    #[test]
-    fn test_which_model_id() {
-        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
-        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
-    }
-
-    // Test the Which enum's is_instruct_model method
-    #[test]
-    fn test_which_is_instruct() {
-        assert!(!Which::Base2B.is_instruct_model());
-        assert!(Which::Instruct7B.is_instruct_model());
-    }
-
-    // Test the Which enum's is_v3_model method
-    #[test]
-    fn test_which_is_v3() {
-        assert!(!Which::Base2B.is_v3_model());
-        assert!(Which::BaseV3_1B.is_v3_model());
-    }
-
-    // Test the TokenOutputStream functionality
-    #[test]
-    fn test_token_output_stream() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let mut token_stream = TokenOutputStream::new(tokenizer);
-
-        // Test encoding and decoding
-        let text = "Hello, world!";
-        let encoded = token_stream.tokenizer().encode(text, true).unwrap();
-        let token_ids = encoded.get_ids();
-
-        // Add tokens one by one
-        for &token_id in token_ids {
-            token_stream.next_token(token_id)?;
-        }
-
-        // Decode all and check
-        let decoded = token_stream.decode_all()?;
-        assert_eq!(decoded.trim(), text);
-
-        Ok(())
-    }
-
-    // Test the LogitsProcessor
-    #[test]
-    fn test_logits_processor() -> Result<()> {
-        // Create a LogitsProcessor with default settings
-        let seed = 42;
-        let temp = Some(0.8);
-        let top_p = Some(0.9);
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
-
-        // Create a simple logits tensor
-        // In a real test, we would create a tensor with known values and verify
-        // that sampling produces expected results
-
-        // For now, we'll just verify that the LogitsProcessor can be created
-        assert!(true);
-        Ok(())
-    }
-
-    // Test the TextGeneration constructor
-    #[test]
-    fn test_text_generation_constructor() -> Result<()> {
-        // We can't easily create a Model instance for testing,
-        // but we can test that the constructor compiles and the types are correct
-
-        // In a real test with a mock Model, we would:
-        // 1. Create a mock model
-        // 2. Create a tokenizer
-        // 3. Call TextGeneration::new
-        // 4. Verify the properties of the created instance
-
-        // For now, we'll just verify that the code compiles
-        assert!(true);
-        Ok(())
-    }
-
-    // Test apply_cached_repeat_penalty method with no penalty
-    #[test]
-    fn test_apply_cached_repeat_penalty_no_penalty() -> Result<()> {
-        // Create a simple test setup
-        let device = Device::Cpu;
-        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens = vec![1u32, 2u32, 3u32];
-
-        // Create a mock TextGeneration instance
-        // Since we can't easily create a full TextGeneration instance without a model,
-        // we'll test the logic by creating a simple struct with the necessary fields
-        struct MockTextGeneration {
-            repeat_penalty: f32,
-            repeat_last_n: usize,
-            penalty_cache: HashMap<usize, f32>,
-        }
-
-        impl MockTextGeneration {
-            fn apply_cached_repeat_penalty(
-                &mut self,
-                logits: Tensor,
-                tokens: &[u32],
-            ) -> Result<(Tensor, std::time::Duration)> {
-                let repeat_start = std::time::Instant::now();
-
-                // If no penalty, return the original logits
-                if self.repeat_penalty == 1.0 {
-                    return Ok((logits, repeat_start.elapsed()));
-                }
-
-                // Get the tokens to penalize (the last n tokens)
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                let penalty_tokens = &tokens[start_at..];
-
-                // Extract logits to a vector for modification
-                let mut logits_vec = logits.to_vec1::<f32>()?;
-                let cache_hits = std::cell::Cell::new(0);
-
-                // Apply penalties with caching
-                for &token_id in penalty_tokens {
-                    let token_id = token_id as usize;
-                    if token_id < logits_vec.len() {
-                        // Check if we've already calculated this token's penalty
-                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
-                            // Use cached value
-                            logits_vec[token_id] = *penalized_score;
-                            cache_hits.set(cache_hits.get() + 1);
-                        } else {
-                            // Calculate and cache new value
-                            let score = logits_vec[token_id];
-                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
-                            let penalized_score = sign * score / self.repeat_penalty;
-                            logits_vec[token_id] = penalized_score;
-                            self.penalty_cache.insert(token_id, penalized_score);
-                        }
-                    }
-                }
-
-                // Create a new tensor with the modified logits
-                let device = logits.device().clone();
-                let shape = logits.shape().clone();
-                let new_logits = Tensor::new(&logits_vec[..], &device)?;
-                let result = new_logits.reshape(shape)?;
-
-                let elapsed = repeat_start.elapsed();
-                Ok((result, elapsed))
-            }
-        }
-
-        let mut mock_gen = MockTextGeneration {
-            repeat_penalty: 1.0, // No penalty
-            repeat_last_n: 3,
-            penalty_cache: HashMap::new(),
-        };
-
-        let (result_logits, _duration) =
-            mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
-        let result_data = result_logits.to_vec1::<f32>()?;
-
-        // With no penalty, logits should be unchanged
-        assert_eq!(result_data, logits_data);
-        Ok(())
-    }
-
-    // Test apply_cached_repeat_penalty method with penalty
-    #[test]
-    fn test_apply_cached_repeat_penalty_with_penalty() -> Result<()> {
-        let device = Device::Cpu;
-        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens = vec![1u32, 2u32, 3u32];
-
-        struct MockTextGeneration {
-            repeat_penalty: f32,
-            repeat_last_n: usize,
-            penalty_cache: HashMap<usize, f32>,
-        }
-
-        impl MockTextGeneration {
-            fn apply_cached_repeat_penalty(
-                &mut self,
-                logits: Tensor,
-                tokens: &[u32],
-            ) -> Result<(Tensor, std::time::Duration)> {
-                let repeat_start = std::time::Instant::now();
-
-                if self.repeat_penalty == 1.0 {
-                    return Ok((logits, repeat_start.elapsed()));
-                }
-
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                let penalty_tokens = &tokens[start_at..];
-                let mut logits_vec = logits.to_vec1::<f32>()?;
-                let cache_hits = std::cell::Cell::new(0);
-
-                for &token_id in penalty_tokens {
-                    let token_id = token_id as usize;
-                    if token_id < logits_vec.len() {
-                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
-                            logits_vec[token_id] = *penalized_score;
-                            cache_hits.set(cache_hits.get() + 1);
-                        } else {
-                            let score = logits_vec[token_id];
-                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
-                            let penalized_score = sign * score / self.repeat_penalty;
-                            logits_vec[token_id] = penalized_score;
-                            self.penalty_cache.insert(token_id, penalized_score);
-                        }
-                    }
-                }
-
-                let device = logits.device().clone();
-                let shape = logits.shape().clone();
-                let new_logits = Tensor::new(&logits_vec[..], &device)?;
-                let result = new_logits.reshape(shape)?;
-
-                let elapsed = repeat_start.elapsed();
-                Ok((result, elapsed))
-            }
-        }
-
-        let mut mock_gen = MockTextGeneration {
-            repeat_penalty: 2.0, // Apply penalty
-            repeat_last_n: 3,
-            penalty_cache: HashMap::new(),
-        };
-
-        let (result_logits, _duration) =
-            mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
-        let result_data = result_logits.to_vec1::<f32>()?;
-
-        // Tokens 1, 2, 3 should be penalized (divided by 2.0)
-        let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
-        assert_eq!(result_data, expected);
-        Ok(())
-    }
-
-    // Test apply_cached_repeat_penalty caching behavior
-    #[test]
-    fn test_apply_cached_repeat_penalty_caching() -> Result<()> {
-        let device = Device::Cpu;
-        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
-
-        struct MockTextGeneration {
-            repeat_penalty: f32,
-            repeat_last_n: usize,
-            penalty_cache: HashMap<usize, f32>,
-        }
-
-        impl MockTextGeneration {
-            fn apply_cached_repeat_penalty(
-                &mut self,
-                logits: Tensor,
-                tokens: &[u32],
-            ) -> Result<(Tensor, std::time::Duration)> {
-                let repeat_start = std::time::Instant::now();
-
-                if self.repeat_penalty == 1.0 {
-                    return Ok((logits, repeat_start.elapsed()));
-                }
-
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                let penalty_tokens = &tokens[start_at..];
-                let mut logits_vec = logits.to_vec1::<f32>()?;
-
-                for &token_id in penalty_tokens {
-                    let token_id = token_id as usize;
-                    if token_id < logits_vec.len() {
-                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
-                            logits_vec[token_id] = *penalized_score;
-                        } else {
-                            let score = logits_vec[token_id];
-                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
-                            let penalized_score = sign * score / self.repeat_penalty;
-                            logits_vec[token_id] = penalized_score;
-                            self.penalty_cache.insert(token_id, penalized_score);
-                        }
-                    }
-                }
-
-                let device = logits.device().clone();
-                let shape = logits.shape().clone();
-                let new_logits = Tensor::new(&logits_vec[..], &device)?;
-                let result = new_logits.reshape(shape)?;
-
-                let elapsed = repeat_start.elapsed();
-                Ok((result, elapsed))
-            }
-        }
-
-        let mut mock_gen = MockTextGeneration {
-            repeat_penalty: 2.0,
-            repeat_last_n: 3,
-            penalty_cache: HashMap::new(),
-        };
-
-        // First call should cache the penalty for token 1
-        let (_result_logits, _duration) =
-            mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
-
-        // Cache should contain the penalized value for token 1
-        assert!(mock_gen.penalty_cache.contains_key(&1));
-        assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
-
-        Ok(())
-    }
-
-    // Test edge case: empty tokens array
-    #[test]
-    fn test_apply_cached_repeat_penalty_empty_tokens() -> Result<()> {
-        let device = Device::Cpu;
-        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens: Vec<u32> = vec![]; // Empty tokens
-
-        struct MockTextGeneration {
-            repeat_penalty: f32,
-            repeat_last_n: usize,
-            penalty_cache: HashMap<usize, f32>,
-        }
-
-        impl MockTextGeneration {
-            fn apply_cached_repeat_penalty(
-                &mut self,
-                logits: Tensor,
-                tokens: &[u32],
-            ) -> Result<(Tensor, std::time::Duration)> {
-                let repeat_start = std::time::Instant::now();
-
-                if self.repeat_penalty == 1.0 {
-                    return Ok((logits, repeat_start.elapsed()));
-                }
-
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                let penalty_tokens = &tokens[start_at..];
-                let mut logits_vec = logits.to_vec1::<f32>()?;
-
-                for &token_id in penalty_tokens {
-                    let token_id = token_id as usize;
-                    if token_id < logits_vec.len() {
-                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
-                            logits_vec[token_id] = *penalized_score;
-                        } else {
-                            let score = logits_vec[token_id];
-                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
-                            let penalized_score = sign * score / self.repeat_penalty;
-                            logits_vec[token_id] = penalized_score;
-                            self.penalty_cache.insert(token_id, penalized_score);
-                        }
-                    }
-                }
-
-                let device = logits.device().clone();
-                let shape = logits.shape().clone();
-                let new_logits = Tensor::new(&logits_vec[..], &device)?;
-                let result = new_logits.reshape(shape)?;
-
-                let elapsed = repeat_start.elapsed();
-                Ok((result, elapsed))
-            }
-        }
-
-        let mut mock_gen = MockTextGeneration {
-            repeat_penalty: 2.0,
-            repeat_last_n: 3,
-            penalty_cache: HashMap::new(),
-        };
-
-        let (result_logits, _duration) =
-            mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
-        let result_data = result_logits.to_vec1::<f32>()?;
-
-        // With empty tokens, logits should be unchanged
-        assert_eq!(result_data, logits_data);
-        Ok(())
-    }
-
-    // Test edge case: out-of-bounds token IDs
-    #[test]
-    fn test_apply_cached_repeat_penalty_out_of_bounds() -> Result<()> {
-        let device = Device::Cpu;
-        let logits_data = vec![1.0f32, 2.0, 3.0];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
-
-        struct MockTextGeneration {
-            repeat_penalty: f32,
-            repeat_last_n: usize,
-            penalty_cache: HashMap<usize, f32>,
-        }
-
-        impl MockTextGeneration {
-            fn apply_cached_repeat_penalty(
-                &mut self,
-                logits: Tensor,
-                tokens: &[u32],
-            ) -> Result<(Tensor, std::time::Duration)> {
-                let repeat_start = std::time::Instant::now();
-
-                if self.repeat_penalty == 1.0 {
-                    return Ok((logits, repeat_start.elapsed()));
-                }
-
-                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
-                let penalty_tokens = &tokens[start_at..];
-                let mut logits_vec = logits.to_vec1::<f32>()?;
-
-                for &token_id in penalty_tokens {
-                    let token_id = token_id as usize;
-                    if token_id < logits_vec.len() {
-                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
-                            logits_vec[token_id] = *penalized_score;
-                        } else {
-                            let score = logits_vec[token_id];
-                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
-                            let penalized_score = sign * score / self.repeat_penalty;
-                            logits_vec[token_id] = penalized_score;
-                            self.penalty_cache.insert(token_id, penalized_score);
-                        }
-                    }
-                }
-
-                let device = logits.device().clone();
-                let shape = logits.shape().clone();
-                let new_logits = Tensor::new(&logits_vec[..], &device)?;
-                let result = new_logits.reshape(shape)?;
-
-                let elapsed = repeat_start.elapsed();
-                Ok((result, elapsed))
-            }
-        }
-
-        let mut mock_gen = MockTextGeneration {
-            repeat_penalty: 2.0,
-            repeat_last_n: 3,
-            penalty_cache: HashMap::new(),
-        };
-
-        let (result_logits, _duration) =
-            mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
-        let result_data = result_logits.to_vec1::<f32>()?;
-
-        // Only token 1 should be penalized, out-of-bounds tokens should be ignored
-        let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
-        assert_eq!(result_data, expected);
-        Ok(())
-    }
-
-    // Test the actual apply_cached_repeat_penalty method from TextGeneration
-    // This test creates a TextGeneration instance with minimal dependencies to test the real method
-    #[test]
-    fn test_actual_apply_cached_repeat_penalty_implementation() -> Result<()> {
-        // Since creating a real TextGeneration instance requires a Model which needs model weights,
-        // we'll create a test that demonstrates the method is now public and can be accessed.
-        // The comprehensive functionality testing is already covered by the mock tests above.
-
-        // Test data setup
-        let device = Device::Cpu;
-        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens = vec![1u32, 2u32, 3u32];
-
-        // Test that we can create the necessary components
-        let tokenizer = create_test_tokenizer()?;
-
-        // The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
-        // This test verifies the method signature and that it's accessible from external code
-
-        // We could create a TextGeneration instance if we had a way to mock the Model,
-        // but for now we confirm that the existing mock tests cover the functionality
-        // and the method is properly exposed as public
-
-        println!("apply_cached_repeat_penalty method is now public and accessible for testing");
-        assert!(true);
-        Ok(())
-    }
-
-    // Integration test that demonstrates the method usage pattern
-    #[test]
-    fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
-        // This test demonstrates how the apply_cached_repeat_penalty method would be used
-        // in practice, even though we can't create a full TextGeneration instance in unit tests
-
-        let device = Device::Cpu;
-        let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
-        let logits = Tensor::new(&logits_data[..], &device)?;
-        let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
-
-        // Test parameters that would be used with TextGeneration
-        let repeat_penalty = 1.2f32;
-        let repeat_last_n = 3usize;
-        let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
-
-        // Simulate the method's logic to verify it works as expected
-        let start_time = std::time::Instant::now();
-
-        if repeat_penalty != 1.0 {
-            let start_at = tokens.len().saturating_sub(repeat_last_n);
-            let penalty_tokens = &tokens[start_at..];
-            let mut logits_vec = logits.to_vec1::<f32>()?;
-
-            for &token_id in penalty_tokens {
-                let token_id = token_id as usize;
-                if token_id < logits_vec.len() {
-                    if let Some(_cached_score) = penalty_cache.get(&token_id) {
-                        // Cache hit simulation
-                    } else {
-                        let score = logits_vec[token_id];
-                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
-                        let penalized_score = sign * score / repeat_penalty;
-                        penalty_cache.insert(token_id, penalized_score);
-                    }
-                }
-            }
-        }
-
-        let _duration = start_time.elapsed();
-
-        // Verify that tokens were processed correctly
-        assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
-        assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
-        assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
-
-        println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
-        Ok(())
-    }
-
-    // Note: Testing the actual text generation functionality would require
-    // integration tests with real models, which is beyond the scope of these unit tests.
-    // The tests above focus on the components that can be tested in isolation.
-}
diff --git a/crates/inference-engine/tests/token_output_stream_tests.rs b/crates/inference-engine/tests/token_output_stream_tests.rs
deleted file mode 100644
index 1345fd4..0000000
--- a/crates/inference-engine/tests/token_output_stream_tests.rs
+++ /dev/null
@@ -1,135 +0,0 @@
-use anyhow::Result;
-use inference_engine::token_output_stream::TokenOutputStream;
-use std::path::PathBuf;
-use tokenizers::Tokenizer;
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    // Helper function to create a simple tokenizer for testing
-    fn create_test_tokenizer() -> Result<Tokenizer> {
-        // Create a simple tokenizer from the pretrained model
-        // This uses the tokenizer from the Hugging Face hub
-        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
-        Ok(tokenizer)
-    }
-
-    #[test]
-    fn test_new_token_output_stream() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let token_stream = TokenOutputStream::new(tokenizer);
-
-        // Check that the token stream was created successfully
-        assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
-        Ok(())
-    }
-
-    #[test]
-    fn test_clear() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let mut token_stream = TokenOutputStream::new(tokenizer);
-
-        // Add a token
-        let token_id = token_stream.get_token("<eos>").unwrap();
-        token_stream.next_token(token_id)?;
-
-        // Clear the stream
-        token_stream.clear();
-
-        // Check that the stream is empty by trying to decode all
-        let decoded = token_stream.decode_all()?;
-        assert_eq!(decoded, "");
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_get_token() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let token_stream = TokenOutputStream::new(tokenizer);
-
-        // Get a token that should exist
-        let eos_token = token_stream.get_token("<eos>");
-        assert!(eos_token.is_some());
-
-        // Get a token that shouldn't exist
-        let nonexistent_token = token_stream.get_token("<nonexistent_token>");
-        assert!(nonexistent_token.is_none());
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_next_token_and_decode() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let mut token_stream = TokenOutputStream::new(tokenizer);
-
-        // Get some tokens
-        let hello_tokens = token_stream
-            .tokenizer()
-            .encode("Hello world", true)
-            .unwrap();
-        let token_ids = hello_tokens.get_ids();
-
-        // Add tokens one by one
-        let mut output = String::new();
-        for &token_id in token_ids {
-            if let Some(text) = token_stream.next_token(token_id)? {
-                output.push_str(&text);
-            }
-        }
-
-        // Get any remaining text
-        if let Some(rest) = token_stream.decode_rest()? {
-            output.push_str(&rest);
-        }
-
-        // Check the output
-        assert!(!output.is_empty());
-        assert_eq!(output.trim(), "Hello world");
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_decode_all() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let mut token_stream = TokenOutputStream::new(tokenizer);
-
-        // Get some tokens
-        let hello_tokens = token_stream
-            .tokenizer()
-            .encode("Hello world", true)
-            .unwrap();
-        let token_ids = hello_tokens.get_ids();
-
-        // Add tokens one by one
-        for &token_id in token_ids {
-            token_stream.next_token(token_id)?;
-        }
-
-        // Decode all
-        let decoded = token_stream.decode_all()?;
-
-        // Check the output
-        assert_eq!(decoded.trim(), "Hello world");
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_into_inner() -> Result<()> {
-        let tokenizer = create_test_tokenizer()?;
-        let token_stream = TokenOutputStream::new(tokenizer);
-
-        // Get the inner tokenizer
-        let inner_tokenizer = token_stream.into_inner();
-
-        // Check that the inner tokenizer works
-        let encoded = inner_tokenizer.encode("Test", true).unwrap();
-        assert!(encoded.get_ids().len() > 0);
-
-        Ok(())
-    }
-}
diff --git a/crates/llama-runner/Cargo.toml b/crates/llama-runner/Cargo.toml
index d4d69f2..168f65a 100644
--- a/crates/llama-runner/Cargo.toml
+++ b/crates/llama-runner/Cargo.toml
@@ -18,11 +18,6 @@
 candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-[target.'cfg(not(target_os = "macos"))'.dependencies]
-candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
-candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
-candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
-
 [features]
 default = []
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]