From 2deecb5e51124d9623e9d4981288e8e8dd8d71ce Mon Sep 17 00:00:00 2001 From: geoffsee <> Date: Mon, 1 Sep 2025 22:29:54 -0400 Subject: [PATCH] chat client only displays available models --- Cargo.lock | 902 ++++++++++++-- Cargo.toml | 2 +- README.md | 5 + crates/gemma-runner/Cargo.toml | 4 +- crates/gemma-runner/src/gemma_api.rs | 23 +- crates/inference-engine/Cargo.toml | 4 +- crates/inference-engine/src/model.rs | 183 +-- crates/inference-engine/src/server.rs | 496 ++++---- crates/llama-runner/Cargo.toml | 4 +- crates/llama-runner/src/llama_api.rs | 2 +- .../predict-otron-9000/src/standalone_mode.rs | 3 +- crates/utils/Cargo.toml | 88 ++ crates/utils/src/audio.rs | 138 +++ crates/utils/src/bs1770.rs | 506 ++++++++ crates/utils/src/coco_classes.rs | 82 ++ crates/utils/src/imagenet.rs | 1056 +++++++++++++++++ crates/utils/src/lib.rs | 156 +++ crates/utils/src/main.rs | 3 + crates/utils/src/token_output_stream.rs | 85 ++ crates/utils/src/wav.rs | 56 + 20 files changed, 3314 insertions(+), 484 deletions(-) create mode 100644 crates/utils/Cargo.toml create mode 100644 crates/utils/src/audio.rs create mode 100644 crates/utils/src/bs1770.rs create mode 100644 crates/utils/src/coco_classes.rs create mode 100644 crates/utils/src/imagenet.rs create mode 100644 crates/utils/src/lib.rs create mode 100644 crates/utils/src/main.rs create mode 100644 crates/utils/src/token_output_stream.rs create mode 100644 crates/utils/src/wav.rs diff --git a/Cargo.lock b/Cargo.lock index ddfafa3..b195186 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,6 +18,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "366ffbaa4442f4684d91e2cd7c5ea7c4ed8add41959a31447066e279e432b618" +[[package]] +name = "accelerate-src" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" + [[package]] name = "addr2line" version = "0.24.2" @@ -87,6 +93,21 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.20" @@ -193,6 +214,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "assert_float_eq" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10d2119f741b79fe9907f5396d19bffcb46568cfcc315e78677d731972ac7085" + [[package]] name = "async-lock" version = "3.4.1" @@ -416,6 +443,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d27c3610c36aee21ce8ac510e6224498de4228ad772a171ed65643a24693a5a8" +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.13.1" @@ -441,7 +474,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 2.1.1", "shlex", "syn 2.0.106", ] @@ -527,6 +560,7 @@ source 
= "registry+https://github.com/rust-lang/crates.io-index" checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", + "regex-automata", "serde", ] @@ -542,6 +576,12 @@ version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +[[package]] +name = "by_address" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64fa3c856b712db6612c019f14756e64e4bcea13337a6b33b696333a9eaa2d06" + [[package]] name = "bytemuck" version = "1.23.2" @@ -592,21 +632,27 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9f51e2ecf6efe9737af8f993433c839f956d2b6ed4fd2dd4a7c6d8b0fa667ff" dependencies = [ + "accelerate-src", "byteorder", "candle-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", - "cudarc", + "candle-metal-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "cudarc 0.16.6", "gemm 0.17.1", "half", + "intel-mkl-src", + "libc", "memmap2", + "metal 0.27.0", "num-traits", "num_cpus", "rand 0.9.2", "rand_distr 0.5.1", "rayon", - "safetensors", + "safetensors 0.4.5", "thiserror 1.0.69", "ug", "ug-cuda", + "ug-metal", "yoke 0.7.5", "zip", ] @@ -618,8 +664,8 @@ source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0 dependencies = [ "byteorder", "candle-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-metal-kernels", - "cudarc", + "candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)", + "cudarc 0.16.6", "float8", "gemm 0.17.1", "half", @@ -631,7 +677,7 @@ dependencies = [ "rand 0.9.2", "rand_distr 0.5.1", "rayon", - "safetensors", + "safetensors 0.4.5", "thiserror 1.0.69", "ug", "ug-cuda", @@ -640,26 +686,6 @@ dependencies = [ "zip", ] -[[package]] -name = "candle-examples" -version = "0.9.1" -source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af" -dependencies = [ - "anyhow", - "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-transformers", - "csv", - "hf-hub 0.4.3", - "image", - "num-traits", - "rayon", - "safetensors", - "serde", - "serde_json", - "tokenizers 0.21.4", -] - [[package]] name = "candle-flash-attn" version = "0.9.1" @@ -689,6 +715,19 @@ dependencies = [ "bindgen_cuda", ] +[[package]] +name = "candle-metal-kernels" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a323ee9c813707f73b6e59300661b354a70410f31fe4135170c4eda8a061534" +dependencies = [ + "half", + "metal 0.27.0", + "once_cell", + "thiserror 1.0.69", + "tracing", +] + [[package]] name = "candle-metal-kernels" version = "0.9.1" @@ -709,11 +748,15 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1980d53280c8f9e2c6cbe1785855d7ff8010208b46e21252b978badf13ad69d" dependencies = [ + "accelerate-src", "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "candle-metal-kernels 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "half", + "intel-mkl-src", + "metal 0.27.0", "num-traits", "rayon", - "safetensors", + "safetensors 0.4.5", "serde", "thiserror 1.0.69", ] @@ -724,12 +767,12 @@ version = "0.9.1" source = "git+https://github.com/huggingface/candle.git#06387ae55d8db4b5d29564d0e1e350246bc458af" dependencies = [ 
"candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-metal-kernels", + "candle-metal-kernels 0.9.1 (git+https://github.com/huggingface/candle.git)", "half", "num-traits", "objc2-metal", "rayon", - "safetensors", + "safetensors 0.4.5", "serde", "thiserror 1.0.69", ] @@ -746,6 +789,28 @@ dependencies = [ "prost-build", ] +[[package]] +name = "candle-transformers" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186cb80045dbe47e0b387ea6d3e906f02fb3056297080d9922984c90e90a72b0" +dependencies = [ + "accelerate-src", + "byteorder", + "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "candle-flash-attn", + "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "fancy-regex", + "intel-mkl-src", + "num-traits", + "rand 0.9.2", + "rayon", + "serde", + "serde_json", + "serde_plain", + "tracing", +] + [[package]] name = "candle-transformers" version = "0.9.1" @@ -840,6 +905,20 @@ dependencies = [ "web-sys", ] +[[package]] +name = "chrono" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -853,9 +932,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.45" +version = "4.5.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc0e74a703892159f5ae7d3aac52c8e6c392f5ae5f359c70b5881d60aaac318" +checksum = "2c5e4fcf9c21d2e544ca1ee9d8552de13019a42aa7dbf32747fa7aaf1df76e57" dependencies = [ "clap_builder", "clap_derive", @@ -863,9 +942,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.44" +version = "4.5.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e7f4214277f3c7aa526a59dd3fbe306a370daee1f8b7b8c987069cd8e888a8" +checksum = "fecb53a0e6fcfb055f686001bc2e2592fa527efaf38dbe81a6a9563562e57d41" dependencies = [ "anstream", "anstyle", @@ -1223,6 +1302,15 @@ dependencies = [ "libloading", ] +[[package]] +name = "cudarc" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72ba848ae5c6f3cb36e71eab5f268763e3fabcabe3f7bc683e16f7fa3d46281e" +dependencies = [ + "libloading", +] + [[package]] name = "custom_derive" version = "0.1.7" @@ -1362,6 +1450,15 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "directories" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" +dependencies = [ + "dirs-sys 0.4.1", +] + [[package]] name = "dirs" version = "5.0.1" @@ -1502,6 +1599,18 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enterpolation" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fadf5c8cbf7c6765ff05ccbd8811cd7bc3a763e4671755204552bf8740d042a" +dependencies = [ + "assert_float_eq", + "num-traits", + "serde", + "topology-traits", +] + [[package]] name = "enum-as-inner" version = "0.6.1" @@ -1514,6 +1623,29 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", + "regex", +] + +[[package]] 
+name = "env_logger" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + [[package]] name = "equator" version = "0.4.2" @@ -1625,10 +1757,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" dependencies = [ "bit-set", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] +[[package]] +name = "fast-srgb8" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd2e7510819d6fbf51a5545c8f922716ecfb14df168a3242f7d33e0239efe6a1" + [[package]] name = "fastembed" version = "4.9.1" @@ -1694,7 +1832,7 @@ name = "float8" version = "0.2.1" source = "git+https://github.com/zackangelo/float8?branch=cudarc_0_16#03c1f5fe7cdb2f9cb690823fdd40593be57c408f" dependencies = [ - "cudarc", + "cudarc 0.16.6", "half", "num-traits", "rand 0.9.2", @@ -2097,16 +2235,16 @@ version = "0.1.3" dependencies = [ "anyhow", "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-examples", "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-transformers", + "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)", "clap", "hf-hub 0.4.3", "serde_json", - "tokenizers 0.21.4", + "tokenizers 0.22.0", "tracing", "tracing-chrome", "tracing-subscriber", + "utils", ] [[package]] @@ -2146,6 +2284,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getset" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf0fc11e47561d47397154977bc219f4cf809b2974facc3ccb3b89e2436f912" +dependencies = [ + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "gif" version = "0.13.3" @@ -2177,8 +2327,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2331,6 +2481,12 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + [[package]] name = "html-escape" version = "0.2.13" @@ -2488,6 +2644,30 @@ dependencies = [ "windows-registry", ] +[[package]] +name = "iana-time-zone" +version = "0.1.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core 0.61.2", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "2.0.0" @@ -2652,6 +2832,24 @@ dependencies = [ "rand_distr 0.4.3", ] +[[package]] +name = "imageproc" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2393fb7808960751a52e8a154f67e7dd3f8a2ef9bd80d1553078a7b4e8ed3f0d" +dependencies = [ + "ab_glyph", + "approx", + "getrandom 0.2.16", + "image", + "itertools 0.12.1", + 
"nalgebra", + "num", + "rand 0.8.5", + "rand_distr 0.4.3", + "rayon", +] + [[package]] name = "imgref" version = "1.11.0" @@ -2682,6 +2880,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + [[package]] name = "inference-engine" version = "0.1.3" @@ -2695,13 +2899,13 @@ dependencies = [ "candle-flash-attn", "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-onnx", - "candle-transformers", + "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)", "clap", "cpal", "either", "futures-util", "gemma-runner", - "imageproc", + "imageproc 0.24.0", "llama-runner", "memmap2", "pdf2image", @@ -2731,6 +2935,28 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "intel-mkl-src" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee70586cd5b3e772a8739a1bd43eaa90d4f4bf0fb2a4edc202e979937ee7f5e" +dependencies = [ + "anyhow", + "intel-mkl-tool", + "ocipkg", +] + +[[package]] +name = "intel-mkl-tool" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887a16b4537d82227af54d3372971cfa5e0cde53322e60f57584056c16ada1b4" +dependencies = [ + "anyhow", + "log", + "walkdir", +] + [[package]] name = "interpolate_name" version = "0.2.4" @@ -2832,6 +3058,30 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jiff" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde", +] + +[[package]] +name = "jiff-static" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "jni" version = "0.21.1" @@ -2915,7 +3165,7 @@ dependencies = [ "paste", "rand 0.9.2", "reactive_graph", - "rustc-hash", + "rustc-hash 2.1.1", "rustc_version", "send_wrapper", "serde", @@ -3137,7 +3387,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.53.3", ] [[package]] @@ -3182,7 +3432,7 @@ dependencies = [ "anyhow", "candle-core 0.9.1 (git+https://github.com/huggingface/candle.git)", "candle-nn 0.9.1 (git+https://github.com/huggingface/candle.git)", - "candle-transformers", + "candle-transformers 0.9.1 (git+https://github.com/huggingface/candle.git)", "clap", "hf-hub 0.3.2", "serde_json", @@ -3279,11 +3529,11 @@ dependencies = [ [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -3320,14 +3570,38 @@ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "memmap2" -version = "0.9.7" 
+version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "483758ad303d734cec05e5c12b41d7e93e6a6390c5e9dae6bdeb7c1259012d28" +checksum = "843a98750cd611cc2965a8213b53b43e715f13c37a9e096c6408e69990961db7" dependencies = [ "libc", "stable_deref_trait", ] +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "metal" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" +dependencies = [ + "bitflags 2.9.2", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + [[package]] name = "metal" version = "0.29.0" @@ -3536,12 +3810,11 @@ checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -3675,6 +3948,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" dependencies = [ "malloc_buf", + "objc_exception", ] [[package]] @@ -3730,6 +4004,15 @@ dependencies = [ "objc2-foundation", ] +[[package]] +name = "objc_exception" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" +dependencies = [ + "cc", +] + [[package]] name = "object" version = "0.36.7" @@ -3762,6 +4045,48 @@ dependencies = [ "cc", ] +[[package]] +name = "oci-spec" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdf88ddc01cc6bccbe1044adb6a29057333f523deadcb4953c011a73158cfa5e" +dependencies = [ + "derive_builder", + "getset", + "serde", + "serde_json", + "strum", + "strum_macros", + "thiserror 1.0.69", +] + +[[package]] +name = "ocipkg" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bb3293021f06540803301af45e7ab81693d50e89a7398a3420bdab139e7ba5e" +dependencies = [ + "base16ct", + "base64 0.22.1", + "chrono", + "directories", + "flate2", + "lazy_static", + "log", + "oci-spec", + "regex", + "serde", + "serde_json", + "sha2", + "tar", + "thiserror 1.0.69", + "toml 0.8.23", + "ureq", + "url", + "uuid", + "walkdir", +] + [[package]] name = "oco_ref" version = "0.2.1" @@ -3886,12 +4211,6 @@ dependencies = [ "ureq", ] -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owned_ttf_parser" version = "0.25.1" @@ -3901,6 +4220,30 @@ dependencies = [ "ttf-parser", ] +[[package]] +name = "palette" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cbf71184cc5ecc2e4e1baccdb21026c20e5fc3dcf63028a086131b3ab00b6e6" +dependencies = [ + "approx", + "fast-srgb8", + "palette_derive", + "phf", +] + +[[package]] +name = "palette_derive" +version = "0.7.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5030daf005bface118c096f510ffb781fc28f9ab6a32ab224d8631be6851d30" +dependencies = [ + "by_address", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "parking" version = "2.2.1" @@ -3970,6 +4313,48 @@ dependencies = [ "indexmap", ] +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_macros", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand 0.8.5", +] + +[[package]] +name = "phf_macros" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.10" @@ -4086,6 +4471,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-crate" version = "3.3.0" @@ -4272,6 +4666,69 @@ dependencies = [ "version_check", ] +[[package]] +name = "pyo3" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.106", +] + [[package]] name = "qoi" version = "0.4.1" @@ -4298,7 +4755,7 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash", + "rustc-hash 2.1.1", "rustls", "socket2 0.5.10", "thiserror 2.0.15", @@ -4318,7 +4775,7 @@ 
dependencies = [ "lru-slab", "rand 0.9.2", "ring", - "rustc-hash", + "rustc-hash 2.1.1", "rustls", "rustls-pki-types", "slab", @@ -4339,7 +4796,7 @@ dependencies = [ "once_cell", "socket2 0.5.10", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4588,7 +5045,7 @@ dependencies = [ "indexmap", "or_poisoned", "pin-project-lite", - "rustc-hash", + "rustc-hash 2.1.1", "rustc_version", "send_wrapper", "serde", @@ -4610,7 +5067,7 @@ dependencies = [ "paste", "reactive_graph", "reactive_stores_macro", - "rustc-hash", + "rustc-hash 2.1.1", "send_wrapper", ] @@ -4627,6 +5084,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "realfft" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677" +dependencies = [ + "rustfft", +] + [[package]] name = "reborrow" version = "0.5.5" @@ -4672,17 +5138,8 @@ checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -4693,15 +5150,9 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -4807,6 +5258,30 @@ dependencies = [ "thiserror 2.0.15", ] +[[package]] +name = "rubato" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5d18b486e7d29a408ef3f825bc1327d8f87af091c987ca2f5b734625940e234" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + +[[package]] +name = "rubato" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5258099699851cfd0082aeb645feb9c084d9a5e1f1b8d5372086b989fc5e56a1" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + [[package]] name = "rust-embed" version = "8.7.2" @@ -4849,6 +5324,12 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hash" version = "2.1.1" @@ -4864,6 +5345,20 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6f140db74548f7c9d7cce60912c9ac414e74df5e718dc947d514b051b42f3f4" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "rustix" version = "1.0.8" @@ -4956,6 +5451,16 @@ dependencies = [ "serde_json", ] +[[package]] +name = "safetensors" +version = "0.6.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "172dd94c5a87b5c79f945c863da53b2ebc7ccef4eca24ac63cca66a41aab2178" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -5278,6 +5783,12 @@ dependencies = [ "quote", ] +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "slab" version = "0.4.11" @@ -5360,12 +5871,37 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.106", +] + [[package]] name = "subtle" version = "2.6.1" @@ -5707,7 +6243,7 @@ dependencies = [ "paste", "reactive_graph", "reactive_stores", - "rustc-hash", + "rustc-hash 2.1.1", "rustc_version", "send_wrapper", "slotmap", @@ -5733,6 +6269,28 @@ version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +[[package]] +name = "tekken-rs" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49623843103837f53f7ebe8cfafc19ccff28ff0e15e7c4b9f6ad21e36fbfde3a" +dependencies = [ + "anyhow", + "base64 0.22.1", + "env_logger", + "hound", + "log", + "ndarray", + "regex", + "rubato 0.16.2", + "rustc-hash 1.1.0", + "rustfft", + "serde", + "serde_json", + "thiserror 2.0.15", + "tiktoken-rs", +] + [[package]] name = "tempfile" version = "3.20.0" @@ -5743,7 +6301,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5815,6 +6373,21 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken-rs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bstr", + "fancy-regex", + "lazy_static", + "regex", + "rustc-hash 1.1.0", +] + [[package]] name = "tinystr" version = "0.8.1" @@ -5862,7 +6435,7 @@ dependencies = [ "rayon", "rayon-cond 0.3.0", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_json", "spm_precompiled", @@ -5886,6 +6459,39 @@ dependencies = [ "esaxx-rs", "getrandom 0.3.3", "hf-hub 0.4.3", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond 0.4.0", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.15", + 
"unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokenizers" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af10f51be57162b69d90a15cb226eef12c9e4faecbd5e3ea98a86bfb920b3d71" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.3", "indicatif", "itertools 0.14.0", "log", @@ -5897,7 +6503,7 @@ dependencies = [ "rayon", "rayon-cond 0.4.0", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_json", "spm_precompiled", @@ -6067,6 +6673,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "topology-traits" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0c8dab428531e30115d3bfd6e3092b55256a4a7b4f87cb3abe37a000b1f4032" +dependencies = [ + "num-traits", +] + [[package]] name = "tower" version = "0.5.2" @@ -6180,14 +6795,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -6196,6 +6811,16 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -6265,7 +6890,7 @@ dependencies = [ "num-traits", "num_cpus", "rayon", - "safetensors", + "safetensors 0.4.5", "serde", "thiserror 1.0.69", "tracing", @@ -6278,7 +6903,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14053653d0b7fa7b21015aa9a62edc8af2f60aa6f9c54e66386ecce55f22ed29" dependencies = [ - "cudarc", + "cudarc 0.16.6", "half", "serde", "thiserror 1.0.69", @@ -6292,7 +6917,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76daec3c7a32a1b4a0e3307b6b057fa067aa64e750713987410a2c402e5cd731" dependencies = [ "half", - "metal", + "metal 0.29.0", "objc", "serde", "thiserror 1.0.69", @@ -6344,6 +6969,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "untrusted" version = "0.9.0" @@ -6405,6 +7036,50 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "utils" +version = "0.0.0" +dependencies = [ + "ab_glyph", + "accelerate-src", + "anyhow", + "bindgen_cuda", + "byteorder", + "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "candle-flash-attn", + "candle-nn 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + 
"candle-onnx", + "candle-transformers 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", + "clap", + "cpal", + "csv", + "cudarc 0.17.3", + "enterpolation", + "half", + "hf-hub 0.4.3", + "image", + "imageproc 0.25.0", + "intel-mkl-src", + "memmap2", + "num-traits", + "palette", + "pdf2image", + "pyo3", + "rand 0.9.2", + "rayon", + "rubato 0.15.0", + "safetensors 0.6.2", + "serde", + "serde_json", + "symphonia", + "tekken-rs", + "tokenizers 0.22.0", + "tokio", + "tracing", + "tracing-chrome", + "tracing-subscriber", +] + [[package]] name = "utoipa" version = "4.2.3" @@ -6695,7 +7370,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -6710,7 +7385,7 @@ version = "0.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49" dependencies = [ - "windows-core", + "windows-core 0.54.0", "windows-targets 0.52.6", ] @@ -6724,6 +7399,41 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result 0.3.4", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "windows-link" version = "0.1.3" diff --git a/Cargo.toml b/Cargo.toml index 60cfab9..6611cfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ "crates/gemma-runner", "crates/cli", "crates/chat-ui" -] +, "crates/utils"] default-members = ["crates/predict-otron-9000"] resolver = "2" diff --git a/README.md b/README.md index 220e17a..091a908 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,11 @@ Stability is currently best effort. Many models require unique configuration. Wh A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces. 
+~~~shell +./scripts/run.sh +~~~ + + ## Project Overview The predict-otron-9000 is a flexible AI platform that provides: diff --git a/crates/gemma-runner/Cargo.toml b/crates/gemma-runner/Cargo.toml index fce2a32..57154db 100644 --- a/crates/gemma-runner/Cargo.toml +++ b/crates/gemma-runner/Cargo.toml @@ -10,15 +10,15 @@ edition = "2021" candle-core = { git = "https://github.com/huggingface/candle.git" } candle-nn = { git = "https://github.com/huggingface/candle.git" } candle-transformers = { git = "https://github.com/huggingface/candle.git" } -candle-examples = { git = "https://github.com/huggingface/candle.git" } hf-hub = "0.4" -tokenizers = "0.21" +tokenizers = "0.22.0" anyhow = "1.0" clap = { version = "4.0", features = ["derive", "string"] } serde_json = "1.0" tracing = "0.1" tracing-chrome = "0.7" tracing-subscriber = "0.3" +utils = {path = "../utils"} [target.'cfg(target_os = "macos")'.dependencies] candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } diff --git a/crates/gemma-runner/src/gemma_api.rs b/crates/gemma-runner/src/gemma_api.rs index 1c524ac..5f91ee5 100644 --- a/crates/gemma-runner/src/gemma_api.rs +++ b/crates/gemma-runner/src/gemma_api.rs @@ -10,16 +10,17 @@ use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use clap::ValueEnum; // Removed gemma_cli import as it's not needed for the API -use candle_core::{utils, DType, Device, Tensor}; -use candle_examples::token_output_stream::TokenOutputStream; +use candle_core::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use std::io::Write; -use tokenizers::Tokenizer; use std::sync::mpsc::{self, Receiver, Sender}; use std::thread; +use tokenizers::Tokenizer; +use utils::hub_load_safetensors; +use utils::token_output_stream::TokenOutputStream; #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] pub enum WhichModel { @@ -85,9 +86,9 @@ pub struct TextGeneration { fn device(cpu: bool) -> Result { if cpu { Ok(Device::Cpu) - } else if utils::cuda_is_available() { + } else if candle_core::utils::cuda_is_available() { Ok(Device::new_cuda(0)?) - } else if utils::metal_is_available() { + } else if candle_core::utils::metal_is_available() { Ok(Device::new_metal(0)?) 
} else { Ok(Device::Cpu) @@ -98,7 +99,7 @@ impl TextGeneration { #[allow(clippy::too_many_arguments)] fn new( model: Model, - tokenizer: Tokenizer, + tokenizer: tokenizers::Tokenizer, seed: u64, temp: Option, top_p: Option, @@ -262,10 +263,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result Result vec![repo.get("model.safetensors")?], - _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?, }; println!("Retrieved files in {:?}", start.elapsed()); diff --git a/crates/inference-engine/Cargo.toml b/crates/inference-engine/Cargo.toml index e5eed78..1857d59 100644 --- a/crates/inference-engine/Cargo.toml +++ b/crates/inference-engine/Cargo.toml @@ -31,8 +31,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] } uuid = { version = "1.7.0", features = ["v4"] } reborrow = "0.5.5" futures-util = "0.3.31" -gemma-runner = { path = "../gemma-runner" } -llama-runner = { path = "../llama-runner" } +gemma-runner = { path = "../gemma-runner", features = ["metal"] } +llama-runner = { path = "../llama-runner", features = ["metal"]} [target.'cfg(target_os = "macos")'.dependencies] candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } diff --git a/crates/inference-engine/src/model.rs b/crates/inference-engine/src/model.rs index 283e63d..89270ff 100644 --- a/crates/inference-engine/src/model.rs +++ b/crates/inference-engine/src/model.rs @@ -1,49 +1,9 @@ -// use candle_core::Tensor; use candle_transformers::models::csm::{LlamaConfig, LlamaModel}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; -#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] -pub enum Which { - #[value(name = "2b")] - Base2B, - #[value(name = "7b")] - Base7B, - #[value(name = "2b-it")] - Instruct2B, - #[value(name = "7b-it")] - Instruct7B, - #[value(name = "1.1-2b-it")] - InstructV1_1_2B, - #[value(name = "1.1-7b-it")] - InstructV1_1_7B, - #[value(name = "code-2b")] - CodeBase2B, - #[value(name = "code-7b")] - CodeBase7B, - #[value(name = "code-2b-it")] - CodeInstruct2B, - #[value(name = "code-7b-it")] - CodeInstruct7B, - #[value(name = "2-2b")] - BaseV2_2B, - #[value(name = "2-2b-it")] - InstructV2_2B, - #[value(name = "2-9b")] - BaseV2_9B, - #[value(name = "2-9b-it")] - InstructV2_9B, - #[value(name = "3-1b")] - BaseV3_1B, - #[value(name = "3-1b-it")] - InstructV3_1B, - #[value(name = "llama-3.2-1b-it")] - LlamaInstruct3_2_1B, - #[value(name = "llama-3.2-3b-it")] - LlamaInstruct3_2_3B, -} - +#[derive(Clone, Debug)] pub enum Model { V1(Model1), V2(Model2), @@ -66,48 +26,127 @@ impl Model { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Family { + GemmaV1, + GemmaV2, + GemmaV3, + Llama, +} + +#[derive(Clone, Copy, Debug)] +pub struct ModelMeta { + pub id: &'static str, + pub family: Family, + pub instruct: bool, +} + +const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta { + ModelMeta { id, family, instruct } +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +pub enum Which { + // Gemma 1.x + #[value(name = "2b")] + Base2B, + #[value(name = "7b")] + Base7B, + #[value(name = "2b-it")] + Instruct2B, + #[value(name = "7b-it")] + Instruct7B, + #[value(name = "1.1-2b-it")] + InstructV1_1_2B, + #[value(name = "1.1-7b-it")] + InstructV1_1_7B, + + // 
CodeGemma + #[value(name = "code-2b")] + CodeBase2B, + #[value(name = "code-7b")] + CodeBase7B, + #[value(name = "code-2b-it")] + CodeInstruct2B, + #[value(name = "code-7b-it")] + CodeInstruct7B, + + // Gemma 2 + #[value(name = "2-2b")] + BaseV2_2B, + #[value(name = "2-2b-it")] + InstructV2_2B, + #[value(name = "2-9b")] + BaseV2_9B, + #[value(name = "2-9b-it")] + InstructV2_9B, + + // Gemma 3 + #[value(name = "3-1b")] + BaseV3_1B, + #[value(name = "3-1b-it")] + InstructV3_1B, + + // Llama 3.2 (use aliases instead of duplicate variants) + #[value(name = "llama-3.2-1b")] + Llama32_1B, + #[value(name = "llama-3.2-1b-it", alias = "llama-3.2-1b-instruct")] + Llama32_1BInstruct, + #[value(name = "llama-3.2-3b")] + Llama32_3B, + #[value(name = "llama-3.2-3b-it", alias = "llama-3.2-3b-instruct")] + Llama32_3BInstruct, +} + impl Which { - pub fn to_model_id(&self) -> String { + pub const fn meta(&self) -> ModelMeta { + use Family::*; match self { - Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(), - Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(), - Self::Base2B => "google/gemma-2b".to_string(), - Self::Base7B => "google/gemma-7b".to_string(), - Self::Instruct2B => "google/gemma-2b-it".to_string(), - Self::Instruct7B => "google/gemma-7b-it".to_string(), - Self::CodeBase2B => "google/codegemma-2b".to_string(), - Self::CodeBase7B => "google/codegemma-7b".to_string(), - Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(), - Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(), - Self::BaseV2_2B => "google/gemma-2-2b".to_string(), - Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(), - Self::BaseV2_9B => "google/gemma-2-9b".to_string(), - Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(), - Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(), - Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(), - Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(), - Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(), + // Gemma 1.x + Self::Base2B => m("google/gemma-2b", GemmaV1, false), + Self::Base7B => m("google/gemma-7b", GemmaV1, false), + Self::Instruct2B => m("google/gemma-2b-it", GemmaV1, true), + Self::Instruct7B => m("google/gemma-7b-it", GemmaV1, true), + Self::InstructV1_1_2B => m("google/gemma-1.1-2b-it", GemmaV1, true), + Self::InstructV1_1_7B => m("google/gemma-1.1-7b-it", GemmaV1, true), + + // CodeGemma + Self::CodeBase2B => m("google/codegemma-2b", GemmaV1, false), + Self::CodeBase7B => m("google/codegemma-7b", GemmaV1, false), + Self::CodeInstruct2B => m("google/codegemma-2b-it", GemmaV1, true), + Self::CodeInstruct7B => m("google/codegemma-7b-it", GemmaV1, true), + + // Gemma 2 + Self::BaseV2_2B => m("google/gemma-2-2b", GemmaV2, false), + Self::InstructV2_2B => m("google/gemma-2-2b-it", GemmaV2, true), + Self::BaseV2_9B => m("google/gemma-2-9b", GemmaV2, false), + Self::InstructV2_9B => m("google/gemma-2-9b-it", GemmaV2, true), + + // Gemma 3 + Self::BaseV3_1B => m("google/gemma-3-1b-pt", GemmaV3, false), + Self::InstructV3_1B => m("google/gemma-3-1b-it", GemmaV3, true), + + // Llama 3.2 + Self::Llama32_1B => m("meta-llama/Llama-3.2-1B", Llama, false), + Self::Llama32_1BInstruct => m("meta-llama/Llama-3.2-1B-Instruct", Llama, true), + Self::Llama32_3B => m("meta-llama/Llama-3.2-3B", Llama, false), + Self::Llama32_3BInstruct => m("meta-llama/Llama-3.2-3B-Instruct", Llama, true), } } + pub fn to_model_id(&self) -> String { + self.meta().id.to_string() + } + pub fn 
is_instruct_model(&self) -> bool { - match self { - Self::Base2B - | Self::Base7B - | Self::CodeBase2B - | Self::CodeBase7B - | Self::BaseV2_2B - | Self::BaseV2_9B - | Self::BaseV3_1B => false, - _ => true, - } + self.meta().instruct } pub fn is_v3_model(&self) -> bool { - matches!(self, Self::BaseV3_1B | Self::InstructV3_1B) + matches!(self.meta().family, Family::GemmaV3) } pub fn is_llama_model(&self) -> bool { - matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B) + matches!(self.meta().family, Family::Llama) } } diff --git a/crates/inference-engine/src/server.rs b/crates/inference-engine/src/server.rs index c2c6d2b..79e87af 100644 --- a/crates/inference-engine/src/server.rs +++ b/crates/inference-engine/src/server.rs @@ -42,13 +42,18 @@ pub struct AppState { impl Default for AppState { fn default() -> Self { + // Configure a default model to prevent 503 errors from the chat-ui + // This can be overridden by environment variables if needed + let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string()); + let gemma_config = GemmaInferenceConfig { model: gemma_runner::WhichModel::InstructV3_1B, ..Default::default() }; + Self { model_type: ModelType::Gemma, - model_id: "gemma-3-1b-it".to_string(), + model_id: default_model_id, gemma_config: Some(gemma_config), llama_config: None, } @@ -59,6 +64,34 @@ impl Default for AppState { // Helper functions // ------------------------- +fn model_id_to_which(model_id: &str) -> Option { + let normalized = normalize_model_id(model_id); + match normalized.as_str() { + "gemma-2b" => Some(Which::Base2B), + "gemma-7b" => Some(Which::Base7B), + "gemma-2b-it" => Some(Which::Instruct2B), + "gemma-7b-it" => Some(Which::Instruct7B), + "gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B), + "gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B), + "codegemma-2b" => Some(Which::CodeBase2B), + "codegemma-7b" => Some(Which::CodeBase7B), + "codegemma-2b-it" => Some(Which::CodeInstruct2B), + "codegemma-7b-it" => Some(Which::CodeInstruct7B), + "gemma-2-2b" => Some(Which::BaseV2_2B), + "gemma-2-2b-it" => Some(Which::InstructV2_2B), + "gemma-2-9b" => Some(Which::BaseV2_9B), + "gemma-2-9b-it" => Some(Which::InstructV2_9B), + "gemma-3-1b" => Some(Which::BaseV3_1B), + "gemma-3-1b-it" => Some(Which::InstructV3_1B), + "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct), + "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct), + _ => None, + } +} + + + + fn normalize_model_id(model_id: &str) -> String { model_id.to_lowercase().replace("_", "-") } @@ -116,90 +149,76 @@ pub async fn chat_completions_non_streaming_proxy( state: AppState, request: ChatCompletionRequest, ) -> Result)> { - // Enforce model selection behavior: reject if a different model is requested - let configured_model = state.model_id.clone(); - let requested_model = request.model.clone(); - if requested_model.to_lowercase() != "default" { - let normalized_requested = normalize_model_id(&requested_model); - let normalized_configured = normalize_model_id(&configured_model); - if normalized_requested != normalized_configured { + // Use the model specified in the request + let model_id = request.model.clone(); + let which_model = model_id_to_which(&model_id); + + // Validate that the requested model is supported + let which_model = match which_model { + Some(model) => model, + None => { return Err(( StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": { - "message": format!( - "Requested model '{}' is not available. 
This server is running '{}' only.", - requested_model, configured_model - ), - "type": "model_mismatch" + "message": format!("Unsupported model: {}", model_id), + "type": "model_not_supported" } })), )); } - } - - let model_id = state.model_id.clone(); + }; let max_tokens = request.max_tokens.unwrap_or(1000); // Build prompt based on model type - let prompt = match state.model_type { - ModelType::Gemma => build_gemma_prompt(&request.messages), - ModelType::Llama => { - // For Llama, just use the last user message for now - request - .messages - .last() - .and_then(|m| m.content.as_ref()) - .and_then(|c| match c { - MessageContent(Either::Left(text)) => Some(text.clone()), - _ => None, - }) - .unwrap_or_default() - } + let prompt = if which_model.is_llama_model() { + // For Llama, just use the last user message for now + request + .messages + .last() + .and_then(|m| m.content.as_ref()) + .and_then(|c| match c { + MessageContent(Either::Left(text)) => Some(text.clone()), + _ => None, + }) + .unwrap_or_default() + } else { + build_gemma_prompt(&request.messages) }; // Get streaming receiver based on model type - let rx = - match state.model_type { - ModelType::Gemma => { - if let Some(mut config) = state.gemma_config { - config.prompt = prompt.clone(); - config.max_tokens = max_tokens; - run_gemma_api(config).map_err(|e| ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": format!("Error initializing Gemma model: {}", e) } - })) - ))? - } else { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": "Gemma configuration not available" } - })), - )); - } - } - ModelType::Llama => { - if let Some(mut config) = state.llama_config { - config.prompt = prompt.clone(); - config.max_tokens = max_tokens; - run_llama_inference(config).map_err(|e| ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": format!("Error initializing Llama model: {}", e) } - })) - ))? - } else { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": "Llama configuration not available" } - })), - )); - } - } + let rx = if which_model.is_llama_model() { + // Create Llama configuration dynamically + let mut config = LlamaInferenceConfig::default(); + config.prompt = prompt.clone(); + config.max_tokens = max_tokens; + run_llama_inference(config).map_err(|e| ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { "message": format!("Error initializing Llama model: {}", e) } + })) + ))? + } else { + // Create Gemma configuration dynamically + let gemma_model = if which_model.is_v3_model() { + gemma_runner::WhichModel::InstructV3_1B + } else { + gemma_runner::WhichModel::InstructV3_1B // Default fallback }; + + let mut config = GemmaInferenceConfig { + model: gemma_model, + ..Default::default() + }; + config.prompt = prompt.clone(); + config.max_tokens = max_tokens; + run_gemma_api(config).map_err(|e| ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { "message": format!("Error initializing Gemma model: {}", e) } + })) + ))? 
+ }; // Collect all tokens from the stream let mut completion = String::new(); @@ -258,27 +277,25 @@ async fn handle_streaming_request( state: AppState, request: ChatCompletionRequest, ) -> Result>>, (StatusCode, Json)> { - // Validate requested model vs configured model - let configured_model = state.model_id.clone(); - let requested_model = request.model.clone(); - if requested_model.to_lowercase() != "default" { - let normalized_requested = normalize_model_id(&requested_model); - let normalized_configured = normalize_model_id(&configured_model); - if normalized_requested != normalized_configured { + // Use the model specified in the request + let model_id = request.model.clone(); + let which_model = model_id_to_which(&model_id); + + // Validate that the requested model is supported + let which_model = match which_model { + Some(model) => model, + None => { return Err(( StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": { - "message": format!( - "Requested model '{}' is not available. This server is running '{}' only.", - requested_model, configured_model - ), - "type": "model_mismatch" + "message": format!("Unsupported model: {}", model_id), + "type": "model_not_supported" } })), )); } - } + }; // Generate a unique ID and metadata let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")); @@ -286,24 +303,22 @@ async fn handle_streaming_request( .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(); - let model_id = state.model_id.clone(); let max_tokens = request.max_tokens.unwrap_or(1000); // Build prompt based on model type - let prompt = match state.model_type { - ModelType::Gemma => build_gemma_prompt(&request.messages), - ModelType::Llama => { - // For Llama, just use the last user message for now - request - .messages - .last() - .and_then(|m| m.content.as_ref()) - .and_then(|c| match c { - MessageContent(Either::Left(text)) => Some(text.clone()), - _ => None, - }) - .unwrap_or_default() - } + let prompt = if which_model.is_llama_model() { + // For Llama, just use the last user message for now + request + .messages + .last() + .and_then(|m| m.content.as_ref()) + .and_then(|c| match c { + MessageContent(Either::Left(text)) => Some(text.clone()), + _ => None, + }) + .unwrap_or_default() + } else { + build_gemma_prompt(&request.messages) }; tracing::debug!("Formatted prompt: {}", prompt); @@ -330,51 +345,43 @@ async fn handle_streaming_request( } // Get streaming receiver based on model type - let model_rx = match state.model_type { - ModelType::Gemma => { - if let Some(mut config) = state.gemma_config { - config.prompt = prompt.clone(); - config.max_tokens = max_tokens; - match run_gemma_api(config) { - Ok(rx) => rx, - Err(e) => { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": format!("Error initializing Gemma model: {}", e) } - })), - )); - } - } - } else { + let model_rx = if which_model.is_llama_model() { + // Create Llama configuration dynamically + let mut config = LlamaInferenceConfig::default(); + config.prompt = prompt.clone(); + config.max_tokens = max_tokens; + match run_llama_inference(config) { + Ok(rx) => rx, + Err(e) => { return Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ - "error": { "message": "Gemma configuration not available" } + "error": { "message": format!("Error initializing Llama model: {}", e) } })), )); } } - ModelType::Llama => { - if let Some(mut config) = state.llama_config { - config.prompt = prompt.clone(); - 
config.max_tokens = max_tokens; - match run_llama_inference(config) { - Ok(rx) => rx, - Err(e) => { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { "message": format!("Error initializing Llama model: {}", e) } - })), - )); - } - } - } else { + } else { + // Create Gemma configuration dynamically + let gemma_model = if which_model.is_v3_model() { + gemma_runner::WhichModel::InstructV3_1B + } else { + gemma_runner::WhichModel::InstructV3_1B // Default fallback + }; + + let mut config = GemmaInferenceConfig { + model: gemma_model, + ..Default::default() + }; + config.prompt = prompt.clone(); + config.max_tokens = max_tokens; + match run_gemma_api(config) { + Ok(rx) => rx, + Err(e) => { return Err(( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ - "error": { "message": "Llama configuration not available" } + "error": { "message": format!("Error initializing Gemma model: {}", e) } })), )); } @@ -500,172 +507,69 @@ pub fn create_router(app_state: AppState) -> Router { /// Handler for GET /v1/models - returns list of available models pub async fn list_models() -> Json { // Get all available model variants from the Which enum - let models = vec![ - // Gemma models + let which_variants = vec![ + Which::Base2B, + Which::Base7B, + Which::Instruct2B, + Which::Instruct7B, + Which::InstructV1_1_2B, + Which::InstructV1_1_7B, + Which::CodeBase2B, + Which::CodeBase7B, + Which::CodeInstruct2B, + Which::CodeInstruct7B, + Which::BaseV2_2B, + Which::InstructV2_2B, + Which::BaseV2_9B, + Which::InstructV2_9B, + Which::BaseV3_1B, + Which::InstructV3_1B, + Which::Llama32_1B, + Which::Llama32_1BInstruct, + Which::Llama32_3B, + Which::Llama32_3BInstruct, + ]; + + let models: Vec = which_variants.into_iter().map(|which| { + let meta = which.meta(); + let model_id = match which { + Which::Base2B => "gemma-2b", + Which::Base7B => "gemma-7b", + Which::Instruct2B => "gemma-2b-it", + Which::Instruct7B => "gemma-7b-it", + Which::InstructV1_1_2B => "gemma-1.1-2b-it", + Which::InstructV1_1_7B => "gemma-1.1-7b-it", + Which::CodeBase2B => "codegemma-2b", + Which::CodeBase7B => "codegemma-7b", + Which::CodeInstruct2B => "codegemma-2b-it", + Which::CodeInstruct7B => "codegemma-7b-it", + Which::BaseV2_2B => "gemma-2-2b", + Which::InstructV2_2B => "gemma-2-2b-it", + Which::BaseV2_9B => "gemma-2-9b", + Which::InstructV2_9B => "gemma-2-9b-it", + Which::BaseV3_1B => "gemma-3-1b", + Which::InstructV3_1B => "gemma-3-1b-it", + Which::Llama32_1B => "llama-3.2-1b", + Which::Llama32_1BInstruct => "llama-3.2-1b-instruct", + Which::Llama32_3B => "llama-3.2-3b", + Which::Llama32_3BInstruct => "llama-3.2-3b-instruct", + }; + + let owned_by = if meta.id.starts_with("google/") { + "google" + } else if meta.id.starts_with("meta-llama/") { + "meta" + } else { + "unknown" + }; + Model { - id: "gemma-2b".to_string(), + id: model_id.to_string(), object: "model".to_string(), created: 1686935002, // Using same timestamp as OpenAI example - owned_by: "google".to_string(), - }, - Model { - id: "gemma-7b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-2b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-7b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-1.1-2b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: 
"google".to_string(), - }, - Model { - id: "gemma-1.1-7b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "codegemma-2b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "codegemma-7b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "codegemma-2b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "codegemma-7b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-2-2b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-2-2b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-2-9b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-2-9b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-3-1b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - Model { - id: "gemma-3-1b-it".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "google".to_string(), - }, - // Llama models - Model { - id: "llama-3.2-1b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "meta".to_string(), - }, - Model { - id: "llama-3.2-1b-instruct".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "meta".to_string(), - }, - Model { - id: "llama-3.2-3b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "meta".to_string(), - }, - Model { - id: "llama-3.2-3b-instruct".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "meta".to_string(), - }, - Model { - id: "smollm2-135m".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "huggingface".to_string(), - }, - Model { - id: "smollm2-135m-instruct".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "huggingface".to_string(), - }, - Model { - id: "smollm2-360m".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "huggingface".to_string(), - }, - Model { - id: "smollm2-360m-instruct".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "huggingface".to_string(), - }, - Model { - id: "smollm2-1.7b".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "huggingface".to_string(), - }, - Model { - id: "smollm2-1.7b-instruct".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "huggingface".to_string(), - }, - Model { - id: "tinyllama-1.1b-chat".to_string(), - object: "model".to_string(), - created: 1686935002, - owned_by: "tinyllama".to_string(), - }, - ]; + owned_by: owned_by.to_string(), + } + }).collect(); Json(ModelListResponse { object: "list".to_string(), diff --git a/crates/llama-runner/Cargo.toml b/crates/llama-runner/Cargo.toml index 4927a22..1dd90f7 100644 --- a/crates/llama-runner/Cargo.toml +++ b/crates/llama-runner/Cargo.toml @@ -5,8 +5,8 @@ edition = "2021" [dependencies] candle-core = { git = "https://github.com/huggingface/candle.git" } 
-candle-nn = { git = "https://github.com/huggingface/candle.git" } -candle-transformers = { git = "https://github.com/huggingface/candle.git" } +candle-nn = { git = "https://github.com/huggingface/candle.git" } +candle-transformers = { git = "https://github.com/huggingface/candle.git"} hf-hub = "0.3" tokenizers = "0.20" anyhow = "1.0" diff --git a/crates/llama-runner/src/llama_api.rs b/crates/llama-runner/src/llama_api.rs index 474bbe2..59024c9 100644 --- a/crates/llama-runner/src/llama_api.rs +++ b/crates/llama-runner/src/llama_api.rs @@ -82,7 +82,7 @@ impl Default for LlamaInferenceConfig { // Performance flags no_kv_cache: false, // keep cache ON for speed - use_flash_attn: true, // great speed boost if supported + use_flash_attn: false, // great speed boost if supported // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed. dtype: Some("bf16".to_string()), diff --git a/crates/predict-otron-9000/src/standalone_mode.rs b/crates/predict-otron-9000/src/standalone_mode.rs index 91cb903..0881d94 100644 --- a/crates/predict-otron-9000/src/standalone_mode.rs +++ b/crates/predict-otron-9000/src/standalone_mode.rs @@ -6,7 +6,8 @@ pub fn create_standalone_router(server_config: ServerConfig) -> Router { // Create unified router by merging embeddings and inference routers (existing behavior) let embeddings_router = embeddings_engine::create_embeddings_router(); - // Create AppState with correct model configuration + // Create AppState - no default model, must be configured explicitly + // This removes the hardcoded gemma-3-1b-it default behavior let app_state = AppState::default(); // Get the inference router directly from the inference engine diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml new file mode 100644 index 0000000..fa474c5 --- /dev/null +++ b/crates/utils/Cargo.toml @@ -0,0 +1,88 @@ +[package] +name = "utils" + +[lib] +path = "src/lib.rs" + +[dependencies] +accelerate-src = {version = "0.3.2", optional = true } +candle-nn = {version = "0.9.1" } +candle-transformers = {version = "0.9.1" } + +candle-flash-attn = {version = "0.9.1", optional = true } +candle-onnx = {version = "0.9.1", optional = true } +candle-core="0.9.1" +csv = "1.3.0" +anyhow = "1.0.99" +cudarc = {version = "0.17.3", optional = true } +half = {version = "2.6.0", optional = true } +hf-hub = {version = "0.4.3", features = ["tokio"] } +image = {version = "0.25.6" } +intel-mkl-src = {version = "0.8.1", optional = true } +num-traits = {version = "0.2.19" } +palette = { version = "0.7.6", optional = true } +enterpolation = { version = "0.2.1", optional = true } +pyo3 = { version = "0.22.0", features = [ + "auto-initialize", + "abi3-py311", +], optional = true } +rayon = {version = "1.11.0" } +rubato = { version = "0.15.0", optional = true } +safetensors = {version = "0.6.2" } +serde = {version = "1.0.219" } +serde_json = {version = "1.0.143" } +symphonia = { version = "0.5.3", features = ["all"], optional = true } +tokenizers = {version = "0.22.0", features = ["onig"] } +cpal = { version = "0.15.2", optional = true } +pdf2image = { version = "0.1.2", optional = true } +tekken-rs = { version = "0.1.1", optional = true } + +[dev-dependencies] +anyhow = {version = "1.0.99" } +byteorder = {version = "1.5.0" } +clap = {version = "4.5.46" } +imageproc = {version = "0.25.0" } +memmap2 = {version = "0.9.8" } +rand = {version = "0.9.2" } +ab_glyph = {version = "0.2.31" } +tracing = {version = "0.1.41" } +tracing-chrome = {version = "0.7.2" } +tracing-subscriber = {version = "0.3.20" } +# Necessary 
to disambiguate with tokio in wasm examples which are 1.28.1 +tokio = "1.43.0" + +[build-dependencies] +anyhow = {version = "1.0.99" } +bindgen_cuda = { version = "0.1.1", optional = true } +# +[features] +default = [] +accelerate = [ + "dep:accelerate-src", + "candle-core/accelerate", + "candle-nn/accelerate", + "candle-transformers/accelerate", +] +cuda = [ + "candle-core/cuda", + "candle-nn/cuda", + "candle-transformers/cuda", + "dep:bindgen_cuda", +] +cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"] +flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] +mkl = [ + "dep:intel-mkl-src", + "candle-core/mkl", + "candle-nn/mkl", + "candle-transformers/mkl", +] +nccl = ["cuda", "cudarc/nccl", "dep:half"] +onnx = ["candle-onnx"] +metal = ["candle-core/metal", "candle-nn/metal"] +microphone = ["cpal", "rubato"] +encodec = ["cpal", "symphonia", "rubato"] +mimi = ["cpal", "symphonia", "rubato"] +snac = ["cpal", "symphonia", "rubato"] +depth_anything_v2 = ["palette", "enterpolation"] +tekken = ["tekken-rs"] \ No newline at end of file diff --git a/crates/utils/src/audio.rs b/crates/utils/src/audio.rs new file mode 100644 index 0000000..0b17b94 --- /dev/null +++ b/crates/utils/src/audio.rs @@ -0,0 +1,138 @@ +use candle_core::{Result, Tensor}; + +// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57 +pub fn normalize_loudness( + wav: &Tensor, + sample_rate: u32, + loudness_compressor: bool, +) -> Result { + let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::()?; + if energy < 2e-3 { + return Ok(wav.clone()); + } + let wav_array = wav.to_vec1::()?; + let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate); + meter.push(wav_array.into_iter()); + let power = meter.as_100ms_windows(); + let loudness = match crate::bs1770::gated_mean(power) { + None => return Ok(wav.clone()), + Some(gp) => gp.loudness_lkfs() as f64, + }; + let delta_loudness = -14. - loudness; + let gain = 10f64.powf(delta_loudness / 20.); + let wav = (wav * gain)?; + if loudness_compressor { + wav.tanh() + } else { + Ok(wav) + } +} + +#[cfg(feature = "symphonia")] +pub fn pcm_decode>(path: P) -> Result<(Vec, u32)> { + use symphonia::core::audio::{AudioBufferRef, Signal}; + use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL}; + use symphonia::core::conv::FromSample; + + fn conv( + samples: &mut Vec, + data: std::borrow::Cow>, + ) where + T: symphonia::core::sample::Sample, + f32: symphonia::core::conv::FromSample, + { + samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) + } + + // Open the media source. + let src = std::fs::File::open(path).map_err(candle::Error::wrap)?; + + // Create the media source stream. + let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); + + // Create a probe hint using the file's extension. [Optional] + let hint = symphonia::core::probe::Hint::new(); + + // Use the default options for metadata and format readers. + let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); + let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); + + // Probe the media source. + let probed = symphonia::default::get_probe() + .format(&hint, mss, &fmt_opts, &meta_opts) + .map_err(candle::Error::wrap)?; + // Get the instantiated format reader. + let mut format = probed.format; + + // Find the first audio track with a known (decodeable) codec. 
+ let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) + .ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?; + + // Use the default options for the decoder. + let dec_opts: DecoderOptions = Default::default(); + + // Create a decoder for the track. + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &dec_opts) + .map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?; + let track_id = track.id; + let sample_rate = track.codec_params.sample_rate.unwrap_or(0); + let mut pcm_data = Vec::new(); + // The decode loop. + while let Ok(packet) = format.next_packet() { + // Consume any new metadata that has been read since the last packet. + while !format.metadata().is_latest() { + format.metadata().pop(); + } + + // If the packet does not belong to the selected track, skip over it. + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet).map_err(candle::Error::wrap)? { + AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), + AudioBufferRef::U8(data) => conv(&mut pcm_data, data), + AudioBufferRef::U16(data) => conv(&mut pcm_data, data), + AudioBufferRef::U24(data) => conv(&mut pcm_data, data), + AudioBufferRef::U32(data) => conv(&mut pcm_data, data), + AudioBufferRef::S8(data) => conv(&mut pcm_data, data), + AudioBufferRef::S16(data) => conv(&mut pcm_data, data), + AudioBufferRef::S24(data) => conv(&mut pcm_data, data), + AudioBufferRef::S32(data) => conv(&mut pcm_data, data), + AudioBufferRef::F64(data) => conv(&mut pcm_data, data), + } + } + Ok((pcm_data, sample_rate)) +} + +#[cfg(feature = "rubato")] +pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result> { + use rubato::Resampler; + + let mut pcm_out = + Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024); + + let mut resampler = rubato::FftFixedInOut::::new(sr_in as usize, sr_out as usize, 1024, 1) + .map_err(candle::Error::wrap)?; + let mut output_buffer = resampler.output_buffer_allocate(true); + let mut pos_in = 0; + while pos_in + resampler.input_frames_next() < pcm_in.len() { + let (in_len, out_len) = resampler + .process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None) + .map_err(candle::Error::wrap)?; + pos_in += in_len; + pcm_out.extend_from_slice(&output_buffer[0][..out_len]); + } + + if pos_in < pcm_in.len() { + let (_in_len, out_len) = resampler + .process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None) + .map_err(candle::Error::wrap)?; + pcm_out.extend_from_slice(&output_buffer[0][..out_len]); + } + + Ok(pcm_out) +} diff --git a/crates/utils/src/bs1770.rs b/crates/utils/src/bs1770.rs new file mode 100644 index 0000000..fbda6df --- /dev/null +++ b/crates/utils/src/bs1770.rs @@ -0,0 +1,506 @@ +// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs +// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770 +// Copyright 2020 Ruud van Asseldonk + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// A copy of the License has been included in the root of the repository. + +//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704]. +//! +//! This library offers the building blocks to perform BS.1770 loudness +//! measurements, but you need to put the pieces together yourself. +//! +//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en +//! +//! # Stereo integrated loudness example +//! 
+//! ```ignore +//! # fn load_stereo_audio() -> [Vec; 2] { +//! # [vec![0; 48_000], vec![0; 48_000]] +//! # } +//! # +//! let sample_rate_hz = 44_100; +//! let bits_per_sample = 16; +//! let channel_samples: [Vec; 2] = load_stereo_audio(); +//! +//! // When converting integer samples to float, note that the maximum amplitude +//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit. +//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32; +//! +//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| { +//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz); +//! meter.push(samples.iter().map(|&s| s as f32 * normalizer)); +//! meter.into_100ms_windows() +//! }).collect(); +//! +//! let stereo_power = bs1770::reduce_stereo( +//! channel_power[0].as_ref(), +//! channel_power[1].as_ref(), +//! ); +//! +//! let gated_power = bs1770::gated_mean( +//! stereo_power.as_ref() +//! ).unwrap_or(bs1770::Power(0.0)); +//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs()); +//! ``` + +use std::f32; + +/// Coefficients for a 2nd-degree infinite impulse response filter. +/// +/// Coefficient a0 is implicitly 1.0. +#[derive(Clone)] +struct Filter { + a1: f32, + a2: f32, + b0: f32, + b1: f32, + b2: f32, + + // The past two input and output samples. + x1: f32, + x2: f32, + y1: f32, + y2: f32, +} + +impl Filter { + /// Stage 1 of th BS.1770-4 pre-filter. + pub fn high_shelf(sample_rate_hz: f32) -> Filter { + // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136. + let gain_db = 3.999_843_8; + let q = 0.707_175_25; + let center_hz = 1_681.974_5; + + // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143. + let k = (f32::consts::PI * center_hz / sample_rate_hz).tan(); + let vh = 10.0_f32.powf(gain_db / 20.0); + let vb = vh.powf(0.499_666_78); + let a0 = 1.0 + k / q + k * k; + Filter { + b0: (vh + vb * k / q + k * k) / a0, + b1: 2.0 * (k * k - vh) / a0, + b2: (vh - vb * k / q + k * k) / a0, + a1: 2.0 * (k * k - 1.0) / a0, + a2: (1.0 - k / q + k * k) / a0, + + x1: 0.0, + x2: 0.0, + y1: 0.0, + y2: 0.0, + } + } + + /// Stage 2 of th BS.1770-4 pre-filter. + pub fn high_pass(sample_rate_hz: f32) -> Filter { + // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136. + let q = 0.500_327_05; + let center_hz = 38.135_47; + + // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/ + // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151 + let k = (f32::consts::PI * center_hz / sample_rate_hz).tan(); + Filter { + a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k), + a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k), + b0: 1.0, + b1: -2.0, + b2: 1.0, + + x1: 0.0, + x2: 0.0, + y1: 0.0, + y2: 0.0, + } + } + + /// Feed the next input sample, get the next output sample. + #[inline(always)] + pub fn apply(&mut self, x0: f32) -> f32 { + let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2 + - self.a1 * self.y1 + - self.a2 * self.y2; + + self.x2 = self.x1; + self.x1 = x0; + self.y2 = self.y1; + self.y1 = y0; + + y0 + } +} + +/// Compensated sum, for summing many values of different orders of magnitude +/// accurately. 
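+/// (This is Kahan-style compensated summation: `residue` carries the low-order
+/// bits that would otherwise be lost when adding a small sample to a large
+/// running sum, keeping the accumulated 100ms window powers accurate.)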
+#[derive(Copy, Clone, PartialEq)] +struct Sum { + sum: f32, + residue: f32, +} + +impl Sum { + #[inline(always)] + fn zero() -> Sum { + Sum { + sum: 0.0, + residue: 0.0, + } + } + + #[inline(always)] + fn add(&mut self, x: f32) { + let sum = self.sum + (self.residue + x); + self.residue = (self.residue + x) - (sum - self.sum); + self.sum = sum; + } +} + +/// The mean of the squares of the K-weighted samples in a window of time. +/// +/// K-weighted power is equivalent to K-weighted loudness, the only difference +/// is one of scale: power is quadratic in sample amplitudes, whereas loudness +/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power, +/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS). +/// +/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale) +/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise +/// interchangeable with the more widespread term “LUFS” (Loudness Units, +/// relative to Full Scale). Loudness units are related to decibels in the +/// following sense: boosting a signal that has a loudness of +/// -LK LUFS by LK dB (by +/// multiplying the amplitude by 10LK/20) will +/// bring the loudness to 0 LUFS. +/// +/// K-weighting refers to a high-shelf and high-pass filter that model the +/// effect that humans perceive a certain amount of power in low frequencies to +/// be less loud than the same amount of power in higher frequencies. In this +/// library the `Power` type is used exclusively to refer to power after applying K-weighting. +/// +/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the +/// mean square of the samples, if no input samples exceeded the full scale, the +/// power will be in the range [0.0, 1.0]. However, the power delivered by +/// multiple channels, which is a weighted sum over individual channel powers, +/// can exceed this range, because the weighted sum is not normalized. +#[derive(Copy, Clone, PartialEq, PartialOrd)] +pub struct Power(pub f32); + +impl Power { + /// Convert Loudness Units relative to Full Scale into a squared sample amplitude. + /// + /// This is the inverse of `loudness_lkfs`. + pub fn from_lkfs(lkfs: f32) -> Power { + // The inverse of the formula below. + Power(10.0_f32.powf((lkfs + 0.691) * 0.1)) + } + + /// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale. + /// + /// This is the inverse of `from_lkfs`. + pub fn loudness_lkfs(&self) -> f32 { + // Equation 2 (p.5) of BS.1770-4. + -0.691 + 10.0 * self.0.log10() + } +} + +/// A `T` value for non-overlapping windows of audio, 100ms in length. +/// +/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power +/// for non-overlapping windows of 100ms duration. +/// +/// These non-overlapping 100ms windows can later be combined into overlapping +/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or +/// to perform a gated measurement, or they can be combined into even larger +/// windows for a momentary loudness measurement. +#[derive(Copy, Clone, Debug)] +pub struct Windows100ms { + pub inner: T, +} + +impl Windows100ms { + /// Wrap a new empty vector. + pub fn new() -> Windows100ms> { + Windows100ms { inner: Vec::new() } + } + + /// Apply `as_ref` to the inner value. + pub fn as_ref(&self) -> Windows100ms<&[Power]> + where + T: AsRef<[Power]>, + { + Windows100ms { + inner: self.inner.as_ref(), + } + } + + /// Apply `as_mut` to the inner value. 
+ pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]> + where + T: AsMut<[Power]>, + { + Windows100ms { + inner: self.inner.as_mut(), + } + } + + #[allow(clippy::len_without_is_empty)] + /// Apply `len` to the inner value. + pub fn len(&self) -> usize + where + T: AsRef<[Power]>, + { + self.inner.as_ref().len() + } +} + +/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio. +/// +/// # Output +/// +/// The output of the meter is an intermediate result in the form of power for +/// 100ms non-overlapping windows. The windows need to be processed further to +/// get one of the instantaneous, momentary, and integrated loudness +/// measurements defined in BS.1770. +/// +/// The windows can also be inspected directly; the data is meaningful +/// on its own (the K-weighted power delivered in that window of time), but it +/// is not something that BS.1770 defines a term for. +/// +/// # Multichannel audio +/// +/// To perform a loudness measurement of multichannel audio, construct a +/// `ChannelLoudnessMeter` per channel, and later combine the measured power +/// with e.g. `reduce_stereo`. +/// +/// # Instantaneous loudness +/// +/// The instantaneous loudness is the power over a 400ms window, so you can +/// average four 100ms windows. No special functionality is implemented to help +/// with that at this time. ([Pull requests would be accepted.][contribute]) +/// +/// # Momentary loudness +/// +/// The momentary loudness is the power over a 3-second window, so you can +/// average thirty 100ms windows. No special functionality is implemented to +/// help with that at this time. ([Pull requests would be accepted.][contribute]) +/// +/// # Integrated loudness +/// +/// Use `gated_mean` to perform an integrated loudness measurement: +/// +/// ```ignore +/// # use std::iter; +/// # use bs1770::{ChannelLoudnessMeter, gated_mean}; +/// # let sample_rate_hz = 44_100; +/// # let samples_per_100ms = sample_rate_hz / 10; +/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz); +/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin())); +/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows()) +/// .unwrap_or(bs1770::Power(0.0)) +/// .loudness_lkfs(); +/// ``` +/// +/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md +#[derive(Clone)] +pub struct ChannelLoudnessMeter { + /// The number of samples that fit in 100ms of audio. + samples_per_100ms: u32, + + /// Stage 1 filter (head effects, high shelf). + filter_stage1: Filter, + + /// Stage 2 filter (high-pass). + filter_stage2: Filter, + + /// Sum of the squares over non-overlapping windows of 100ms. + windows: Windows100ms>, + + /// The number of samples in the current unfinished window. + count: u32, + + /// The sum of the squares of the samples in the current unfinished window. + square_sum: Sum, +} + +impl ChannelLoudnessMeter { + /// Construct a new loudness meter for the given sample rate. + pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter { + ChannelLoudnessMeter { + samples_per_100ms: sample_rate_hz / 10, + filter_stage1: Filter::high_shelf(sample_rate_hz as f32), + filter_stage2: Filter::high_pass(sample_rate_hz as f32), + windows: Windows100ms::new(), + count: 0, + square_sum: Sum::zero(), + } + } + + /// Feed input samples for loudness analysis. + /// + /// # Full scale + /// + /// Full scale for the input samples is the interval [-1.0, 1.0]. 
If your + /// input consists of signed integer samples, you can convert as follows: + /// + /// ```ignore + /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100); + /// # let bits_per_sample = 16_usize; + /// # let samples = &[0_i16]; + /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`, + /// // one bit is the sign bit. + /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32; + /// meter.push(samples.iter().map(|&s| s as f32 * normalizer)); + /// ``` + /// + /// # Repeated calls + /// + /// You can call `push` multiple times to feed multiple batches of samples. + /// This is equivalent to feeding a single chained iterator. The leftover of + /// samples that did not fill a full 100ms window is not discarded: + /// + /// ```ignore + /// # use std::iter; + /// # use bs1770::ChannelLoudnessMeter; + /// let sample_rate_hz = 44_100; + /// let samples_per_100ms = sample_rate_hz / 10; + /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz); + /// + /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1)); + /// assert_eq!(meter.as_100ms_windows().len(), 0); + /// + /// meter.push(iter::once(0.0)); + /// assert_eq!(meter.as_100ms_windows().len(), 1); + /// ``` + pub fn push>(&mut self, samples: I) { + let normalizer = 1.0 / self.samples_per_100ms as f32; + + // LLVM, if you could go ahead and inline those apply calls, and then + // unroll and vectorize the loop, that'd be terrific. + for x in samples { + let y = self.filter_stage1.apply(x); + let z = self.filter_stage2.apply(y); + + self.square_sum.add(z * z); + self.count += 1; + + // TODO: Should this branch be marked cold? + if self.count == self.samples_per_100ms { + let mean_squares = Power(self.square_sum.sum * normalizer); + self.windows.inner.push(mean_squares); + // We intentionally do not reset the residue. That way, leftover + // energy from this window is not lost, so for the file overall, + // the sum remains more accurate. + self.square_sum.sum = 0.0; + self.count = 0; + } + } + } + + /// Return a reference to the 100ms windows analyzed so far. + pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> { + self.windows.as_ref() + } + + /// Return all 100ms windows analyzed so far. + pub fn into_100ms_windows(self) -> Windows100ms> { + self.windows + } +} + +/// Combine power for multiple channels by taking a weighted sum. +/// +/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted +/// sum over channels which is not normalized. This means that a stereo signal +/// is inherently louder than a mono signal. For a mono signal played back on +/// stereo speakers, you should therefore still apply `reduce_stereo`, passing +/// in the same signal for both channels. +pub fn reduce_stereo( + left: Windows100ms<&[Power]>, + right: Windows100ms<&[Power]>, +) -> Windows100ms> { + assert_eq!( + left.len(), + right.len(), + "Channels must have the same length." + ); + let mut result = Vec::with_capacity(left.len()); + for (l, r) in left.inner.iter().zip(right.inner) { + result.push(Power(l.0 + r.0)); + } + Windows100ms { inner: result } +} + +/// In-place version of `reduce_stereo` that stores the result in the former left channel. +pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) { + assert_eq!( + left.len(), + right.len(), + "Channels must have the same length." 
+ ); + for (l, r) in left.inner.iter_mut().zip(right.inner) { + l.0 += r.0; + } +} + +/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement. +/// +/// The integrated loudness measurement is not just the average power over the +/// entire signal. BS.1770-4 defines two stages of gating that exclude +/// parts of the signal, to ensure that silent parts do not contribute to the +/// loudness measurement. This function performs that gating, and returns the +/// average power over the windows that were not excluded. +/// +/// The result of this function is the integrated loudness measurement. +/// +/// When no signal remains after applying the gate, this function returns +/// `None`. In particular, this happens when all of the signal is softer than +/// -70 LKFS, including a signal that consists of pure silence. +pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option { + let mut gating_blocks = Vec::with_capacity(windows_100ms.len()); + + // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.) + let absolute_threshold = Power::from_lkfs(-70.0); + + // Iterate over all 400ms windows. + for window in windows_100ms.inner.windows(4) { + // Note that the sum over channels has already been performed at this point. + let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::()); + + if gating_block_power > absolute_threshold { + gating_blocks.push(gating_block_power); + } + } + + if gating_blocks.is_empty() { + return None; + } + + // Compute the loudness after applying the absolute gate, in order to + // determine the threshold for the relative gate. + let mut sum_power = Sum::zero(); + for &gating_block_power in &gating_blocks { + sum_power.add(gating_block_power.0); + } + let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32)); + + // Stage 2: Apply the relative gate. 
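+ // (BS.1770-4 places the relative threshold 10 LU below the loudness of the
+ // absolutely-gated signal computed above; only 400ms blocks louder than that
+ // contribute to the final integrated measurement.)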
+ let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0); + let mut sum_power = Sum::zero(); + let mut n_blocks = 0_usize; + for &gating_block_power in &gating_blocks { + if gating_block_power > relative_threshold { + sum_power.add(gating_block_power.0); + n_blocks += 1; + } + } + + if n_blocks == 0 { + return None; + } + + let relative_gated_power = Power(sum_power.sum / n_blocks as f32); + Some(relative_gated_power) +} diff --git a/crates/utils/src/coco_classes.rs b/crates/utils/src/coco_classes.rs new file mode 100644 index 0000000..0075352 --- /dev/null +++ b/crates/utils/src/coco_classes.rs @@ -0,0 +1,82 @@ +pub const NAMES: [&str; 80] = [ + "person", + "bicycle", + "car", + "motorbike", + "aeroplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "sofa", + "pottedplant", + "bed", + "diningtable", + "toilet", + "tvmonitor", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +]; diff --git a/crates/utils/src/imagenet.rs b/crates/utils/src/imagenet.rs new file mode 100644 index 0000000..3dcb312 --- /dev/null +++ b/crates/utils/src/imagenet.rs @@ -0,0 +1,1056 @@ +use candle_transformers::models::mimi::candle; +use candle_core::{Device, Result, Tensor}; + +pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406]; +pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225]; + +/// Loads an image from disk using the image crate at the requested resolution, +/// using the given std and mean parameters. +/// This returns a tensor with shape (3, res, res). imagenet normalization is applied. +pub fn load_image_with_std_mean>( + p: P, + res: usize, + mean: &[f32; 3], + std: &[f32; 3], +) -> Result { + let img = image::ImageReader::open(p)? + .decode() + .map_err(candle::Error::wrap)? + .resize_to_fill( + res as u32, + res as u32, + image::imageops::FilterType::Triangle, + ); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle_core::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) +} + +/// Loads an image from disk using the image crate at the requested resolution. +/// This returns a tensor with shape (3, res, res). imagenet normalization is applied. +pub fn load_image>(p: P, res: usize) -> Result { + load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD) +} + +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 224, 224). imagenet normalization is applied. 
+pub fn load_image224>(p: P) -> Result { + load_image(p, 224) +} + +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 518, 518). imagenet normalization is applied. +/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens). +pub fn load_image518>(p: P) -> Result { + load_image(p, 518) +} + +pub const CLASS_COUNT: i64 = 1000; + +pub const CLASSES: [&str; 1000] = [ + "tench, Tinca tinca", + "goldfish, Carassius auratus", + "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + "tiger shark, Galeocerdo cuvieri", + "hammerhead, hammerhead shark", + "electric ray, crampfish, numbfish, torpedo", + "stingray", + "cock", + "hen", + "ostrich, Struthio camelus", + "brambling, Fringilla montifringilla", + "goldfinch, Carduelis carduelis", + "house finch, linnet, Carpodacus mexicanus", + "junco, snowbird", + "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + "robin, American robin, Turdus migratorius", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel, dipper", + "kite", + "bald eagle, American eagle, Haliaeetus leucocephalus", + "vulture", + "great grey owl, great gray owl, Strix nebulosa", + "European fire salamander, Salamandra salamandra", + "common newt, Triturus vulgaris", + "eft", + "spotted salamander, Ambystoma maculatum", + "axolotl, mud puppy, Ambystoma mexicanum", + "bullfrog, Rana catesbeiana", + "tree frog, tree-frog", + "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + "loggerhead, loggerhead turtle, Caretta caretta", + "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + "mud turtle", + "terrapin", + "box turtle, box tortoise", + "banded gecko", + "common iguana, iguana, Iguana iguana", + "American chameleon, anole, Anolis carolinensis", + "whiptail, whiptail lizard", + "agama", + "frilled lizard, Chlamydosaurus kingi", + "alligator lizard", + "Gila monster, Heloderma suspectum", + "green lizard, Lacerta viridis", + "African chameleon, Chamaeleo chamaeleon", + "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + "African crocodile, Nile crocodile, Crocodylus niloticus", + "American alligator, Alligator mississipiensis", + "triceratops", + "thunder snake, worm snake, Carphophis amoenus", + "ringneck snake, ring-necked snake, ring snake", + "hognose snake, puff adder, sand viper", + "green snake, grass snake", + "king snake, kingsnake", + "garter snake, grass snake", + "water snake", + "vine snake", + "night snake, Hypsiglena torquata", + "boa constrictor, Constrictor constrictor", + "rock python, rock snake, Python sebae", + "Indian cobra, Naja naja", + "green mamba", + "sea snake", + "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + "diamondback, diamondback rattlesnake, Crotalus adamanteus", + "sidewinder, horned rattlesnake, Crotalus cerastes", + "trilobite", + "harvestman, daddy longlegs, Phalangium opilio", + "scorpion", + "black and gold garden spider, Argiope aurantia", + "barn spider, Araneus cavaticus", + "garden spider, Aranea diademata", + "black widow, Latrodectus mactans", + "tarantula", + "wolf spider, hunting spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse, partridge, Bonasa umbellus", + "prairie chicken, prairie grouse, prairie fowl", + "peacock", + "quail", + "partridge", + "African grey, African gray, Psittacus erithacus", + "macaw", + "sulphur-crested cockatoo, Kakatoe galerita, Cacatua 
galerita", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser, Mergus serrator", + "goose", + "black swan, Cygnus atratus", + "tusker", + "echidna, spiny anteater, anteater", + "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + "wallaby, brush kangaroo", + "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + "wombat", + "jellyfish", + "sea anemone, anemone", + "brain coral", + "flatworm, platyhelminth", + "nematode, nematode worm, roundworm", + "conch", + "snail", + "slug", + "sea slug, nudibranch", + "chiton, coat-of-mail shell, sea cradle, polyplacophore", + "chambered nautilus, pearly nautilus, nautilus", + "Dungeness crab, Cancer magister", + "rock crab, Cancer irroratus", + "fiddler crab", + "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + "American lobster, Northern lobster, Maine lobster, Homarus americanus", + "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + "crayfish, crawfish, crawdad, crawdaddy", + "hermit crab", + "isopod", + "white stork, Ciconia ciconia", + "black stork, Ciconia nigra", + "spoonbill", + "flamingo", + "little blue heron, Egretta caerulea", + "American egret, great white heron, Egretta albus", + "bittern", + "crane", + "limpkin, Aramus pictus", + "European gallinule, Porphyrio porphyrio", + "American coot, marsh hen, mud hen, water hen, Fulica americana", + "bustard", + "ruddy turnstone, Arenaria interpres", + "red-backed sandpiper, dunlin, Erolia alpina", + "redshank, Tringa totanus", + "dowitcher", + "oystercatcher, oyster catcher", + "pelican", + "king penguin, Aptenodytes patagonica", + "albatross, mollymawk", + "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + "dugong, Dugong dugon", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog, Maltese terrier, Maltese", + "Pekinese, Pekingese, Peke", + "Shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound, Afghan", + "basset, basset hound", + "beagle", + "bloodhound, sleuthhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound, Walker foxhound", + "English foxhound", + "redbone", + "borzoi, Russian wolfhound", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound, Ibizan Podenco", + "Norwegian elkhound, elkhound", + "otterhound, otter hound", + "Saluki, gazelle hound", + "Scottish deerhound, deerhound", + "Weimaraner", + "Staffordshire bullterrier, Staffordshire bull terrier", + "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "Sealyham terrier, Sealyham", + "Airedale, Airedale terrier", + "cairn, cairn terrier", + "Australian terrier", + "Dandie Dinmont, Dandie Dinmont terrier", + "Boston bull, Boston terrier", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "Scotch terrier, Scottish terrier, Scottie", + "Tibetan terrier, chrysanthemum dog", + "silky terrier, Sydney silky", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa, Lhasa apso", + "flat-coated retriever", + "curly-coated retriever", + "golden 
retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla, Hungarian pointer", + "English setter", + "Irish setter, red setter", + "Gordon setter", + "Brittany spaniel", + "clumber, clumber spaniel", + "English springer, English springer spaniel", + "Welsh springer spaniel", + "cocker spaniel, English cocker spaniel, cocker", + "Sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog, bobtail", + "Shetland sheepdog, Shetland sheep dog, Shetland", + "collie", + "Border collie", + "Bouvier des Flandres, Bouviers des Flandres", + "Rottweiler", + "German shepherd, German shepherd dog, German police dog, alsatian", + "Doberman, Doberman pinscher", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "Saint Bernard, St Bernard", + "Eskimo dog, husky", + "malamute, malemute, Alaskan malamute", + "Siberian husky", + "dalmatian, coach dog, carriage dog", + "affenpinscher, monkey pinscher, monkey dog", + "basenji", + "pug, pug-dog", + "Leonberg", + "Newfoundland, Newfoundland dog", + "Great Pyrenees", + "Samoyed, Samoyede", + "Pomeranian", + "chow, chow chow", + "keeshond", + "Brabancon griffon", + "Pembroke, Pembroke Welsh corgi", + "Cardigan, Cardigan Welsh corgi", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf, grey wolf, gray wolf, Canis lupus", + "white wolf, Arctic wolf, Canis lupus tundrarum", + "red wolf, maned wolf, Canis rufus, Canis niger", + "coyote, prairie wolf, brush wolf, Canis latrans", + "dingo, warrigal, warragal, Canis dingo", + "dhole, Cuon alpinus", + "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + "hyena, hyaena", + "red fox, Vulpes vulpes", + "kit fox, Vulpes macrotis", + "Arctic fox, white fox, Alopex lagopus", + "grey fox, gray fox, Urocyon cinereoargenteus", + "tabby, tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat, Siamese", + "Egyptian cat", + "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + "lynx, catamount", + "leopard, Panthera pardus", + "snow leopard, ounce, Panthera uncia", + "jaguar, panther, Panthera onca, Felis onca", + "lion, king of beasts, Panthera leo", + "tiger, Panthera tigris", + "cheetah, chetah, Acinonyx jubatus", + "brown bear, bruin, Ursus arctos", + "American black bear, black bear, Ursus americanus, Euarctos americanus", + "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + "sloth bear, Melursus ursinus, Ursus ursinus", + "mongoose", + "meerkat, mierkat", + "tiger beetle", + "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + "ground beetle, carabid beetle", + "long-horned beetle, longicorn, longicorn beetle", + "leaf beetle, chrysomelid", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant, emmet, pismire", + "grasshopper, hopper", + "cricket", + "walking stick, walkingstick, stick insect", + "cockroach, roach", + "mantis, mantid", + "cicada, cicala", + "leafhopper", + "lacewing, lacewing fly", + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + "damselfly", + "admiral", + "ringlet, ringlet butterfly", + "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + "cabbage butterfly", + "sulphur 
butterfly, sulfur butterfly", + "lycaenid, lycaenid butterfly", + "starfish, sea star", + "sea urchin", + "sea cucumber, holothurian", + "wood rabbit, cottontail, cottontail rabbit", + "hare", + "Angora, Angora rabbit", + "hamster", + "porcupine, hedgehog", + "fox squirrel, eastern fox squirrel, Sciurus niger", + "marmot", + "beaver", + "guinea pig, Cavia cobaya", + "sorrel", + "zebra", + "hog, pig, grunter, squealer, Sus scrofa", + "wild boar, boar, Sus scrofa", + "warthog", + "hippopotamus, hippo, river horse, Hippopotamus amphibius", + "ox", + "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + "bison", + "ram, tup", + "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + "ibex, Capra ibex", + "hartebeest", + "impala, Aepyceros melampus", + "gazelle", + "Arabian camel, dromedary, Camelus dromedarius", + "llama", + "weasel", + "mink", + "polecat, fitch, foulmart, foumart, Mustela putorius", + "black-footed ferret, ferret, Mustela nigripes", + "otter", + "skunk, polecat, wood pussy", + "badger", + "armadillo", + "three-toed sloth, ai, Bradypus tridactylus", + "orangutan, orang, orangutang, Pongo pygmaeus", + "gorilla, Gorilla gorilla", + "chimpanzee, chimp, Pan troglodytes", + "gibbon, Hylobates lar", + "siamang, Hylobates syndactylus, Symphalangus syndactylus", + "guenon, guenon monkey", + "patas, hussar monkey, Erythrocebus patas", + "baboon", + "macaque", + "langur", + "colobus, colobus monkey", + "proboscis monkey, Nasalis larvatus", + "marmoset", + "capuchin, ringtail, Cebus capucinus", + "howler monkey, howler", + "titi, titi monkey", + "spider monkey, Ateles geoffroyi", + "squirrel monkey, Saimiri sciureus", + "Madagascar cat, ring-tailed lemur, Lemur catta", + "indri, indris, Indri indri, Indri brevicaudatus", + "Indian elephant, Elephas maximus", + "African elephant, Loxodonta africana", + "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + "barracouta, snoek", + "eel", + "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + "rock beauty, Holocanthus tricolor", + "anemone fish", + "sturgeon", + "gar, garfish, garpike, billfish, Lepisosteus osseus", + "lionfish", + "puffer, pufferfish, blowfish, globefish", + "abacus", + "abaya", + "academic gown, academic robe, judge's robe", + "accordion, piano accordion, squeeze box", + "acoustic guitar", + "aircraft carrier, carrier, flattop, attack aircraft carrier", + "airliner", + "airship, dirigible", + "altar", + "ambulance", + "amphibian, amphibious vehicle", + "analog clock", + "apiary, bee house", + "apron", + "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "assault rifle, assault gun", + "backpack, back pack, knapsack, packsack, rucksack, haversack", + "bakery, bakeshop, bakehouse", + "balance beam, beam", + "balloon", + "ballpoint, ballpoint pen, ballpen, Biro", + "Band Aid", + "banjo", + "bannister, banister, balustrade, balusters, handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel, cask", + "barrow, garden cart, lawn cart, wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap, swimming cap", + "bath towel", + "bathtub, bathing tub, bath, tub", + "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + "beacon, lighthouse, beacon light, pharos", + "beaker", + "bearskin, busby, shako", + "beer bottle", 
+ "beer glass", + "bell cote, bell cot", + "bib", + "bicycle-built-for-two, tandem bicycle, tandem", + "bikini, two-piece", + "binder, ring-binder", + "binoculars, field glasses, opera glasses", + "birdhouse", + "boathouse", + "bobsled, bobsleigh, bob", + "bolo tie, bolo, bola tie, bola", + "bonnet, poke bonnet", + "bookcase", + "bookshop, bookstore, bookstall", + "bottlecap", + "bow", + "bow tie, bow-tie, bowtie", + "brass, memorial tablet, plaque", + "brassiere, bra, bandeau", + "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + "breastplate, aegis, egis", + "broom", + "bucket, pail", + "buckle", + "bulletproof vest", + "bullet train, bullet", + "butcher shop, meat market", + "cab, hack, taxi, taxicab", + "caldron, cauldron", + "candle, taper, wax light", + "cannon", + "canoe", + "can opener, tin opener", + "cardigan", + "car mirror", + "carousel, carrousel, merry-go-round, roundabout, whirligig", + "carpenter's kit, tool kit", + "carton", + "car wheel", + "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello, violoncello", + "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "chain", + "chainlink fence", + "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + "chain saw, chainsaw", + "chest", + "chiffonier, commode", + "chime, bell, gong", + "china cabinet, china closet", + "Christmas stocking", + "church, church building", + "cinema, movie theater, movie theatre, movie house, picture palace", + "cleaver, meat cleaver, chopper", + "cliff dwelling", + "cloak", + "clog, geta, patten, sabot", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil, spiral, volute, whorl, helix", + "combination lock", + "computer keyboard, keypad", + "confectionery, confectionary, candy store", + "container ship, containership, container vessel", + "convertible", + "corkscrew, bottle screw", + "cornet, horn, trumpet, trump", + "cowboy boot", + "cowboy hat, ten-gallon hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib, cot", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam, dike, dyke", + "desk", + "desktop computer", + "dial telephone, dial phone", + "diaper, nappy, napkin", + "digital clock", + "digital watch", + "dining table, board", + "dishrag, dishcloth", + "dishwasher, dish washer, dishwashing machine", + "disk brake, disc brake", + "dock, dockage, docking facility", + "dogsled, dog sled, dog sleigh", + "dome", + "doormat, welcome mat", + "drilling platform, offshore rig", + "drum, membranophone, tympan", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan, blower", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa, boa", + "file, file cabinet, filing cabinet", + "fireboat", + "fire engine, fire truck", + "fire screen, fireguard", + "flagpole, flagstaff", + "flute, transverse flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn, horn", + "frying pan, frypan, skillet", + "fur coat", + "garbage truck, dustcart", + "gasmask, respirator, gas helmet", + "gas pump, gasoline pump, petrol pump, island dispenser", + "goblet", + "go-kart", + "golf ball", + "golfcart, golf cart", + "gondola", + "gong, tam-tam", + "gown", + "grand piano, grand", + "greenhouse, nursery, 
glasshouse", + "grille, radiator grille", + "grocery store, grocery, food market, market", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower, blow dryer, blow drier, hair dryer, hair drier", + "hand-held computer, hand-held microcomputer", + "handkerchief, hankie, hanky, hankey", + "hard disc, hard disk, fixed disk", + "harmonica, mouth organ, harp, mouth harp", + "harp", + "harvester, reaper", + "hatchet", + "holster", + "home theater, home theatre", + "honeycomb", + "hook, claw", + "hoopskirt, crinoline", + "horizontal bar, high bar", + "horse cart, horse-cart", + "hourglass", + "iPod", + "iron, smoothing iron", + "jack-o'-lantern", + "jean, blue jean, denim", + "jeep, landrover", + "jersey, T-shirt, tee shirt", + "jigsaw puzzle", + "jinrikisha, ricksha, rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat, laboratory coat", + "ladle", + "lampshade, lamp shade", + "laptop, laptop computer", + "lawn mower, mower", + "lens cap, lens cover", + "letter opener, paper knife, paperknife", + "library", + "lifeboat", + "lighter, light, igniter, ignitor", + "limousine, limo", + "liner, ocean liner", + "lipstick, lip rouge", + "Loafer", + "lotion", + "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "loupe, jeweler's loupe", + "lumbermill, sawmill", + "magnetic compass", + "mailbag, postbag", + "mailbox, letter box", + "maillot", + "maillot, tank suit", + "manhole cover", + "maraca", + "marimba, xylophone", + "mask", + "matchstick", + "maypole", + "maze, labyrinth", + "measuring cup", + "medicine chest, medicine cabinet", + "megalith, megalithic structure", + "microphone, mike", + "microwave, microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt, mini", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home, manufactured home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter, scooter", + "mountain bike, all-terrain bike, off-roader", + "mountain tent", + "mouse, computer mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook, notebook computer", + "obelisk", + "oboe, hautboy, hautbois", + "ocarina, sweet potato", + "odometer, hodometer, mileometer, milometer", + "oil filter", + "organ, pipe organ", + "oscilloscope, scope, cathode-ray oscilloscope, CRO", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle, boat paddle", + "paddlewheel, paddle wheel", + "padlock", + "paintbrush", + "pajama, pyjama, pj's, jammies", + "palace", + "panpipe, pandean pipe, syrinx", + "paper towel", + "parachute, chute", + "parallel bars, bars", + "park bench", + "parking meter", + "passenger car, coach, carriage", + "patio, terrace", + "pay-phone, pay-station", + "pedestal, plinth, footstall", + "pencil box, pencil case", + "pencil sharpener", + "perfume, essence", + "Petri dish", + "photocopier", + "pick, plectrum, plectron", + "pickelhaube", + "picket fence, paling", + "pickup, pickup truck", + "pier", + "piggy bank, penny bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate, pirate ship", + "pitcher, ewer", + "plane, carpenter's plane, woodworking plane", + "planetarium", + "plastic bag", + "plate rack", + "plow, plough", + "plunger, plumber's helper", + "Polaroid camera, Polaroid Land camera", + "pole", + "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + "poncho", + 
"pool table, billiard table, snooker table", + "pop bottle, soda bottle", + "pot, flowerpot", + "potter's wheel", + "power drill", + "prayer rug, prayer mat", + "printer", + "prison, prison house", + "projectile, missile", + "projector", + "puck, hockey puck", + "punching bag, punch bag, punching ball, punchball", + "purse", + "quill, quill pen", + "quilt, comforter, comfort, puff", + "racer, race car, racing car", + "racket, racquet", + "radiator", + "radio, wireless", + "radio telescope, radio reflector", + "rain barrel", + "recreational vehicle, RV, R.V.", + "reel", + "reflex camera", + "refrigerator, icebox", + "remote control, remote", + "restaurant, eating house, eating place, eatery", + "revolver, six-gun, six-shooter", + "rifle", + "rocking chair, rocker", + "rotisserie", + "rubber eraser, rubber, pencil eraser", + "rugby ball", + "rule, ruler", + "running shoe", + "safe", + "safety pin", + "saltshaker, salt shaker", + "sandal", + "sarong", + "sax, saxophone", + "scabbard", + "scale, weighing machine", + "school bus", + "schooner", + "scoreboard", + "screen, CRT screen", + "screw", + "screwdriver", + "seat belt, seatbelt", + "sewing machine", + "shield, buckler", + "shoe shop, shoe-shop, shoe store", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule, slipstick", + "sliding door", + "slot, one-armed bandit", + "snorkel", + "snowmobile", + "snowplow, snowplough", + "soap dispenser", + "soccer ball", + "sock", + "solar dish, solar collector, solar furnace", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web, spider's web", + "spindle", + "sports car, sport car", + "spotlight, spot", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch, stop watch", + "stove", + "strainer", + "streetcar, tram, tramcar, trolley, trolley car", + "stretcher", + "studio couch, day bed", + "stupa, tope", + "submarine, pigboat, sub, U-boat", + "suit, suit of clothes", + "sundial", + "sunglass", + "sunglasses, dark glasses, shades", + "sunscreen, sunblock, sun blocker", + "suspension bridge", + "swab, swob, mop", + "sweatshirt", + "swimming trunks, bathing trunks", + "swing", + "switch, electric switch, electrical switch", + "syringe", + "table lamp", + "tank, army tank, armored combat vehicle, armoured combat vehicle", + "tape player", + "teapot", + "teddy, teddy bear", + "television, television system", + "tennis ball", + "thatch, thatched roof", + "theater curtain, theatre curtain", + "thimble", + "thresher, thrasher, threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop, tobacconist shop, tobacconist", + "toilet seat", + "torch", + "totem pole", + "tow truck, tow car, wrecker", + "toyshop", + "tractor", + "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + "tray", + "trench coat", + "tricycle, trike, velocipede", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus, trolley coach, trackless trolley", + "trombone", + "tub, vat", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle, monocycle", + "upright, upright piano", + "vacuum, vacuum cleaner", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin, fiddle", + "volleyball", + "waffle iron", + "wall clock", + "wallet, billfold, notecase, pocketbook", + "wardrobe, closet, press", + "warplane, military 
plane", + "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + "washer, automatic washer, washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool, woolen, woollen", + "worm fence, snake fence, snake-rail fence, Virginia fence", + "wreck", + "yawl", + "yurt", + "web site, website, internet site, site", + "comic book", + "crossword puzzle, crossword", + "street sign", + "traffic light, traffic signal, stoplight", + "book jacket, dust cover, dust jacket, dust wrapper", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot, hotpot", + "trifle", + "ice cream, icecream", + "ice lolly, lolly, lollipop, popsicle", + "French loaf", + "bagel, beigel", + "pretzel", + "cheeseburger", + "hotdog, hot dog, red hot", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini, courgette", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber, cuke", + "artichoke, globe artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple, ananas", + "banana", + "jackfruit, jak, jack", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce, chocolate syrup", + "dough", + "meat loaf, meatloaf", + "pizza, pizza pie", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff, drop, drop-off", + "coral reef", + "geyser", + "lakeside, lakeshore", + "promontory, headland, head, foreland", + "sandbar, sand bar", + "seashore, coast, seacoast, sea-coast", + "valley, vale", + "volcano", + "ballplayer, baseball player", + "groom, bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + "corn", + "acorn", + "hip, rose hip, rosehip", + "buckeye, horse chestnut, conker", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn, carrion fungus", + "earthstar", + "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + "bolete", + "ear, spike, capitulum", + "toilet tissue, toilet paper, bathroom tissue", +]; diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs new file mode 100644 index 0000000..3b13714 --- /dev/null +++ b/crates/utils/src/lib.rs @@ -0,0 +1,156 @@ +extern crate candle_core; +extern crate candle_transformers; +extern crate tokenizers; + +pub mod audio; +pub mod bs1770; +pub mod coco_classes; +pub mod imagenet; +pub mod token_output_stream; +pub mod wav; +use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}}; + + +pub fn device(cpu: bool) -> Result { + if cpu { + Ok(Device::Cpu) + } else if cuda_is_available() { + Ok(Device::new_cuda(0)?) + } else if metal_is_available() { + Ok(Device::new_metal(0)?) + } else { + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + { + println!( + "Running on CPU, to run on GPU(metal), build this example with `--features metal`" + ); + } + #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))] + { + println!("Running on CPU, to run on GPU, build this example with `--features cuda`"); + } + Ok(Device::Cpu) + } +} + +pub fn load_image>( + p: P, + resize_longest: Option, +) -> Result<(Tensor, usize, usize), anyhow::Error> { + let img = image::ImageReader::open(p)? 
+ .decode() + .map_err(candle_core::Error::wrap)?; + let (initial_h, initial_w) = (img.height() as usize, img.width() as usize); + let img = match resize_longest { + None => img, + Some(resize_longest) => { + let (height, width) = (img.height(), img.width()); + let resize_longest = resize_longest as u32; + let (height, width) = if height < width { + let h = (resize_longest * height) / width; + (h, resize_longest) + } else { + let w = (resize_longest * width) / height; + (resize_longest, w) + }; + img.resize_exact(width, height, image::imageops::FilterType::CatmullRom) + } + }; + let (height, width) = (img.height() as usize, img.width() as usize); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?; + Ok((data, initial_h, initial_w)) +} + +pub fn load_image_and_resize<P: AsRef<std::path::Path>>( + p: P, + width: usize, + height: usize, +) -> candle_core::Result<Tensor> { + let img = image::ImageReader::open(p)? + .decode() + .map_err(candle_core::Error::wrap)? + .resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::Triangle, + ); + let img = img.to_rgb8(); + let data = img.into_raw(); + Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1)) +} + +/// Saves an image to disk using the image crate, this expects an input with shape +/// (c, height, width). +pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> { + let p = p.as_ref(); + let (channel, height, width) = img.dims3()?; + if channel != 3 { + anyhow::bail!("save_image expects an input of shape (3, height, width)") + } + let img = img.permute((1, 2, 0))?.flatten_all()?; + let pixels = img.to_vec1::<u8>()?; + let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => anyhow::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle_core::Error::wrap)?; + Ok(()) +} + +/// Loads the safetensors files for a model from the hub based on a json index file.
+pub fn hub_load_safetensors( + repo: &hf_hub::api::sync::ApiRepo, + json_file: &str, +) -> Result<Vec<std::path::PathBuf>, anyhow::Error> { + let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?; + let json_file = std::fs::File::open(json_file)?; + let json: serde_json::Value = + serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => anyhow::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file.to_string()); + } + } + let safetensors_files = safetensors_files + .iter() + .map(|v| { + repo.get(v) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + }) + .collect::<Result<Vec<_>, std::io::Error>>()?; + Ok(safetensors_files) +} + +pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>( + path: P, + json_file: &str, +) -> Result<Vec<std::path::PathBuf>, anyhow::Error> { + let path = path.as_ref(); + let jsfile = std::fs::File::open(path.join(json_file))?; + let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => anyhow::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file); + } + } + let safetensors_files: Vec<_> = safetensors_files + .into_iter() + .map(|v| path.join(v)) + .collect(); + Ok(safetensors_files) +} diff --git a/crates/utils/src/main.rs b/crates/utils/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/crates/utils/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/crates/utils/src/token_output_stream.rs b/crates/utils/src/token_output_stream.rs new file mode 100644 index 0000000..856b67f --- /dev/null +++ b/crates/utils/src/token_output_stream.rs @@ -0,0 +1,85 @@ +use candle_core::Result; +use tokenizers::Tokenizer; + +pub struct TokenOutputStream { + tokenizer: tokenizers::Tokenizer, + tokens: Vec<u32>, + prev_index: usize, + current_index: usize, +} + +impl TokenOutputStream { + pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { + Self { + tokenizer, + tokens: Vec::new(), + prev_index: 0, + current_index: 0, + } + } + + pub fn into_inner(self) -> tokenizers::Tokenizer { + self.tokenizer + } + + fn decode(&self, tokens: &[u32]) -> Result<String> { + match self.tokenizer.decode(tokens, true) { + Ok(str) => Ok(str), + Err(err) => candle_core::bail!("cannot decode: {err}"), + } + } + + // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 + pub fn next_token(&mut self, token: u32) -> Result<Option<String>> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)?
+ }; + self.tokens.push(token); + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { + let text = text.split_at(prev_text.len()); + self.prev_index = self.current_index; + self.current_index = self.tokens.len(); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_rest(&self) -> Result<Option<String>> { + let prev_text = if self.tokens.is_empty() { + String::new() + } else { + let tokens = &self.tokens[self.prev_index..self.current_index]; + self.decode(tokens)? + }; + let text = self.decode(&self.tokens[self.prev_index..])?; + if text.len() > prev_text.len() { + let text = text.split_at(prev_text.len()); + Ok(Some(text.1.to_string())) + } else { + Ok(None) + } + } + + pub fn decode_all(&self) -> Result<String> { + self.decode(&self.tokens) + } + + pub fn get_token(&self, token_s: &str) -> Option<u32> { + self.tokenizer.get_vocab(true).get(token_s).copied() + } + + pub fn tokenizer(&self) -> &tokenizers::Tokenizer { + &self.tokenizer + } + + pub fn clear(&mut self) { + self.tokens.clear(); + self.prev_index = 0; + self.current_index = 0; + } +} diff --git a/crates/utils/src/wav.rs b/crates/utils/src/wav.rs new file mode 100644 index 0000000..df98aa1 --- /dev/null +++ b/crates/utils/src/wav.rs @@ -0,0 +1,56 @@ +use std::io::prelude::*; + +pub trait Sample { + fn to_i16(&self) -> i16; +} + +impl Sample for f32 { + fn to_i16(&self) -> i16 { + (self.clamp(-1.0, 1.0) * 32767.0) as i16 + } +} + +impl Sample for f64 { + fn to_i16(&self) -> i16 { + (self.clamp(-1.0, 1.0) * 32767.0) as i16 + } +} + +impl Sample for i16 { + fn to_i16(&self) -> i16 { + *self + } +} + +pub fn write_pcm_as_wav<W: Write, S: Sample>( + w: &mut W, + samples: &[S], + sample_rate: u32, +) -> std::io::Result<()> { + let len = 12u32; // header + let len = len + 24u32; // fmt + let len = len + samples.len() as u32 * 2 + 8; // data + let n_channels = 1u16; + let bytes_per_second = sample_rate * 2 * n_channels as u32; + w.write_all(b"RIFF")?; + w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes + w.write_all(b"WAVE")?; + + // Format block + w.write_all(b"fmt ")?; + w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes + w.write_all(&1u16.to_le_bytes())?; // PCM + w.write_all(&n_channels.to_le_bytes())?; // one channel + w.write_all(&sample_rate.to_le_bytes())?; + w.write_all(&bytes_per_second.to_le_bytes())?; + w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample + w.write_all(&16u16.to_le_bytes())?; // bits per sample + + // Data block + w.write_all(b"data")?; + w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?; + for sample in samples.iter() { + w.write_all(&sample.to_i16().to_le_bytes())? + } + Ok(()) +}
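
Note: a minimal usage sketch of the new utils crate added above, showing the intended flow of device selection, streaming token decoding, and WAV output. This is illustrative only; the crate path `utils`, the `demo` function, and the sample values are assumptions, not code contained in this patch.

    // Assumes the crate added by this patch is importable as `utils`; adjust to match its Cargo.toml package name.
    use utils::{device, token_output_stream::TokenOutputStream, wav::write_pcm_as_wav};

    fn demo(tokenizer: tokenizers::Tokenizer, token_ids: &[u32]) -> anyhow::Result<()> {
        // Prefer CUDA or Metal when compiled in and available, otherwise fall back to CPU.
        let _dev = device(false)?;

        // Incrementally turn generated token ids into printable text fragments.
        let mut stream = TokenOutputStream::new(tokenizer);
        for &id in token_ids {
            if let Some(fragment) = stream.next_token(id)? {
                print!("{fragment}");
            }
        }
        if let Some(rest) = stream.decode_rest()? {
            print!("{rest}");
        }

        // Write one second of silence as 16-bit mono PCM at 16 kHz.
        let samples = vec![0f32; 16_000];
        let mut out = std::fs::File::create("out.wav")?;
        write_pcm_as_wav(&mut out, &samples, 16_000)?;
        Ok(())
    }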