reorg + update docs with new paths

geoffsee
2025-09-04 12:27:13 -04:00
parent 400c70f17d
commit ff55d882c7
43 changed files with 493 additions and 182 deletions

View File

@@ -1,11 +0,0 @@
[package]
name = "cli"
version.workspace = true
edition = "2021"
build = "build.rs"
[[bin]]
name = "cli"
path = "src/main.rs"
[dependencies]

View File

@@ -1,24 +0,0 @@
# cli
A Rust/TypeScript Hybrid
```console
bun run cli.ts [options] [prompt]
Simple CLI tool for testing the local OpenAI-compatible API server.
Options:
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message
Examples:
cd crates/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"
bun run cli.ts --list-models
The server must be running at http://localhost:8080
```
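For reference, the CLI simply talks to a standard OpenAI chat-completions endpoint. Below is a minimal sketch of the equivalent raw request, assuming the same base URL and default model that `cli.ts` uses (and that the server also accepts non-streaming requests):
```ts
// Sketch only: one non-streaming request to the local OpenAI-compatible server.
// Assumes the /v1/chat/completions route and the gemma-3-1b-it default from cli.ts.
const res = await fetch("http://localhost:8080/v1/chat/completions", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    model: "gemma-3-1b-it",
    messages: [{ role: "user", content: "What is the capital of France?" }],
  }),
});
console.log(await res.json());
```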

View File

@@ -1,204 +0,0 @@
use std::env;
use std::fs;
use std::io::{self, BufRead, Write};
use std::path::{Path, PathBuf};
use std::process::{ChildStderr, ChildStdout, Command, Stdio};
use std::thread;
use std::time::{Duration, SystemTime};
mod bun_target;
use bun_target::BunTarget;
fn main() {
println!("cargo:rerun-if-changed=");
if let Err(e) = run_build() {
println!("cargo:warning=build.rs failed: {e}");
std::process::exit(1);
}
}
fn run_build() -> io::Result<()> {
let manifest_dir =
PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"));
let package_dir = manifest_dir.join("package");
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
let output_path = out_dir.join("client-cli");
let bun_tgt = BunTarget::from_cargo_env()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
// Optional: warn if using a Bun target that's marked unsupported in your chart
if matches!(bun_tgt, BunTarget::WindowsArm64) {
println!(
"cargo:warning=bun-windows-arm64 is marked unsupported in the compatibility chart"
);
}
warn(&format!("Building CLI into: {}", output_path.display()));
// --- bun install (in ./package), keep temps inside OUT_DIR ---
let mut install = Command::new("bun")
.current_dir(&package_dir)
.env("TMPDIR", &out_dir)
.arg("install")
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun install`: {e}")))?;
let install_join = stream_child("bun install", install.stdout.take(), install.stderr.take());
let install_status = install.wait()?;
// ensure streams finish
join_streams(install_join);
if !install_status.success() {
let code = install_status.code().unwrap_or(1);
return Err(io::Error::new(
io::ErrorKind::Other,
format!("bun install failed with status {code}"),
));
}
let target = env::var("TARGET").unwrap();
// --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
let mut build = Command::new("bun")
.current_dir(&package_dir)
.env("TMPDIR", &out_dir)
.arg("build")
.arg("./cli.ts")
.arg(format!("--target={}", bun_tgt.as_bun_flag()))
.arg("--compile")
.arg("--outfile")
.arg(&output_path)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun build`: {e}")))?;
let build_join = stream_child("bun build", build.stdout.take(), build.stderr.take());
let status = build.wait()?;
// ensure streams finish
join_streams(build_join);
if status.success() {
info("bun build succeeded");
} else {
let code = status.code().unwrap_or(1);
warn(&format!("bun build failed with status: {code}"));
return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
}
// Ensure the output is executable (after it exists)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(&output_path)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&output_path, perms)?;
}
println!("cargo:warning=Built CLI at {}", output_path.display());
println!("cargo:rustc-env=CLIENT_CLI_BIN={}", output_path.display());
// --- Cleanup stray .bun-build temp files (conservative: older than 5 minutes) ---
for dir in [&manifest_dir, &package_dir, &out_dir] {
if let Err(e) = remove_bun_temp_files(dir, Some(Duration::from_secs(5 * 60))) {
println!("cargo:warning=cleanup in {} failed: {e}", dir.display());
}
}
Ok(())
}
// Spawn readers for child's stdout/stderr so we don't deadlock on pipe buffers
fn stream_child(
tag: &str,
stdout: Option<ChildStdout>,
stderr: Option<ChildStderr>,
) -> (
Option<thread::JoinHandle<()>>,
Option<thread::JoinHandle<()>>,
) {
let t1 = stdout.map(|out| {
let tag = tag.to_string();
thread::spawn(move || {
let reader = io::BufReader::new(out);
for line in reader.lines() {
info(&format!("[{tag} stdout] {}", line.unwrap_or_default()));
}
})
});
let t2 = stderr.map(|err| {
let tag = tag.to_string();
thread::spawn(move || {
let reader = io::BufReader::new(err);
for line in reader.lines() {
warn(&format!("[{tag} stderr] {}", line.unwrap_or_default()));
}
})
});
(t1, t2)
}
fn join_streams(
joins: (
Option<thread::JoinHandle<()>>,
Option<thread::JoinHandle<()>>,
),
) {
if let Some(j) = joins.0 {
let _ = j.join();
}
if let Some(j) = joins.1 {
let _ = j.join();
}
}
fn remove_bun_temp_files(dir: &Path, older_than: Option<Duration>) -> io::Result<()> {
let now = SystemTime::now();
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if !path.is_file() {
continue;
}
// Files like ".1860e7df40ff1bef-00000000.bun-build"
let name = entry.file_name();
let name = name.to_string_lossy();
let looks_like_bun_temp = name.starts_with('.') && name.ends_with(".bun-build");
if !looks_like_bun_temp {
continue;
}
if let Some(age) = older_than {
if let Ok(meta) = entry.metadata() {
if let Ok(modified) = meta.modified() {
if now.duration_since(modified).unwrap_or_default() < age {
// too new; skip to avoid racing an in-flight builder
continue;
}
}
}
}
match fs::remove_file(&path) {
Ok(_) => println!("cargo:warning=removed stray bun temp {}", path.display()),
Err(e) => println!("cargo:warning=failed to remove {}: {e}", path.display()),
}
}
Ok(())
}
fn warn(msg: &str) {
let _ = writeln!(io::stderr(), "[build.rs] {msg}");
println!("cargo:warning={msg}");
}
fn info(msg: &str) {
let _ = writeln!(io::stderr(), "[build.rs] {msg}");
println!("cargo:warning=INFO|{msg}");
}

View File

@@ -1,131 +0,0 @@
use std::env;
use std::fmt;
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum BunTarget {
LinuxX64Glibc,
LinuxArm64Glibc,
LinuxX64Musl,
LinuxArm64Musl,
WindowsX64,
WindowsArm64,
MacX64,
MacArm64,
}
impl BunTarget {
pub const fn as_bun_flag(self) -> &'static str {
match self {
BunTarget::LinuxX64Glibc => "bun-linux-x64",
BunTarget::LinuxArm64Glibc => "bun-linux-arm64",
BunTarget::LinuxX64Musl => "bun-linux-x64-musl",
BunTarget::LinuxArm64Musl => "bun-linux-arm64-musl",
BunTarget::WindowsX64 => "bun-windows-x64",
BunTarget::WindowsArm64 => "bun-windows-arm64",
BunTarget::MacX64 => "bun-darwin-x64",
BunTarget::MacArm64 => "bun-darwin-arm64",
}
}
pub const fn rust_triples(self) -> &'static [&'static str] {
match self {
BunTarget::LinuxX64Glibc => {
&["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-gnu.2.17"]
}
BunTarget::LinuxArm64Glibc => &["aarch64-unknown-linux-gnu"],
BunTarget::LinuxX64Musl => &["x86_64-unknown-linux-musl"],
BunTarget::LinuxArm64Musl => &["aarch64-unknown-linux-musl"],
BunTarget::WindowsX64 => &["x86_64-pc-windows-msvc"],
BunTarget::WindowsArm64 => &["aarch64-pc-windows-msvc"], // chart says unsupported; still map
BunTarget::MacX64 => &["x86_64-apple-darwin"],
BunTarget::MacArm64 => &["aarch64-apple-darwin"],
}
}
pub fn from_rust_target(triple: &str) -> Option<Self> {
let norm = triple.trim();
if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
return Some(BunTarget::LinuxX64Glibc);
}
if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
return Some(BunTarget::LinuxArm64Glibc);
}
if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("musl") {
return Some(BunTarget::LinuxX64Musl);
}
if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("musl") {
return Some(BunTarget::LinuxArm64Musl);
}
if norm == "x86_64-pc-windows-msvc" {
return Some(BunTarget::WindowsX64);
}
if norm == "aarch64-pc-windows-msvc" {
return Some(BunTarget::WindowsArm64);
}
if norm == "x86_64-apple-darwin" {
return Some(BunTarget::MacX64);
}
if norm == "aarch64-apple-darwin" {
return Some(BunTarget::MacArm64);
}
for bt in [
BunTarget::LinuxX64Glibc,
BunTarget::LinuxArm64Glibc,
BunTarget::LinuxX64Musl,
BunTarget::LinuxArm64Musl,
BunTarget::WindowsX64,
BunTarget::WindowsArm64,
BunTarget::MacX64,
BunTarget::MacArm64,
] {
for &t in bt.rust_triples() {
if t == norm {
return Some(bt);
}
}
}
None
}
pub fn from_cargo_env() -> Result<Self, BunTargetError> {
if let Ok(triple) = env::var("TARGET") {
if let Some(bt) = Self::from_rust_target(&triple) {
return Ok(bt);
}
return Err(BunTargetError::UnknownTriple(triple));
}
let os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
let envv = env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default();
let vendor = env::var("CARGO_CFG_TARGET_VENDOR").unwrap_or_else(|_| "unknown".into());
let triple = format!(
"{}-{}-{}-{}",
arch,
vendor,
os,
if envv.is_empty() { "gnu" } else { &envv }
);
if let Some(bt) = Self::from_rust_target(&triple) {
Ok(bt)
} else {
Err(BunTargetError::UnknownTriple(triple))
}
}
}
#[derive(Debug)]
pub enum BunTargetError {
UnknownTriple(String),
}
impl fmt::Display for BunTargetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BunTargetError::UnknownTriple(t) => write!(f, "unrecognized Rust target triple: {t}"),
}
}
}
impl std::error::Error for BunTargetError {}

View File

@@ -1,339 +0,0 @@
#!/usr/bin/env bun
import OpenAI from "openai";
import { parseArgs } from "util";
// =====================
// Config
// =====================
const DEFAULT_MODEL = "gemma-3-1b-it";
const DEFAULT_MAX_TOKENS = 256;
// Toggle this to reduce log overhead during timing runs
const PRINT_CHUNK_DEBUG = false;
// How many rows to show in the timing tables
const SHOW_FIRST_N = 3;
const SHOW_SLOWEST_N = 3;
// =====================
// Helpers
// =====================
const now = () => performance.now();
type ChunkStat = {
index: number;
tSinceRequestStartMs: number;
dtSincePrevMs: number;
contentChars: number;
};
function printHelp() {
console.log(`
./cli [options] [prompt]
Simple CLI tool for testing the local OpenAI-compatible API server.
Options:
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message
Examples:
./cli "What is the capital of France?"
./cli --model gemma-3-1b-it --prompt "Hello, world!"
./cli --prompt "Who was the 16th president of the United States?"
./cli --list-models
The server must be running at http://localhost:8080
`);
}
const { values, positionals } = parseArgs({
args: process.argv.slice(2),
options: {
model: { type: "string" },
prompt: { type: "string" },
help: { type: "boolean" },
"list-models": { type: "boolean" },
},
strict: false,
allowPositionals: true,
});
async function requestLocalOpenAI(model: string, userPrompt: string) {
const openai = new OpenAI({
baseURL: "http://localhost:8080/v1",
apiKey: "not used",
});
try {
console.log("[DEBUG] Creating chat completion request...");
return openai.chat.completions.create({
model,
max_tokens: DEFAULT_MAX_TOKENS,
stream: true,
messages: [
{
role: "system",
content: "You are a helpful assistant who responds thoughtfully and concisely.",
},
{ role: "user", content: userPrompt },
],
});
} catch (e: any) {
console.error("[ERROR] Failed to connect to local OpenAI server:", e.message);
console.error("[HINT] Make sure the server is running at http://localhost:8080");
console.error("[HINT] Start it with: ./run_server.sh");
throw e;
}
}
async function listModels() {
const openai = new OpenAI({
baseURL: "http://localhost:8080/v1",
apiKey: "not used",
});
try {
const models = await openai.models.list();
console.log(`[INFO] Available models from http://localhost:8080/v1:`);
console.log("---");
if (models.data && models.data.length > 0) {
models.data.forEach((model, index) => {
console.log(`${index + 1}. ${model.id}`);
console.log(` Owner: ${model.owned_by}`);
console.log(` Created: ${new Date(model.created * 1000).toISOString()}`);
console.log("");
});
console.log(`Total: ${models.data.length} models available`);
} else {
console.log("No models found.");
}
} catch (e: any) {
console.error("[ERROR] Failed to fetch models from local OpenAI server:", e.message);
console.error("[HINT] Make sure the server is running at http://localhost:8080");
console.error("[HINT] Start it with: ./run_server.sh");
throw e;
}
}
// =====================
// Timing math
// =====================
function median(nums: number[]) {
if (nums.length === 0) return 0;
const s = [...nums].sort((a, b) => a - b);
const mid = Math.floor(s.length / 2);
return s.length % 2 ? s[mid] : (s[mid - 1] + s[mid]) / 2;
}
function quantile(nums: number[], q: number) {
if (nums.length === 0) return 0;
const s = [...nums].sort((a, b) => a - b);
const pos = (s.length - 1) * q;
const base = Math.floor(pos);
const rest = pos - base;
return s[base + 1] !== undefined ? s[base] + rest * (s[base + 1] - s[base]) : s[base];
}
function ms(n: number) {
return `${n.toFixed(1)} ms`;
}
// =====================
// Main
// =====================
async function main() {
const tProgramStart = now();
if (values.help) {
printHelp();
process.exit(0);
}
if (values["list-models"]) {
try {
await listModels();
process.exit(0);
} catch (error: any) {
console.error("\n[ERROR] Failed to list models:", error.message);
process.exit(1);
}
}
const prompt = values.prompt ?? positionals[0];
if (!prompt) {
console.error("[ERROR] No prompt provided!");
printHelp();
process.exit(1);
}
const model = values.model || DEFAULT_MODEL;
console.log(`[INFO] Using model: ${model}`);
console.log(`[INFO] Prompt: ${prompt}`);
console.log(`[INFO] Connecting to: http://localhost:8080/v1`);
console.log("---");
const tBeforeRequest = now();
try {
console.log("[DEBUG] Initiating request to OpenAI server...");
const response = await requestLocalOpenAI(model, prompt);
const tAfterCreate = now();
// Streaming handling + timing
let fullResponse = "";
let chunkCount = 0;
const chunkStats: ChunkStat[] = [];
let tFirstChunk: number | null = null;
let tPrevChunk: number | null = null;
console.log("[INFO] Waiting for model to generate response...");
let loadingInterval;
if (!PRINT_CHUNK_DEBUG) {
// Show loading animation only if not in debug mode
const loadingChars = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'];
let i = 0;
process.stdout.write('\r[INFO] Thinking ');
loadingInterval = setInterval(() => {
process.stdout.write(`\r[INFO] Thinking ${loadingChars[i++ % loadingChars.length]} `);
}, 80);
} else {
console.log("[DEBUG] Starting to receive streaming response...");
}
for await (const chunk of response) {
// Clear loading animation on first chunk
if (loadingInterval) {
clearInterval(loadingInterval);
loadingInterval = undefined; // only clear the spinner once, not on every chunk
process.stdout.write('\r \r');
}
const tNow = now();
chunkCount++;
// Extract content (delta) if present
const content = chunk.choices?.[0]?.delta?.content ?? "";
if (PRINT_CHUNK_DEBUG) {
console.log(`[DEBUG] Received chunk #${chunkCount}:`, JSON.stringify(chunk));
if (content) console.log(`[DEBUG] Chunk content: "${content}"`);
}
if (content) {
process.stdout.write(content);
fullResponse += content;
}
if (tFirstChunk === null) tFirstChunk = tNow;
const dtSincePrev = tPrevChunk === null ? 0 : tNow - tPrevChunk;
chunkStats.push({
index: chunkCount,
tSinceRequestStartMs: tNow - tBeforeRequest,
dtSincePrevMs: dtSincePrev,
contentChars: content.length,
});
tPrevChunk = tNow;
}
// =========
// Summary
// =========
const tStreamEnd = now();
const totalChars = fullResponse.length;
console.log("\n---");
console.log(`[DEBUG] Stream completed after ${chunkCount} chunks`);
console.log(`[INFO] Response completed. Total length: ${totalChars} characters`);
// Build timing metrics
const ttfbMs = (tFirstChunk ?? tStreamEnd) - tAfterCreate; // time from create() resolved → first chunk
const createOverheadMs = tAfterCreate - tBeforeRequest; // time spent awaiting create() promise
const totalSinceRequestMs = tStreamEnd - tBeforeRequest; // from just before create() to last chunk
const streamDurationMs =
tFirstChunk === null ? 0 : tStreamEnd - tFirstChunk;
const gaps = chunkStats
.map((c) => c.dtSincePrevMs)
// ignore the first "gap" which is 0 by construction
.slice(1);
const avgGapMs = gaps.length ? gaps.reduce((a, b) => a + b, 0) / gaps.length : 0;
const medGapMs = median(gaps);
const p95GapMs = quantile(gaps, 0.95);
let maxGapMs = 0;
let maxGapAtChunk = 0;
for (let i = 0; i < gaps.length; i++) {
if (gaps[i] > maxGapMs) {
maxGapMs = gaps[i];
maxGapAtChunk = i + 2; // +1 to move from 0-based, +1 because we sliced starting at second chunk
}
}
// Pretty print summary
console.log("\n=== Timing Summary ===");
console.log(`create() await time: ${ms(createOverheadMs)}`);
console.log(`TTFB (to 1st chunk): ${ms(ttfbMs)}`);
console.log(`Stream duration: ${ms(streamDurationMs)}`);
console.log(`End-to-end (req→last): ${ms(totalSinceRequestMs)}`);
console.log(`Chunks: ${chunkCount}`);
console.log(`Total content chars: ${totalChars}`);
console.log(`Avg chars/chunk: ${(chunkCount ? totalChars / chunkCount : 0).toFixed(1)}`);
console.log(`Inter-chunk gap (avg): ${ms(avgGapMs)}`);
console.log(`Inter-chunk gap (median): ${ms(medGapMs)}`);
console.log(`Inter-chunk gap (p95): ${ms(p95GapMs)}`);
if (gaps.length > 0) {
console.log(`Largest gap: ${ms(maxGapMs)} (before chunk #${maxGapAtChunk})`);
}
// Small tables: first N and slowest N gaps
const firstRows = chunkStats.slice(0, SHOW_FIRST_N).map((c) => ({
chunk: c.index,
"t since request": `${c.tSinceRequestStartMs.toFixed(1)} ms`,
"dt since prev": `${c.dtSincePrevMs.toFixed(1)} ms`,
"chars": c.contentChars,
}));
const slowestRows = chunkStats
.slice(1) // skip first (no meaningful gap)
.sort((a, b) => b.dtSincePrevMs - a.dtSincePrevMs)
.slice(0, SHOW_SLOWEST_N)
.map((c) => ({
chunk: c.index,
"t since request": `${c.tSinceRequestStartMs.toFixed(1)} ms`,
"dt since prev": `${c.dtSincePrevMs.toFixed(1)} ms`,
"chars": c.contentChars,
}));
if (firstRows.length > 0) {
console.log("\n--- First chunk timings ---");
// @ts-ignore Bun/Node support console.table
console.table(firstRows);
}
if (slowestRows.length > 0) {
console.log(`\n--- Slowest ${SHOW_SLOWEST_N} gaps ---`);
// @ts-ignore
console.table(slowestRows);
}
const tProgramEnd = now();
console.log("\n=== Program Overhead ===");
console.log(`Total program runtime: ${ms(tProgramEnd - tProgramStart)}`);
} catch (error: any) {
console.error("\n[ERROR] Request failed:", error.message);
process.exit(1);
}
}
// Run the main function
main().catch((error) => {
console.error("[FATAL ERROR]:", error);
process.exit(1);
});

View File

@@ -1,11 +0,0 @@
{
"name": "cli",
"main": "cli.ts",
"scripts": {
"build": "bun build cli.ts --compile --outfile cli"
},
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0"
}
}

View File

@@ -1,32 +0,0 @@
use std::{env, fs, io, path::PathBuf, process::Command};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
fn main() -> io::Result<()> {
// Absolute path provided by build.rs at compile time.
// `include_bytes!` accepts string literals; `env!` expands to a literal at compile time.
const CLIENT_CLI: &[u8] = include_bytes!(env!("CLIENT_CLI_BIN"));
// Write to a temp file
let mut tmp = env::temp_dir();
tmp.push("client-cli-embedded");
fs::write(&tmp, CLIENT_CLI)?;
// Ensure it's executable on Unix
#[cfg(unix)]
{
let mut perms = fs::metadata(&tmp)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&tmp, perms)?;
}
// Run it
let status = Command::new(&tmp).arg("--version").status()?;
if !status.success() {
return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
}
Ok(())
}

View File

@@ -1,43 +1,183 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
use axum::{Json, Router, response::Json as ResponseJson, routing::{get, post}, http::StatusCode};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tower_http::trace::TraceLayer;
use tracing;
// Persistent model instance (singleton pattern)
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
tracing::info!("Initializing persistent embedding model (singleton)");
// Cache for multiple embedding models
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> = Lazy::new(|| {
RwLock::new(HashMap::new())
});
#[derive(Serialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub owned_by: String,
pub description: String,
pub dimensions: usize,
}
#[derive(Serialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
match model_name {
// Sentence Transformers models
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),
// BGE models
"BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
"BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
"BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
"BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
"BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
"BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
"BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
"BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
// Nomic models
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),
// Paraphrase models
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
// ModernBert
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),
// Multilingual E5 models
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),
// Mixedbread models
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),
// GTE models
"Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
"Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),
// CLIP model
"Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
// Jina model
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),
_ => Err(format!("Unsupported embedding model: {}", model_name)),
}
}
// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
match model {
EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
EmbeddingModel::BGESmallZHV15 => 512,
EmbeddingModel::BGELargeZHV15 => 1024,
EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
EmbeddingModel::ModernBertEmbedLarge => 1024,
EmbeddingModel::MultilingualE5Small => 384,
EmbeddingModel::MultilingualE5Base => 768,
EmbeddingModel::MultilingualE5Large => 1024,
EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
EmbeddingModel::ClipVitB32 => 512,
EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
}
}
// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
// First try to get from cache (read lock)
{
let cache = MODEL_CACHE.read().map_err(|e| format!("Failed to acquire read lock: {}", e))?;
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model: {:?}", embedding_model);
return Ok(Arc::clone(model));
}
}
// Model not in cache, create it (write lock)
let mut cache = MODEL_CACHE.write().map_err(|e| format!("Failed to acquire write lock: {}", e))?;
// Double-check after acquiring write lock
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
return Ok(Arc::clone(model));
}
tracing::info!("Initializing new embedding model: {:?}", embedding_model);
let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
)
.expect("Failed to initialize persistent embedding model");
.map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
let model_init_time = model_start_time.elapsed();
tracing::info!(
"Persistent embedding model initialized in {:.2?}",
"Embedding model {:?} initialized in {:.2?}",
embedding_model,
model_init_time
);
model
});
let model_arc = Arc::new(model);
cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
Ok(model_arc)
}
pub async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
// Start timing the entire process
let start_time = std::time::Instant::now();
// Phase 1: Access persistent model instance
// Phase 1: Parse and get the embedding model
let model_start_time = std::time::Instant::now();
// Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request
let embedding_model = match parse_embedding_model(&payload.model) {
Ok(model) => model,
Err(e) => {
tracing::error!("Invalid model requested: {}", e);
return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
}
};
let model = match get_or_create_model(embedding_model.clone()) {
Ok(model) => model,
Err(e) => {
tracing::error!("Failed to get/create model: {}", e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("Model initialization failed: {}", e)));
}
};
let model_access_time = model_start_time.elapsed();
tracing::debug!(
"Persistent model access completed in {:.2?}",
"Model access/creation completed in {:.2?}",
model_access_time
);
@@ -65,9 +205,12 @@ pub async fn embeddings_create(
// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();
let embeddings = EMBEDDING_MODEL
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
.map_err(|e| {
tracing::error!("Failed to generate embeddings: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, format!("Embedding generation failed: {}", e))
})?;
let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!(
@@ -117,8 +260,9 @@ pub async fn embeddings_create(
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
let expected_dimensions = get_model_dimensions(&embedding_model);
let mut random_embedding = Vec::with_capacity(expected_dimensions);
for _ in 0..expected_dimensions {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
@@ -138,18 +282,19 @@ pub async fn embeddings_create(
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
let padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
// Use the actual model dimensions instead of hardcoded 768
let actual_dimensions = padded_embedding.len();
let expected_dimensions = get_model_dimensions(&embedding_model);
if actual_dimensions != expected_dimensions {
tracing::warn!(
"Model {:?} produced {} dimensions but expected {}",
embedding_model,
actual_dimensions,
expected_dimensions
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
@@ -203,11 +348,232 @@ pub async fn embeddings_create(
postprocessing_time
);
ResponseJson(response)
Ok(ResponseJson(response))
}
pub async fn models_list() -> ResponseJson<ModelsResponse> {
let models = vec![
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the small Chinese model".to_string(),
dimensions: 512,
},
ModelInfo {
id: "BAAI/bge-large-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large Chinese model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(),
dimensions: 768,
},
ModelInfo {
id: "lightonai/modernbert-embed-large".to_string(),
object: "model".to_string(),
owned_by: "lightonai".to_string(),
description: "Large model of ModernBert Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "intfloat/multilingual-e5-small".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Small model of multilingual E5 Text Embeddings".to_string(),
dimensions: 384,
},
ModelInfo {
id: "intfloat/multilingual-e5-base".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Base model of multilingual E5 Text Embeddings".to_string(),
dimensions: 768,
},
ModelInfo {
id: "intfloat/multilingual-e5-large".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Large model of multilingual E5 Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Quantized Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Qdrant/clip-ViT-B-32-text".to_string(),
object: "model".to_string(),
owned_by: "Qdrant".to_string(),
description: "CLIP text encoder based on ViT-B/32".to_string(),
dimensions: 512,
},
ModelInfo {
id: "jinaai/jina-embeddings-v2-base-code".to_string(),
object: "model".to_string(),
owned_by: "jinaai".to_string(),
description: "Jina embeddings v2 base code".to_string(),
dimensions: 768,
},
];
ResponseJson(ModelsResponse {
object: "list".to_string(),
data: models,
})
}
pub fn create_embeddings_router() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
// .route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
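With the per-model cache above, `/v1/embeddings` accepts any of the ids returned by `models_list` (or their short aliases) in the request's `model` field. A minimal sketch of exercising it with the `openai` TypeScript client already used elsewhere in this repo, assuming the server is listening on the usual local port:
```ts
import OpenAI from "openai";

// Sketch only: request embeddings from the local server with an explicit model id.
// Base URL and dummy API key mirror what cli.ts uses; the port and path are assumptions.
const openai = new OpenAI({ baseURL: "http://localhost:8080/v1", apiKey: "not used" });

const res = await openai.embeddings.create({
  model: "BAAI/bge-small-en-v1.5", // the short alias "bge-small-en-v1.5" also parses
  input: "The quick brown fox jumps over the lazy dog",
});

console.log(res.data[0].embedding.length); // 384 dimensions for this model
```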

View File

@@ -4,8 +4,6 @@ use axum::{
response::Json as ResponseJson,
routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
use std::env;
use tower_http::trace::TraceLayer;
use tracing;
@@ -13,127 +11,30 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
use embeddings_engine;
async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
)
.expect("Failed to initialize model");
let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => {
panic!("Integer array input not supported for text embeddings");
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
match embeddings_engine::embeddings_create(Json(payload)).await {
Ok(response) => Ok(response),
Err((status_code, message)) => {
Err(axum::response::Response::builder()
.status(status_code)
.body(axum::body::Body::from(message))
.unwrap())
}
EmbeddingInput::ArrayOfIntegerArray(_) => {
panic!("Array of integer arrays not supported for text embeddings");
}
};
}
}
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
if all_zeros {
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
val = rng.gen_range(-1.0..1.0);
}
random_embedding.push(val);
}
// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
}
};
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": final_embedding
}
],
"model": payload.model,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
});
ResponseJson(response)
async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
embeddings_engine::models_list().await
}
fn create_app() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
.route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -1,32 +0,0 @@
[package]
name = "gemma-runner"
version.workspace = true
edition = "2021"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.4"
tokenizers = "0.22.0"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
utils = {path = "../utils"}
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -1,137 +0,0 @@
# Gemma Runner
Fast Gemma inference with the Candle framework in Rust.
## Features
- Support for multiple Gemma model versions (v1, v2, v3)
- GPU acceleration with CUDA and Metal
- Configurable sampling parameters
- Multiple model variants including instruct and code models
## Supported Models
### Gemma v1
- `gemma-2b` - Base 2B model
- `gemma-7b` - Base 7B model
- `gemma-2b-it` - Instruct 2B model
- `gemma-7b-it` - Instruct 7B model
- `gemma-1.1-2b-it` - Instruct 2B v1.1 model
- `gemma-1.1-7b-it` - Instruct 7B v1.1 model
### CodeGemma
- `codegemma-2b` - Code base 2B model
- `codegemma-7b` - Code base 7B model
- `codegemma-2b-it` - Code instruct 2B model
- `codegemma-7b-it` - Code instruct 7B model
### Gemma v2
- `gemma-2-2b` - Base 2B v2 model (default)
- `gemma-2-2b-it` - Instruct 2B v2 model
- `gemma-2-9b` - Base 9B v2 model
- `gemma-2-9b-it` - Instruct 9B v2 model
### Gemma v3
- `gemma-3-1b` - Base 1B v3 model
- `gemma-3-1b-it` - Instruct 1B v3 model
## Installation
```bash
cd gemma-runner
cargo build --release
```
For GPU support:
```bash
# CUDA
cargo build --release --features cuda
# Metal (macOS)
cargo build --release --features metal
```
## Usage
### Basic Usage
```bash
# Run with default model (gemma-2-2b)
cargo run -- --prompt "The capital of France is"
# Specify a different model
cargo run -- --model gemma-2b-it --prompt "Explain quantum computing"
# Generate more tokens
cargo run -- --model codegemma-2b-it --prompt "Write a Python function to sort a list" --max-tokens 200
```
### Advanced Options
```bash
# Use CPU instead of GPU
cargo run -- --cpu --prompt "Hello world"
# Adjust sampling parameters
cargo run -- --temperature 0.8 --top-p 0.9 --prompt "Write a story about"
# Use custom model from HuggingFace Hub
cargo run -- --model-id "google/gemma-2-2b-it" --prompt "What is AI?"
# Enable tracing for performance analysis
cargo run -- --tracing --prompt "Explain machine learning"
```
### Command Line Arguments
- `--prompt, -p` - The prompt to generate text from (default: "The capital of France is")
- `--model, -m` - The model to use (default: "gemma-2-2b")
- `--cpu` - Run on CPU rather than GPU
- `--temperature, -t` - Sampling temperature (optional)
- `--top-p` - Nucleus sampling probability cutoff (optional)
- `--seed` - Random seed (default: 299792458)
- `--max-tokens, -n` - Maximum tokens to generate (default: 100)
- `--model-id` - Custom model ID from HuggingFace Hub
- `--revision` - Model revision (default: "main")
- `--use-flash-attn` - Use flash attention
- `--repeat-penalty` - Repetition penalty (default: 1.1)
- `--repeat-last-n` - Context size for repeat penalty (default: 64)
- `--dtype` - Data type (f16, bf16, f32)
- `--tracing` - Enable performance tracing
## Examples
### Text Generation
```bash
cargo run -- --model gemma-2b-it --prompt "Explain the theory of relativity" --max-tokens 150
```
### Code Generation
```bash
cargo run -- --model codegemma-7b-it --prompt "Write a Rust function to calculate factorial" --max-tokens 100
```
### Creative Writing
```bash
cargo run -- --model gemma-7b-it --temperature 0.9 --prompt "Once upon a time in a magical forest" --max-tokens 200
```
### Chat with Gemma 3 (Instruct format)
```bash
cargo run -- --model gemma-3-1b-it --prompt "How do I learn Rust programming?"
```
## Performance Notes
- GPU acceleration is automatically detected and used when available
- BF16 precision is used on CUDA for better performance
- F32 precision is used on CPU
- Flash attention can be enabled with `--use-flash-attn` for supported models
- Model files are cached locally after first download
## Requirements
- Rust 1.70+
- CUDA toolkit (for CUDA support)
- Metal (automatically available on macOS)
- Internet connection for first-time model download

View File

@@ -1,398 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
#[value(name = "gemma-7b")]
Base7B,
#[value(name = "gemma-2b-it")]
Instruct2B,
#[value(name = "gemma-7b-it")]
Instruct7B,
#[value(name = "gemma-1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "gemma-1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "codegemma-2b")]
CodeBase2B,
#[value(name = "codegemma-7b")]
CodeBase7B,
#[value(name = "codegemma-2b-it")]
CodeInstruct2B,
#[value(name = "codegemma-7b-it")]
CodeInstruct7B,
#[value(name = "gemma-2-2b")]
BaseV2_2B,
#[value(name = "gemma-2-2b-it")]
InstructV2_2B,
#[value(name = "gemma-2-9b")]
BaseV2_9B,
#[value(name = "gemma-2-9b-it")]
InstructV2_9B,
#[value(name = "gemma-3-1b")]
BaseV3_1B,
#[value(name = "gemma-3-1b-it")]
InstructV3_1B,
}
enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle_core::Result<Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
pub struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if candle_core::utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if candle_core::utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: tokenizers::Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
/// Stream-only generation: sends freshly generated token strings over `tx`.
/// (Does not send the prompt tokens; only newly generated model tokens.)
fn run_stream(
&mut self,
prompt: &str,
sample_len: usize,
tx: Sender<Result<String>>,
) -> Result<()> {
self.tokenizer.clear();
// Encode prompt (context only; do not emit prompt tokens to the stream).
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Warm the tokenizer's internal state with prompt tokens (so merges are correct),
// but do not send them to the receiver.
for &t in tokens.iter() {
let _ = self.tokenizer.next_token(t)?;
}
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
eprintln!("Warning: <end_of_turn> token not found, using <eos> as backup");
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
// Best-effort send; ignore if receiver dropped.
let _ = tx.send(Ok(t));
}
}
let _dt = start_gen.elapsed();
// Flush any remaining buffered bytes as one final chunk.
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
let _ = tx.send(Ok(rest));
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
pub model: WhichModel,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: String,
pub use_flash_attn: bool,
pub seed: u64,
pub temperature: f64,
pub top_p: Option<f64>,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
pub max_tokens: usize,
}
impl Default for GemmaInferenceConfig {
fn default() -> Self {
Self {
tracing: false,
prompt: "Hello".to_string(),
model: WhichModel::InstructV2_2B,
cpu: false,
dtype: None,
model_id: None,
revision: "main".to_string(),
use_flash_attn: false,
seed: 299792458,
temperature: 0.8,
top_p: None,
repeat_penalty: 1.1,
repeat_last_n: 128,
max_tokens: 100,
}
}
}
// Removed From<Args> implementation as Args is not available and not needed for API usage
/// Builds the model and returns a channel that streams generated token strings.
/// If model setup fails, the `Result` is returned immediately.
pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String>>> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = if cfg.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
None => {
if device.is_cuda() {
DType::BF16
} else {
DType::F16
}
}
};
println!("Using dtype: {:?}", dtype);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
WhichModel::Base2B => "google/gemma-2b",
WhichModel::Base7B => "google/gemma-7b",
WhichModel::Instruct2B => "google/gemma-2b-it",
WhichModel::Instruct7B => "google/gemma-7b-it",
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
WhichModel::CodeBase2B => "google/codegemma-2b",
WhichModel::CodeBase7B => "google/codegemma-7b",
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
WhichModel::BaseV2_2B => "google/gemma-2-2b",
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
WhichModel::BaseV2_9B => "google/gemma-2-9b",
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
}
.to_string()
});
println!("Loading model: {}", &model_id);
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, cfg.revision));
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model {
WhichModel::Base2B
| WhichModel::Base7B
| WhichModel::Instruct2B
| WhichModel::Instruct7B
| WhichModel::InstructV1_1_2B
| WhichModel::InstructV1_1_7B
| WhichModel::CodeBase2B
| WhichModel::CodeBase7B
| WhichModel::CodeInstruct2B
| WhichModel::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
WhichModel::BaseV2_2B
| WhichModel::InstructV2_2B
| WhichModel::BaseV2_9B
| WhichModel::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("Loaded model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
cfg.seed,
cfg.temperature.into(),
cfg.top_p,
cfg.repeat_penalty,
cfg.repeat_last_n,
&device,
);
let prompt = match cfg.model {
WhichModel::InstructV3_1B => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt
)
}
_ => cfg.prompt,
};
println!("Starting inference...");
// Create the channel after successful setup.
let (tx, rx) = mpsc::channel::<Result<String>>();
// Spawn generation thread; send tokens to the channel.
thread::spawn(move || {
// If generation fails, forward the error once.
if let Err(e) = pipeline.run_stream(&prompt, cfg.max_tokens, tx.clone()) {
let _ = tx.send(Err(e));
}
// Channel closes when tx is dropped.
});
Ok(rx)
}
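For orientation, here is a minimal sketch of how library code might drive this streaming API directly. It mirrors the `run_cli` wrapper further down; the field values are the CLI defaults, and the prompt string is only an illustration.

```rust
use anyhow::Result;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use std::io::Write;

fn main() -> Result<()> {
    // Field values mirror the CLI defaults shown in run_cli() below.
    let cfg = GemmaInferenceConfig {
        tracing: false,
        prompt: "Why is the sky blue?".to_string(), // illustrative prompt
        model: WhichModel::InstructV3_1B,
        cpu: false,
        dtype: None,
        model_id: None,
        revision: "main".to_string(),
        use_flash_attn: false,
        seed: 299792458,
        temperature: 0.8,
        top_p: None,
        repeat_penalty: 1.1,
        repeat_last_n: 64,
        max_tokens: 100,
    };

    // Setup runs synchronously; tokens then stream in from the background thread.
    let rx = run_gemma_api(cfg)?;
    for msg in rx {
        match msg {
            Ok(tok) => {
                print!("{tok}");
                std::io::stdout().flush().ok();
            }
            Err(e) => {
                eprintln!("generation error: {e}");
                break;
            }
        }
    }
    Ok(())
}
```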

View File

@@ -1,97 +0,0 @@
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
pub struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
pub(crate) prompt: String,
/// The model to use
#[arg(short, long, default_value = "gemma-2-2b")]
pub(crate) model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
pub(crate) cpu: bool,
/// The temperature used to generate samples
#[arg(short, long)]
pub(crate) temperature: Option<f64>,
/// Nucleus sampling probability cutoff
#[arg(long)]
pub(crate) top_p: Option<f64>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
pub(crate) seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
pub(crate) max_tokens: usize,
/// Use different dtype than default
#[arg(long)]
pub(crate) dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
pub(crate) model_id: Option<String>,
/// Model revision
#[arg(long, default_value = "main")]
pub(crate) revision: String,
/// Use flash attention
#[arg(long)]
pub(crate) use_flash_attn: bool,
/// Penalty to be applied for repeating tokens; 1.0 means no penalty
#[arg(long, default_value_t = 1.1)]
pub(crate) repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 64)]
pub(crate) repeat_last_n: usize,
/// Enable tracing
#[arg(long)]
pub(crate) tracing: bool,
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = GemmaInferenceConfig {
tracing: args.tracing,
prompt: args.prompt,
model: args.model,
cpu: args.cpu,
dtype: args.dtype,
model_id: args.model_id,
revision: args.revision,
use_flash_attn: args.use_flash_attn,
seed: args.seed,
temperature: args.temperature.unwrap_or(0.8),
top_p: args.top_p,
repeat_penalty: args.repeat_penalty,
repeat_last_n: args.repeat_last_n,
max_tokens: args.max_tokens,
};
let rx = run_gemma_api(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -1,3 +0,0 @@
pub mod gemma_api;
pub use gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};

View File

@@ -1,17 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod gemma_api;
mod gemma_cli;
use anyhow::Error;
use clap::{Parser, ValueEnum};
use crate::gemma_cli::run_cli;
use std::io::Write;
/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
run_cli()
}

View File

@@ -1,16 +0,0 @@
[package]
name = "helm-chart-tool"
version.workspace = true
edition = "2021"
[[bin]]
name = "helm-chart-tool"
path = "src/main.rs"
[dependencies]
toml = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive"] }
walkdir = "2.0"

View File

@@ -1,218 +0,0 @@
# Helm Chart Tool
A Rust-based tool that automatically generates Helm charts from Cargo.toml metadata in Rust workspace projects.
## Overview
This tool scans a Rust workspace for crates containing Docker/Kubernetes metadata in their `Cargo.toml` files and generates a complete, production-ready Helm chart with deployments, services, ingress, and configuration templates.
## Features
- **Automatic Service Discovery**: Scans all `Cargo.toml` files in a workspace to find services with Kubernetes metadata
- **Complete Helm Chart Generation**: Creates Chart.yaml, values.yaml, deployment templates, service templates, ingress template, and helper templates
- **Metadata Extraction**: Uses `[package.metadata.kube]` sections from Cargo.toml files to extract:
- Docker image names
- Service ports
- Replica counts
- Service names
- **Production Ready**: Generated charts include health checks, resource limits, node selectors, affinity rules, and tolerations
- **Helm Best Practices**: Follows Helm chart conventions and passes `helm lint` validation
## Installation
Build the tool from source:
```bash
cd helm-chart-tool
cargo build --release
```
The binary will be available at `target/release/helm-chart-tool`.
## Usage
### Basic Usage
```bash
./target/release/helm-chart-tool --workspace /path/to/rust/workspace --output ./my-helm-chart
```
### Command Line Options
- `--workspace, -w PATH`: Path to the workspace root (default: `.`)
- `--output, -o PATH`: Output directory for the Helm chart (default: `./helm-chart`)
- `--name, -n NAME`: Name of the Helm chart (default: `predict-otron-9000`)
### Example
```bash
# Generate chart from current workspace
./target/release/helm-chart-tool
# Generate chart from specific workspace with custom output
./target/release/helm-chart-tool -w /path/to/my/workspace -o ./charts/my-app -n my-application
```
## Cargo.toml Metadata Format
The tool expects crates to have Kubernetes metadata in their `Cargo.toml` files:
```toml
[package]
name = "my-service"
version = "0.1.0"
# Required: Kubernetes metadata
[package.metadata.kube]
image = "ghcr.io/myorg/my-service:latest"
replicas = 1
port = 8080
# Optional: Docker Compose metadata (currently not used but parsed)
[package.metadata.compose]
image = "ghcr.io/myorg/my-service:latest"
port = 8080
```
### Required Fields
- `image`: Full Docker image name including registry and tag
- `port`: Port number the service listens on
- `replicas`: Number of replicas to deploy (optional, defaults to 1)
## Generated Chart Structure
The tool generates a complete Helm chart with the following structure:
```
helm-chart/
├── Chart.yaml # Chart metadata
├── values.yaml # Default configuration values
├── .helmignore # Files to ignore when packaging
└── templates/
├── _helpers.tpl # Template helper functions
├── ingress.yaml # Ingress configuration (optional)
├── {service}-deployment.yaml # Deployment for each service
└── {service}-service.yaml # Service for each service
```
### Generated Files
#### Chart.yaml
- Standard chart metadata using `apiVersion: v2` (the Helm 3 chart format)
- Includes keywords for AI/ML applications
- Maintainer information
#### values.yaml
- Individual service configurations
- Resource limits and requests
- Service types and ports
- Node selectors, affinity, and tolerations
- Global settings and ingress configuration
#### Deployment Templates
- Kubernetes Deployment manifests
- Health checks (liveness and readiness probes)
- Resource management
- Container port configuration from metadata
- Support for node selectors, affinity, and tolerations
#### Service Templates
- Kubernetes Service manifests
- ClusterIP services by default
- Port mapping from metadata
#### Ingress Template
- Optional ingress configuration
- Disabled by default
- Configurable through values.yaml
## Example Output
When run against the predict-otron-9000 workspace, the tool generates:
```bash
$ ./target/release/helm-chart-tool --workspace .. --output ../generated-helm-chart
Parsing workspace at: ..
Output directory: ../generated-helm-chart
Chart name: predict-otron-9000
Found 4 services:
- chat-ui: ghcr.io/geoffsee/chat-ui:latest (port 8788)
- inference-engine: ghcr.io/geoffsee/inference-service:latest (port 8080)
- embeddings-engine: ghcr.io/geoffsee/embeddings-service:latest (port 8080)
- predict-otron-9000: ghcr.io/geoffsee/predict-otron-9000:latest (port 8080)
Helm chart generated successfully!
```
## Validation
The generated charts pass Helm validation:
```bash
$ helm lint generated-helm-chart
==> Linting generated-helm-chart
[INFO] Chart.yaml: icon is recommended
1 chart(s) linted, 0 chart(s) failed
```
## Deployment
Deploy the generated chart:
```bash
# Install the chart
helm install my-release ./generated-helm-chart
# Upgrade the chart
helm upgrade my-release ./generated-helm-chart
# Uninstall the chart
helm uninstall my-release
```
### Customization
Customize the deployment by modifying `values.yaml`:
```yaml
# Enable ingress
ingress:
enabled: true
className: "nginx"
hosts:
- host: my-app.example.com
# Adjust resources for a specific service
predict_otron_9000:
replicas: 3
resources:
limits:
memory: "4Gi"
cpu: "2000m"
requests:
memory: "2Gi"
cpu: "1000m"
```
## Requirements
- Rust 2021+ (for building the tool)
- Helm 3.x (for deploying the generated charts)
- Kubernetes cluster (for deployment)
## Limitations
- Currently assumes every service exposes health checks on the `/health` endpoint
- Resource limits are hardcoded defaults (can be overridden in values.yaml)
- Ingress configuration is basic (can be customized through values.yaml)
## Contributing
1. Add new features to the tool
2. Test with various Cargo.toml metadata configurations
3. Validate generated charts with `helm lint`
4. Ensure charts deploy successfully to test clusters
## License
This tool is part of the predict-otron-9000 project and follows the same license terms.

View File

@@ -1,525 +0,0 @@
use anyhow::{Context, Result};
use clap::{Arg, Command};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
struct CargoToml {
package: Option<Package>,
}
#[derive(Debug, Deserialize)]
struct Package {
name: String,
metadata: Option<Metadata>,
}
#[derive(Debug, Deserialize)]
struct Metadata {
kube: Option<KubeMetadata>,
compose: Option<ComposeMetadata>,
}
#[derive(Debug, Deserialize)]
struct KubeMetadata {
image: String,
replicas: Option<u32>,
port: u16,
}
#[derive(Debug, Deserialize)]
struct ComposeMetadata {
image: Option<String>,
port: Option<u16>,
}
#[derive(Debug, Clone)]
struct ServiceInfo {
name: String,
image: String,
port: u16,
replicas: u32,
}
fn main() -> Result<()> {
let matches = Command::new("helm-chart-tool")
.about("Generate Helm charts from Cargo.toml metadata")
.arg(
Arg::new("workspace")
.short('w')
.long("workspace")
.value_name("PATH")
.help("Path to the workspace root")
.default_value("."),
)
.arg(
Arg::new("output")
.short('o')
.long("output")
.value_name("PATH")
.help("Output directory for the Helm chart")
.default_value("./helm-chart"),
)
.arg(
Arg::new("chart-name")
.short('n')
.long("name")
.value_name("NAME")
.help("Name of the Helm chart")
.default_value("predict-otron-9000"),
)
.get_matches();
let workspace_path = matches.get_one::<String>("workspace").unwrap();
let output_path = matches.get_one::<String>("output").unwrap();
let chart_name = matches.get_one::<String>("chart-name").unwrap();
println!("Parsing workspace at: {}", workspace_path);
println!("Output directory: {}", output_path);
println!("Chart name: {}", chart_name);
let services = discover_services(workspace_path)?;
println!("Found {} services:", services.len());
for service in &services {
println!(
" - {}: {} (port {})",
service.name, service.image, service.port
);
}
generate_helm_chart(output_path, chart_name, &services)?;
println!("Helm chart generated successfully!");
Ok(())
}
fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
let workspace_root = Path::new(workspace_path);
let mut services = Vec::new();
// Find all Cargo.toml files in the workspace
for entry in WalkDir::new(workspace_root)
.into_iter()
.filter_map(|e| e.ok())
{
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
if let Ok(service_info) = parse_cargo_toml(entry.path()) {
services.push(service_info);
}
}
}
Ok(services)
}
fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?;
let cargo_toml: CargoToml = toml::from_str(&content)
.with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;
let package = cargo_toml
.package
.ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;
let metadata = package
.metadata
.ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;
let kube_metadata = metadata
.kube
.ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;
Ok(ServiceInfo {
name: package.name,
image: kube_metadata.image,
port: kube_metadata.port,
replicas: kube_metadata.replicas.unwrap_or(1),
})
}
fn generate_helm_chart(
output_path: &str,
chart_name: &str,
services: &[ServiceInfo],
) -> Result<()> {
let chart_dir = Path::new(output_path);
let templates_dir = chart_dir.join("templates");
// Create directories
fs::create_dir_all(&templates_dir)?;
// Generate Chart.yaml
generate_chart_yaml(chart_dir, chart_name)?;
// Generate values.yaml
generate_values_yaml(chart_dir, services)?;
// Generate templates for each service
for service in services {
generate_deployment_template(&templates_dir, service)?;
generate_service_template(&templates_dir, service)?;
}
// Generate ingress template
generate_ingress_template(&templates_dir, services)?;
// Generate helper templates
generate_helpers_template(&templates_dir)?;
// Generate .helmignore
generate_helmignore(chart_dir)?;
Ok(())
}
fn generate_chart_yaml(chart_dir: &Path, chart_name: &str) -> Result<()> {
let chart_yaml = format!(
r#"apiVersion: v2
name: {}
description: A Helm chart for the predict-otron-9000 AI platform
type: application
version: 0.1.0
appVersion: "0.1.0"
keywords:
- ai
- llm
- inference
- embeddings
- chat
maintainers:
- name: predict-otron-9000-team
"#,
chart_name
);
fs::write(chart_dir.join("Chart.yaml"), chart_yaml)?;
Ok(())
}
fn generate_values_yaml(chart_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
let mut values = String::from(
r#"# Default values for predict-otron-9000
# This is a YAML-formatted file.
global:
imagePullPolicy: IfNotPresent
serviceType: ClusterIP
# Ingress configuration
ingress:
enabled: false
className: ""
annotations: {}
hosts:
- host: predict-otron-9000.local
paths:
- path: /
pathType: Prefix
backend:
service:
name: predict-otron-9000
port:
number: 8080
tls: []
"#,
);
for service in services {
let service_config = format!(
r#"{}:
image:
repository: {}
tag: "latest"
pullPolicy: IfNotPresent
replicas: {}
service:
type: ClusterIP
port: {}
resources:
limits:
memory: "1Gi"
cpu: "1000m"
requests:
memory: "512Mi"
cpu: "250m"
nodeSelector: {{}}
tolerations: []
affinity: {{}}
"#,
service.name.replace("-", "_"),
service.image.split(':').next().unwrap_or(&service.image),
service.replicas,
service.port
);
values.push_str(&service_config);
}
fs::write(chart_dir.join("values.yaml"), values)?;
Ok(())
}
fn generate_deployment_template(templates_dir: &Path, service: &ServiceInfo) -> Result<()> {
let service_name_underscore = service.name.replace("-", "_");
let deployment_template = format!(
r#"apiVersion: apps/v1
kind: Deployment
metadata:
name: {{{{ include "predict-otron-9000.fullname" . }}}}-{}
labels:
{{{{- include "predict-otron-9000.labels" . | nindent 4 }}}}
app.kubernetes.io/component: {}
spec:
replicas: {{{{ .Values.{}.replicas }}}}
selector:
matchLabels:
{{{{- include "predict-otron-9000.selectorLabels" . | nindent 6 }}}}
app.kubernetes.io/component: {}
template:
metadata:
labels:
{{{{- include "predict-otron-9000.selectorLabels" . | nindent 8 }}}}
app.kubernetes.io/component: {}
spec:
containers:
- name: {}
image: "{{{{ .Values.{}.image.repository }}}}:{{{{ .Values.{}.image.tag }}}}"
imagePullPolicy: {{{{ .Values.{}.image.pullPolicy }}}}
ports:
- name: http
containerPort: {}
protocol: TCP
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 5
resources:
{{{{- toYaml .Values.{}.resources | nindent 12 }}}}
{{{{- with .Values.{}.nodeSelector }}}}
nodeSelector:
{{{{- toYaml . | nindent 8 }}}}
{{{{- end }}}}
{{{{- with .Values.{}.affinity }}}}
affinity:
{{{{- toYaml . | nindent 8 }}}}
{{{{- end }}}}
{{{{- with .Values.{}.tolerations }}}}
tolerations:
{{{{- toYaml . | nindent 8 }}}}
{{{{- end }}}}
"#,
service.name,
service.name,
service_name_underscore,
service.name,
service.name,
service.name,
service_name_underscore,
service_name_underscore,
service_name_underscore,
service.port,
service_name_underscore,
service_name_underscore,
service_name_underscore,
service_name_underscore
);
let filename = format!("{}-deployment.yaml", service.name);
fs::write(templates_dir.join(filename), deployment_template)?;
Ok(())
}
fn generate_service_template(templates_dir: &Path, service: &ServiceInfo) -> Result<()> {
let service_template = format!(
r#"apiVersion: v1
kind: Service
metadata:
name: {{{{ include "predict-otron-9000.fullname" . }}}}-{}
labels:
{{{{- include "predict-otron-9000.labels" . | nindent 4 }}}}
app.kubernetes.io/component: {}
spec:
type: {{{{ .Values.{}.service.type }}}}
ports:
- port: {{{{ .Values.{}.service.port }}}}
targetPort: http
protocol: TCP
name: http
selector:
{{{{- include "predict-otron-9000.selectorLabels" . | nindent 4 }}}}
app.kubernetes.io/component: {}
"#,
service.name,
service.name,
service.name.replace("-", "_"),
service.name.replace("-", "_"),
service.name
);
let filename = format!("{}-service.yaml", service.name);
fs::write(templates_dir.join(filename), service_template)?;
Ok(())
}
fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
let ingress_template = r#"{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "predict-otron-9000.fullname" . }}
labels:
{{- include "predict-otron-9000.labels" . | nindent 4 }}
{{- with .Values.ingress.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
{{- if .Values.ingress.className }}
ingressClassName: {{ .Values.ingress.className }}
{{- end }}
{{- if .Values.ingress.tls }}
tls:
{{- range .Values.ingress.tls }}
- hosts:
{{- range .hosts }}
- {{ . | quote }}
{{- end }}
secretName: {{ .secretName }}
{{- end }}
{{- end }}
rules:
{{- range .Values.ingress.hosts }}
- host: {{ .host | quote }}
http:
paths:
{{- range .paths }}
- path: {{ .path }}
{{- if .pathType }}
pathType: {{ .pathType }}
{{- end }}
backend:
service:
name: {{ include "predict-otron-9000.fullname" $ }}-{{ .backend.service.name }}
port:
number: {{ .backend.service.port.number }}
{{- end }}
{{- end }}
{{- end }}
"#;
fs::write(templates_dir.join("ingress.yaml"), ingress_template)?;
Ok(())
}
fn generate_helpers_template(templates_dir: &Path) -> Result<()> {
let helpers_template = r#"{{/*
Expand the name of the chart.
*/}}
{{- define "predict-otron-9000.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "predict-otron-9000.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "predict-otron-9000.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "predict-otron-9000.labels" -}}
helm.sh/chart: {{ include "predict-otron-9000.chart" . }}
{{ include "predict-otron-9000.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{/*
Selector labels
*/}}
{{- define "predict-otron-9000.selectorLabels" -}}
app.kubernetes.io/name: {{ include "predict-otron-9000.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "predict-otron-9000.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "predict-otron-9000.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}
"#;
fs::write(templates_dir.join("_helpers.tpl"), helpers_template)?;
Ok(())
}
fn generate_helmignore(chart_dir: &Path) -> Result<()> {
let helmignore_content = r#"# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/
"#;
fs::write(chart_dir.join(".helmignore"), helmignore_content)?;
Ok(())
}

View File

@@ -31,8 +31,9 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
futures-util = "0.3.31"
gemma-runner = { path = "../gemma-runner", features = ["metal"] }
llama-runner = { path = "../llama-runner", features = ["metal"]}
gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
llama-runner = { path = "../../integration/llama-runner", features = ["metal"]}
embeddings-engine = { path = "../embeddings-engine" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

View File

@@ -19,6 +19,7 @@ use crate::openai_types::{
};
use crate::Which;
use either::Either;
use embeddings_engine::models_list;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
@@ -530,7 +531,9 @@ pub async fn list_models() -> Json<ModelListResponse> {
Which::Llama32_3BInstruct,
];
let models: Vec<Model> = which_variants.into_iter().map(|which| {
let mut models: Vec<Model> = which_variants.into_iter().map(|which| {
let meta = which.meta();
let model_id = match which {
Which::Base2B => "gemma-2b",
@@ -566,11 +569,25 @@ pub async fn list_models() -> Json<ModelListResponse> {
Model {
id: model_id.to_string(),
object: "model".to_string(),
created: 1686935002, // Using same timestamp as OpenAI example
created: 1686935002,
owned_by: owned_by.to_string(),
}
}).collect();
// Get embeddings models and convert them to inference Model format
let embeddings_response = models_list().await;
let embeddings_models: Vec<Model> = embeddings_response.0.data.into_iter().map(|embedding_model| {
Model {
id: embedding_model.id,
object: embedding_model.object,
created: 1686935002,
owned_by: format!("{} - {}", embedding_model.owned_by, embedding_model.description),
}
}).collect();
// Add embeddings models to the main models list
models.extend(embeddings_models);
Json(ModelListResponse {
object: "list".to_string(),
data: models,

View File

@@ -1,24 +0,0 @@
[package]
name = "llama-runner"
version.workspace = true
edition = "2021"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git"}
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -1,188 +0,0 @@
# Llama Runner
A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.
## Features
- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
- **Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults
## Supported Models
| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |
Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).
## Installation
```bash
# Clone the repository
git clone <repository-url>
cd llama-runner
# Build with GPU acceleration (recommended)
cargo build --release --features metal # macOS
cargo build --release --features cuda # Linux/Windows with NVIDIA GPU
# CPU-only build
cargo build --release
```
## Quick Start
```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"
# Specify a model and parameters
cargo run --features metal -- \
--prompt "Write a short story about space exploration" \
--model smollm2-360m \
--max-tokens 100 \
--temperature 0.8
# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```
## Usage Examples
### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"
# Creative writing with higher temperature
cargo run --features metal -- \
--prompt "Once upon a time" \
--temperature 1.0 \
--max-tokens 200
```
### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
--prompt "Explain artificial intelligence" \
--top-k 40 \
--top-p 0.9 \
--temperature 0.7
# Reduce repetition
cargo run --features metal -- \
--prompt "List the benefits of renewable energy" \
--repeat-penalty 1.2 \
--repeat-last-n 64
```
### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
--prompt "Quick test" \
--model smollm2-135m
# Better quality with larger model
cargo run --features metal -- \
--prompt "Explain quantum physics" \
--model llama-3.2-1b \
--max-tokens 150
```
## Command-Line Options
| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `llama-3.2-1b-instruct` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |
## Performance
Typical performance on Apple M2 with Metal acceleration:
| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |
## Requirements
- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows
## GPU Support
### macOS (Metal)
```bash
cargo run --features metal -- [options]
```
### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```
### CPU Only
```bash
cargo run -- --cpu [options]
```
## Model Downloads
Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:
- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes
## Troubleshooting
### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model
### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues
### Model Download Issues
- Check internet connection
- Some models may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`
## Contributing
Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.
## License
MIT License - see LICENSE file for details.

View File

@@ -1,7 +0,0 @@
pub mod llama_api;
use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";

View File

@@ -1,333 +0,0 @@
use crate::EOS_TOKEN;
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as model;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use clap::ValueEnum;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
#[value(name = "llama-3.2-1b")]
#[default]
Llama32_1B,
#[value(name = "llama-3.2-1b-instruct")]
Llama32_1BInstruct,
#[value(name = "llama-3.2-3b")]
Llama32_3B,
#[value(name = "llama-3.2-3b-instruct")]
Llama32_3BInstruct,
#[value(name = "smollm2-135m")]
SmolLM2_135M,
#[value(name = "smollm2-135m-instruct")]
SmolLM2_135MInstruct,
#[value(name = "smollm2-360m")]
SmolLM2_360M,
#[value(name = "smollm2-360m-instruct")]
SmolLM2_360MInstruct,
#[value(name = "smollm2-1.7b")]
SmolLM2_1_7B,
#[value(name = "smollm2-1.7b-instruct")]
SmolLM2_1_7BInstruct,
#[value(name = "tinyllama-1.1b-chat")]
TinyLlama1_1BChat,
}
#[derive(Debug, Clone)]
pub struct LlamaInferenceConfig {
pub prompt: String,
pub model: WhichModel,
pub cpu: bool,
pub temperature: f64,
pub top_p: Option<f64>,
pub top_k: Option<usize>,
pub seed: u64,
pub max_tokens: usize,
pub no_kv_cache: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: Option<String>,
pub use_flash_attn: bool,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
}
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
// Leave prompt empty by default; let call sites set it.
prompt: String::new(),
// Keep your existing model choice; swap at call-site if needed.
model: WhichModel::Llama32_1BInstruct,
// Prefer GPU if available.
cpu: false,
// Sampling: balanced + stable
temperature: 0.7,
top_p: Some(0.95),
top_k: Some(50),
// Reproducible by default; override for variability.
seed: 42,
// Don't run unbounded generations.
max_tokens: 512,
// Performance flags
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: false, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
// Optional model source pinning (None = app defaults)
model_id: None,
revision: None,
// Anti-repeat heuristics
repeat_penalty: 1.15,
repeat_last_n: 128,
}
}
}
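/// Pick the compute device: CUDA if available, then Metal, otherwise CPU (or CPU when forced).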
fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
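/// Resolve the safetensors shards listed in a sharded index file (e.g. `model.safetensors.index.json`)
/// and fetch each one through the HF Hub API.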
fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
let json_file = api.get(json_file)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
let weight_map = match json.get("weight_map") {
None => bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| api.get(v))
.collect::<anyhow::Result<Vec<_>, _>>()?;
Ok(safetensors_files)
}
pub fn run_llama_inference(
cfg: LlamaInferenceConfig,
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
// ---- Device & dtype -----------------------------------------------------
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
println!("Using dtype: {:?}", dtype);
// ---- Load model & tokenizer --------------------------------------------
let (llama, tokenizer, mut cache) = {
let api = Api::new()?;
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
match cfg.model {
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(cfg.use_flash_attn);
let filenames = match cfg.model {
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
hub_load_safetensors(&api, "model.safetensors.index.json")?
}
_ => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let llama = Llama::load(vb, &config)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
(llama, tokenizer, cache)
};
// ---- Prepare prompt & sampler ------------------------------------------
let eos_token_id = tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single);
let mut tokens = tokenizer
.encode(cfg.prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("Starting inference...");
let mut logits_processor = {
let temperature = cfg.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (cfg.top_k, cfg.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(cfg.seed, sampling)
};
// Channel for streaming decoded fragments to the caller.
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
// ---- Spawn generation thread -------------------------------------------
std::thread::spawn(move || {
let start_gen = std::time::Instant::now();
let mut index_pos = 0usize;
let mut token_generated = 0usize;
for index in 0..cfg.max_tokens {
// Use KV-cache for single-token step after the first pass.
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match llama.forward(&input, context_index, &mut cache) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match logits.squeeze(0) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = if cfg.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
match candle_transformers::utils::apply_repeat_penalty(
&logits,
cfg.repeat_penalty,
&tokens[start_at..],
) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
}
};
index_pos += ctxt.len();
let next_token = match logits_processor.sample(&logits) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
token_generated += 1;
tokens.push(next_token);
// Early stop on EOS.
let stop = match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
None => false,
};
if stop {
break;
}
// Decode this token's text and stream it out.
match tokenizer.decode(&[next_token], false) {
Ok(text) => {
if !text.is_empty() {
// Best-effort send; if receiver is gone, just stop.
if tx.send(Ok(text)).is_err() {
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
break;
}
}
}
// Optional: final stats as a debug line (not sent through the stream).
let dt = start_gen.elapsed();
eprintln!(
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
// Dropping tx closes the stream.
});
Ok(rx)
}

View File

@@ -1,108 +0,0 @@
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug, Default)]
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
prompt: String,
/// The model to use
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples
#[arg(short, long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
max_tokens: usize,
/// Disable the key-value cache
#[arg(long)]
no_kv_cache: bool,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
model_id: Option<String>,
/// Model revision
#[arg(long)]
revision: Option<String>,
/// Use flash attention
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens; 1.0 means no penalty
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
impl Into<LlamaInferenceConfig> for Args {
fn into(self) -> LlamaInferenceConfig {
LlamaInferenceConfig {
prompt: self.prompt,
model: self.model,
cpu: self.cpu,
temperature: self.temperature,
top_p: self.top_p,
top_k: self.top_k,
seed: self.seed,
max_tokens: self.max_tokens,
no_kv_cache: self.no_kv_cache,
dtype: self.dtype,
model_id: self.model_id,
revision: self.revision,
use_flash_attn: self.use_flash_attn,
repeat_penalty: self.repeat_penalty,
repeat_last_n: self.repeat_last_n,
}
}
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = args.into();
let rx = run_llama_inference(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -1,19 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_api;
mod llama_cli;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use std::io::Write;
use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> {
run_cli()
}

View File

@@ -144,6 +144,7 @@ async fn main() {
tracing::info!("Available endpoints:");
tracing::info!(" GET / - Leptos chat web application");
tracing::info!(" GET /health - Health check");
tracing::info!(" POST /v1/models - List Models");
tracing::info!(" POST /v1/embeddings - Text embeddings API");
tracing::info!(" POST /v1/chat/completions - Chat completions API");

View File

@@ -1,88 +0,0 @@
[package]
name = "utils"
[lib]
path = "src/lib.rs"
[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }
candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
candle-core="0.9.1"
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
half = {version = "2.6.0", optional = true }
hf-hub = {version = "0.4.3", features = ["tokio"] }
image = {version = "0.25.6" }
intel-mkl-src = {version = "0.8.1", optional = true }
num-traits = {version = "0.2.19" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.22.0", features = [
"auto-initialize",
"abi3-py311",
], optional = true }
rayon = {version = "1.11.0" }
rubato = { version = "0.15.0", optional = true }
safetensors = {version = "0.6.2" }
serde = {version = "1.0.219" }
serde_json = {version = "1.0.143" }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = {version = "0.22.0", features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2", optional = true }
tekken-rs = { version = "0.1.1", optional = true }
[dev-dependencies]
anyhow = {version = "1.0.99" }
byteorder = {version = "1.5.0" }
clap = {version = "4.5.46" }
imageproc = {version = "0.25.0" }
memmap2 = {version = "0.9.8" }
rand = {version = "0.9.2" }
ab_glyph = {version = "0.2.31" }
tracing = {version = "0.1.41" }
tracing-chrome = {version = "0.7.2" }
tracing-subscriber = {version = "0.3.20" }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"
[build-dependencies]
anyhow = {version = "1.0.99" }
bindgen_cuda = { version = "0.1.1", optional = true }
#
[features]
default = []
accelerate = [
"dep:accelerate-src",
"candle-core/accelerate",
"candle-nn/accelerate",
"candle-transformers/accelerate",
]
cuda = [
"candle-core/cuda",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:bindgen_cuda",
]
cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = [
"dep:intel-mkl-src",
"candle-core/mkl",
"candle-nn/mkl",
"candle-transformers/mkl",
]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle-core/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]

View File

@@ -1,138 +0,0 @@
use candle_core::{Result, Tensor};
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
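/// Normalize `wav` toward -14 LUFS using a BS.1770 gated loudness measurement, returning the
/// input unchanged for near-silent signals; optionally applies a tanh compressor afterwards.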
pub fn normalize_loudness(
wav: &Tensor,
sample_rate: u32,
loudness_compressor: bool,
) -> Result<Tensor> {
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
if energy < 2e-3 {
return Ok(wav.clone());
}
let wav_array = wav.to_vec1::<f32>()?;
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
meter.push(wav_array.into_iter());
let power = meter.as_100ms_windows();
let loudness = match crate::bs1770::gated_mean(power) {
None => return Ok(wav.clone()),
Some(gp) => gp.loudness_lkfs() as f64,
};
let delta_loudness = -14. - loudness;
let gain = 10f64.powf(delta_loudness / 20.);
let wav = (wav * gain)?;
if loudness_compressor {
wav.tanh()
} else {
Ok(wav)
}
}
#[cfg(feature = "symphonia")]
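/// Decode an audio file into f32 PCM samples (first channel only) plus its sample rate, using Symphonia.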
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;
fn conv<T>(
samples: &mut Vec<f32>,
data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
) where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
// Open the media source.
let src = std::fs::File::open(path).map_err(candle::Error::wrap)?;
// Create the media source stream.
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
// Create a probe hint using the file's extension. [Optional]
let hint = symphonia::core::probe::Hint::new();
// Use the default options for metadata and format readers.
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
// Probe the media source.
let probed = symphonia::default::get_probe()
.format(&hint, mss, &fmt_opts, &meta_opts)
.map_err(candle::Error::wrap)?;
// Get the instantiated format reader.
let mut format = probed.format;
// Find the first audio track with a known (decodeable) codec.
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?;
// Use the default options for the decoder.
let dec_opts: DecoderOptions = Default::default();
// Create a decoder for the track.
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?;
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
// The decode loop.
while let Ok(packet) = format.next_packet() {
// Consume any new metadata that has been read since the last packet.
while !format.metadata().is_latest() {
format.metadata().pop();
}
// If the packet does not belong to the selected track, skip over it.
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet).map_err(candle::Error::wrap)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[cfg(feature = "rubato")]
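/// Resample f32 PCM from `sr_in` to `sr_out` using rubato's FFT-based fixed in/out resampler.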
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
.map_err(candle::Error::wrap)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) = resampler
.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
.map_err(candle::Error::wrap)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler
.process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
.map_err(candle::Error::wrap)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@@ -1,506 +0,0 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! # [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//! meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//! channel_power[0].as_ref(),
//! channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//! stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```
use std::f32;
/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
a1: f32,
a2: f32,
b0: f32,
b1: f32,
b2: f32,
// The past two input and output samples.
x1: f32,
x2: f32,
y1: f32,
y2: f32,
}
impl Filter {
/// Stage 1 of the BS.1770-4 pre-filter.
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let gain_db = 3.999_843_8;
let q = 0.707_175_25;
let center_hz = 1_681.974_5;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
let vh = 10.0_f32.powf(gain_db / 20.0);
let vb = vh.powf(0.499_666_78);
let a0 = 1.0 + k / q + k * k;
Filter {
b0: (vh + vb * k / q + k * k) / a0,
b1: 2.0 * (k * k - vh) / a0,
b2: (vh - vb * k / q + k * k) / a0,
a1: 2.0 * (k * k - 1.0) / a0,
a2: (1.0 - k / q + k * k) / a0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Stage 2 of the BS.1770-4 pre-filter.
pub fn high_pass(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let q = 0.500_327_05;
let center_hz = 38.135_47;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
Filter {
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
b0: 1.0,
b1: -2.0,
b2: 1.0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Feed the next input sample, get the next output sample.
#[inline(always)]
pub fn apply(&mut self, x0: f32) -> f32 {
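// Direct Form I biquad: y0 = b0*x0 + b1*x1 + b2*x2 - a1*y1 - a2*y2 (a0 is implicitly 1.0).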
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
- self.a1 * self.y1
- self.a2 * self.y2;
self.x2 = self.x1;
self.x1 = x0;
self.y2 = self.y1;
self.y1 = y0;
y0
}
}
/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
sum: f32,
residue: f32,
}
impl Sum {
#[inline(always)]
fn zero() -> Sum {
Sum {
sum: 0.0,
residue: 0.0,
}
}
#[inline(always)]
fn add(&mut self, x: f32) {
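// Compensated (Kahan-style) addition: fold the stored rounding error back in, then capture
// the new low-order bits lost when adding to the running sum.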
let sum = self.sum + (self.residue + x);
self.residue = (self.residue + x) - (sum - self.sum);
self.sum = sum;
}
}
/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);
impl Power {
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
///
/// This is the inverse of `loudness_lkfs`.
pub fn from_lkfs(lkfs: f32) -> Power {
// The inverse of the formula below.
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
}
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
///
/// This is the inverse of `from_lkfs`.
pub fn loudness_lkfs(&self) -> f32 {
// Equation 2 (p.5) of BS.1770-4.
-0.691 + 10.0 * self.0.log10()
}
}
/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
pub inner: T,
}
impl<T> Windows100ms<T> {
/// Wrap a new empty vector.
pub fn new() -> Windows100ms<Vec<T>> {
Windows100ms { inner: Vec::new() }
}
/// Apply `as_ref` to the inner value.
pub fn as_ref(&self) -> Windows100ms<&[Power]>
where
T: AsRef<[Power]>,
{
Windows100ms {
inner: self.inner.as_ref(),
}
}
/// Apply `as_mut` to the inner value.
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
where
T: AsMut<[Power]>,
{
Windows100ms {
inner: self.inner.as_mut(),
}
}
#[allow(clippy::len_without_is_empty)]
/// Apply `len` to the inner value.
pub fn len(&self) -> usize
where
T: AsRef<[Power]>,
{
self.inner.as_ref().len()
}
}
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
/// .unwrap_or(bs1770::Power(0.0))
/// .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
/// The number of samples that fit in 100ms of audio.
samples_per_100ms: u32,
/// Stage 1 filter (head effects, high shelf).
filter_stage1: Filter,
/// Stage 2 filter (high-pass).
filter_stage2: Filter,
/// Sum of the squares over non-overlapping windows of 100ms.
windows: Windows100ms<Vec<Power>>,
/// The number of samples in the current unfinished window.
count: u32,
/// The sum of the squares of the samples in the current unfinished window.
square_sum: Sum,
}
impl ChannelLoudnessMeter {
/// Construct a new loudness meter for the given sample rate.
pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
ChannelLoudnessMeter {
samples_per_100ms: sample_rate_hz / 10,
filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
filter_stage2: Filter::high_pass(sample_rate_hz as f32),
windows: Windows100ms::new(),
count: 0,
square_sum: Sum::zero(),
}
}
/// Feed input samples for loudness analysis.
///
/// # Full scale
///
/// Full scale for the input samples is the interval [-1.0, 1.0]. If your
/// input consists of signed integer samples, you can convert as follows:
///
/// ```ignore
/// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
/// # let bits_per_sample = 16_usize;
/// # let samples = &[0_i16];
/// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
/// // since one bit is used for the sign.
/// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
/// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
/// ```
///
/// # Repeated calls
///
/// You can call `push` multiple times to feed multiple batches of samples.
/// This is equivalent to feeding a single chained iterator. Any leftover
/// samples that did not fill a full 100ms window are not discarded:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::ChannelLoudnessMeter;
/// let sample_rate_hz = 44_100;
/// let samples_per_100ms = sample_rate_hz / 10;
/// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
///
/// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
/// assert_eq!(meter.as_100ms_windows().len(), 0);
///
/// meter.push(iter::once(0.0));
/// assert_eq!(meter.as_100ms_windows().len(), 1);
/// ```
pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
let normalizer = 1.0 / self.samples_per_100ms as f32;
// LLVM, if you could go ahead and inline those apply calls, and then
// unroll and vectorize the loop, that'd be terrific.
for x in samples {
let y = self.filter_stage1.apply(x);
let z = self.filter_stage2.apply(y);
self.square_sum.add(z * z);
self.count += 1;
// TODO: Should this branch be marked cold?
if self.count == self.samples_per_100ms {
let mean_squares = Power(self.square_sum.sum * normalizer);
self.windows.inner.push(mean_squares);
// We intentionally do not reset the residue. That way, leftover
// energy from this window is not lost, so for the file overall,
// the sum remains more accurate.
self.square_sum.sum = 0.0;
self.count = 0;
}
}
}
/// Return a reference to the 100ms windows analyzed so far.
pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
self.windows.as_ref()
}
/// Return all 100ms windows analyzed so far.
pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
self.windows
}
}
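// Illustrative sketch, not part of the original source: instantaneous
// loudness as described in the struct documentation above, i.e. the average
// of four consecutive 100ms windows (400ms windows spaced 100ms apart),
// expressed in LKFS. The function name is hypothetical.
#[allow(dead_code)]
fn instantaneous_loudness_lkfs(windows: Windows100ms<&[Power]>) -> Vec<f32> {
    windows
        .inner
        .windows(4)
        .map(|w| Power(0.25 * w.iter().map(|p| p.0).sum::<f32>()).loudness_lkfs())
        .collect()
}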
/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
left: Windows100ms<&[Power]>,
right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
let mut result = Vec::with_capacity(left.len());
for (l, r) in left.inner.iter().zip(right.inner) {
result.push(Power(l.0 + r.0));
}
Windows100ms { inner: result }
}
/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
for (l, r) in left.inner.iter_mut().zip(right.inner) {
l.0 += r.0;
}
}
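// Illustrative sketch, not part of the original source: an end-to-end
// integrated loudness measurement for an interleaved stereo f32 buffer,
// running one meter per channel, combining the powers with `reduce_stereo`,
// and gating with `gated_mean` (defined below). The function name and the
// interleaved-input assumption are hypothetical.
#[allow(dead_code)]
fn integrated_loudness_stereo(interleaved: &[f32], sample_rate_hz: u32) -> Option<f32> {
    let mut left = ChannelLoudnessMeter::new(sample_rate_hz);
    let mut right = ChannelLoudnessMeter::new(sample_rate_hz);
    left.push(interleaved.iter().step_by(2).copied());
    right.push(interleaved.iter().skip(1).step_by(2).copied());
    let stereo_power = reduce_stereo(left.as_100ms_windows(), right.as_100ms_windows());
    gated_mean(stereo_power.as_ref()).map(|p| p.loudness_lkfs())
}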
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
let absolute_threshold = Power::from_lkfs(-70.0);
// Iterate over all 400ms windows.
for window in windows_100ms.inner.windows(4) {
// Note that the sum over channels has already been performed at this point.
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
if gating_block_power > absolute_threshold {
gating_blocks.push(gating_block_power);
}
}
if gating_blocks.is_empty() {
return None;
}
// Compute the loudness after applying the absolute gate, in order to
// determine the threshold for the relative gate.
let mut sum_power = Sum::zero();
for &gating_block_power in &gating_blocks {
sum_power.add(gating_block_power.0);
}
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
// Stage 2: Apply the relative gate.
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
let mut sum_power = Sum::zero();
let mut n_blocks = 0_usize;
for &gating_block_power in &gating_blocks {
if gating_block_power > relative_threshold {
sum_power.add(gating_block_power.0);
n_blocks += 1;
}
}
if n_blocks == 0 {
return None;
}
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
Some(relative_gated_power)
}

View File

@@ -1,82 +0,0 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

File diff suppressed because it is too large

View File

@@ -1,156 +0,0 @@
extern crate candle_core;
extern crate candle_transformers;
extern crate tokenizers;
pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(Device::Cpu)
}
}
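// Illustrative sketch, not part of the original source: pick the best
// available device and allocate a tensor on it. The function name is
// hypothetical; `Tensor::zeros` is the standard candle_core constructor.
#[allow(dead_code)]
fn device_example() -> Result<(), anyhow::Error> {
    let dev = device(false)?;
    let t = Tensor::zeros((2, 3), candle_core::DType::F32, &dev)?;
    println!("allocated a tensor with shape {:?}", t.shape());
    Ok(())
}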
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize), anyhow::Error> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?;
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
let img = match resize_longest {
None => img,
Some(resize_longest) => {
let (height, width) = (img.height(), img.width());
let resize_longest = resize_longest as u32;
let (height, width) = if height < width {
let h = (resize_longest * height) / width;
(h, resize_longest)
} else {
let w = (resize_longest * width) / height;
(resize_longest, w)
};
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
}
};
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
Ok((data, initial_h, initial_w))
}
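// Illustrative sketch, not part of the original source: load an image with
// its longest side capped at 512 pixels. The returned tensor is (3, h, w) in
// u8, alongside the original dimensions. The path and function name are
// placeholders.
#[allow(dead_code)]
fn load_image_example() -> Result<(), anyhow::Error> {
    let (tensor, original_h, original_w) = load_image("input.jpg", Some(512))?;
    println!("original {original_w}x{original_h}, tensor dims {:?}", tensor.dims3()?);
    Ok(())
}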
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P,
width: usize,
height: usize,
) -> candle_core::Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?
.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let data = img.into_raw();
Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
}
/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
anyhow::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => anyhow::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| {
repo.get(v)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
})
.collect::<Result<Vec<_>, std::io::Error>>()?;
Ok(safetensors_files)
}
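// Illustrative sketch, not part of the original source: resolve the shards of
// a safetensors-indexed model hosted on the Hugging Face Hub. The repository
// id and index file name below are placeholders, not real artifacts.
#[allow(dead_code)]
fn hub_load_example() -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("example-org/example-model".to_string());
    hub_load_safetensors(&repo, "model.safetensors.index.json")
}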
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}

View File

@@ -1,3 +0,0 @@
fn main() {
println!("Hello, world!");
}

View File

@@ -1,85 +0,0 @@
use candle_core::Result;
use tokenizers::Tokenizer;
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
prev_index: usize,
current_index: usize,
}
impl TokenOutputStream {
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
Self {
tokenizer,
tokens: Vec::new(),
prev_index: 0,
current_index: 0,
}
}
pub fn into_inner(self) -> tokenizers::Tokenizer {
self.tokenizer
}
fn decode(&self, tokens: &[u32]) -> Result<String> {
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => candle_core::bail!("cannot decode: {err}"),
}
}
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_rest(&self) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() {
let text = text.split_at(prev_text.len());
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_all(&self) -> Result<String> {
self.decode(&self.tokens)
}
pub fn get_token(&self, token_s: &str) -> Option<u32> {
self.tokenizer.get_vocab(true).get(token_s).copied()
}
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
&self.tokenizer
}
pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
self.current_index = 0;
}
}
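// Illustrative sketch, not part of the original source: streaming decode of a
// sequence of token ids. In a real generation loop the ids would come from
// sampling; here they are passed in, and the function name is hypothetical.
#[allow(dead_code)]
fn stream_tokens(tokenizer: Tokenizer, token_ids: &[u32]) -> Result<String> {
    let mut stream = TokenOutputStream::new(tokenizer);
    let mut out = String::new();
    for &id in token_ids {
        // `next_token` only emits text once it ends on an alphanumeric
        // boundary, so partial words are held back until they are complete.
        if let Some(piece) = stream.next_token(id)? {
            out.push_str(&piece);
        }
    }
    // Flush whatever is still buffered after the last token.
    if let Some(rest) = stream.decode_rest()? {
        out.push_str(&rest);
    }
    Ok(out)
}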

View File

@@ -1,56 +0,0 @@
use std::io::prelude::*;
pub trait Sample {
fn to_i16(&self) -> i16;
}
impl Sample for f32 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for f64 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for i16 {
fn to_i16(&self) -> i16 {
*self
}
}
pub fn write_pcm_as_wav<W: Write, S: Sample>(
w: &mut W,
samples: &[S],
sample_rate: u32,
) -> std::io::Result<()> {
let len = 12u32; // header
let len = len + 24u32; // fmt
let len = len + samples.len() as u32 * 2 + 8; // data
let n_channels = 1u16;
let bytes_per_second = sample_rate * 2 * n_channels as u32;
w.write_all(b"RIFF")?;
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
w.write_all(b"WAVE")?;
// Format block
w.write_all(b"fmt ")?;
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
w.write_all(&1u16.to_le_bytes())?; // PCM
w.write_all(&n_channels.to_le_bytes())?; // one channel
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&bytes_per_second.to_le_bytes())?;
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
w.write_all(&16u16.to_le_bytes())?; // bits per sample
// Data block
w.write_all(b"data")?;
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
for sample in samples.iter() {
w.write_all(&sample.to_i16().to_le_bytes())?
}
Ok(())
}
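// Illustrative sketch, not part of the original source: write one second of a
// 440 Hz sine tone as 16-bit mono PCM. The output path, frequency, and sample
// rate are arbitrary example values.
#[allow(dead_code)]
fn write_sine_example() -> std::io::Result<()> {
    let sample_rate = 16_000u32;
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
        .collect();
    let mut file = std::fs::File::create("tone.wav")?;
    write_pcm_as_wav(&mut file, &samples, sample_rate)
}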