Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)

Commit: reorg + update docs with new paths

integration/cli/Cargo.toml (new file, 11 lines)
@@ -0,0 +1,11 @@
[package]
name = "cli"
version.workspace = true
edition = "2021"
build = "build.rs"

[[bin]]
name = "cli"
path = "src/main.rs"

[dependencies]

integration/cli/README.md (new file, 24 lines)
@@ -0,0 +1,24 @@
# cli

A Rust/TypeScript hybrid.

```console
bun run cli.ts [options] [prompt]

Simple CLI tool for testing the local OpenAI-compatible API server.

Options:
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message

Examples:
cd integration/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"
bun run cli.ts --list-models

The server must be running at http://localhost:8080
```

integration/cli/build.rs (new file, 204 lines)
@@ -0,0 +1,204 @@
use std::env;
use std::fs;
use std::io::{self, BufRead, Write};
use std::path::{Path, PathBuf};
use std::process::{ChildStderr, ChildStdout, Command, Stdio};
use std::thread;
use std::time::{Duration, SystemTime};
mod bun_target;
use bun_target::BunTarget;

fn main() {
    println!("cargo:rerun-if-changed=");

    if let Err(e) = run_build() {
        println!("cargo:warning=build.rs failed: {e}");
        std::process::exit(1);
    }
}

fn run_build() -> io::Result<()> {
    let manifest_dir =
        PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"));
    let package_dir = manifest_dir.join("package");
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
    let output_path = out_dir.join("client-cli");

    let bun_tgt = BunTarget::from_cargo_env()
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;

    // Optional: warn if using a Bun target that's marked unsupported in your chart
    if matches!(bun_tgt, BunTarget::WindowsArm64) {
        println!(
            "cargo:warning=bun-windows-arm64 is marked unsupported in the compatibility chart"
        );
    }

    warn(&format!("Building CLI into: {}", output_path.display()));

    // --- bun install (in ./package), keep temps inside OUT_DIR ---
    let mut install = Command::new("bun")
        .current_dir(&package_dir)
        .env("TMPDIR", &out_dir)
        .arg("install")
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun install`: {e}")))?;

    let install_join = stream_child("bun install", install.stdout.take(), install.stderr.take());
    let install_status = install.wait()?;
    // ensure streams finish
    join_streams(install_join);

    if !install_status.success() {
        let code = install_status.code().unwrap_or(1);
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("bun install failed with status {code}"),
        ));
    }

    let target = env::var("TARGET").unwrap();

    // --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
    let mut build = Command::new("bun")
        .current_dir(&package_dir)
        .env("TMPDIR", &out_dir)
        .arg("build")
        .arg("./cli.ts")
        .arg(format!("--target={}", bun_tgt.as_bun_flag()))
        .arg("--compile")
        .arg("--outfile")
        .arg(&output_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun build`: {e}")))?;

    let build_join = stream_child("bun build", build.stdout.take(), build.stderr.take());
    let status = build.wait()?;
    // ensure streams finish
    join_streams(build_join);

    if status.success() {
        info("bun build succeeded");
    } else {
        let code = status.code().unwrap_or(1);
        warn(&format!("bun build failed with status: {code}"));
        return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
    }

    // Ensure the output is executable (after it exists)
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&output_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&output_path, perms)?;
    }

    println!("cargo:warning=Built CLI at {}", output_path.display());
    println!("cargo:rustc-env=CLIENT_CLI_BIN={}", output_path.display());

    // --- Cleanup stray .bun-build temp files (conservative: older than 5 minutes) ---
    for dir in [&manifest_dir, &package_dir, &out_dir] {
        if let Err(e) = remove_bun_temp_files(dir, Some(Duration::from_secs(5 * 60))) {
            println!("cargo:warning=cleanup in {} failed: {e}", dir.display());
        }
    }

    Ok(())
}

// Spawn readers for child's stdout/stderr so we don't deadlock on pipe buffers
fn stream_child(
    tag: &str,
    stdout: Option<ChildStdout>,
    stderr: Option<ChildStderr>,
) -> (
    Option<thread::JoinHandle<()>>,
    Option<thread::JoinHandle<()>>,
) {
    let t1 = stdout.map(|out| {
        let tag = tag.to_string();
        thread::spawn(move || {
            let reader = io::BufReader::new(out);
            for line in reader.lines() {
                info(&format!("[{tag} stdout] {}", line.unwrap_or_default()));
            }
        })
    });
    let t2 = stderr.map(|err| {
        let tag = tag.to_string();
        thread::spawn(move || {
            let reader = io::BufReader::new(err);
            for line in reader.lines() {
                warn(&format!("[{tag} stderr] {}", line.unwrap_or_default()));
            }
        })
    });
    (t1, t2)
}

fn join_streams(
    joins: (
        Option<thread::JoinHandle<()>>,
        Option<thread::JoinHandle<()>>,
    ),
) {
    if let Some(j) = joins.0 {
        let _ = j.join();
    }
    if let Some(j) = joins.1 {
        let _ = j.join();
    }
}

fn remove_bun_temp_files(dir: &Path, older_than: Option<Duration>) -> io::Result<()> {
    let now = SystemTime::now();
    for entry in fs::read_dir(dir)? {
        let entry = entry?;
        let path = entry.path();
        if !path.is_file() {
            continue;
        }

        // Files like ".1860e7df40ff1bef-00000000.bun-build"
        let name = entry.file_name();
        let name = name.to_string_lossy();
        let looks_like_bun_temp = name.starts_with('.') && name.ends_with(".bun-build");

        if !looks_like_bun_temp {
            continue;
        }

        if let Some(age) = older_than {
            if let Ok(meta) = entry.metadata() {
                if let Ok(modified) = meta.modified() {
                    if now.duration_since(modified).unwrap_or_default() < age {
                        // too new; skip to avoid racing an in-flight builder
                        continue;
                    }
                }
            }
        }

        match fs::remove_file(&path) {
            Ok(_) => println!("cargo:warning=removed stray bun temp {}", path.display()),
            Err(e) => println!("cargo:warning=failed to remove {}: {e}", path.display()),
        }
    }
    Ok(())
}

fn warn(msg: &str) {
    let _ = writeln!(io::stderr(), "[build.rs] {msg}");
    println!("cargo:warning={msg}");
}

fn info(msg: &str) {
    let _ = writeln!(io::stderr(), "[build.rs] {msg}");
    println!("cargo:warning=INFO|{msg}");
}

integration/cli/bun_target.rs (new file, 131 lines)
@@ -0,0 +1,131 @@
use std::env;
use std::fmt;

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum BunTarget {
    LinuxX64Glibc,
    LinuxArm64Glibc,
    LinuxX64Musl,
    LinuxArm64Musl,
    WindowsX64,
    WindowsArm64,
    MacX64,
    MacArm64,
}

impl BunTarget {
    pub const fn as_bun_flag(self) -> &'static str {
        match self {
            BunTarget::LinuxX64Glibc => "bun-linux-x64",
            BunTarget::LinuxArm64Glibc => "bun-linux-arm64",
            BunTarget::LinuxX64Musl => "bun-linux-x64-musl",
            BunTarget::LinuxArm64Musl => "bun-linux-arm64-musl",
            BunTarget::WindowsX64 => "bun-windows-x64",
            BunTarget::WindowsArm64 => "bun-windows-arm64",
            BunTarget::MacX64 => "bun-darwin-x64",
            BunTarget::MacArm64 => "bun-darwin-arm64",
        }
    }

    pub const fn rust_triples(self) -> &'static [&'static str] {
        match self {
            BunTarget::LinuxX64Glibc => {
                &["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-gnu.2.17"]
            }
            BunTarget::LinuxArm64Glibc => &["aarch64-unknown-linux-gnu"],
            BunTarget::LinuxX64Musl => &["x86_64-unknown-linux-musl"],
            BunTarget::LinuxArm64Musl => &["aarch64-unknown-linux-musl"],
            BunTarget::WindowsX64 => &["x86_64-pc-windows-msvc"],
            BunTarget::WindowsArm64 => &["aarch64-pc-windows-msvc"], // chart says unsupported; still map
            BunTarget::MacX64 => &["x86_64-apple-darwin"],
            BunTarget::MacArm64 => &["aarch64-apple-darwin"],
        }
    }

    pub fn from_rust_target(triple: &str) -> Option<Self> {
        let norm = triple.trim();
        if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
            return Some(BunTarget::LinuxX64Glibc);
        }
        if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
            return Some(BunTarget::LinuxArm64Glibc);
        }
        if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("musl") {
            return Some(BunTarget::LinuxX64Musl);
        }
        if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("musl") {
            return Some(BunTarget::LinuxArm64Musl);
        }
        if norm == "x86_64-pc-windows-msvc" {
            return Some(BunTarget::WindowsX64);
        }
        if norm == "aarch64-pc-windows-msvc" {
            return Some(BunTarget::WindowsArm64);
        }
        if norm == "x86_64-apple-darwin" {
            return Some(BunTarget::MacX64);
        }
        if norm == "aarch64-apple-darwin" {
            return Some(BunTarget::MacArm64);
        }
        for bt in [
            BunTarget::LinuxX64Glibc,
            BunTarget::LinuxArm64Glibc,
            BunTarget::LinuxX64Musl,
            BunTarget::LinuxArm64Musl,
            BunTarget::WindowsX64,
            BunTarget::WindowsArm64,
            BunTarget::MacX64,
            BunTarget::MacArm64,
        ] {
            for &t in bt.rust_triples() {
                if t == norm {
                    return Some(bt);
                }
            }
        }
        None
    }

    pub fn from_cargo_env() -> Result<Self, BunTargetError> {
        if let Ok(triple) = env::var("TARGET") {
            if let Some(bt) = Self::from_rust_target(&triple) {
                return Ok(bt);
            }
            return Err(BunTargetError::UnknownTriple(triple));
        }

        let os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
        let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
        let envv = env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default();
        let vendor = env::var("CARGO_CFG_TARGET_VENDOR").unwrap_or_else(|_| "unknown".into());

        let triple = format!(
            "{}-{}-{}-{}",
            arch,
            vendor,
            os,
            if envv.is_empty() { "gnu" } else { &envv }
        );
        if let Some(bt) = Self::from_rust_target(&triple) {
            Ok(bt)
        } else {
            Err(BunTargetError::UnknownTriple(triple))
        }
    }
}

#[derive(Debug)]
pub enum BunTargetError {
    UnknownTriple(String),
}

impl fmt::Display for BunTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BunTargetError::UnknownTriple(t) => write!(f, "unrecognized Rust target triple: {t}"),
        }
    }
}

impl std::error::Error for BunTargetError {}

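For orientation, a brief sketch of how the mapping above is consumed (this mirrors what build.rs does; the hard-coded triple is only an illustration):

```rust
// Illustrative only: resolve the Bun `--target` flag for a known Rust triple.
let bt = BunTarget::from_rust_target("aarch64-apple-darwin").expect("recognized triple");
assert_eq!(bt.as_bun_flag(), "bun-darwin-arm64");

// Inside build.rs the triple comes from Cargo's environment instead:
let bt = BunTarget::from_cargo_env().expect("TARGET / CARGO_CFG_* provided by Cargo");
println!("cargo:warning=bun target: {}", bt.as_bun_flag());
```
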
integration/cli/package/bun.lock (new file, 17 lines)
@@ -0,0 +1,17 @@
{
  "lockfileVersion": 1,
  "workspaces": {
    "": {
      "name": "cli",
      "dependencies": {
        "install": "^0.13.0",
        "openai": "^5.16.0",
      },
    },
  },
  "packages": {
    "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],

    "openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
  }
}

integration/cli/package/cli.ts (new executable file, 339 lines)
@@ -0,0 +1,339 @@
#!/usr/bin/env bun

import OpenAI from "openai";
import { parseArgs } from "util";

// =====================
// Config
// =====================
const DEFAULT_MODEL = "gemma-3-1b-it";
const DEFAULT_MAX_TOKENS = 256;

// Toggle this to reduce log overhead during timing runs
const PRINT_CHUNK_DEBUG = false;

// How many rows to show in the timing tables
const SHOW_FIRST_N = 3;
const SHOW_SLOWEST_N = 3;

// =====================
// Helpers
// =====================
const now = () => performance.now();

type ChunkStat = {
  index: number;
  tSinceRequestStartMs: number;
  dtSincePrevMs: number;
  contentChars: number;
};

function printHelp() {
  console.log(`
./cli [options] [prompt]

Simple CLI tool for testing the local OpenAI-compatible API server.

Options:
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message

Examples:
./cli "What is the capital of France?"
./cli --model gemma-3-1b-it --prompt "Hello, world!"
./cli --prompt "Who was the 16th president of the United States?"
./cli --list-models

The server must be running at http://localhost:8080
`);
}

const { values, positionals } = parseArgs({
  args: process.argv.slice(2),
  options: {
    model: { type: "string" },
    prompt: { type: "string" },
    help: { type: "boolean" },
    "list-models": { type: "boolean" },
  },
  strict: false,
  allowPositionals: true,
});

async function requestLocalOpenAI(model: string, userPrompt: string) {
  const openai = new OpenAI({
    baseURL: "http://localhost:8080/v1",
    apiKey: "not used",
  });
  try {
    console.log("[DEBUG] Creating chat completion request...");
    return openai.chat.completions.create({
      model,
      max_tokens: DEFAULT_MAX_TOKENS,
      stream: true,
      messages: [
        {
          role: "system",
          content: "You are a helpful assistant who responds thoughtfully and concisely.",
        },
        { role: "user", content: userPrompt },
      ],
    });
  } catch (e: any) {
    console.error("[ERROR] Failed to connect to local OpenAI server:", e.message);
    console.error("[HINT] Make sure the server is running at http://localhost:8080");
    console.error("[HINT] Start it with: ./run_server.sh");
    throw e;
  }
}

async function listModels() {
  const openai = new OpenAI({
    baseURL: "http://localhost:8080/v1",
    apiKey: "not used",
  });
  try {
    const models = await openai.models.list();
    console.log(`[INFO] Available models from http://localhost:8080/v1:`);
    console.log("---");

    if (models.data && models.data.length > 0) {
      models.data.forEach((model, index) => {
        console.log(`${index + 1}. ${model.id}`);
        console.log(` Owner: ${model.owned_by}`);
        console.log(` Created: ${new Date(model.created * 1000).toISOString()}`);
        console.log("");
      });
      console.log(`Total: ${models.data.length} models available`);
    } else {
      console.log("No models found.");
    }
  } catch (e: any) {
    console.error("[ERROR] Failed to fetch models from local OpenAI server:", e.message);
    console.error("[HINT] Make sure the server is running at http://localhost:8080");
    console.error("[HINT] Start it with: ./run_server.sh");
    throw e;
  }
}

// =====================
// Timing math
// =====================
function median(nums: number[]) {
  if (nums.length === 0) return 0;
  const s = [...nums].sort((a, b) => a - b);
  const mid = Math.floor(s.length / 2);
  return s.length % 2 ? s[mid] : (s[mid - 1] + s[mid]) / 2;
}

function quantile(nums: number[], q: number) {
  if (nums.length === 0) return 0;
  const s = [...nums].sort((a, b) => a - b);
  const pos = (s.length - 1) * q;
  const base = Math.floor(pos);
  const rest = pos - base;
  return s[base + 1] !== undefined ? s[base] + rest * (s[base + 1] - s[base]) : s[base];
}

function ms(n: number) {
  return `${n.toFixed(1)} ms`;
}

// =====================
// Main
// =====================
async function main() {
  const tProgramStart = now();

  if (values.help) {
    printHelp();
    process.exit(0);
  }

  if (values["list-models"]) {
    try {
      await listModels();
      process.exit(0);
    } catch (error: any) {
      console.error("\n[ERROR] Failed to list models:", error.message);
      process.exit(1);
    }
  }

  const prompt = values.prompt ?? positionals[0];

  if (!prompt) {
    console.error("[ERROR] No prompt provided!");
    printHelp();
    process.exit(1);
  }

  const model = values.model || DEFAULT_MODEL;

  console.log(`[INFO] Using model: ${model}`);
  console.log(`[INFO] Prompt: ${prompt}`);
  console.log(`[INFO] Connecting to: http://localhost:8080/v1`);
  console.log("---");

  const tBeforeRequest = now();

  try {
    console.log("[DEBUG] Initiating request to OpenAI server...");
    const response = await requestLocalOpenAI(model, prompt);
    const tAfterCreate = now();

    // Streaming handling + timing
    let fullResponse = "";
    let chunkCount = 0;

    const chunkStats: ChunkStat[] = [];
    let tFirstChunk: number | null = null;
    let tPrevChunk: number | null = null;

    console.log("[INFO] Waiting for model to generate response...");
    let loadingInterval;
    if (!PRINT_CHUNK_DEBUG) {
      // Show loading animation only if not in debug mode
      const loadingChars = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'];
      let i = 0;
      process.stdout.write('\r[INFO] Thinking ');
      loadingInterval = setInterval(() => {
        process.stdout.write(`\r[INFO] Thinking ${loadingChars[i++ % loadingChars.length]} `);
      }, 80);
    } else {
      console.log("[DEBUG] Starting to receive streaming response...");
    }

    for await (const chunk of response) {
      // Clear loading animation on first chunk
      if (loadingInterval) {
        clearInterval(loadingInterval);
        process.stdout.write('\r \r');
      }
      const tNow = now();
      chunkCount++;

      // Extract content (delta) if present
      const content = chunk.choices?.[0]?.delta?.content ?? "";
      if (PRINT_CHUNK_DEBUG) {
        console.log(`[DEBUG] Received chunk #${chunkCount}:`, JSON.stringify(chunk));
        if (content) console.log(`[DEBUG] Chunk content: "${content}"`);
      }

      if (content) {
        process.stdout.write(content);
        fullResponse += content;
      }

      if (tFirstChunk === null) tFirstChunk = tNow;

      const dtSincePrev = tPrevChunk === null ? 0 : tNow - tPrevChunk;
      chunkStats.push({
        index: chunkCount,
        tSinceRequestStartMs: tNow - tBeforeRequest,
        dtSincePrevMs: dtSincePrev,
        contentChars: content.length,
      });

      tPrevChunk = tNow;
    }

    // =========
    // Summary
    // =========
    const tStreamEnd = now();
    const totalChars = fullResponse.length;

    console.log("\n---");
    console.log(`[DEBUG] Stream completed after ${chunkCount} chunks`);
    console.log(`[INFO] Response completed. Total length: ${totalChars} characters`);

    // Build timing metrics
    const ttfbMs = (tFirstChunk ?? tStreamEnd) - tAfterCreate; // time from create() resolved → first chunk
    const createOverheadMs = tAfterCreate - tBeforeRequest; // time spent awaiting create() promise
    const totalSinceRequestMs = tStreamEnd - tBeforeRequest; // from just before create() to last chunk
    const streamDurationMs =
      tFirstChunk === null ? 0 : tStreamEnd - tFirstChunk;

    const gaps = chunkStats
      .map((c) => c.dtSincePrevMs)
      // ignore the first "gap" which is 0 by construction
      .slice(1);

    const avgGapMs = gaps.length ? gaps.reduce((a, b) => a + b, 0) / gaps.length : 0;
    const medGapMs = median(gaps);
    const p95GapMs = quantile(gaps, 0.95);

    let maxGapMs = 0;
    let maxGapAtChunk = 0;
    for (let i = 0; i < gaps.length; i++) {
      if (gaps[i] > maxGapMs) {
        maxGapMs = gaps[i];
        maxGapAtChunk = i + 2; // +1 to move from 0-based, +1 because we sliced starting at second chunk
      }
    }

    // Pretty print summary
    console.log("\n=== Timing Summary ===");
    console.log(`create() await time: ${ms(createOverheadMs)}`);
    console.log(`TTFB (to 1st chunk): ${ms(ttfbMs)}`);
    console.log(`Stream duration: ${ms(streamDurationMs)}`);
    console.log(`End-to-end (req→last): ${ms(totalSinceRequestMs)}`);
    console.log(`Chunks: ${chunkCount}`);
    console.log(`Total content chars: ${totalChars}`);
    console.log(`Avg chars/chunk: ${(chunkCount ? totalChars / chunkCount : 0).toFixed(1)}`);
    console.log(`Inter-chunk gap (avg): ${ms(avgGapMs)}`);
    console.log(`Inter-chunk gap (median): ${ms(medGapMs)}`);
    console.log(`Inter-chunk gap (p95): ${ms(p95GapMs)}`);
    if (gaps.length > 0) {
      console.log(`Largest gap: ${ms(maxGapMs)} (before chunk #${maxGapAtChunk})`);
    }

    // Small tables: first N and slowest N gaps
    const firstRows = chunkStats.slice(0, SHOW_FIRST_N).map((c) => ({
      chunk: c.index,
      "t since request": `${c.tSinceRequestStartMs.toFixed(1)} ms`,
      "dt since prev": `${c.dtSincePrevMs.toFixed(1)} ms`,
      "chars": c.contentChars,
    }));

    const slowestRows = chunkStats
      .slice(1) // skip first (no meaningful gap)
      .sort((a, b) => b.dtSincePrevMs - a.dtSincePrevMs)
      .slice(0, SHOW_SLOWEST_N)
      .map((c) => ({
        chunk: c.index,
        "t since request": `${c.tSinceRequestStartMs.toFixed(1)} ms`,
        "dt since prev": `${c.dtSincePrevMs.toFixed(1)} ms`,
        "chars": c.contentChars,
      }));

    if (firstRows.length > 0) {
      console.log("\n--- First chunk timings ---");
      // @ts-ignore Bun/Node support console.table
      console.table(firstRows);
    }

    if (slowestRows.length > 0) {
      console.log(`\n--- Slowest ${SHOW_SLOWEST_N} gaps ---`);
      // @ts-ignore
      console.table(slowestRows);
    }

    const tProgramEnd = now();
    console.log("\n=== Program Overhead ===");
    console.log(`Total program runtime: ${ms(tProgramEnd - tProgramStart)}`);

  } catch (error: any) {
    console.error("\n[ERROR] Request failed:", error.message);
    process.exit(1);
  }
}

// Run the main function
main().catch((error) => {
  console.error("[FATAL ERROR]:", error);
  process.exit(1);
});

integration/cli/package/package.json (new file, 11 lines)
@@ -0,0 +1,11 @@
{
  "name": "cli",
  "main": "cli.ts",
  "scripts": {
    "build": "bun build cli.ts --compile --outfile cli"
  },
  "dependencies": {
    "install": "^0.13.0",
    "openai": "^5.16.0"
  }
}

integration/cli/src/main.rs (new file, 32 lines)
@@ -0,0 +1,32 @@
use std::{env, fs, io, path::PathBuf, process::Command};

#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;

fn main() -> io::Result<()> {
    // Absolute path provided by build.rs at compile time.
    // `include_bytes!` accepts string literals; `env!` expands to a literal at compile time.
    const CLIENT_CLI: &[u8] = include_bytes!(env!("CLIENT_CLI_BIN"));

    // Write to a temp file
    let mut tmp = env::temp_dir();
    tmp.push("client-cli-embedded");

    fs::write(&tmp, CLIENT_CLI)?;

    // Ensure it's executable on Unix
    #[cfg(unix)]
    {
        let mut perms = fs::metadata(&tmp)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&tmp, perms)?;
    }

    // Run it
    let status = Command::new(&tmp).arg("--version").status()?;
    if !status.success() {
        return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
    }

    Ok(())
}

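The wrapper above always invokes the embedded binary with `--version`. A hedged sketch of a variation that forwards the wrapper's own arguments instead (hypothetical, not part of this commit):

```rust
// Hypothetical variation: forward the wrapper's CLI arguments to the embedded
// client-cli instead of the fixed `--version` flag.
let status = Command::new(&tmp)
    .args(env::args_os().skip(1)) // skip the wrapper's own executable name
    .status()?;
if !status.success() {
    return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
}
```
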
integration/gemma-runner/Cargo.toml (new file, 32 lines)
@@ -0,0 +1,32 @@
[package]
name = "gemma-runner"
version.workspace = true
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.4"
tokenizers = "0.22.0"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
utils = { path = "../utils" }

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

integration/gemma-runner/README.md (new file, 137 lines)
@@ -0,0 +1,137 @@
# Gemma Runner

Fast Gemma inference with the Candle framework in Rust.

## Features

- Support for multiple Gemma model versions (v1, v2, v3)
- GPU acceleration with CUDA and Metal
- Configurable sampling parameters
- Multiple model variants including instruct and code models

## Supported Models

### Gemma v1
- `gemma-2b` - Base 2B model
- `gemma-7b` - Base 7B model
- `gemma-2b-it` - Instruct 2B model
- `gemma-7b-it` - Instruct 7B model
- `gemma-1.1-2b-it` - Instruct 2B v1.1 model
- `gemma-1.1-7b-it` - Instruct 7B v1.1 model

### CodeGemma
- `codegemma-2b` - Code base 2B model
- `codegemma-7b` - Code base 7B model
- `codegemma-2b-it` - Code instruct 2B model
- `codegemma-7b-it` - Code instruct 7B model

### Gemma v2
- `gemma-2-2b` - Base 2B v2 model (default)
- `gemma-2-2b-it` - Instruct 2B v2 model
- `gemma-2-9b` - Base 9B v2 model
- `gemma-2-9b-it` - Instruct 9B v2 model

### Gemma v3
- `gemma-3-1b` - Base 1B v3 model
- `gemma-3-1b-it` - Instruct 1B v3 model

## Installation

```bash
cd gemma-runner
cargo build --release
```

For GPU support:
```bash
# CUDA
cargo build --release --features cuda

# Metal (macOS)
cargo build --release --features metal
```

## Usage

### Basic Usage

```bash
# Run with default model (gemma-2-2b)
cargo run -- --prompt "The capital of France is"

# Specify a different model
cargo run -- --model gemma-2b-it --prompt "Explain quantum computing"

# Generate more tokens
cargo run -- --model codegemma-2b-it --prompt "Write a Python function to sort a list" --max-tokens 200
```

### Advanced Options

```bash
# Use CPU instead of GPU
cargo run -- --cpu --prompt "Hello world"

# Adjust sampling parameters
cargo run -- --temperature 0.8 --top-p 0.9 --prompt "Write a story about"

# Use custom model from HuggingFace Hub
cargo run -- --model-id "google/gemma-2-2b-it" --prompt "What is AI?"

# Enable tracing for performance analysis
cargo run -- --tracing --prompt "Explain machine learning"
```

### Command Line Arguments

- `--prompt, -p` - The prompt to generate text from (default: "The capital of France is")
- `--model, -m` - The model to use (default: "gemma-2-2b")
- `--cpu` - Run on CPU rather than GPU
- `--temperature, -t` - Sampling temperature (optional)
- `--top-p` - Nucleus sampling probability cutoff (optional)
- `--seed` - Random seed (default: 299792458)
- `--max-tokens, -n` - Maximum tokens to generate (default: 100)
- `--model-id` - Custom model ID from HuggingFace Hub
- `--revision` - Model revision (default: "main")
- `--use-flash-attn` - Use flash attention
- `--repeat-penalty` - Repetition penalty (default: 1.1)
- `--repeat-last-n` - Context size for repeat penalty (default: 64)
- `--dtype` - Data type (f16, bf16, f32)
- `--tracing` - Enable performance tracing

## Examples

### Text Generation
```bash
cargo run -- --model gemma-2b-it --prompt "Explain the theory of relativity" --max-tokens 150
```

### Code Generation
```bash
cargo run -- --model codegemma-7b-it --prompt "Write a Rust function to calculate factorial" --max-tokens 100
```

### Creative Writing
```bash
cargo run -- --model gemma-7b-it --temperature 0.9 --prompt "Once upon a time in a magical forest" --max-tokens 200
```

### Chat with Gemma 3 (Instruct format)
```bash
cargo run -- --model gemma-3-1b-it --prompt "How do I learn Rust programming?"
```

## Performance Notes

- GPU acceleration is automatically detected and used when available
- BF16 precision is used on CUDA for better performance
- F32 precision is used on CPU
- Flash attention can be enabled with `--use-flash-attn` for supported models
- Model files are cached locally after first download

## Requirements

- Rust 1.70+
- CUDA toolkit (for CUDA support)
- Metal (automatically available on macOS)
- Internet connection for first-time model download

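Beyond the CLI shown above, the crate also exposes a library API (`run_gemma_api`, re-exported from `src/lib.rs`). A minimal sketch of driving it from Rust, assuming `gemma-runner` is added as a path dependency; it mirrors what `src/gemma_cli.rs` does:

```rust
use gemma_runner::{run_gemma_api, GemmaInferenceConfig, WhichModel};

fn main() -> anyhow::Result<()> {
    // Start from the config defaults and override only what we need.
    let cfg = GemmaInferenceConfig {
        prompt: "Explain the borrow checker in one paragraph".to_string(),
        model: WhichModel::InstructV3_1B,
        max_tokens: 128,
        ..Default::default()
    };

    // run_gemma_api returns a Receiver that streams generated token strings.
    let rx = run_gemma_api(cfg)?;
    for token in rx {
        print!("{}", token?);
    }
    Ok(())
}
```
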
integration/gemma-runner/src/gemma_api.rs (new file, 398 lines)
@@ -0,0 +1,398 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;

// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
pub enum WhichModel {
    #[value(name = "gemma-2b")]
    Base2B,
    #[value(name = "gemma-7b")]
    Base7B,
    #[value(name = "gemma-2b-it")]
    Instruct2B,
    #[value(name = "gemma-7b-it")]
    Instruct7B,
    #[value(name = "gemma-1.1-2b-it")]
    InstructV1_1_2B,
    #[value(name = "gemma-1.1-7b-it")]
    InstructV1_1_7B,
    #[value(name = "codegemma-2b")]
    CodeBase2B,
    #[value(name = "codegemma-7b")]
    CodeBase7B,
    #[value(name = "codegemma-2b-it")]
    CodeInstruct2B,
    #[value(name = "codegemma-7b-it")]
    CodeInstruct7B,
    #[value(name = "gemma-2-2b")]
    BaseV2_2B,
    #[value(name = "gemma-2-2b-it")]
    InstructV2_2B,
    #[value(name = "gemma-2-9b")]
    BaseV2_9B,
    #[value(name = "gemma-2-9b-it")]
    InstructV2_9B,
    #[value(name = "gemma-3-1b")]
    BaseV3_1B,
    #[value(name = "gemma-3-1b-it")]
    InstructV3_1B,
}

enum Model {
    V1(Model1),
    V2(Model2),
    V3(Model3),
}

impl Model {
    fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle_core::Result<Tensor> {
        match self {
            Self::V1(m) => m.forward(input_ids, pos),
            Self::V2(m) => m.forward(input_ids, pos),
            Self::V3(m) => m.forward(input_ids, pos),
        }
    }
}

pub struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

fn device(cpu: bool) -> Result<Device> {
    if cpu {
        Ok(Device::Cpu)
    } else if candle_core::utils::cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if candle_core::utils::metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        Ok(Device::Cpu)
    }
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: tokenizers::Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    /// Stream-only generation: sends freshly generated token strings over `tx`.
    /// (Does not send the prompt tokens; only newly generated model tokens.)
    fn run_stream(
        &mut self,
        prompt: &str,
        sample_len: usize,
        tx: Sender<Result<String>>,
    ) -> Result<()> {
        self.tokenizer.clear();

        // Encode prompt (context only; do not emit prompt tokens to the stream).
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();

        // Warm the tokenizer's internal state with prompt tokens (so merges are correct),
        // but do not send them to the receiver.
        for &t in tokens.iter() {
            let _ = self.tokenizer.next_token(t)?;
        }
        // Make sure stdout isn't holding anything (if caller also prints).
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;

        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };
        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                eprintln!("Warning: <end_of_turn> token not found, using <eos> as backup");
                eos_token
            }
        };

        let start_gen = std::time::Instant::now();

        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];

            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;

            if next_token == eos_token || next_token == eot_token {
                break;
            }

            if let Some(t) = self.tokenizer.next_token(next_token)? {
                // Best-effort send; ignore if receiver dropped.
                let _ = tx.send(Ok(t));
            }
        }

        let _dt = start_gen.elapsed();

        // Flush any remaining buffered bytes as one final chunk.
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            let _ = tx.send(Ok(rest));
        }

        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct GemmaInferenceConfig {
    pub tracing: bool,
    pub prompt: String,
    pub model: WhichModel,
    pub cpu: bool,
    pub dtype: Option<String>,
    pub model_id: Option<String>,
    pub revision: String,
    pub use_flash_attn: bool,
    pub seed: u64,
    pub temperature: f64,
    pub top_p: Option<f64>,
    pub repeat_penalty: f32,
    pub repeat_last_n: usize,
    pub max_tokens: usize,
}

impl Default for GemmaInferenceConfig {
    fn default() -> Self {
        Self {
            tracing: false,
            prompt: "Hello".to_string(),
            model: WhichModel::InstructV2_2B,
            cpu: false,
            dtype: None,
            model_id: None,
            revision: "main".to_string(),
            use_flash_attn: false,
            seed: 299792458,
            temperature: 0.8,
            top_p: None,
            repeat_penalty: 1.1,
            repeat_last_n: 128,
            max_tokens: 100,
        }
    }
}

// Removed From<Args> implementation as Args is not available and not needed for API usage

/// Builds the model and returns a channel that streams generated token strings.
/// If model setup fails, the `Result` is returned immediately.
pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String>>> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let _guard = if cfg.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );

    let device = device(cfg.cpu)?;
    println!("Device: {:?}", device);

    let dtype = match cfg.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
        None => {
            if device.is_cuda() {
                DType::BF16
            } else {
                DType::F16
            }
        }
    };
    println!("Using dtype: {:?}", dtype);

    let start = std::time::Instant::now();
    let api = Api::new()?;

    let model_id = cfg.model_id.unwrap_or_else(|| {
        match cfg.model {
            WhichModel::Base2B => "google/gemma-2b",
            WhichModel::Base7B => "google/gemma-7b",
            WhichModel::Instruct2B => "google/gemma-2b-it",
            WhichModel::Instruct7B => "google/gemma-7b-it",
            WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
            WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
            WhichModel::CodeBase2B => "google/codegemma-2b",
            WhichModel::CodeBase7B => "google/codegemma-7b",
            WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
            WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
            WhichModel::BaseV2_2B => "google/gemma-2-2b",
            WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
            WhichModel::BaseV2_9B => "google/gemma-2-9b",
            WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
            WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
            WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
        }
        .to_string()
    });

    println!("Loading model: {}", &model_id);

    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, cfg.revision));
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let config_filename = repo.get("config.json")?;
    let filenames = match cfg.model {
        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
        _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("Retrieved files in {:?}", start.elapsed());

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };

    let model: Model = match cfg.model {
        WhichModel::Base2B
        | WhichModel::Base7B
        | WhichModel::Instruct2B
        | WhichModel::Instruct7B
        | WhichModel::InstructV1_1_2B
        | WhichModel::InstructV1_1_7B
        | WhichModel::CodeBase2B
        | WhichModel::CodeBase7B
        | WhichModel::CodeInstruct2B
        | WhichModel::CodeInstruct7B => {
            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
            Model::V1(model)
        }
        WhichModel::BaseV2_2B
        | WhichModel::InstructV2_2B
        | WhichModel::BaseV2_9B
        | WhichModel::InstructV2_9B => {
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
            Model::V2(model)
        }
        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
            Model::V3(model)
        }
    };
    println!("Loaded model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        cfg.seed,
        cfg.temperature.into(),
        cfg.top_p,
        cfg.repeat_penalty,
        cfg.repeat_last_n,
        &device,
    );

    let prompt = match cfg.model {
        WhichModel::InstructV3_1B => {
            format!(
                "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
                cfg.prompt
            )
        }
        _ => cfg.prompt,
    };

    println!("Starting inference...");

    // Create the channel after successful setup.
    let (tx, rx) = mpsc::channel::<Result<String>>();

    // Spawn generation thread; send tokens to the channel.
    thread::spawn(move || {
        // If generation fails, forward the error once.
        if let Err(e) = pipeline.run_stream(&prompt, cfg.max_tokens, tx.clone()) {
            let _ = tx.send(Err(e));
        }
        // Channel closes when tx is dropped.
    });

    Ok(rx)
}

integration/gemma-runner/src/gemma_cli.rs (new file, 97 lines)
@@ -0,0 +1,97 @@
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;

#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
pub struct Args {
    /// The prompt to generate text from
    #[arg(short, long, default_value = "The capital of France is")]
    pub(crate) prompt: String,

    /// The model to use
    #[arg(short, long, default_value = "gemma-2-2b")]
    pub(crate) model: WhichModel,

    /// Run on CPU rather than GPU
    #[arg(long)]
    pub(crate) cpu: bool,

    /// The temperature used to generate samples
    #[arg(short, long)]
    pub(crate) temperature: Option<f64>,

    /// Nucleus sampling probability cutoff
    #[arg(long)]
    pub(crate) top_p: Option<f64>,

    /// The seed to use when generating random samples
    #[arg(long, default_value_t = 299792458)]
    pub(crate) seed: u64,

    /// The length of the sample to generate (in tokens)
    #[arg(short = 'n', long, default_value_t = 100)]
    pub(crate) max_tokens: usize,

    /// Use different dtype than default
    #[arg(long)]
    pub(crate) dtype: Option<String>,

    /// Custom model ID from HuggingFace Hub
    #[arg(long)]
    pub(crate) model_id: Option<String>,

    /// Model revision
    #[arg(long, default_value = "main")]
    pub(crate) revision: String,

    /// Use flash attention
    #[arg(long)]
    pub(crate) use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty
    #[arg(long, default_value_t = 1.1)]
    pub(crate) repeat_penalty: f32,

    /// The context size to consider for the repeat penalty
    #[arg(long, default_value_t = 64)]
    pub(crate) repeat_last_n: usize,

    /// Enable tracing
    #[arg(long)]
    pub(crate) tracing: bool,
}

pub fn run_cli() -> anyhow::Result<()> {
    let args = Args::parse();
    let cfg = GemmaInferenceConfig {
        tracing: args.tracing,
        prompt: args.prompt,
        model: args.model,
        cpu: args.cpu,
        dtype: args.dtype,
        model_id: args.model_id,
        revision: args.revision,
        use_flash_attn: args.use_flash_attn,
        seed: args.seed,
        temperature: args.temperature.unwrap_or(0.8),
        top_p: args.top_p,
        repeat_penalty: args.repeat_penalty,
        repeat_last_n: args.repeat_last_n,
        max_tokens: args.max_tokens,
    };
    let rx = run_gemma_api(cfg)?;
    for msg in rx {
        match msg {
            Ok(tok) => {
                print!("{tok}");
                let _ = std::io::stdout().flush(); // <- force it out now
            }
            Err(e) => {
                eprintln!("generation error: {e}");
                break;
            }
        }
    }
    Ok(())
}

integration/gemma-runner/src/lib.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
pub mod gemma_api;

pub use gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};

integration/gemma-runner/src/main.rs (new file, 17 lines)
@@ -0,0 +1,17 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod gemma_api;
mod gemma_cli;

use anyhow::Error;
use clap::{Parser, ValueEnum};

use crate::gemma_cli::run_cli;
use std::io::Write;

/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
    run_cli()
}

integration/helm-chart-tool/Cargo.toml (new file, 16 lines)
@@ -0,0 +1,16 @@
[package]
name = "helm-chart-tool"
version.workspace = true
edition = "2021"

[[bin]]
name = "helm-chart-tool"
path = "src/main.rs"

[dependencies]
toml = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive"] }
walkdir = "2.0"

218
integration/helm-chart-tool/README.md
Normal file
218
integration/helm-chart-tool/README.md
Normal file
@@ -0,0 +1,218 @@
|
||||
# Helm Chart Tool
|
||||
|
||||
A Rust-based tool that automatically generates Helm charts from Cargo.toml metadata in Rust workspace projects.
|
||||
|
||||
## Overview
|
||||
|
||||
This tool scans a Rust workspace for crates containing Docker/Kubernetes metadata in their `Cargo.toml` files and generates a complete, production-ready Helm chart with deployments, services, ingress, and configuration templates.
|
||||
|
||||
## Features
|
||||
|
||||
- **Automatic Service Discovery**: Scans all `Cargo.toml` files in a workspace to find services with Kubernetes metadata
|
||||
- **Complete Helm Chart Generation**: Creates Chart.yaml, values.yaml, deployment templates, service templates, ingress template, and helper templates
|
||||
- **Metadata Extraction**: Uses `[package.metadata.kube]` sections from Cargo.toml files to extract:
|
||||
- Docker image names
|
||||
- Service ports
|
||||
- Replica counts
|
||||
- Service names
|
||||
- **Production Ready**: Generated charts include health checks, resource limits, node selectors, affinity rules, and tolerations
|
||||
- **Helm Best Practices**: Follows Helm chart conventions and passes `helm lint` validation
|
||||
|
||||
## Installation
|
||||
|
||||
Build the tool from source:
|
||||
|
||||
```bash
|
||||
cd helm-chart-tool
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
The binary will be available at `target/release/helm-chart-tool`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
./target/release/helm-chart-tool --workspace /path/to/rust/workspace --output ./my-helm-chart
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
- `--workspace, -w PATH`: Path to the workspace root (default: `.`)
|
||||
- `--output, -o PATH`: Output directory for the Helm chart (default: `./helm-chart`)
|
||||
- `--name, -n NAME`: Name of the Helm chart (default: `predict-otron-9000`)
|
||||
|
||||
### Example
|
||||
|
||||
```bash
|
||||
# Generate chart from current workspace
|
||||
./target/release/helm-chart-tool
|
||||
|
||||
# Generate chart from specific workspace with custom output
|
||||
./target/release/helm-chart-tool -w /path/to/my/workspace -o ./charts/my-app -n my-application
|
||||
```
|
||||
|
||||
## Cargo.toml Metadata Format
|
||||
|
||||
The tool expects crates to have Kubernetes metadata in their `Cargo.toml` files:
|
||||
|
||||
```toml
|
||||
[package]
|
||||
name = "my-service"
|
||||
version = "0.1.0"
|
||||
|
||||
# Required: Kubernetes metadata
|
||||
[package.metadata.kube]
|
||||
image = "ghcr.io/myorg/my-service:latest"
|
||||
replicas = 1
|
||||
port = 8080
|
||||
|
||||
# Optional: Docker Compose metadata (currently not used but parsed)
|
||||
[package.metadata.compose]
|
||||
image = "ghcr.io/myorg/my-service:latest"
|
||||
port = 8080
|
||||
```
|
||||
|
||||
### Required Fields
|
||||
|
||||
- `image`: Full Docker image name including registry and tag
|
||||
- `port`: Port number the service listens on
|
||||
- `replicas`: Number of replicas to deploy (optional, defaults to 1)
|
||||
|
||||
## Generated Chart Structure
|
||||
|
||||
The tool generates a complete Helm chart with the following structure:
|
||||
|
||||
```
|
||||
helm-chart/
|
||||
├── Chart.yaml # Chart metadata
|
||||
├── values.yaml # Default configuration values
|
||||
├── .helmignore # Files to ignore when packaging
|
||||
└── templates/
|
||||
├── _helpers.tpl # Template helper functions
|
||||
├── ingress.yaml # Ingress configuration (optional)
|
||||
├── {service}-deployment.yaml # Deployment for each service
|
||||
└── {service}-service.yaml # Service for each service
|
||||
```
|
||||
|
||||
### Generated Files
|
||||
|
||||
#### Chart.yaml
|
||||
- Standard Helm v2 chart metadata
|
||||
- Includes keywords for AI/ML applications
|
||||
- Maintainer information
|
||||
|
||||
#### values.yaml
|
||||
- Individual service configurations
|
||||
- Resource limits and requests
|
||||
- Service types and ports
|
||||
- Node selectors, affinity, and tolerations
|
||||
- Global settings and ingress configuration
|
||||
|
||||
#### Deployment Templates
|
||||
- Kubernetes Deployment manifests
|
||||
- Health checks (liveness and readiness probes)
|
||||
- Resource management
|
||||
- Container port configuration from metadata
|
||||
- Support for node selectors, affinity, and tolerations

#### Service Templates
- Kubernetes Service manifests
- ClusterIP services by default
- Port mapping from metadata

#### Ingress Template
- Optional ingress configuration
- Disabled by default
- Configurable through values.yaml

## Example Output

When run against the predict-otron-9000 workspace, the tool generates:

```bash
$ ./target/release/helm-chart-tool --workspace .. --output ../generated-helm-chart
Parsing workspace at: ..
Output directory: ../generated-helm-chart
Chart name: predict-otron-9000
Found 4 services:
  - chat-ui: ghcr.io/geoffsee/chat-ui:latest (port 8788)
  - inference-engine: ghcr.io/geoffsee/inference-service:latest (port 8080)
  - embeddings-engine: ghcr.io/geoffsee/embeddings-service:latest (port 8080)
  - predict-otron-9000: ghcr.io/geoffsee/predict-otron-9000:latest (port 8080)
Helm chart generated successfully!
```

## Validation

The generated charts pass Helm validation:

```bash
$ helm lint generated-helm-chart
==> Linting generated-helm-chart
[INFO] Chart.yaml: icon is recommended
1 chart(s) linted, 0 chart(s) failed
```

## Deployment

Deploy the generated chart:

```bash
# Install the chart
helm install my-release ./generated-helm-chart

# Upgrade the chart
helm upgrade my-release ./generated-helm-chart

# Uninstall the chart
helm uninstall my-release
```

### Customization

Customize the deployment by modifying `values.yaml`:

```yaml
# Enable ingress
ingress:
  enabled: true
  className: "nginx"
  hosts:
    - host: my-app.example.com

# Adjust resources for a specific service
predict_otron_9000:
  replicas: 3
  resources:
    limits:
      memory: "4Gi"
      cpu: "2000m"
    requests:
      memory: "2Gi"
      cpu: "1000m"
```
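
If you prefer to keep the generated chart untouched, the same overrides can live in a separate values file passed at install time (standard Helm usage; `custom-values.yaml` is just a placeholder name):

```bash
helm upgrade --install my-release ./generated-helm-chart -f custom-values.yaml
```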

## Requirements

- Rust 2021 edition or later (for building the tool)
- Helm 3.x (for deploying the generated charts)
- Kubernetes cluster (for deployment)

## Limitations

- Currently assumes all services expose health checks on a `/health` endpoint
- Resource limits are hardcoded defaults (can be overridden in values.yaml)
- Ingress configuration is basic (can be customized through values.yaml)

## Contributing

1. Add new features to the tool
2. Test with various Cargo.toml metadata configurations
3. Validate generated charts with `helm lint`
4. Ensure charts deploy successfully to test clusters
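
A quick pre-deployment check that does not require write access to a real cluster (standard Helm commands; the release name is arbitrary):

```bash
# Render the templates locally and inspect the output
helm template test-release ./generated-helm-chart

# Simulate an install against the current kube context
helm install test-release ./generated-helm-chart --dry-run --debug
```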

## License

This tool is part of the predict-otron-9000 project and follows the same license terms.
525
integration/helm-chart-tool/src/main.rs
Normal file
@@ -0,0 +1,525 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::{Arg, Command};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CargoToml {
|
||||
package: Option<Package>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Package {
|
||||
name: String,
|
||||
metadata: Option<Metadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Metadata {
|
||||
kube: Option<KubeMetadata>,
|
||||
compose: Option<ComposeMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct KubeMetadata {
|
||||
image: String,
|
||||
replicas: Option<u32>,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ComposeMetadata {
|
||||
image: Option<String>,
|
||||
port: Option<u16>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ServiceInfo {
|
||||
name: String,
|
||||
image: String,
|
||||
port: u16,
|
||||
replicas: u32,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let matches = Command::new("helm-chart-tool")
|
||||
.about("Generate Helm charts from Cargo.toml metadata")
|
||||
.arg(
|
||||
Arg::new("workspace")
|
||||
.short('w')
|
||||
.long("workspace")
|
||||
.value_name("PATH")
|
||||
.help("Path to the workspace root")
|
||||
.default_value("."),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("output")
|
||||
.short('o')
|
||||
.long("output")
|
||||
.value_name("PATH")
|
||||
.help("Output directory for the Helm chart")
|
||||
.default_value("./helm-chart"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("chart-name")
|
||||
.short('n')
|
||||
.long("name")
|
||||
.value_name("NAME")
|
||||
.help("Name of the Helm chart")
|
||||
.default_value("predict-otron-9000"),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let workspace_path = matches.get_one::<String>("workspace").unwrap();
|
||||
let output_path = matches.get_one::<String>("output").unwrap();
|
||||
let chart_name = matches.get_one::<String>("chart-name").unwrap();
|
||||
|
||||
println!("Parsing workspace at: {}", workspace_path);
|
||||
println!("Output directory: {}", output_path);
|
||||
println!("Chart name: {}", chart_name);
|
||||
|
||||
let services = discover_services(workspace_path)?;
|
||||
println!("Found {} services:", services.len());
|
||||
for service in &services {
|
||||
println!(
|
||||
" - {}: {} (port {})",
|
||||
service.name, service.image, service.port
|
||||
);
|
||||
}
|
||||
|
||||
generate_helm_chart(output_path, chart_name, &services)?;
|
||||
println!("Helm chart generated successfully!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
|
||||
let workspace_root = Path::new(workspace_path);
|
||||
let mut services = Vec::new();
|
||||
|
||||
// Find all Cargo.toml files in the workspace
|
||||
for entry in WalkDir::new(workspace_root)
|
||||
.into_iter()
|
||||
.filter_map(|e| e.ok())
|
||||
{
|
||||
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("../../../Cargo.toml") {
|
||||
if let Ok(service_info) = parse_cargo_toml(entry.path()) {
|
||||
services.push(service_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(services)
|
||||
}
|
||||
|
||||
fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
|
||||
let content = fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?;
|
||||
|
||||
let cargo_toml: CargoToml = toml::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;
|
||||
|
||||
let package = cargo_toml
|
||||
.package
|
||||
.ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;
|
||||
|
||||
let metadata = package
|
||||
.metadata
|
||||
.ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;
|
||||
|
||||
let kube_metadata = metadata
|
||||
.kube
|
||||
.ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;
|
||||
|
||||
Ok(ServiceInfo {
|
||||
name: package.name,
|
||||
image: kube_metadata.image,
|
||||
port: kube_metadata.port,
|
||||
replicas: kube_metadata.replicas.unwrap_or(1),
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_helm_chart(
|
||||
output_path: &str,
|
||||
chart_name: &str,
|
||||
services: &[ServiceInfo],
|
||||
) -> Result<()> {
|
||||
let chart_dir = Path::new(output_path);
|
||||
let templates_dir = chart_dir.join("templates");
|
||||
|
||||
// Create directories
|
||||
fs::create_dir_all(&templates_dir)?;
|
||||
|
||||
// Generate Chart.yaml
|
||||
generate_chart_yaml(chart_dir, chart_name)?;
|
||||
|
||||
// Generate values.yaml
|
||||
generate_values_yaml(chart_dir, services)?;
|
||||
|
||||
// Generate templates for each service
|
||||
for service in services {
|
||||
generate_deployment_template(&templates_dir, service)?;
|
||||
generate_service_template(&templates_dir, service)?;
|
||||
}
|
||||
|
||||
// Generate ingress template
|
||||
generate_ingress_template(&templates_dir, services)?;
|
||||
|
||||
// Generate helper templates
|
||||
generate_helpers_template(&templates_dir)?;
|
||||
|
||||
// Generate .helmignore
|
||||
generate_helmignore(chart_dir)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_chart_yaml(chart_dir: &Path, chart_name: &str) -> Result<()> {
|
||||
let chart_yaml = format!(
|
||||
r#"apiVersion: v2
|
||||
name: {}
|
||||
description: A Helm chart for the predict-otron-9000 AI platform
|
||||
type: application
|
||||
version: 0.1.0
|
||||
appVersion: "0.1.0"
|
||||
keywords:
|
||||
- ai
|
||||
- llm
|
||||
- inference
|
||||
- embeddings
|
||||
- chat
|
||||
maintainers:
|
||||
- name: predict-otron-9000-team
|
||||
"#,
|
||||
chart_name
|
||||
);
|
||||
|
||||
fs::write(chart_dir.join("Chart.yaml"), chart_yaml)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_values_yaml(chart_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
|
||||
let mut values = String::from(
|
||||
r#"# Default values for predict-otron-9000
|
||||
# This is a YAML-formatted file.
|
||||
|
||||
global:
|
||||
imagePullPolicy: IfNotPresent
|
||||
serviceType: ClusterIP
|
||||
|
||||
# Ingress configuration
|
||||
ingress:
|
||||
enabled: false
|
||||
className: ""
|
||||
annotations: {}
|
||||
hosts:
|
||||
- host: predict-otron-9000.local
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: predict-otron-9000
|
||||
port:
|
||||
number: 8080
|
||||
tls: []
|
||||
|
||||
"#,
|
||||
);
|
||||
|
||||
for service in services {
|
||||
let service_config = format!(
|
||||
r#"{}:
|
||||
image:
|
||||
repository: {}
|
||||
tag: "latest"
|
||||
pullPolicy: IfNotPresent
|
||||
replicas: {}
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: {}
|
||||
resources:
|
||||
limits:
|
||||
memory: "1Gi"
|
||||
cpu: "1000m"
|
||||
requests:
|
||||
memory: "512Mi"
|
||||
cpu: "250m"
|
||||
nodeSelector: {{}}
|
||||
tolerations: []
|
||||
affinity: {{}}
|
||||
|
||||
"#,
|
||||
service.name.replace("-", "_"),
|
||||
service.image.split(':').next().unwrap_or(&service.image),
|
||||
service.replicas,
|
||||
service.port
|
||||
);
|
||||
values.push_str(&service_config);
|
||||
}
|
||||
|
||||
fs::write(chart_dir.join("values.yaml"), values)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_deployment_template(templates_dir: &Path, service: &ServiceInfo) -> Result<()> {
|
||||
let service_name_underscore = service.name.replace("-", "_");
|
||||
let deployment_template = format!(
|
||||
r#"apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{{{ include "predict-otron-9000.fullname" . }}}}-{}
|
||||
labels:
|
||||
{{{{- include "predict-otron-9000.labels" . | nindent 4 }}}}
|
||||
app.kubernetes.io/component: {}
|
||||
spec:
|
||||
replicas: {{{{ .Values.{}.replicas }}}}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{{{- include "predict-otron-9000.selectorLabels" . | nindent 6 }}}}
|
||||
app.kubernetes.io/component: {}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{{{- include "predict-otron-9000.selectorLabels" . | nindent 8 }}}}
|
||||
app.kubernetes.io/component: {}
|
||||
spec:
|
||||
containers:
|
||||
- name: {}
|
||||
image: "{{{{ .Values.{}.image.repository }}}}:{{{{ .Values.{}.image.tag }}}}"
|
||||
imagePullPolicy: {{{{ .Values.{}.image.pullPolicy }}}}
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {}
|
||||
protocol: TCP
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 30
|
||||
periodSeconds: 10
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 5
|
||||
periodSeconds: 5
|
||||
resources:
|
||||
{{{{- toYaml .Values.{}.resources | nindent 12 }}}}
|
||||
{{{{- with .Values.{}.nodeSelector }}}}
|
||||
nodeSelector:
|
||||
{{{{- toYaml . | nindent 8 }}}}
|
||||
{{{{- end }}}}
|
||||
{{{{- with .Values.{}.affinity }}}}
|
||||
affinity:
|
||||
{{{{- toYaml . | nindent 8 }}}}
|
||||
{{{{- end }}}}
|
||||
{{{{- with .Values.{}.tolerations }}}}
|
||||
tolerations:
|
||||
{{{{- toYaml . | nindent 8 }}}}
|
||||
{{{{- end }}}}
|
||||
"#,
|
||||
service.name,
|
||||
service.name,
|
||||
service_name_underscore,
|
||||
service.name,
|
||||
service.name,
|
||||
service.name,
|
||||
service_name_underscore,
|
||||
service_name_underscore,
|
||||
service_name_underscore,
|
||||
service.port,
|
||||
service_name_underscore,
|
||||
service_name_underscore,
|
||||
service_name_underscore,
|
||||
service_name_underscore
|
||||
);
|
||||
|
||||
let filename = format!("{}-deployment.yaml", service.name);
|
||||
fs::write(templates_dir.join(filename), deployment_template)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_service_template(templates_dir: &Path, service: &ServiceInfo) -> Result<()> {
|
||||
let service_template = format!(
|
||||
r#"apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{{{ include "predict-otron-9000.fullname" . }}}}-{}
|
||||
labels:
|
||||
{{{{- include "predict-otron-9000.labels" . | nindent 4 }}}}
|
||||
app.kubernetes.io/component: {}
|
||||
spec:
|
||||
type: {{{{ .Values.{}.service.type }}}}
|
||||
ports:
|
||||
- port: {{{{ .Values.{}.service.port }}}}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{{{- include "predict-otron-9000.selectorLabels" . | nindent 4 }}}}
|
||||
app.kubernetes.io/component: {}
|
||||
"#,
|
||||
service.name,
|
||||
service.name,
|
||||
service.name.replace("-", "_"),
|
||||
service.name.replace("-", "_"),
|
||||
service.name
|
||||
);
|
||||
|
||||
let filename = format!("{}-service.yaml", service.name);
|
||||
fs::write(templates_dir.join(filename), service_template)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
|
||||
let ingress_template = r#"{{- if .Values.ingress.enabled -}}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "predict-otron-9000.fullname" . }}
|
||||
labels:
|
||||
{{- include "predict-otron-9000.labels" . | nindent 4 }}
|
||||
{{- with .Values.ingress.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- if .Values.ingress.className }}
|
||||
ingressClassName: {{ .Values.ingress.className }}
|
||||
{{- end }}
|
||||
{{- if .Values.ingress.tls }}
|
||||
tls:
|
||||
{{- range .Values.ingress.tls }}
|
||||
- hosts:
|
||||
{{- range .hosts }}
|
||||
- {{ . | quote }}
|
||||
{{- end }}
|
||||
secretName: {{ .secretName }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
rules:
|
||||
{{- range .Values.ingress.hosts }}
|
||||
- host: {{ .host | quote }}
|
||||
http:
|
||||
paths:
|
||||
{{- range .paths }}
|
||||
- path: {{ .path }}
|
||||
{{- if .pathType }}
|
||||
pathType: {{ .pathType }}
|
||||
{{- end }}
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "predict-otron-9000.fullname" $ }}-{{ .backend.service.name }}
|
||||
port:
|
||||
number: {{ .backend.service.port.number }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
"#;
|
||||
|
||||
fs::write(templates_dir.join("ingress.yaml"), ingress_template)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_helpers_template(templates_dir: &Path) -> Result<()> {
|
||||
let helpers_template = r#"{{/*
|
||||
Expand the name of the chart.
|
||||
*/}}
|
||||
{{- define "predict-otron-9000.name" -}}
|
||||
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Create a default fully qualified app name.
|
||||
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
|
||||
If release name contains chart name it will be used as a full name.
|
||||
*/}}
|
||||
{{- define "predict-otron-9000.fullname" -}}
|
||||
{{- if .Values.fullnameOverride }}
|
||||
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
|
||||
{{- else }}
|
||||
{{- $name := default .Chart.Name .Values.nameOverride }}
|
||||
{{- if contains $name .Release.Name }}
|
||||
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
|
||||
{{- else }}
|
||||
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Create chart name and version as used by the chart label.
|
||||
*/}}
|
||||
{{- define "predict-otron-9000.chart" -}}
|
||||
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Common labels
|
||||
*/}}
|
||||
{{- define "predict-otron-9000.labels" -}}
|
||||
helm.sh/chart: {{ include "predict-otron-9000.chart" . }}
|
||||
{{ include "predict-otron-9000.selectorLabels" . }}
|
||||
{{- if .Chart.AppVersion }}
|
||||
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
|
||||
{{- end }}
|
||||
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Selector labels
|
||||
*/}}
|
||||
{{- define "predict-otron-9000.selectorLabels" -}}
|
||||
app.kubernetes.io/name: {{ include "predict-otron-9000.name" . }}
|
||||
app.kubernetes.io/instance: {{ .Release.Name }}
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Create the name of the service account to use
|
||||
*/}}
|
||||
{{- define "predict-otron-9000.serviceAccountName" -}}
|
||||
{{- if .Values.serviceAccount.create }}
|
||||
{{- default (include "predict-otron-9000.fullname" .) .Values.serviceAccount.name }}
|
||||
{{- else }}
|
||||
{{- default "default" .Values.serviceAccount.name }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
"#;
|
||||
|
||||
fs::write(templates_dir.join("_helpers.tpl"), helpers_template)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate_helmignore(chart_dir: &Path) -> Result<()> {
|
||||
let helmignore_content = r#"# Patterns to ignore when building packages.
|
||||
# This supports shell glob matching, relative path matching, and
|
||||
# negation (prefixed with !). Only one pattern per line.
|
||||
.DS_Store
|
||||
# Common VCS dirs
|
||||
.git/
|
||||
.gitignore
|
||||
.bzr/
|
||||
.bzrignore
|
||||
.hg/
|
||||
.hgignore
|
||||
.svn/
|
||||
# Common backup files
|
||||
*.swp
|
||||
*.bak
|
||||
*.tmp
|
||||
*.orig
|
||||
*~
|
||||
# Various IDEs
|
||||
.project
|
||||
.idea/
|
||||
*.tmproj
|
||||
.vscode/
|
||||
"#;
|
||||
|
||||
fs::write(chart_dir.join(".helmignore"), helmignore_content)?;
|
||||
Ok(())
|
||||
}
|
24
integration/llama-runner/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
[package]
name = "llama-runner"
version.workspace = true
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
188
integration/llama-runner/README.md
Normal file
@@ -0,0 +1,188 @@
# Llama Runner

A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.

## Features

- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
- ⚡ **Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults

## Supported Models

| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |

Add the `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).
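
For instance, to run an instruction-tuned variant, only the model name changes; all other flags work as in the examples below:

```bash
cargo run --features metal -- \
  --model smollm2-135m-instruct \
  --prompt "Give me three facts about the Moon."
```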

## Installation

```bash
# Clone the repository
git clone <repository-url>
cd llama-runner

# Build with GPU acceleration (recommended)
cargo build --release --features metal  # macOS
cargo build --release --features cuda   # Linux/Windows with NVIDIA GPU

# CPU-only build
cargo build --release
```

## Quick Start

```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"

# Specify a model and parameters
cargo run --features metal -- \
  --prompt "Write a short story about space exploration" \
  --model smollm2-360m \
  --max-tokens 100 \
  --temperature 0.8

# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```

## Usage Examples

### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"

# Creative writing with higher temperature
cargo run --features metal -- \
  --prompt "Once upon a time" \
  --temperature 1.0 \
  --max-tokens 200
```

### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
  --prompt "Explain artificial intelligence" \
  --top-k 40 \
  --top-p 0.9 \
  --temperature 0.7

# Reduce repetition
cargo run --features metal -- \
  --prompt "List the benefits of renewable energy" \
  --repeat-penalty 1.2 \
  --repeat-last-n 64
```

### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
  --prompt "Quick test" \
  --model smollm2-135m

# Better quality with larger model
cargo run --features metal -- \
  --prompt "Explain quantum physics" \
  --model llama-3.2-1b \
  --max-tokens 150
```

## Command-Line Options

| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `llama-3.2-1b-instruct` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |
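
A temperature of 0.0 switches the sampler to greedy (arg-max) decoding, so output is deterministic regardless of seed, which is handy for regression testing:

```bash
cargo run --features metal -- \
  --prompt "The capital of France is" \
  --temperature 0.0 \
  --max-tokens 20
```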

## Performance

Typical performance on Apple M2 with Metal acceleration:

| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |

## Requirements

- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows

## GPU Support

### macOS (Metal)
```bash
cargo run --features metal -- [options]
```

### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```

### CPU Only
```bash
cargo run -- --cpu [options]
```

## Model Downloads

Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Approximate download times:

- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes

## Troubleshooting

### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model

### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues

### Model Download Issues
- Check your internet connection
- Some models (e.g., the gated Meta Llama 3.2 weights) may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`
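
If a gated model fails to download, authenticating once with the Hugging Face CLI caches a token on disk that the `hf-hub` crate can pick up; whether a given hf-hub version also honors the `HF_TOKEN` environment variable is an assumption, so the CLI login is the safer path:

```bash
pip install -U huggingface_hub
huggingface-cli login
# Then request access to the gated repo (e.g. meta-llama/Llama-3.2-1B) on huggingface.co
```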

## Contributing

Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.

## License

MIT License - see LICENSE file for details.
7
integration/llama-runner/src/lib.rs
Normal file
@@ -0,0 +1,7 @@
pub mod llama_api;

use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";
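For callers wiring this crate in as a library rather than via the CLI, a minimal sketch of the streaming API looks like the following (assuming `llama-runner` is added as a path dependency; the config fields and the `Receiver<anyhow::Result<String>>` return type come from `llama_api.rs` below):

```rust
use llama_runner::{run_llama_inference, LlamaInferenceConfig, WhichModel};

fn main() -> anyhow::Result<()> {
    // Small, ungated model and a short completion; remaining fields keep their defaults.
    let cfg = LlamaInferenceConfig {
        prompt: "The capital of France is".to_string(),
        model: WhichModel::SmolLM2_135M,
        max_tokens: 50,
        ..Default::default()
    };

    // The returned Receiver yields decoded text fragments as they are generated.
    for fragment in run_llama_inference(cfg)? {
        print!("{}", fragment?);
    }
    Ok(())
}
```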
333
integration/llama-runner/src/llama_api.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
use crate::EOS_TOKEN;
|
||||
use anyhow::{bail, Error as E};
|
||||
use candle_core::{utils, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use candle_transformers::models::llama as model;
|
||||
use candle_transformers::models::llama::{Llama, LlamaConfig};
|
||||
use clap::ValueEnum;
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use std::sync::mpsc::{self, Receiver};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "llama-3.2-1b")]
|
||||
#[default]
|
||||
Llama32_1B,
|
||||
#[value(name = "llama-3.2-1b-instruct")]
|
||||
Llama32_1BInstruct,
|
||||
#[value(name = "llama-3.2-3b")]
|
||||
Llama32_3B,
|
||||
#[value(name = "llama-3.2-3b-instruct")]
|
||||
Llama32_3BInstruct,
|
||||
#[value(name = "smollm2-135m")]
|
||||
SmolLM2_135M,
|
||||
#[value(name = "smollm2-135m-instruct")]
|
||||
SmolLM2_135MInstruct,
|
||||
#[value(name = "smollm2-360m")]
|
||||
SmolLM2_360M,
|
||||
#[value(name = "smollm2-360m-instruct")]
|
||||
SmolLM2_360MInstruct,
|
||||
#[value(name = "smollm2-1.7b")]
|
||||
SmolLM2_1_7B,
|
||||
#[value(name = "smollm2-1.7b-instruct")]
|
||||
SmolLM2_1_7BInstruct,
|
||||
#[value(name = "tinyllama-1.1b-chat")]
|
||||
TinyLlama1_1BChat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlamaInferenceConfig {
|
||||
pub prompt: String,
|
||||
|
||||
pub model: WhichModel,
|
||||
pub cpu: bool,
|
||||
pub temperature: f64,
|
||||
pub top_p: Option<f64>,
|
||||
pub top_k: Option<usize>,
|
||||
pub seed: u64,
|
||||
pub max_tokens: usize,
|
||||
pub no_kv_cache: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
pub revision: Option<String>,
|
||||
pub use_flash_attn: bool,
|
||||
pub repeat_penalty: f32,
|
||||
pub repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Default for LlamaInferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// Leave prompt empty by default; let call sites set it.
|
||||
prompt: String::new(),
|
||||
|
||||
// Keep your existing model choice; swap at call-site if needed.
|
||||
model: WhichModel::Llama32_1BInstruct,
|
||||
|
||||
// Prefer GPU if available.
|
||||
cpu: false,
|
||||
|
||||
// Sampling: balanced + stable
|
||||
temperature: 0.7,
|
||||
top_p: Some(0.95),
|
||||
top_k: Some(50),
|
||||
|
||||
// Reproducible by default; override for variability.
|
||||
seed: 42,
|
||||
|
||||
// Don’t run unbounded generations.
|
||||
max_tokens: 512,
|
||||
|
||||
// Performance flags
|
||||
no_kv_cache: false, // keep cache ON for speed
|
||||
use_flash_attn: false, // great speed boost if supported
|
||||
|
||||
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
|
||||
dtype: Some("bf16".to_string()),
|
||||
|
||||
// Optional model source pinning (None = app defaults)
|
||||
model_id: None,
|
||||
revision: None,
|
||||
|
||||
// Anti-repeat heuristics
|
||||
repeat_penalty: 1.15,
|
||||
repeat_last_n: 128,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn device(cpu: bool) -> anyhow::Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if utils::cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if utils::metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
fn hub_load_safetensors(
|
||||
api: &hf_hub::api::sync::ApiRepo,
|
||||
json_file: &str,
|
||||
) -> anyhow::Result<Vec<std::path::PathBuf>> {
|
||||
let json_file = api.get(json_file)?;
|
||||
let json_file = std::fs::File::open(json_file)?;
|
||||
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file.to_string());
|
||||
}
|
||||
}
|
||||
let safetensors_files = safetensors_files
|
||||
.iter()
|
||||
.map(|v| api.get(v))
|
||||
.collect::<anyhow::Result<Vec<_>, _>>()?;
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
|
||||
pub fn run_llama_inference(
|
||||
cfg: LlamaInferenceConfig,
|
||||
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
|
||||
// ---- Device & dtype -----------------------------------------------------
|
||||
let device = device(cfg.cpu)?;
|
||||
println!("Device: {:?}", device);
|
||||
|
||||
let dtype = match cfg.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
|
||||
// ---- Load model & tokenizer --------------------------------------------
|
||||
let (llama, tokenizer, mut cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
|
||||
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
|
||||
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
|
||||
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
|
||||
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
|
||||
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
|
||||
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
|
||||
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
|
||||
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
println!("Loading model: {}", model_id);
|
||||
let revision = cfg.revision.clone().unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(cfg.use_flash_attn);
|
||||
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
|
||||
hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
_ => vec![api.get("model.safetensors")?],
|
||||
};
|
||||
|
||||
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let llama = Llama::load(vb, &config)?;
|
||||
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
(llama, tokenizer, cache)
|
||||
};
|
||||
|
||||
// ---- Prepare prompt & sampler ------------------------------------------
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(EOS_TOKEN)
|
||||
.map(model::LlamaEosToks::Single);
|
||||
|
||||
let mut tokens = tokenizer
|
||||
.encode(cfg.prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
println!("Starting inference...");
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = cfg.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (cfg.top_k, cfg.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(cfg.seed, sampling)
|
||||
};
|
||||
|
||||
// Channel for streaming decoded fragments to the caller.
|
||||
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
|
||||
|
||||
// ---- Spawn generation thread -------------------------------------------
|
||||
std::thread::spawn(move || {
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0usize;
|
||||
let mut token_generated = 0usize;
|
||||
|
||||
for index in 0..cfg.max_tokens {
|
||||
// Use KV-cache for single-token step after the first pass.
|
||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||
(1, index_pos)
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let logits = match llama.forward(&input, context_index, &mut cache) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
let logits = match logits.squeeze(0) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let logits = if cfg.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
|
||||
match candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
cfg.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = match logits_processor.sample(&logits) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(e.into()));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
token_generated += 1;
|
||||
tokens.push(next_token);
|
||||
|
||||
// Early stop on EOS.
|
||||
let stop = match eos_token_id {
|
||||
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
|
||||
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
|
||||
None => false,
|
||||
};
|
||||
if stop {
|
||||
break;
|
||||
}
|
||||
|
||||
// Decode this token's text and stream it out.
|
||||
match tokenizer.decode(&[next_token], false) {
|
||||
Ok(text) => {
|
||||
if !text.is_empty() {
|
||||
// Best-effort send; if receiver is gone, just stop.
|
||||
if tx.send(Ok(text)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Optional: final stats as a debug line (not sent through the stream).
|
||||
let dt = start_gen.elapsed();
|
||||
eprintln!(
|
||||
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
|
||||
token_generated,
|
||||
token_generated as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
// Dropping tx closes the stream.
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
108
integration/llama-runner/src/llama_cli.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
|
||||
use clap::Parser;
|
||||
use std::io::Write;
|
||||
|
||||
#[derive(Parser, Debug, Default)]
|
||||
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to generate text from
|
||||
#[arg(short, long, default_value = "The capital of France is")]
|
||||
prompt: String,
|
||||
|
||||
/// The model to use
|
||||
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
|
||||
model: WhichModel,
|
||||
|
||||
/// Run on CPU rather than GPU
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples
|
||||
#[arg(short, long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens)
|
||||
#[arg(short = 'n', long, default_value_t = 100)]
|
||||
max_tokens: usize,
|
||||
|
||||
/// Disable the key-value cache
|
||||
#[arg(long)]
|
||||
no_kv_cache: bool,
|
||||
|
||||
/// Use different dtype than f16
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
/// Custom model ID from HuggingFace Hub
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// Model revision
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// Use flash attention
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty
|
||||
#[arg(long, default_value_t = 128)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl Into<LlamaInferenceConfig> for Args {
|
||||
fn into(self) -> LlamaInferenceConfig {
|
||||
LlamaInferenceConfig {
|
||||
prompt: self.prompt,
|
||||
model: self.model,
|
||||
cpu: self.cpu,
|
||||
temperature: self.temperature,
|
||||
top_p: self.top_p,
|
||||
top_k: self.top_k,
|
||||
seed: self.seed,
|
||||
max_tokens: self.max_tokens,
|
||||
no_kv_cache: self.no_kv_cache,
|
||||
dtype: self.dtype,
|
||||
model_id: self.model_id,
|
||||
revision: self.revision,
|
||||
use_flash_attn: self.use_flash_attn,
|
||||
repeat_penalty: self.repeat_penalty,
|
||||
repeat_last_n: self.repeat_last_n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_cli() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let cfg = args.into();
|
||||
let rx = run_llama_inference(cfg)?;
|
||||
for msg in rx {
|
||||
match msg {
|
||||
Ok(tok) => {
|
||||
print!("{tok}");
|
||||
let _ = std::io::stdout().flush(); // <- force it out now
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("generation error: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
19
integration/llama-runner/src/main.rs
Normal file
@@ -0,0 +1,19 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_api;
mod llama_cli;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use std::io::Write;

use crate::llama_cli::run_cli;

const EOS_TOKEN: &str = "</s>";

fn main() -> Result<()> {
    run_cli()
}
88
integration/utils/Cargo.toml
Normal file
@@ -0,0 +1,88 @@
|
||||
[package]
|
||||
name = "utils"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = {version = "0.3.2", optional = true }
|
||||
candle-nn = {version = "0.9.1" }
|
||||
candle-transformers = {version = "0.9.1" }
|
||||
|
||||
candle-flash-attn = {version = "0.9.1", optional = true }
|
||||
candle-onnx = {version = "0.9.1", optional = true }
|
||||
candle-core="0.9.1"
|
||||
csv = "1.3.0"
|
||||
anyhow = "1.0.99"
|
||||
cudarc = {version = "0.17.3", optional = true }
|
||||
half = {version = "2.6.0", optional = true }
|
||||
hf-hub = {version = "0.4.3", features = ["tokio"] }
|
||||
image = {version = "0.25.6" }
|
||||
intel-mkl-src = {version = "0.8.1", optional = true }
|
||||
num-traits = {version = "0.2.19" }
|
||||
palette = { version = "0.7.6", optional = true }
|
||||
enterpolation = { version = "0.2.1", optional = true }
|
||||
pyo3 = { version = "0.22.0", features = [
|
||||
"auto-initialize",
|
||||
"abi3-py311",
|
||||
], optional = true }
|
||||
rayon = {version = "1.11.0" }
|
||||
rubato = { version = "0.15.0", optional = true }
|
||||
safetensors = {version = "0.6.2" }
|
||||
serde = {version = "1.0.219" }
|
||||
serde_json = {version = "1.0.143" }
|
||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||
tokenizers = {version = "0.22.0", features = ["onig"] }
|
||||
cpal = { version = "0.15.2", optional = true }
|
||||
pdf2image = { version = "0.1.2", optional = true }
|
||||
tekken-rs = { version = "0.1.1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = {version = "1.0.99" }
|
||||
byteorder = {version = "1.5.0" }
|
||||
clap = {version = "4.5.46" }
|
||||
imageproc = {version = "0.25.0" }
|
||||
memmap2 = {version = "0.9.8" }
|
||||
rand = {version = "0.9.2" }
|
||||
ab_glyph = {version = "0.2.31" }
|
||||
tracing = {version = "0.1.41" }
|
||||
tracing-chrome = {version = "0.7.2" }
|
||||
tracing-subscriber = {version = "0.3.20" }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.43.0"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = {version = "1.0.99" }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
#
|
||||
[features]
|
||||
default = []
|
||||
accelerate = [
|
||||
"dep:accelerate-src",
|
||||
"candle-core/accelerate",
|
||||
"candle-nn/accelerate",
|
||||
"candle-transformers/accelerate",
|
||||
]
|
||||
cuda = [
|
||||
"candle-core/cuda",
|
||||
"candle-nn/cuda",
|
||||
"candle-transformers/cuda",
|
||||
"dep:bindgen_cuda",
|
||||
]
|
||||
cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = [
|
||||
"dep:intel-mkl-src",
|
||||
"candle-core/mkl",
|
||||
"candle-nn/mkl",
|
||||
"candle-transformers/mkl",
|
||||
]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle-core/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal", "rubato"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
snac = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
tekken = ["tekken-rs"]
|
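Since every acceleration backend here sits behind a feature flag, a typical build selects the crate and backend explicitly. A sketch, assuming the crate is addressed from the workspace root as `-p utils`:

```bash
# CPU-only
cargo build -p utils

# Metal kernels on macOS
cargo build -p utils --features metal

# CUDA on Linux/Windows
cargo build -p utils --features cuda
```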
138
integration/utils/src/audio.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use candle_core::{Result, Tensor};
|
||||
|
||||
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
|
||||
pub fn normalize_loudness(
|
||||
wav: &Tensor,
|
||||
sample_rate: u32,
|
||||
loudness_compressor: bool,
|
||||
) -> Result<Tensor> {
|
||||
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
|
||||
if energy < 2e-3 {
|
||||
return Ok(wav.clone());
|
||||
}
|
||||
let wav_array = wav.to_vec1::<f32>()?;
|
||||
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
|
||||
meter.push(wav_array.into_iter());
|
||||
let power = meter.as_100ms_windows();
|
||||
let loudness = match crate::bs1770::gated_mean(power) {
|
||||
None => return Ok(wav.clone()),
|
||||
Some(gp) => gp.loudness_lkfs() as f64,
|
||||
};
|
||||
let delta_loudness = -14. - loudness;
|
||||
let gain = 10f64.powf(delta_loudness / 20.);
|
||||
let wav = (wav * gain)?;
|
||||
if loudness_compressor {
|
||||
wav.tanh()
|
||||
} else {
|
||||
Ok(wav)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "symphonia")]
|
||||
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
|
||||
use symphonia::core::conv::FromSample;
|
||||
|
||||
fn conv<T>(
|
||||
samples: &mut Vec<f32>,
|
||||
data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
|
||||
) where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
// Open the media source.
|
||||
let src = std::fs::File::open(path).map_err(candle::Error::wrap)?;
|
||||
|
||||
// Create the media source stream.
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
|
||||
// Create a probe hint using the file's extension. [Optional]
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
|
||||
// Use the default options for metadata and format readers.
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
|
||||
// Probe the media source.
|
||||
let probed = symphonia::default::get_probe()
|
||||
.format(&hint, mss, &fmt_opts, &meta_opts)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
// Get the instantiated format reader.
|
||||
let mut format = probed.format;
|
||||
|
||||
// Find the first audio track with a known (decodeable) codec.
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
|
||||
.ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?;
|
||||
|
||||
// Use the default options for the decoder.
|
||||
let dec_opts: DecoderOptions = Default::default();
|
||||
|
||||
// Create a decoder for the track.
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &dec_opts)
|
||||
.map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?;
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
// The decode loop.
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
// Consume any new metadata that has been read since the last packet.
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
|
||||
// If the packet does not belong to the selected track, skip over it.
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet).map_err(candle::Error::wrap)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
#[cfg(feature = "rubato")]
|
||||
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pcm_out =
|
||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||
|
||||
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||
let mut pos_in = 0;
|
||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||
let (in_len, out_len) = resampler
|
||||
.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
pos_in += in_len;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
if pos_in < pcm_in.len() {
|
||||
let (_in_len, out_len) = resampler
|
||||
.process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
Ok(pcm_out)
|
||||
}
|
506
integration/utils/src/bs1770.rs
Normal file
@@ -0,0 +1,506 @@
|
||||
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
|
||||
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
|
||||
// Copyright 2020 Ruud van Asseldonk
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// A copy of the License has been included in the root of the repository.
|
||||
|
||||
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
|
||||
//!
|
||||
//! This library offers the building blocks to perform BS.1770 loudness
|
||||
//! measurements, but you need to put the pieces together yourself.
|
||||
//!
|
||||
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
|
||||
//!
|
||||
//! # Stereo integrated loudness example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
|
||||
//! # [vec![0; 48_000], vec![0; 48_000]]
|
||||
//! # }
|
||||
//! #
|
||||
//! let sample_rate_hz = 44_100;
|
||||
//! let bits_per_sample = 16;
|
||||
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
|
||||
//!
|
||||
//! // When converting integer samples to float, note that the maximum amplitude
|
||||
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
|
||||
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
|
||||
//!
|
||||
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
|
||||
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
|
||||
//! meter.into_100ms_windows()
|
||||
//! }).collect();
|
||||
//!
|
||||
//! let stereo_power = bs1770::reduce_stereo(
|
||||
//! channel_power[0].as_ref(),
|
||||
//! channel_power[1].as_ref(),
|
||||
//! );
|
||||
//!
|
||||
//! let gated_power = bs1770::gated_mean(
|
||||
//! stereo_power.as_ref()
|
||||
//! ).unwrap_or(bs1770::Power(0.0));
|
||||
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
|
||||
//! ```
|
||||
|
||||
use std::f32;
|
||||
|
||||
/// Coefficients for a 2nd-degree infinite impulse response filter.
|
||||
///
|
||||
/// Coefficient a0 is implicitly 1.0.
|
||||
#[derive(Clone)]
|
||||
struct Filter {
|
||||
a1: f32,
|
||||
a2: f32,
|
||||
b0: f32,
|
||||
b1: f32,
|
||||
b2: f32,
|
||||
|
||||
// The past two input and output samples.
|
||||
x1: f32,
|
||||
x2: f32,
|
||||
y1: f32,
|
||||
y2: f32,
|
||||
}
|
||||
|
||||
impl Filter {
|
||||
/// Stage 1 of the BS.1770-4 pre-filter.
|
||||
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
|
||||
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
|
||||
let gain_db = 3.999_843_8;
|
||||
let q = 0.707_175_25;
|
||||
let center_hz = 1_681.974_5;
|
||||
|
||||
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
|
||||
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
|
||||
let vh = 10.0_f32.powf(gain_db / 20.0);
|
||||
let vb = vh.powf(0.499_666_78);
|
||||
let a0 = 1.0 + k / q + k * k;
|
||||
Filter {
|
||||
b0: (vh + vb * k / q + k * k) / a0,
|
||||
b1: 2.0 * (k * k - vh) / a0,
|
||||
b2: (vh - vb * k / q + k * k) / a0,
|
||||
a1: 2.0 * (k * k - 1.0) / a0,
|
||||
a2: (1.0 - k / q + k * k) / a0,
|
||||
|
||||
x1: 0.0,
|
||||
x2: 0.0,
|
||||
y1: 0.0,
|
||||
y2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stage 2 of the BS.1770-4 pre-filter.
|
||||
pub fn high_pass(sample_rate_hz: f32) -> Filter {
|
||||
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
|
||||
let q = 0.500_327_05;
|
||||
let center_hz = 38.135_47;
|
||||
|
||||
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
|
||||
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
|
||||
Filter {
|
||||
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
|
||||
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
|
||||
b0: 1.0,
|
||||
b1: -2.0,
|
||||
b2: 1.0,
|
||||
|
||||
x1: 0.0,
|
||||
x2: 0.0,
|
||||
y1: 0.0,
|
||||
y2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed the next input sample, get the next output sample.
|
||||
#[inline(always)]
|
||||
pub fn apply(&mut self, x0: f32) -> f32 {
|
||||
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
|
||||
- self.a1 * self.y1
|
||||
- self.a2 * self.y2;
|
||||
|
||||
self.x2 = self.x1;
|
||||
self.x1 = x0;
|
||||
self.y2 = self.y1;
|
||||
self.y1 = y0;
|
||||
|
||||
y0
|
||||
}
|
||||
}
|
||||
|
||||
/// Compensated sum, for summing many values of different orders of magnitude
|
||||
/// accurately.
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
struct Sum {
|
||||
sum: f32,
|
||||
residue: f32,
|
||||
}
|
||||
|
||||
impl Sum {
|
||||
#[inline(always)]
|
||||
fn zero() -> Sum {
|
||||
Sum {
|
||||
sum: 0.0,
|
||||
residue: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn add(&mut self, x: f32) {
|
||||
let sum = self.sum + (self.residue + x);
|
||||
self.residue = (self.residue + x) - (sum - self.sum);
|
||||
self.sum = sum;
|
||||
}
|
||||
}

/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after
/// applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);

impl Power {
    /// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
    ///
    /// This is the inverse of `loudness_lkfs`.
    pub fn from_lkfs(lkfs: f32) -> Power {
        // The inverse of the formula below.
        Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
    }

    /// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
    ///
    /// This is the inverse of `from_lkfs`.
    pub fn loudness_lkfs(&self) -> f32 {
        // Equation 2 (p.5) of BS.1770-4.
        -0.691 + 10.0 * self.0.log10()
    }
}
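
// An illustrative sketch (not from the upstream bs1770 sources): a quick numeric
// check of the two conversions above. The absolute gate used later sits at
// -70 LKFS, which per Equation 2 corresponds to a mean-square power of
// 10^((-70 + 0.691) / 10), roughly 1.17e-7. A window whose mean square is 1.0
// corresponds to -0.691 LKFS. The tolerances are arbitrary.
#[cfg(test)]
mod power_sketch_tests {
    use super::Power;

    #[test]
    fn lkfs_round_trip() {
        let gate = Power::from_lkfs(-70.0);
        assert!((gate.0 - 1.17e-7).abs() < 1e-8);
        assert!((gate.loudness_lkfs() - (-70.0)).abs() < 1e-3);

        assert!((Power(1.0).loudness_lkfs() - (-0.691)).abs() < 1e-6);
    }
}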

/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
    pub inner: T,
}

impl<T> Windows100ms<T> {
    /// Wrap a new empty vector.
    pub fn new() -> Windows100ms<Vec<T>> {
        Windows100ms { inner: Vec::new() }
    }

    /// Apply `as_ref` to the inner value.
    pub fn as_ref(&self) -> Windows100ms<&[Power]>
    where
        T: AsRef<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_ref(),
        }
    }

    /// Apply `as_mut` to the inner value.
    pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
    where
        T: AsMut<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_mut(),
        }
    }

    #[allow(clippy::len_without_is_empty)]
    /// Apply `len` to the inner value.
    pub fn len(&self) -> usize
    where
        T: AsRef<[Power]>,
    {
        self.inner.as_ref().len()
    }
}

/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
///     .unwrap_or(bs1770::Power(0.0))
///     .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
    /// The number of samples that fit in 100ms of audio.
    samples_per_100ms: u32,

    /// Stage 1 filter (head effects, high shelf).
    filter_stage1: Filter,

    /// Stage 2 filter (high-pass).
    filter_stage2: Filter,

    /// Sum of the squares over non-overlapping windows of 100ms.
    windows: Windows100ms<Vec<Power>>,

    /// The number of samples in the current unfinished window.
    count: u32,

    /// The sum of the squares of the samples in the current unfinished window.
    square_sum: Sum,
}

impl ChannelLoudnessMeter {
    /// Construct a new loudness meter for the given sample rate.
    pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
        ChannelLoudnessMeter {
            samples_per_100ms: sample_rate_hz / 10,
            filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
            filter_stage2: Filter::high_pass(sample_rate_hz as f32),
            windows: Windows100ms::new(),
            count: 0,
            square_sum: Sum::zero(),
        }
    }

    /// Feed input samples for loudness analysis.
    ///
    /// # Full scale
    ///
    /// Full scale for the input samples is the interval [-1.0, 1.0]. If your
    /// input consists of signed integer samples, you can convert as follows:
    ///
    /// ```ignore
    /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
    /// # let bits_per_sample = 16_usize;
    /// # let samples = &[0_i16];
    /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
    /// // one bit is the sign bit.
    /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
    /// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
    /// ```
    ///
    /// # Repeated calls
    ///
    /// You can call `push` multiple times to feed multiple batches of samples.
    /// This is equivalent to feeding a single chained iterator. The leftover of
    /// samples that did not fill a full 100ms window is not discarded:
    ///
    /// ```ignore
    /// # use std::iter;
    /// # use bs1770::ChannelLoudnessMeter;
    /// let sample_rate_hz = 44_100;
    /// let samples_per_100ms = sample_rate_hz / 10;
    /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    ///
    /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
    /// assert_eq!(meter.as_100ms_windows().len(), 0);
    ///
    /// meter.push(iter::once(0.0));
    /// assert_eq!(meter.as_100ms_windows().len(), 1);
    /// ```
    pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
        let normalizer = 1.0 / self.samples_per_100ms as f32;

        // LLVM, if you could go ahead and inline those apply calls, and then
        // unroll and vectorize the loop, that'd be terrific.
        for x in samples {
            let y = self.filter_stage1.apply(x);
            let z = self.filter_stage2.apply(y);

            self.square_sum.add(z * z);
            self.count += 1;

            // TODO: Should this branch be marked cold?
            if self.count == self.samples_per_100ms {
                let mean_squares = Power(self.square_sum.sum * normalizer);
                self.windows.inner.push(mean_squares);
                // We intentionally do not reset the residue. That way, leftover
                // energy from this window is not lost, so for the file overall,
                // the sum remains more accurate.
                self.square_sum.sum = 0.0;
                self.count = 0;
            }
        }
    }

    /// Return a reference to the 100ms windows analyzed so far.
    pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
        self.windows.as_ref()
    }

    /// Return all 100ms windows analyzed so far.
    pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
        self.windows
    }
}
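
// An illustrative sketch (not from the upstream bs1770 sources): the docs above
// leave "instantaneous loudness" (400ms windows, spaced 100ms apart) to the
// caller. A minimal way to derive it from the 100ms windows, mirroring the
// averaging that `gated_mean` below performs for its gating blocks. The
// function name is a placeholder.
#[allow(dead_code)]
fn instantaneous_loudness_lkfs(windows_100ms: Windows100ms<&[Power]>) -> Vec<f32> {
    windows_100ms
        .inner
        .windows(4)
        .map(|w| Power(0.25 * w.iter().map(|p| p.0).sum::<f32>()).loudness_lkfs())
        .collect()
}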

/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
    left: Windows100ms<&[Power]>,
    right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    let mut result = Vec::with_capacity(left.len());
    for (l, r) in left.inner.iter().zip(right.inner) {
        result.push(Power(l.0 + r.0));
    }
    Windows100ms { inner: result }
}
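
// An illustrative sketch (not from the upstream bs1770 sources): per the note
// above, a mono signal destined for stereo playback is measured by passing the
// same windows in for both channels, which doubles the per-window power. The
// function name is a placeholder.
#[allow(dead_code)]
fn reduce_mono_for_stereo_playback(mono: Windows100ms<&[Power]>) -> Windows100ms<Vec<Power>> {
    reduce_stereo(mono, mono)
}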

/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    for (l, r) in left.inner.iter_mut().zip(right.inner) {
        l.0 += r.0;
    }
}

/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
    let mut gating_blocks = Vec::with_capacity(windows_100ms.len());

    // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
    let absolute_threshold = Power::from_lkfs(-70.0);

    // Iterate over all 400ms windows.
    for window in windows_100ms.inner.windows(4) {
        // Note that the sum over channels has already been performed at this point.
        let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());

        if gating_block_power > absolute_threshold {
            gating_blocks.push(gating_block_power);
        }
    }

    if gating_blocks.is_empty() {
        return None;
    }

    // Compute the loudness after applying the absolute gate, in order to
    // determine the threshold for the relative gate.
    let mut sum_power = Sum::zero();
    for &gating_block_power in &gating_blocks {
        sum_power.add(gating_block_power.0);
    }
    let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));

    // Stage 2: Apply the relative gate.
    let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
    let mut sum_power = Sum::zero();
    let mut n_blocks = 0_usize;
    for &gating_block_power in &gating_blocks {
        if gating_block_power > relative_threshold {
            sum_power.add(gating_block_power.0);
            n_blocks += 1;
        }
    }

    if n_blocks == 0 {
        return None;
    }

    let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
    Some(relative_gated_power)
}
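
// An illustrative sketch (not from the upstream bs1770 sources): an end-to-end
// stereo measurement using only the public items in this module. One meter per
// channel, combine the per-window powers with `reduce_stereo`, then gate and
// average. The 48 kHz sample rate and the synthetic tone are arbitrary choices.
#[cfg(test)]
mod integrated_loudness_sketch_tests {
    use super::{gated_mean, reduce_stereo, ChannelLoudnessMeter};

    #[test]
    fn stereo_integrated_loudness() {
        // Two seconds of a 997 Hz tone at amplitude 0.1 per channel.
        fn tone(i: u32) -> f32 {
            0.1 * (i as f32 * 997.0 * 2.0 * std::f32::consts::PI / 48_000.0).sin()
        }

        let mut left = ChannelLoudnessMeter::new(48_000);
        let mut right = ChannelLoudnessMeter::new(48_000);
        left.push((0..96_000).map(tone));
        right.push((0..96_000).map(tone));

        let combined = reduce_stereo(left.as_100ms_windows(), right.as_100ms_windows());
        let lkfs = gated_mean(combined.as_ref())
            .expect("the tone is well above the -70 LKFS gate")
            .loudness_lkfs();

        // Not asserting an exact value; just check the result is in a sane range.
        assert!(lkfs < 0.0 && lkfs > -70.0, "unexpected loudness: {lkfs}");
    }
}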
82
integration/utils/src/coco_classes.rs
Normal file
@@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
    "person",
    "bicycle",
    "car",
    "motorbike",
    "aeroplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "sofa",
    "pottedplant",
    "bed",
    "diningtable",
    "toilet",
    "tvmonitor",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
];
1056
integration/utils/src/imagenet.rs
Normal file
File diff suppressed because it is too large
156
integration/utils/src/lib.rs
Normal file
@@ -0,0 +1,156 @@
extern crate candle_core;
extern crate candle_transformers;
extern crate tokenizers;

pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};

pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
    if cpu {
        Ok(Device::Cpu)
    } else if cuda_is_available() {
        Ok(Device::new_cuda(0)?)
    } else if metal_is_available() {
        Ok(Device::new_metal(0)?)
    } else {
        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
        {
            println!(
                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
            );
        }
        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
        {
            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
        }
        Ok(Device::Cpu)
    }
}
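
// A minimal usage sketch (not part of the original file): pick the best
// available device and allocate a tensor on it. The function name, shape, and
// dtype are arbitrary; `Tensor::zeros` is assumed from the candle_core API.
#[allow(dead_code)]
fn device_usage_sketch() -> Result<Tensor, anyhow::Error> {
    let device = device(false)?;
    let zeros = Tensor::zeros((2, 3), candle_core::DType::F32, &device)?;
    Ok(zeros)
}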

pub fn load_image<P: AsRef<std::path::Path>>(
    p: P,
    resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize), anyhow::Error> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?;
    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
    let img = match resize_longest {
        None => img,
        Some(resize_longest) => {
            let (height, width) = (img.height(), img.width());
            let resize_longest = resize_longest as u32;
            let (height, width) = if height < width {
                let h = (resize_longest * height) / width;
                (h, resize_longest)
            } else {
                let w = (resize_longest * width) / height;
                (resize_longest, w)
            };
            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
        }
    };
    let (height, width) = (img.height() as usize, img.width() as usize);
    let img = img.to_rgb8();
    let data = img.into_raw();
    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    Ok((data, initial_h, initial_w))
}

pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
    p: P,
    width: usize,
    height: usize,
) -> candle_core::Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?
        .resize_to_fill(
            width as u32,
            height as u32,
            image::imageops::FilterType::Triangle,
        );
    let img = img.to_rgb8();
    let data = img.into_raw();
    Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
}

/// Saves an image to disk using the image crate; this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> {
    let p = p.as_ref();
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        anyhow::bail!("save_image expects an input of shape (3, height, width)")
    }
    let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
            Some(image) => image,
            None => anyhow::bail!("error saving image {p:?}"),
        };
    image.save(p).map_err(candle_core::Error::wrap)?;
    Ok(())
}

/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
    repo: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: serde_json::Value =
        serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => anyhow::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file.to_string());
        }
    }
    let safetensors_files = safetensors_files
        .iter()
        .map(|v| {
            repo.get(v)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
        })
        .collect::<Result<Vec<_>, std::io::Error>>()?;
    Ok(safetensors_files)
}
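
// A minimal usage sketch (not part of the original file): fetch a sharded
// safetensors model from the Hugging Face hub via its index file. The repo id
// and the "model.safetensors.index.json" file name are assumptions made for
// illustration; they follow the common convention for sharded checkpoints, not
// something this crate mandates.
#[allow(dead_code)]
fn hub_load_safetensors_sketch() -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("google/gemma-3-1b-it".to_string());
    hub_load_safetensors(&repo, "model.safetensors.index.json")
}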

pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
    path: P,
    json_file: &str,
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let path = path.as_ref();
    let jsfile = std::fs::File::open(path.join(json_file))?;
    let json: serde_json::Value =
        serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => anyhow::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file);
        }
    }
    let safetensors_files: Vec<_> = safetensors_files
        .into_iter()
        .map(|v| path.join(v))
        .collect();
    Ok(safetensors_files)
}
3
integration/utils/src/main.rs
Normal file
@@ -0,0 +1,3 @@
fn main() {
    println!("Hello, world!");
}
85
integration/utils/src/token_output_stream.rs
Normal file
@@ -0,0 +1,85 @@
use candle_core::Result;
use tokenizers::Tokenizer;

pub struct TokenOutputStream {
    tokenizer: tokenizers::Tokenizer,
    tokens: Vec<u32>,
    prev_index: usize,
    current_index: usize,
}

impl TokenOutputStream {
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    pub fn into_inner(self) -> tokenizers::Tokenizer {
        self.tokenizer
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle_core::bail!("cannot decode: {err}"),
        }
    }

    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let text = text.split_at(prev_text.len());
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_all(&self) -> Result<String> {
        self.decode(&self.tokens)
    }

    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }

    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}
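
// A minimal usage sketch (not part of the original file): stream tokens out as
// printable text as soon as they form complete words. The tokenizer path and
// the function name are placeholders; in practice the token ids come from a
// model's sampler.
#[allow(dead_code)]
fn token_stream_sketch(token_ids: &[u32]) -> Result<String> {
    let tokenizer = match Tokenizer::from_file("tokenizer.json") {
        Ok(tokenizer) => tokenizer,
        Err(err) => candle_core::bail!("cannot load tokenizer: {err}"),
    };
    let mut stream = TokenOutputStream::new(tokenizer);
    let mut out = String::new();
    for &id in token_ids {
        if let Some(piece) = stream.next_token(id)? {
            out.push_str(&piece);
        }
    }
    if let Some(rest) = stream.decode_rest()? {
        out.push_str(&rest);
    }
    Ok(out)
}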
56
integration/utils/src/wav.rs
Normal file
@@ -0,0 +1,56 @@
use std::io::prelude::*;

pub trait Sample {
    fn to_i16(&self) -> i16;
}

impl Sample for f32 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for f64 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for i16 {
    fn to_i16(&self) -> i16 {
        *self
    }
}

pub fn write_pcm_as_wav<W: Write, S: Sample>(
    w: &mut W,
    samples: &[S],
    sample_rate: u32,
) -> std::io::Result<()> {
    let len = 12u32; // header
    let len = len + 24u32; // fmt
    let len = len + samples.len() as u32 * 2 + 8; // data
    let n_channels = 1u16;
    let bytes_per_second = sample_rate * 2 * n_channels as u32;
    w.write_all(b"RIFF")?;
    w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
    w.write_all(b"WAVE")?;

    // Format block
    w.write_all(b"fmt ")?;
    w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
    w.write_all(&1u16.to_le_bytes())?; // PCM
    w.write_all(&n_channels.to_le_bytes())?; // one channel
    w.write_all(&sample_rate.to_le_bytes())?;
    w.write_all(&bytes_per_second.to_le_bytes())?;
    w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
    w.write_all(&16u16.to_le_bytes())?; // bits per sample

    // Data block
    w.write_all(b"data")?;
    w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
    for sample in samples.iter() {
        w.write_all(&sample.to_i16().to_le_bytes())?
    }
    Ok(())
}
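
// A minimal usage sketch (not part of the original file): write one second of a
// 440 Hz sine as 16-bit mono PCM. The output path, tone, and sample rate are
// arbitrary choices for the illustration.
#[allow(dead_code)]
fn write_sine_wav_sketch() -> std::io::Result<()> {
    let sample_rate = 16_000u32;
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| (i as f32 * 440.0 * 2.0 * std::f32::consts::PI / sample_rate as f32).sin())
        .collect();
    let mut file = std::fs::File::create("sine.wav")?;
    write_pcm_as_wav(&mut file, &samples, sample_rate)
}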