9 Commits

Author SHA1 Message Date
geoffsee
7bc9479a11 fix format issues, needs precommit hook 2025-08-31 13:24:51 -04:00
geoffsee
0580dc8c5e move cli into crates and stage for release 2025-08-31 13:23:50 -04:00
geoffsee
9e9aa69769 bump version in Cargo.toml 2025-08-31 11:04:31 -04:00
geoffsee
3eb1a5329b add rust compiler optimizations at workspace level, bump minor version and publish first release 2025-08-31 11:02:58 -04:00
geoffsee
eb1591aa5d fix fmt error 2025-08-31 10:52:48 -04:00
geoffsee
e6c417bd83 align dependencies across inference features 2025-08-31 10:49:04 -04:00
geoffsee
f5d2a85f2e cleanup, add ci 2025-08-31 10:31:20 -04:00
Geoff Seemueller
419e1c2ea7 fix Kubernetes spelling 2025-08-30 08:24:24 -04:00
Geoff Seemueller
06fdfcf898 clarify project intent 2025-08-30 08:23:38 -04:00
55 changed files with 1590 additions and 3352 deletions

49
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,49 @@
# Dependabot configuration: weekly, grouped version updates for the main
# crate, plus a daily security-only pass.
version: 2
updates:
  # Monitor Rust dependencies in the main crate
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "weekly"
      day: "monday"
      time: "09:00"
      timezone: "UTC"
    # Focus on security updates with higher priority
    open-pull-requests-limit: 10
    reviewers:
      - "security-team"
    assignees:
      - "maintainer"
    labels:
      - "dependencies"
      - "security"
    # Security updates get higher priority
    allow:
      - dependency-type: "all"
    # Group minor and patch updates to reduce noise
    # Separate major updates for careful review
    ignore:
      - dependency-name: "*"
        update-types: ["version-update:semver-major"]
    commit-message:
      prefix: "deps"
      include: "scope"
  # Monitor security updates more frequently
  # NOTE(review): verify this second entry against the Dependabot config
  # reference — Dependabot rejects two entries with the same
  # package-ecosystem + directory, and `update-types` is documented only
  # under `ignore`, not under `allow`, so this block likely fails validation.
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "daily"
    # Only security updates in daily checks
    allow:
      - dependency-type: "direct"
        update-types: ["security"]
      - dependency-type: "indirect"
        update-types: ["security"]
    open-pull-requests-limit: 5
    labels:
      - "security-update"
      - "high-priority"
    commit-message:
      prefix: "security"
      include: "scope"

47
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,47 @@
# Continuous integration: format check, lint, tests, and a docs build on
# every push and pull request.
name: CI

on:
  push:
  pull_request:

jobs:
  build:
    name: build-and-test
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      # Cache cargo registry/git data and the target dir, keyed on the
      # lockfile. NOTE(review): no `restore-keys` fallback — any Cargo.lock
      # change starts from a cold cache; confirm that is acceptable.
      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
      - name: Setup Rust
        run: rustup update stable && rustup default stable
      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt
      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check
      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets
      - name: Tests
        shell: bash
        run: cargo test --all
      - name: Build Docs
        shell: bash
        run: |
          cargo doc -p predict-otron-9000 --no-deps

235
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,235 @@
# Release pipeline, triggered by pushing a v* tag:
#   test -> build-binaries (4-target matrix) -> release (GitHub Release + assets).
name: Release

on:
  push:
    tags:
      - 'v*'

env:
  CARGO_TERM_COLOR: always

jobs:
  # Gate: the full fmt/clippy/test suite must pass before any binary is built.
  test:
    name: Test before release
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: crates/predict-otron-9000
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
      - name: Setup Rust
        run: rustup update stable && rustup default stable
      # Bun is needed because crates/cli's build.rs shells out to `bun build`.
      - name: Setup Bun
        uses: oven-sh/setup-bun@v2
      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt
      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check
      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets
      - name: Tests
        shell: bash
        run: cargo test --all

  # crates.io publishing is intentionally disabled for now (kept for later).
#  publish:
#    name: Publish to crates.io
#    runs-on: ubuntu-latest
#    permissions:
#      id-token: write # Required for OIDC token exchange https://crates.io/docs/trusted-publishing
#    needs: test
#    defaults:
#      run:
#        working-directory: crates/predict-otron-9000
#    steps:
#      - name: Checkout
#        uses: actions/checkout@v4
#
#      - uses: actions/cache@v4
#        with:
#          path: |
#            ~/.cargo/bin/
#            ~/.cargo/registry/index/
#            ~/.cargo/registry/cache/
#            ~/.cargo/git/db/
#            target/
#          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
#
#      - name: Setup Rust
#        run: rustup update stable && rustup default stable
#
#      - name: Verify tag matches version
#        run: |
#          TAG_VERSION=${GITHUB_REF#refs/tags/v}
#          CARGO_VERSION=$(cargo metadata --no-deps --format-version 1 | jq -r '.packages[0].version')
#          if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then
#            echo "Tag version ($TAG_VERSION) does not match Cargo.toml version ($CARGO_VERSION)"
#            exit 1
#          fi
#
#      # See Trusted publishing: https://crates.io/docs/trusted-publishing
#      - uses: rust-lang/crates-io-auth-action@v1
#        id: auth
#
#      - run: cargo publish
#        env:
#          CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }}

  # Cross-platform binary matrix; each leg uploads one archived artifact
  # containing the server and cli binaries.
  build-binaries:
    name: Build binaries
    runs-on: ${{ matrix.os }}
    needs: test
    strategy:
      fail-fast: false
      matrix:
        include:
          - target: x86_64-unknown-linux-gnu
            os: ubuntu-latest
            name: predict-otron-9000-x86_64-unknown-linux-gnu
          - target: x86_64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-x86_64-apple-darwin
          - target: aarch64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-aarch64-apple-darwin
          - target: x86_64-pc-windows-msvc
            os: windows-latest
            name: predict-otron-9000-x86_64-pc-windows-msvc.exe
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}
      - name: Setup Rust
        run: rustup update stable && rustup default stable
      - name: Add target
        run: rustup target add ${{ matrix.target }}
      - name: Build binary
        run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000 -p cli
        env:
          CARGO_TERM_COLOR: always
      - name: Package binary (Unix)
        if: matrix.os != 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          tar czf ../../../${{ matrix.name }}.tar.gz predict-otron-9000 cli
          cd ../../../
      - name: Package binary (Windows)
        if: matrix.os == 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          7z a ../../../${{ matrix.name }}.zip predict-otron-9000.exe cli.exe
          cd ../../../
      - name: Upload binary artifacts (Unix)
        if: matrix.os != 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.tar.gz
      - name: Upload binary artifacts (Windows)
        if: matrix.os == 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.zip

  # Final stage: changelog from git history, then a GitHub Release with all
  # artifacts attached. NOTE(review): the asset list below is hard-coded and
  # must be kept in sync with the matrix above by hand.
  release:
    name: Create GitHub Release
    runs-on: ubuntu-latest
    needs: [test, build-binaries]
    permissions:
      contents: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          # Full history so `git describe`/`git log` can see earlier tags.
          fetch-depth: 0
      - name: Extract tag name
        id: tag
        run: echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
      - name: Generate changelog
        id: changelog
        run: |
          # Get the previous tag
          PREV_TAG=$(git describe --tags --abbrev=0 HEAD^ 2>/dev/null || echo "")

          # Generate changelog
          if [ -n "$PREV_TAG" ]; then
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            git log --pretty=format:"* %s (%h)" ${PREV_TAG}..HEAD >> changelog.md
            echo "" >> changelog.md
            echo "" >> changelog.md
            echo "**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREV_TAG}...${{ steps.tag.outputs.tag }}" >> changelog.md
          else
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            echo "Initial release of predict-otron-9000" >> changelog.md
            echo "" >> changelog.md
            echo "OpenAI Compatible Inference Server" >> changelog.md
          fi

          # Set the changelog as output (handle multiline)
          echo "changelog<<EOF" >> $GITHUB_OUTPUT
          cat changelog.md >> $GITHUB_OUTPUT
          echo "EOF" >> $GITHUB_OUTPUT
      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: artifacts
      - name: Create Release
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          if [[ "${{ steps.tag.outputs.tag }}" == *"-"* ]]; then
            PRERELEASE_FLAG="--prerelease"
          else
            PRERELEASE_FLAG=""
          fi

          gh release create "${{ steps.tag.outputs.tag }}" \
            --title "Release ${{ steps.tag.outputs.tag }}" \
            --notes-file changelog.md \
            $PRERELEASE_FLAG \
            artifacts/predict-otron-9000-x86_64-unknown-linux-gnu/predict-otron-9000-x86_64-unknown-linux-gnu.tar.gz \
            artifacts/predict-otron-9000-x86_64-apple-darwin/predict-otron-9000-x86_64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-aarch64-apple-darwin/predict-otron-9000-aarch64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-x86_64-pc-windows-msvc.exe/predict-otron-9000-x86_64-pc-windows-msvc.exe.zip

3
.gitignore vendored
View File

@@ -76,3 +76,6 @@ venv/
*.bak *.bak
*.backup *.backup
*~ *~
/scripts/cli
!/scripts/cli.ts
/**/.*.bun-build

949
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -6,11 +6,43 @@ members = [
"crates/leptos-app", "crates/leptos-app",
"crates/helm-chart-tool", "crates/helm-chart-tool",
"crates/llama-runner", "crates/llama-runner",
"crates/gemma-runner" "crates/gemma-runner",
"crates/cli"
] ]
default-members = ["crates/predict-otron-9000"] default-members = ["crates/predict-otron-9000"]
resolver = "2" resolver = "2"
[workspace.package]
# Single version inherited by every crate declaring `version.workspace = true`.
version = "0.1.2"

# Compiler optimization profiles for the workspace
[profile.release]
opt-level = 3
debug = false
strip = true
lto = "thin"
codegen-units = 1
# NOTE(review): panic = "abort" disables unwinding in release builds; `cargo
# test --release` and benches need unwinding — confirm this is intended.
panic = "abort"

[profile.dev]
opt-level = 0
debug = true
strip = false
overflow-checks = true

# Profile for fast development builds with some optimization
[profile.dev-opt]
inherits = "dev"
opt-level = 1
debug = true
overflow-checks = true

# Profile for benchmarking and profiling
[profile.bench]
opt-level = 3
debug = true
lto = "thin"
[[workspace.metadata.leptos]] [[workspace.metadata.leptos]]
# project name # project name
bin-package = "leptos-app" bin-package = "leptos-app"

View File

@@ -1,11 +1,20 @@
# predict-otron-9000 <h1 align="center">
predict-otron-9000
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces. </h1>
<p align="center"> <p align="center">
Powerful local AI inference with OpenAI-compatible APIs Powerful local AI inference with OpenAI-compatible APIs
</p> </p>
<br/>
> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes-based, performance-oriented inference capabilities on air-gapped networks.
> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.
## Project Overview ## Project Overview
The predict-otron-9000 is a flexible AI platform that provides: The predict-otron-9000 is a flexible AI platform that provides:
@@ -24,7 +33,7 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
- **Text Embeddings**: Generate high-quality text embeddings using FastEmbed - **Text Embeddings**: Generate high-quality text embeddings using FastEmbed
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants) - **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants)
- **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput - **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput
- **Web Chat Interface**: Leptos-based WebAssembly (WASM) chat interface for browser-based interaction - **Web Chat Interface**: Leptos chat interface
- **Flexible Deployment**: Run as monolithic service or microservices architecture - **Flexible Deployment**: Run as monolithic service or microservices architecture
## Architecture Overview ## Architecture Overview
@@ -50,7 +59,7 @@ crates/
- **Main Server** (port 8080): Orchestrates inference and embeddings services - **Main Server** (port 8080): Orchestrates inference and embeddings services
- **Embeddings Service** (port 8080): Standalone FastEmbed service with OpenAI API compatibility - **Embeddings Service** (port 8080): Standalone FastEmbed service with OpenAI API compatibility
- **Web Frontend** (port 8788): Leptos WASM chat interface served by Trunk - **Web Frontend** (port 8788): cargo leptos SSR app
- **CLI Client**: TypeScript/Bun client for testing and automation - **CLI Client**: TypeScript/Bun client for testing and automation
### Deployment Modes ### Deployment Modes
@@ -278,7 +287,7 @@ cargo test --workspace
**End-to-end test script:** **End-to-end test script:**
```bash ```bash
./test.sh ./smoke_test.sh
``` ```
This script: This script:
@@ -469,7 +478,7 @@ cd crates/leptos-app && ./run.sh &
**Integration test:** **Integration test:**
```bash ```bash
./test.sh ./smoke_test.sh
``` ```
**Cleanup:** **Cleanup:**

22
bun.lock Normal file
View File

@@ -0,0 +1,22 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "predict-otron-9000",
},
"crates/cli/package": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"cli": ["cli@workspace:crates/cli/package"],
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
"openai": ["openai@5.16.0", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-hoEH8ZNvg1HXjU9mp88L/ZH8O082Z8r6FHCXGiWAzVRrEv443aI57qhch4snu07yQydj+AUAWLenAiBXhu89Tw=="],
}
}

11
crates/cli/Cargo.toml Normal file
View File

@@ -0,0 +1,11 @@
[package]
name = "cli"
version.workspace = true
edition = "2021"
build = "build.rs"
[[bin]]
name = "cli"
path = "src/main.rs"
[dependencies]

23
crates/cli/README.md Normal file
View File

@@ -0,0 +1,23 @@
# cli
A Rust/TypeScript hybrid
```console
./cli [options] [prompt]
Simple CLI tool for testing the local OpenAI-compatible API server.
Options:
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message
Examples:
./cli "What is the capital of France?"
./cli --model gemma-3-1b-it --prompt "Hello, world!"
./cli --prompt "Who was the 16th president of the United States?"
./cli --list-models
The server must be running at http://localhost:8080
```

204
crates/cli/build.rs Normal file
View File

@@ -0,0 +1,204 @@
use std::env;
use std::fs;
use std::io::{self, BufRead, Write};
use std::path::{Path, PathBuf};
use std::process::{ChildStderr, ChildStdout, Command, Stdio};
use std::thread;
use std::time::{Duration, SystemTime};
mod bun_target;
use bun_target::BunTarget;
/// Build-script entry point: delegate the real work to `run_build` and, on
/// failure, surface the error as a Cargo warning before aborting the build.
fn main() {
    println!("cargo:rerun-if-changed=");
    match run_build() {
        Ok(()) => {}
        Err(e) => {
            println!("cargo:warning=build.rs failed: {e}");
            std::process::exit(1);
        }
    }
}
/// Drives the Bun toolchain from the build script:
///   1. `bun install` in ./package to fetch the TS CLI's node dependencies,
///   2. `bun build --compile` the TS entry point into a standalone executable
///      under OUT_DIR, exported to the crate via the CLIENT_CLI_BIN env var,
///   3. best-effort cleanup of stray `.bun-build` temp files.
fn run_build() -> io::Result<()> {
    let manifest_dir =
        PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"));
    let package_dir = manifest_dir.join("package");
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
    let output_path = out_dir.join("client-cli");

    // Map the Rust target triple to Bun's --target flag.
    let bun_tgt = BunTarget::from_cargo_env()
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;

    // Optional: warn if using a Bun target that's marked unsupported in your chart
    if matches!(bun_tgt, BunTarget::WindowsArm64) {
        println!(
            "cargo:warning=bun-windows-arm64 is marked unsupported in the compatibility chart"
        );
    }

    warn(&format!("Building CLI into: {}", output_path.display()));

    // --- bun install (in ./package), keep temps inside OUT_DIR ---
    let mut install = Command::new("bun")
        .current_dir(&package_dir)
        .env("TMPDIR", &out_dir)
        .arg("install")
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun install`: {e}")))?;

    // Drain stdout/stderr on background threads so the child's pipes cannot
    // fill up and deadlock it.
    let install_join = stream_child("bun install", install.stdout.take(), install.stderr.take());
    let install_status = install.wait()?;

    // ensure streams finish
    join_streams(install_join);

    if !install_status.success() {
        let code = install_status.code().unwrap_or(1);
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("bun install failed with status {code}"),
        ));
    }

    // NOTE(review): `target` is unused — the Bun flag above already comes from
    // `BunTarget::from_cargo_env()`. Confirm before removing.
    let target = env::var("TARGET").unwrap();

    // --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
    let mut build = Command::new("bun")
        .current_dir(&package_dir)
        .env("TMPDIR", &out_dir)
        .arg("build")
        .arg("./cli.ts")
        .arg(format!("--target={}", bun_tgt.as_bun_flag()))
        .arg("--compile")
        .arg("--outfile")
        .arg(&output_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun build`: {e}")))?;

    let build_join = stream_child("bun build", build.stdout.take(), build.stderr.take());
    let status = build.wait()?;

    // ensure streams finish
    join_streams(build_join);

    if status.success() {
        info("bun build succeeded");
    } else {
        let code = status.code().unwrap_or(1);
        warn(&format!("bun build failed with status: {code}"));
        return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
    }

    // Ensure the output is executable (after it exists)
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&output_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&output_path, perms)?;
    }

    println!("cargo:warning=Built CLI at {}", output_path.display());
    // Expose the artifact's absolute path to the crate being compiled;
    // src/main.rs embeds it via `include_bytes!(env!("CLIENT_CLI_BIN"))`.
    println!("cargo:rustc-env=CLIENT_CLI_BIN={}", output_path.display());

    // --- Cleanup stray .bun-build temp files (conservative: older than 5 minutes) ---
    for dir in [&manifest_dir, &package_dir, &out_dir] {
        if let Err(e) = remove_bun_temp_files(dir, Some(Duration::from_secs(5 * 60))) {
            println!("cargo:warning=cleanup in {} failed: {e}", dir.display());
        }
    }

    Ok(())
}
/// Spawn one background thread per available child stream (stdout/stderr) so
/// the pipes are continuously drained and the child can never deadlock on a
/// full pipe buffer. Returns the reader threads' join handles.
fn stream_child(
    tag: &str,
    stdout: Option<ChildStdout>,
    stderr: Option<ChildStderr>,
) -> (
    Option<thread::JoinHandle<()>>,
    Option<thread::JoinHandle<()>>,
) {
    let stdout_reader = match stdout {
        None => None,
        Some(out) => {
            let tag = tag.to_string();
            Some(thread::spawn(move || {
                for line in io::BufReader::new(out).lines() {
                    info(&format!("[{tag} stdout] {}", line.unwrap_or_default()));
                }
            }))
        }
    };
    let stderr_reader = match stderr {
        None => None,
        Some(err) => {
            let tag = tag.to_string();
            Some(thread::spawn(move || {
                for line in io::BufReader::new(err).lines() {
                    warn(&format!("[{tag} stderr] {}", line.unwrap_or_default()));
                }
            }))
        }
    };
    (stdout_reader, stderr_reader)
}
/// Block until both optional reader threads (stdout, stderr) have finished;
/// panics inside a reader thread are swallowed.
fn join_streams(
    joins: (
        Option<thread::JoinHandle<()>>,
        Option<thread::JoinHandle<()>>,
    ),
) {
    let (stdout_join, stderr_join) = joins;
    for handle in [stdout_join, stderr_join].into_iter().flatten() {
        let _ = handle.join();
    }
}
/// Delete leftover Bun build artifacts (files named like
/// ".1860e7df40ff1bef-00000000.bun-build") from `dir`. When `older_than` is
/// given, files modified more recently than that duration are left alone so
/// an in-flight build is never raced. Removal results are reported as Cargo
/// warnings; individual failures do not abort the scan.
fn remove_bun_temp_files(dir: &Path, older_than: Option<Duration>) -> io::Result<()> {
    let now = SystemTime::now();
    for entry in fs::read_dir(dir)? {
        let entry = entry?;
        let path = entry.path();
        if !path.is_file() {
            continue;
        }

        // Only touch names matching the ".<hash>.bun-build" temp pattern.
        let file_name = entry.file_name();
        let file_name = file_name.to_string_lossy();
        if !(file_name.starts_with('.') && file_name.ends_with(".bun-build")) {
            continue;
        }

        // Age guard: skip files newer than the threshold (metadata errors fall
        // through to removal, matching a best-effort cleanup).
        let too_new = older_than.is_some_and(|age| {
            entry
                .metadata()
                .and_then(|m| m.modified())
                .map(|modified| now.duration_since(modified).unwrap_or_default() < age)
                .unwrap_or(false)
        });
        if too_new {
            continue;
        }

        match fs::remove_file(&path) {
            Ok(_) => println!("cargo:warning=removed stray bun temp {}", path.display()),
            Err(e) => println!("cargo:warning=failed to remove {}: {e}", path.display()),
        }
    }
    Ok(())
}
/// Emit `msg` both to stderr (visible with `cargo build -vv`) and as a
/// `cargo:warning` line (visible in normal build output). Write errors on
/// stderr are deliberately ignored.
fn warn(msg: &str) {
    let _ = writeln!(io::stderr(), "[build.rs] {msg}");
    println!("cargo:warning={msg}");
}
/// Emit an informational message. Build scripts have no dedicated info
/// channel, so it is routed through `cargo:warning` with an `INFO|` prefix
/// (and mirrored to stderr for `-vv` builds).
fn info(msg: &str) {
    let _ = writeln!(io::stderr(), "[build.rs] {msg}");
    println!("cargo:warning=INFO|{msg}");
}

131
crates/cli/bun_target.rs Normal file
View File

@@ -0,0 +1,131 @@
use std::env;
use std::fmt;
/// Cross-compilation targets understood by Bun's `--target` flag, one variant
/// per supported (OS, architecture, libc/ABI) combination. The mapping to and
/// from Rust target triples lives in the `impl` below.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum BunTarget {
    LinuxX64Glibc,
    LinuxArm64Glibc,
    LinuxX64Musl,
    LinuxArm64Musl,
    WindowsX64,
    // Marked unsupported in Bun's compatibility chart, but still mapped.
    WindowsArm64,
    MacX64,
    MacArm64,
}
impl BunTarget {
    /// The value to pass as Bun's `--target=` flag for this target.
    pub const fn as_bun_flag(self) -> &'static str {
        match self {
            BunTarget::LinuxX64Glibc => "bun-linux-x64",
            BunTarget::LinuxArm64Glibc => "bun-linux-arm64",
            BunTarget::LinuxX64Musl => "bun-linux-x64-musl",
            BunTarget::LinuxArm64Musl => "bun-linux-arm64-musl",
            BunTarget::WindowsX64 => "bun-windows-x64",
            BunTarget::WindowsArm64 => "bun-windows-arm64",
            BunTarget::MacX64 => "bun-darwin-x64",
            BunTarget::MacArm64 => "bun-darwin-arm64",
        }
    }

    /// Rust target triples (including aliases) that map onto this Bun target.
    pub const fn rust_triples(self) -> &'static [&'static str] {
        match self {
            BunTarget::LinuxX64Glibc => {
                &["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-gnu.2.17"]
            }
            BunTarget::LinuxArm64Glibc => &["aarch64-unknown-linux-gnu"],
            BunTarget::LinuxX64Musl => &["x86_64-unknown-linux-musl"],
            BunTarget::LinuxArm64Musl => &["aarch64-unknown-linux-musl"],
            BunTarget::WindowsX64 => &["x86_64-pc-windows-msvc"],
            BunTarget::WindowsArm64 => &["aarch64-pc-windows-msvc"], // chart says unsupported; still map
            BunTarget::MacX64 => &["x86_64-apple-darwin"],
            BunTarget::MacArm64 => &["aarch64-apple-darwin"],
        }
    }

    /// Map a Rust target triple to a Bun target: first by arch/OS/ABI pattern,
    /// then exact matches, then a final sweep through `rust_triples` aliases.
    pub fn from_rust_target(triple: &str) -> Option<Self> {
        let norm = triple.trim();
        if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
            return Some(BunTarget::LinuxX64Glibc);
        }
        if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
            return Some(BunTarget::LinuxArm64Glibc);
        }
        if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("musl") {
            return Some(BunTarget::LinuxX64Musl);
        }
        if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("musl") {
            return Some(BunTarget::LinuxArm64Musl);
        }
        if norm == "x86_64-pc-windows-msvc" {
            return Some(BunTarget::WindowsX64);
        }
        if norm == "aarch64-pc-windows-msvc" {
            return Some(BunTarget::WindowsArm64);
        }
        if norm == "x86_64-apple-darwin" {
            return Some(BunTarget::MacX64);
        }
        if norm == "aarch64-apple-darwin" {
            return Some(BunTarget::MacArm64);
        }
        // Fallback: exhaustive alias lookup.
        for bt in [
            BunTarget::LinuxX64Glibc,
            BunTarget::LinuxArm64Glibc,
            BunTarget::LinuxX64Musl,
            BunTarget::LinuxArm64Musl,
            BunTarget::WindowsX64,
            BunTarget::WindowsArm64,
            BunTarget::MacX64,
            BunTarget::MacArm64,
        ] {
            for &t in bt.rust_triples() {
                if t == norm {
                    return Some(bt);
                }
            }
        }
        None
    }

    /// Resolve the Bun target from Cargo's build-script environment.
    ///
    /// Prefers the full `TARGET` triple (always set for build scripts); falls
    /// back to reassembling a triple from the `CARGO_CFG_TARGET_*` variables.
    ///
    /// # Errors
    /// Returns `BunTargetError::UnknownTriple` when neither path yields a
    /// recognized triple.
    pub fn from_cargo_env() -> Result<Self, BunTargetError> {
        if let Ok(triple) = env::var("TARGET") {
            if let Some(bt) = Self::from_rust_target(&triple) {
                return Ok(bt);
            }
            return Err(BunTargetError::UnknownTriple(triple));
        }

        let os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
        let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
        let envv = env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default();
        let vendor = env::var("CARGO_CFG_TARGET_VENDOR").unwrap_or_else(|_| "unknown".into());

        // Reassemble a triple. The previous code unconditionally appended a
        // fourth component (defaulting to "gnu"), producing impossible
        // triples such as "x86_64-apple-darwin-gnu" for OSes whose canonical
        // triple has only three parts, so the fallback could never match
        // macOS. Append the env/ABI component only when it is set, keeping
        // the historical "gnu" default for env-less Linux builds.
        let triple = match envv.as_str() {
            "" if os == "linux" => format!("{arch}-{vendor}-{os}-gnu"),
            "" => format!("{arch}-{vendor}-{os}"),
            abi => format!("{arch}-{vendor}-{os}-{abi}"),
        };

        Self::from_rust_target(&triple).ok_or(BunTargetError::UnknownTriple(triple))
    }
}
/// Errors produced while resolving a Bun target from the environment.
#[derive(Debug)]
pub enum BunTargetError {
    // The Rust target triple could not be mapped to any known Bun target.
    UnknownTriple(String),
}

impl fmt::Display for BunTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BunTargetError::UnknownTriple(t) => write!(f, "unrecognized Rust target triple: {t}"),
        }
    }
}

// Marker impl so the error can travel through `Box<dyn Error>` / `io::Error`.
impl std::error::Error for BunTargetError {}

View File

@@ -30,24 +30,23 @@ type ChunkStat = {
function printHelp() { function printHelp() {
console.log(` console.log(`
Usage: bun client_cli.ts [options] [prompt] ./cli [options] [prompt]
Simple CLI tool for testing the local OpenAI-compatible API server. Simple CLI tool for testing the local OpenAI-compatible API server.
Options: Options:
--model <model> Model to use (default: ${DEFAULT_MODEL}) --model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument) --prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server --list-models List all available models from the server
--help Show this help message --help Show this help message
Examples: Examples:
./cli.ts "What is the capital of France?" ./cli "What is the capital of France?"
./cli.ts --model gemma-3-1b-it --prompt "Hello, world!" ./cli --model gemma-3-1b-it --prompt "Hello, world!"
./cli.ts --prompt "Who was the 16th president of the United States?" ./cli --prompt "Who was the 16th president of the United States?"
./cli.ts --list-models ./cli --list-models
The server should be running at http://localhost:8080 The server must be running at http://localhost:8080
Start it with: ./run_server.sh
`); `);
} }

View File

@@ -0,0 +1,11 @@
{
"name": "cli",
"main": "cli.ts",
"scripts": {
"build": "bun build cli.ts --compile --outfile cli"
},
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0"
}
}

32
crates/cli/src/main.rs Normal file
View File

@@ -0,0 +1,32 @@
use std::{env, fs, io, path::PathBuf, process::Command};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
fn main() -> io::Result<()> {
// Absolute path provided by build.rs at compile time.
// `include_bytes!` accepts string literals; `env!` expands to a literal at compile time.
const CLIENT_CLI: &[u8] = include_bytes!(env!("CLIENT_CLI_BIN"));
// Write to a temp file
let mut tmp = env::temp_dir();
tmp.push("client-cli-embedded");
fs::write(&tmp, CLIENT_CLI)?;
// Ensure it's executable on Unix
#[cfg(unix)]
{
let mut perms = fs::metadata(&tmp)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&tmp, perms)?;
}
// Run it
let status = Command::new(&tmp).arg("--version").status()?;
if !status.success() {
return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
}
Ok(())
}

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "embeddings-engine" name = "embeddings-engine"
version = "0.1.0" version.workspace = true
edition = "2024" edition = "2024"
[lib] [lib]

View File

@@ -1,9 +1,5 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput}; use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{ use axum::{Json, Router, response::Json as ResponseJson, routing::post};
response::Json as ResponseJson, routing::{post},
Json,
Router,
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
@@ -15,12 +11,15 @@ static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
let model_start_time = std::time::Instant::now(); let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new( let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true) InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
) )
.expect("Failed to initialize persistent embedding model"); .expect("Failed to initialize persistent embedding model");
let model_init_time = model_start_time.elapsed(); let model_init_time = model_start_time.elapsed();
tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time); tracing::info!(
"Persistent embedding model initialized in {:.2?}",
model_init_time
);
model model
}); });
@@ -37,7 +36,10 @@ pub async fn embeddings_create(
// Access the lazy-initialized persistent model instance // Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request // This will only initialize the model on the first request
let model_access_time = model_start_time.elapsed(); let model_access_time = model_start_time.elapsed();
tracing::debug!("Persistent model access completed in {:.2?}", model_access_time); tracing::debug!(
"Persistent model access completed in {:.2?}",
model_access_time
);
// Phase 2: Process input // Phase 2: Process input
let input_start_time = std::time::Instant::now(); let input_start_time = std::time::Instant::now();
@@ -55,7 +57,10 @@ pub async fn embeddings_create(
}; };
let input_processing_time = input_start_time.elapsed(); let input_processing_time = input_start_time.elapsed();
tracing::debug!("Input processing completed in {:.2?}", input_processing_time); tracing::debug!(
"Input processing completed in {:.2?}",
input_processing_time
);
// Phase 3: Generate embeddings // Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now(); let embedding_start_time = std::time::Instant::now();
@@ -65,25 +70,39 @@ pub async fn embeddings_create(
.expect("failed to embed document"); .expect("failed to embed document");
let embedding_generation_time = embedding_start_time.elapsed(); let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time); tracing::info!(
"Embedding generation completed in {:.2?}",
embedding_generation_time
);
// Memory usage estimation (approximate) // Memory usage estimation (approximate)
let embedding_size_bytes = embeddings.iter() let embedding_size_bytes = embeddings
.iter()
.map(|e| e.len() * std::mem::size_of::<f32>()) .map(|e| e.len() * std::mem::size_of::<f32>())
.sum::<usize>(); .sum::<usize>();
tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0); tracing::debug!(
"Embedding size: {:.2} MB",
embedding_size_bytes as f64 / 1024.0 / 1024.0
);
// Only log detailed embedding information at trace level to reduce log volume // Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len()); tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len()); tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level // Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]); tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding // Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count(); let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count(); let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count); tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Phase 4: Post-process embeddings // Phase 4: Post-process embeddings
let postprocessing_start_time = std::time::Instant::now(); let postprocessing_start_time = std::time::Instant::now();
@@ -110,6 +129,8 @@ pub async fn embeddings_create(
// Normalize the random embedding // Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
#[allow(clippy::needless_range_loop)]
for i in 0..random_embedding.len() { for i in 0..random_embedding.len() {
random_embedding[i] /= norm; random_embedding[i] /= norm;
} }
@@ -123,7 +144,11 @@ pub async fn embeddings_create(
let target_dimension = 768; let target_dimension = 768;
if padded_embedding.len() < target_dimension { if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len(); let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension); tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]); padded_embedding.extend(vec![0.0; padding_needed]);
} }
@@ -132,12 +157,18 @@ pub async fn embeddings_create(
}; };
let postprocessing_time = postprocessing_start_time.elapsed(); let postprocessing_time = postprocessing_start_time.elapsed();
tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time); tracing::debug!(
"Embedding post-processing completed in {:.2?}",
postprocessing_time
);
tracing::trace!("Final embedding dimension: {}", final_embedding.len()); tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level // Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]); tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Phase 5: Prepare response // Phase 5: Prepare response
let response_start_time = std::time::Instant::now(); let response_start_time = std::time::Instant::now();

View File

@@ -1,8 +1,8 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput}; use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{ use axum::{
response::Json as ResponseJson, routing::{get, post}, Json, Router,
Json, response::Json as ResponseJson,
Router, routing::{get, post},
}; };
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -13,16 +13,14 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1"; const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080"; const DEFAULT_SERVER_PORT: &str = "8080";
async fn embeddings_create( async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>, Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> { ) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new( let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true) InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
) )
.expect("Failed to initialize model"); .expect("Failed to initialize model");
let embedding_input = payload.input; let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input { let texts_from_embedding_input = match embedding_input {
@@ -45,12 +43,19 @@ async fn embeddings_create(
tracing::info!("Embedding dimension: {}", embeddings[0].len()); tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level // Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]); tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding // Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count(); let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count(); let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count); tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Create the final embedding // Create the final embedding
let final_embedding = { let final_embedding = {
@@ -87,7 +92,11 @@ async fn embeddings_create(
let target_dimension = 768; let target_dimension = 768;
if padded_embedding.len() < target_dimension { if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len(); let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension); tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]); padded_embedding.extend(vec![0.0; padding_needed]);
} }
@@ -98,7 +107,10 @@ async fn embeddings_create(
tracing::trace!("Final embedding dimension: {}", final_embedding.len()); tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level // Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]); tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Return a response that matches the OpenAI API format // Return a response that matches the OpenAI API format
let response = serde_json::json!({ let response = serde_json::json!({
@@ -154,8 +166,8 @@ async fn main() {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use axum::body::to_bytes;
use axum::body::Body; use axum::body::Body;
use axum::body::to_bytes;
use axum::http::StatusCode; use axum::http::StatusCode;
use tower::ServiceExt; use tower::ServiceExt;
@@ -168,7 +180,9 @@ mod tests {
let body = CreateEmbeddingRequest { let body = CreateEmbeddingRequest {
model: "nomic-text-embed".to_string(), model: "nomic-text-embed".to_string(),
input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]), input: EmbeddingInput::from(vec![
"The food was delicious and the waiter...".to_string(),
]),
encoding_format: None, encoding_format: None,
user: None, user: None,
dimensions: Some(768), dimensions: Some(768),

View File

@@ -1,18 +1,16 @@
[package] [package]
name = "gemma-runner" name = "gemma-runner"
version = "0.1.0" version.workspace = true
edition = "2021" edition = "2021"
[dependencies] [dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" } candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" } candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" } candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-examples = { git = "https://github.com/huggingface/candle.git" } candle-examples = { git = "https://github.com/huggingface/candle.git" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
hf-hub = "0.4" hf-hub = "0.4"
tokenizers = "0.21" tokenizers = "0.21"
anyhow = "1.0" anyhow = "1.0"
@@ -22,6 +20,12 @@ tracing = "0.1"
tracing-chrome = "0.7" tracing-chrome = "0.7"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features] [features]
default = [] default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]

View File

@@ -4,10 +4,10 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::ValueEnum;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API // Removed gemma_cli import as it's not needed for the API
use candle_core::{utils, DType, Device, Tensor}; use candle_core::{utils, DType, Device, Tensor};
@@ -119,7 +119,12 @@ impl TextGeneration {
/// Stream-only generation: sends freshly generated token strings over `tx`. /// Stream-only generation: sends freshly generated token strings over `tx`.
/// (Does not send the prompt tokens; only newly generated model tokens.) /// (Does not send the prompt tokens; only newly generated model tokens.)
fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> { fn run_stream(
&mut self,
prompt: &str,
sample_len: usize,
tx: Sender<Result<String>>,
) -> Result<()> {
self.tokenizer.clear(); self.tokenizer.clear();
// Encode prompt (context only; do not emit prompt tokens to the stream). // Encode prompt (context only; do not emit prompt tokens to the stream).
@@ -337,7 +342,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let model = Model1::new(cfg.use_flash_attn, &config, vb)?; let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model) Model::V1(model)
} }
WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => { WhichModel::BaseV2_2B
| WhichModel::InstructV2_2B
| WhichModel::BaseV2_9B
| WhichModel::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?; let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model) Model::V2(model)

View File

@@ -1,6 +1,6 @@
use std::io::Write;
use clap::Parser;
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel}; use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)] #[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]

View File

@@ -2,8 +2,8 @@
extern crate accelerate_src; extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
mod gemma_cli;
mod gemma_api; mod gemma_api;
mod gemma_cli;
use anyhow::Error; use anyhow::Error;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "helm-chart-tool" name = "helm-chart-tool"
version = "0.1.0" version.workspace = true
edition = "2021" edition = "2021"
[[bin]] [[bin]]

View File

@@ -84,7 +84,10 @@ fn main() -> Result<()> {
let services = discover_services(workspace_path)?; let services = discover_services(workspace_path)?;
println!("Found {} services:", services.len()); println!("Found {} services:", services.len());
for service in &services { for service in &services {
println!(" - {}: {} (port {})", service.name, service.image, service.port); println!(
" - {}: {} (port {})",
service.name, service.image, service.port
);
} }
generate_helm_chart(output_path, chart_name, &services)?; generate_helm_chart(output_path, chart_name, &services)?;
@@ -119,13 +122,16 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
let cargo_toml: CargoToml = toml::from_str(&content) let cargo_toml: CargoToml = toml::from_str(&content)
.with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?; .with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;
let package = cargo_toml.package let package = cargo_toml
.package
.ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?; .ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;
let metadata = package.metadata let metadata = package
.metadata
.ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?; .ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;
let kube_metadata = metadata.kube let kube_metadata = metadata
.kube
.ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?; .ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;
Ok(ServiceInfo { Ok(ServiceInfo {
@@ -136,7 +142,11 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
}) })
} }
fn generate_helm_chart(output_path: &str, chart_name: &str, services: &[ServiceInfo]) -> Result<()> { fn generate_helm_chart(
output_path: &str,
chart_name: &str,
services: &[ServiceInfo],
) -> Result<()> {
let chart_dir = Path::new(output_path); let chart_dir = Path::new(output_path);
let templates_dir = chart_dir.join("templates"); let templates_dir = chart_dir.join("templates");

View File

@@ -1,41 +1,15 @@
[package] [package]
name = "inference-engine" name = "inference-engine"
version = "0.1.0" version.workspace = true
edition = "2021" edition = "2021"
[[bin]]
name="gemma_inference"
path = "src/gemma_inference.rs"
required-features = ["bin"]
[[bin]]
name="llama_inference"
path = "src/llama_inference.rs"
required-features = ["bin"]
[dependencies] [dependencies]
accelerate-src = { version = "0.3.2", optional = true } candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-datasets = { version = "=0.9.1", optional = true } candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { version = "=0.9.1" } candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { version = "=0.9.1" }
candle-flash-attn = { version = "=0.9.1", optional = true } candle-flash-attn = { version = "=0.9.1", optional = true }
candle-onnx = { version = "=0.9.1", optional = true } candle-onnx = { version = "=0.9.1", optional = true }
csv = "1.3.0"
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
hf-hub = { version = "0.4.1", features = ["tokio"] }
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
num-traits = { version = "0.2.15" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
rayon = "1.7.0"
rubato = { version = "0.15.0", optional = true }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99" serde_json = "1.0.99"
symphonia = { version = "0.5.3", features = ["all"], optional = true } symphonia = { version = "0.5.3", features = ["all"], optional = true }
@@ -60,19 +34,11 @@ futures-util = "0.3.31"
gemma-runner = { path = "../gemma-runner" } gemma-runner = { path = "../gemma-runner" }
llama-runner = { path = "../llama-runner" } llama-runner = { path = "../llama-runner" }
# --- Add this section for conditional compilation ---
[target.'cfg(target_os = "macos")'.dependencies] [target.'cfg(target_os = "macos")'.dependencies]
# Use CPU backend for macOS to avoid Metal rotary-emb implementation issues candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-core = { version = "=0.9.1", features = ["metal"], optional = false } candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[target.'cfg(not(target_os = "macos"))'.dependencies]
# For Linux or other non-macOS systems, you likely want the CPU backend or CUDA
# If you're building on Linux with a CUDA-enabled GPU:
candle-core = { version = "=0.9.1", features = ["cuda"], default-features = false } # Or just "cuda" if not using default features
# If you're building on Linux with only CPU:
# candle-core = { version = "=0.9.1", default-features = false } # CPU is often the default, but good to be explicit
# --- End of conditional compilation section ---
[dev-dependencies] [dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }

View File

@@ -1,19 +1,14 @@
// Expose modules for testing and library usage // Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model; pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types; pub mod openai_types;
// pub mod cli; // pub mod cli;
pub mod server;
pub mod inference; pub mod inference;
pub mod server;
// Re-export key components for easier access // Re-export key components for easier access
pub use model::{Model, Which};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
pub use inference::ModelInference; pub use inference::ModelInference;
pub use model::{Model, Which};
pub use server::{create_router, AppState};
use std::env; use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -1,8 +1,8 @@
// use candle_core::Tensor; // use candle_core::Tensor;
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which { pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
} }
impl Model { impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> { pub fn forward(
&mut self,
input_ids: &candle_core::Tensor,
pos: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self { match self {
Self::V1(m) => m.forward(input_ids, pos), Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos), Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {
pub fn is_instruct_model(&self) -> bool { pub fn is_instruct_model(&self) -> bool {
match self { match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false, Self::Base2B
| Self::Base7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::BaseV2_2B
| Self::BaseV2_9B
| Self::BaseV3_1B => false,
_ => true, _ => true,
} }
} }

View File

@@ -1,5 +1,6 @@
use either::Either; use either::Either;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap; use std::collections::HashMap;
use utoipa::ToSchema; use utoipa::ToSchema;
@@ -10,7 +11,10 @@ pub struct MessageInnerContent(
); );
impl ToSchema<'_> for MessageInnerContent { impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) { fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
( (
"MessageInnerContent", "MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()), utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -49,8 +53,14 @@ pub struct MessageContent(
); );
impl ToSchema<'_> for MessageContent { impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) { fn schema() -> (
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema())) &'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageContent",
utoipa::openapi::RefOr::T(message_content_schema()),
)
} }
} }

View File

@@ -6,19 +6,22 @@ use axum::{
Json, Router, Json, Router,
}; };
use futures_util::stream::{self, Stream}; use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible; use std::convert::Infallible;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Mutex, mpsc}; use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid; use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage}; use crate::openai_types::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which; use crate::Which;
use either::Either; use either::Either;
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig}; use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig}; use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
// ------------------------- // -------------------------
// Shared app state // Shared app state
// ------------------------- // -------------------------
@@ -67,7 +70,10 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
match message.role.as_str() { match message.role.as_str() {
"system" => { "system" => {
if let Some(MessageContent(Either::Left(content))) = &message.content { if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content)); prompt.push_str(&format!(
"<start_of_turn>system\n{}<end_of_turn>\n",
content
));
} }
} }
"user" => { "user" => {
@@ -97,9 +103,13 @@ pub async fn chat_completions(
Json(request): Json<ChatCompletionRequest>, Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> { ) -> Result<impl IntoResponse, (StatusCode, String)> {
if !request.stream.unwrap_or(false) { if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response()); return Ok(chat_completions_non_streaming_proxy(state, request)
.await
.into_response());
} }
Ok(chat_completions_stream(state, request).await.into_response()) Ok(chat_completions_stream(state, request)
.await
.into_response())
} }
pub async fn chat_completions_non_streaming_proxy( pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
ModelType::Gemma => build_gemma_prompt(&request.messages), ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => { ModelType::Llama => {
// For Llama, just use the last user message for now // For Llama, just use the last user message for now
request.messages.last() request
.messages
.last()
.and_then(|m| m.content.as_ref()) .and_then(|m| m.content.as_ref())
.and_then(|c| match c { .and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()), MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,7 +159,8 @@ pub async fn chat_completions_non_streaming_proxy(
}; };
// Get streaming receiver based on model type // Get streaming receiver based on model type
let rx = match state.model_type { let rx =
match state.model_type {
ModelType::Gemma => { ModelType::Gemma => {
if let Some(mut config) = state.gemma_config { if let Some(mut config) = state.gemma_config {
config.prompt = prompt.clone(); config.prompt = prompt.clone();
@@ -163,7 +176,7 @@ pub async fn chat_completions_non_streaming_proxy(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" } "error": { "message": "Gemma configuration not available" }
})) })),
)); ));
} }
} }
@@ -182,7 +195,7 @@ pub async fn chat_completions_non_streaming_proxy(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Llama configuration not available" } "error": { "message": "Llama configuration not available" }
})) })),
)); ));
} }
} }
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
ModelType::Gemma => build_gemma_prompt(&request.messages), ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => { ModelType::Llama => {
// For Llama, just use the last user message for now // For Llama, just use the last user message for now
request.messages.last() request
.messages
.last()
.and_then(|m| m.content.as_ref()) .and_then(|m| m.content.as_ref())
.and_then(|c| match c { .and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()), MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
model: model_id.clone(), model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice { choices: vec![ChatCompletionChunkChoice {
index: 0, index: 0,
delta: Delta { role: Some("assistant".to_string()), content: None }, delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None, finish_reason: None,
}], }],
}; };
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) } "error": { "message": format!("Error initializing Gemma model: {}", e) }
})) })),
)); ));
} }
} }
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" } "error": { "message": "Gemma configuration not available" }
})) })),
)); ));
} }
} }
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) } "error": { "message": format!("Error initializing Llama model: {}", e) }
})) })),
)); ));
} }
} }
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Llama configuration not available" } "error": { "message": "Llama configuration not available" }
})) })),
)); ));
} }
} }
@@ -394,7 +412,11 @@ async fn handle_streaming_request(
if last_token == second_last { if last_token == second_last {
repetition_count += 1; repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count); tracing::warn!(
"Detected repetition pattern: '{}' (count: {})",
last_token,
repetition_count
);
if repetition_count >= MAX_REPETITION_COUNT { if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition"); tracing::info!("Stopping generation due to excessive repetition");
@@ -412,7 +434,10 @@ async fn handle_streaming_request(
model: model_id_clone.clone(), model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice { choices: vec![ChatCompletionChunkChoice {
index: 0, index: 0,
delta: Delta { role: None, content: Some(token) }, delta: Delta {
role: None,
content: Some(token),
},
finish_reason: None, finish_reason: None,
}], }],
}; };
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
model: model_id_clone.clone(), model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice { choices: vec![ChatCompletionChunkChoice {
index: 0, index: 0,
delta: Delta { role: None, content: None }, delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()), finish_reason: Some("stop".to_string()),
}], }],
}; };
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
Ok(Sse::new(stream)) Ok(Sse::new(stream))
} }
// ------------------------- // -------------------------
// Router // Router
// ------------------------- // -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
}) })
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -681,10 +706,7 @@ mod tests {
let prompt = build_gemma_prompt(&messages); let prompt = build_gemma_prompt(&messages);
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\ let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
<start_of_turn>model\nWho's there?<end_of_turn>\n\
<start_of_turn>user\nGemma.<end_of_turn>\n\
<start_of_turn>model\n";
assert_eq!(prompt, expected); assert_eq!(prompt, expected);
} }
@@ -698,15 +720,13 @@ mod tests {
#[test] #[test]
fn test_missing_content() { fn test_missing_content() {
let messages = vec![ let messages = vec![Message {
Message {
role: "user".to_string(), role: "user".to_string(),
content: None, content: None,
name: None, name: None,
} }];
];
let prompt = build_gemma_prompt(&messages); let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n"); assert_eq!(prompt, "<start_of_turn>model\n");
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,87 +0,0 @@
use candle_core::Result;
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
prev_index: usize,
current_index: usize,
}
impl TokenOutputStream {
    /// Wraps `tokenizer` with empty streaming state.
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    /// Consumes the stream and returns the wrapped tokenizer.
    pub fn into_inner(self) -> tokenizers::Tokenizer {
        self.tokenizer
    }

    /// Decodes a token-id slice, mapping tokenizer errors into candle errors.
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle_core::bail!("cannot decode: {err}"),
        }
    }

    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    /// Feeds one token id and returns any newly completed text.
    ///
    /// Returns `Ok(None)` while the pending tokens do not yet decode to more
    /// text than the previously emitted prefix (e.g. a merged word piece or a
    /// partial multi-byte character).
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        // Emit only once the decoded text has grown AND does not end in the
        // Unicode replacement character: a trailing U+FFFD means the latest
        // token(s) form an incomplete UTF-8 sequence, and emitting now would
        // stream a garbage '\u{FFFD}' that the next token would complete.
        if text.len() > prev_text.len() && !text.ends_with('\u{FFFD}') {
            // Modified to include all tokens, not just alphanumeric ones
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    /// Decodes and returns any text still pending after the last emitted
    /// chunk; call once the token stream is finished.
    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let text = text.split_at(prev_text.len());
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    /// Decodes the full token history in one pass.
    pub fn decode_all(&self) -> Result<String> {
        self.decode(&self.tokens)
    }

    /// Looks up a token id by its string form (e.g. "<eos>").
    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }

    /// Borrow of the wrapped tokenizer.
    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

    /// Resets the streaming state while keeping the tokenizer.
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}

View File

@@ -1,167 +0,0 @@
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result, Tensor};
/// Picks the best available compute device: an explicit CPU request wins,
/// then CUDA, then Metal, falling back to CPU with a build-feature hint.
pub fn device(cpu: bool) -> Result<Device> {
    if cpu {
        return Ok(Device::Cpu);
    }
    if cuda_is_available() {
        return Ok(Device::new_cuda(0)?);
    }
    if metal_is_available() {
        return Ok(Device::new_metal(0)?);
    }
    // Neither accelerator backend is compiled in / available; tell the user
    // which cargo feature would enable one on this platform.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
    println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
    Ok(Device::Cpu)
}
/// Loads an image from disk as a u8 RGB tensor of shape `(3, height, width)`,
/// optionally scaling it so its longest side equals `resize_longest` while
/// preserving aspect ratio. Also returns the image's original (height, width).
pub fn load_image<P: AsRef<std::path::Path>>(
    p: P,
    resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
    let mut img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?;
    let initial_h = img.height() as usize;
    let initial_w = img.width() as usize;
    if let Some(target) = resize_longest {
        let target = target as u32;
        let (h, w) = (img.height(), img.width());
        // Integer rescale of the shorter side so the longer side hits `target`.
        let (new_h, new_w) = if h < w {
            ((target * h) / w, target)
        } else {
            (target, (target * w) / h)
        };
        img = img.resize_exact(new_w, new_h, image::imageops::FilterType::CatmullRom);
    }
    let (height, width) = (img.height() as usize, img.width() as usize);
    // Raw RGB8 bytes are row-major (height, width, 3); permute to CHW.
    let raw = img.to_rgb8().into_raw();
    let tensor = Tensor::from_vec(raw, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
    Ok((tensor, initial_h, initial_w))
}
/// Loads an image from disk, resizes it to exactly `width` x `height`
/// (cropping as needed to preserve aspect ratio), and returns a u8 RGB tensor
/// of shape `(3, height, width)`.
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
    p: P,
    width: usize,
    height: usize,
) -> Result<Tensor> {
    let img = image::ImageReader::open(p)?
        .decode()
        .map_err(candle_core::Error::wrap)?
        .resize_to_fill(
            width as u32,
            height as u32,
            image::imageops::FilterType::Triangle,
        );
    let img = img.to_rgb8();
    let data = img.into_raw();
    // The raw RGB8 buffer is row-major: `height` rows of `width` pixels, so it
    // must be viewed as (height, width, 3) before permuting to CHW — matching
    // `load_image` above. The previous (width, height, 3) shape silently
    // scrambled rows for any non-square target size (the reshape succeeded
    // only because the element counts coincide).
    Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
}
/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
    let p = p.as_ref();
    // Only 3-channel (RGB) CHW tensors are supported.
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        candle_core::bail!("save_image expects an input of shape (3, height, width)")
    }
    // Convert CHW -> flat HWC bytes, the layout ImageBuffer::from_raw expects.
    let pixels = img.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
            Some(image) => image,
            None => candle_core::bail!("error saving image {p:?}"),
        };
    image.save(p).map_err(candle_core::Error::wrap)?;
    Ok(())
}
pub fn save_image_resize<P: AsRef<std::path::Path>>(
img: &Tensor,
p: P,
h: usize,
w: usize,
) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
};
let image = image::DynamicImage::from(image);
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle_core::Error::wrap))
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}

View File

@@ -9,7 +9,10 @@ mod tests {
// Test a few representative model variants // Test a few representative model variants
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b"); assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it"); assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it"); assert_eq!(
Which::InstructV1_1_2B.to_model_id(),
"google/gemma-1.1-2b-it"
);
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b"); assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b"); assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it"); assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");

View File

@@ -1,549 +0,0 @@
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use inference_engine::model::Which;
use inference_engine::text_generation::TextGeneration;
use inference_engine::token_output_stream::TokenOutputStream;
use std::collections::HashMap;
use tokenizers::Tokenizer;
#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a simple tokenizer for testing.
    // Uses the google/gemma-2b tokenizer from the Hugging Face hub, so tests
    // that call it require network access on first run.
    fn create_test_tokenizer() -> Result<Tokenizer> {
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }

    /// Minimal stand-in for `TextGeneration` carrying only the state the
    /// repeat-penalty logic needs, so the algorithm can be unit-tested without
    /// loading model weights. (The previous version of this file duplicated
    /// this struct + impl verbatim inside four separate tests.)
    struct MockTextGeneration {
        repeat_penalty: f32,
        repeat_last_n: usize,
        penalty_cache: HashMap<usize, f32>,
    }

    impl MockTextGeneration {
        fn new(repeat_penalty: f32, repeat_last_n: usize) -> Self {
            Self {
                repeat_penalty,
                repeat_last_n,
                penalty_cache: HashMap::new(),
            }
        }

        /// Mirrors `TextGeneration::apply_cached_repeat_penalty`: divides the
        /// logit of each of the last `repeat_last_n` tokens by
        /// `repeat_penalty` (sign-preserving), memoizing each token's
        /// penalized score in `penalty_cache`. Returns the adjusted logits
        /// and the elapsed time.
        fn apply_cached_repeat_penalty(
            &mut self,
            logits: Tensor,
            tokens: &[u32],
        ) -> Result<(Tensor, std::time::Duration)> {
            let repeat_start = std::time::Instant::now();
            // Fast path: a penalty of 1.0 leaves the logits untouched.
            if self.repeat_penalty == 1.0 {
                return Ok((logits, repeat_start.elapsed()));
            }
            // Only the last `repeat_last_n` tokens are penalized.
            let start_at = tokens.len().saturating_sub(self.repeat_last_n);
            let mut logits_vec = logits.to_vec1::<f32>()?;
            for &token_id in &tokens[start_at..] {
                let token_id = token_id as usize;
                // Ids beyond the vocabulary size are silently ignored.
                if token_id < logits_vec.len() {
                    if let Some(&cached) = self.penalty_cache.get(&token_id) {
                        logits_vec[token_id] = cached;
                    } else {
                        let score = logits_vec[token_id];
                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
                        let penalized_score = sign * score / self.repeat_penalty;
                        logits_vec[token_id] = penalized_score;
                        self.penalty_cache.insert(token_id, penalized_score);
                    }
                }
            }
            let device = logits.device().clone();
            let shape = logits.shape().clone();
            let result = Tensor::new(&logits_vec[..], &device)?.reshape(shape)?;
            Ok((result, repeat_start.elapsed()))
        }
    }

    // Test the Which enum's to_model_id method
    #[test]
    fn test_which_model_id() {
        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
    }

    // Test the Which enum's is_instruct_model method
    #[test]
    fn test_which_is_instruct() {
        assert!(!Which::Base2B.is_instruct_model());
        assert!(Which::Instruct7B.is_instruct_model());
    }

    // Test the Which enum's is_v3_model method
    #[test]
    fn test_which_is_v3() {
        assert!(!Which::Base2B.is_v3_model());
        assert!(Which::BaseV3_1B.is_v3_model());
    }

    // Round-trips a short string through TokenOutputStream token by token.
    #[test]
    fn test_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);
        let text = "Hello, world!";
        let encoded = token_stream.tokenizer().encode(text, true).unwrap();
        for &token_id in encoded.get_ids() {
            token_stream.next_token(token_id)?;
        }
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), text);
        Ok(())
    }

    // LogitsProcessor construction with typical sampling settings must not
    // panic; actual sampling behaviour needs a real model and is covered by
    // integration tests.
    #[test]
    fn test_logits_processor() -> Result<()> {
        let seed = 42;
        let temp = Some(0.8);
        let top_p = Some(0.9);
        let _logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Ok(())
    }

    // Compile-time check that the TextGeneration types line up; constructing
    // a real instance requires model weights and is done in integration tests.
    #[test]
    fn test_text_generation_constructor() -> Result<()> {
        Ok(())
    }

    // With a penalty of 1.0 the logits must come back unchanged.
    #[test]
    fn test_apply_cached_repeat_penalty_no_penalty() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 2u32, 3u32];
        let mut mock_gen = MockTextGeneration::new(1.0, 3);
        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits, &tokens)?;
        assert_eq!(result_logits.to_vec1::<f32>()?, logits_data);
        Ok(())
    }

    // Tokens 1..=3 are each divided by the penalty of 2.0.
    #[test]
    fn test_apply_cached_repeat_penalty_with_penalty() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 2u32, 3u32];
        let mut mock_gen = MockTextGeneration::new(2.0, 3);
        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits, &tokens)?;
        // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
        let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0];
        assert_eq!(result_logits.to_vec1::<f32>()?, expected);
        Ok(())
    }

    // A repeated token must reuse the cached penalized score.
    #[test]
    fn test_apply_cached_repeat_penalty_caching() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 1u32, 1u32]; // repeated token exercises the cache
        let mut mock_gen = MockTextGeneration::new(2.0, 3);
        let (_result_logits, _duration) =
            mock_gen.apply_cached_repeat_penalty(logits, &tokens)?;
        // Token 1's penalized score (2.0 / 2.0 = 1.0) must be cached.
        assert!(mock_gen.penalty_cache.contains_key(&1));
        assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0));
        Ok(())
    }

    // An empty token history leaves the logits unchanged.
    #[test]
    fn test_apply_cached_repeat_penalty_empty_tokens() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens: Vec<u32> = vec![];
        let mut mock_gen = MockTextGeneration::new(2.0, 3);
        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits, &tokens)?;
        assert_eq!(result_logits.to_vec1::<f32>()?, logits_data);
        Ok(())
    }

    // Token ids beyond the logits length are ignored rather than panicking.
    #[test]
    fn test_apply_cached_repeat_penalty_out_of_bounds() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 5u32, 10u32]; // 5 and 10 are out of bounds
        let mut mock_gen = MockTextGeneration::new(2.0, 3);
        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits, &tokens)?;
        // Only token 1 is penalized: [1.0, 2.0/2.0, 3.0]
        let expected = vec![1.0f32, 1.0, 3.0];
        assert_eq!(result_logits.to_vec1::<f32>()?, expected);
        Ok(())
    }

    // The real method is public (`pub fn apply_cached_repeat_penalty`) and so
    // accessible from external code; full coverage needs model weights, so the
    // algorithm itself is exercised through the mock above.
    #[test]
    fn test_actual_apply_cached_repeat_penalty_implementation() -> Result<()> {
        let _tokenizer = create_test_tokenizer()?;
        Ok(())
    }

    // Demonstrates the intended usage pattern: after one pass, every token in
    // the penalty window (including the repeated one) is cached.
    #[test]
    fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
        let device = Device::Cpu;
        let logits = Tensor::new(&[1.5f32, 2.5, 3.5, 4.5, 5.5][..], &device)?;
        let tokens = vec![1u32, 2u32, 1u32, 3u32]; // token 1 repeats
        let mut mock_gen = MockTextGeneration::new(1.2, 3);
        let (_result, _duration) = mock_gen.apply_cached_repeat_penalty(logits, &tokens)?;
        // repeat_last_n = 3 covers [2, 1, 3]; all three get cached entries.
        assert!(mock_gen.penalty_cache.contains_key(&1));
        assert!(mock_gen.penalty_cache.contains_key(&2));
        assert!(mock_gen.penalty_cache.contains_key(&3));
        Ok(())
    }

    // Note: end-to-end text generation requires real model weights and is
    // covered by integration tests rather than these unit tests.
}

View File

@@ -1,129 +0,0 @@
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a real tokenizer for these tests by fetching google/gemma-2b
    /// from the Hugging Face hub (requires network access on first run).
    fn create_test_tokenizer() -> Result<Tokenizer> {
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }

    #[test]
    fn test_new_token_output_stream() -> Result<()> {
        let stream = TokenOutputStream::new(create_test_tokenizer()?);
        // A freshly wrapped tokenizer must expose a non-empty vocabulary.
        assert!(stream.tokenizer().get_vocab(true).len() > 0);
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let mut stream = TokenOutputStream::new(create_test_tokenizer()?);
        // Feed a single known token into the stream...
        let eos = stream.get_token("<eos>").unwrap();
        stream.next_token(eos)?;
        // ...then reset it.
        stream.clear();
        // After clearing, decoding the full history must yield an empty string.
        assert_eq!(stream.decode_all()?, "");
        Ok(())
    }

    #[test]
    fn test_get_token() -> Result<()> {
        let stream = TokenOutputStream::new(create_test_tokenizer()?);
        // A token that exists in the vocabulary is found...
        assert!(stream.get_token("<eos>").is_some());
        // ...and an unknown one is not.
        assert!(stream.get_token("<this_token_does_not_exist>").is_none());
        Ok(())
    }

    #[test]
    fn test_next_token_and_decode() -> Result<()> {
        let mut stream = TokenOutputStream::new(create_test_tokenizer()?);
        let encoding = stream.tokenizer().encode("Hello world", true).unwrap();
        let ids: Vec<u32> = encoding.get_ids().to_vec();
        // Stream tokens one at a time, accumulating whatever text is emitted.
        let mut output = String::new();
        for id in ids {
            if let Some(chunk) = stream.next_token(id)? {
                output.push_str(&chunk);
            }
        }
        // Flush any text still pending after the last token.
        if let Some(rest) = stream.decode_rest()? {
            output.push_str(&rest);
        }
        assert!(!output.is_empty());
        assert_eq!(output.trim(), "Hello world");
        Ok(())
    }

    #[test]
    fn test_decode_all() -> Result<()> {
        let mut stream = TokenOutputStream::new(create_test_tokenizer()?);
        let encoding = stream.tokenizer().encode("Hello world", true).unwrap();
        for &id in encoding.get_ids() {
            stream.next_token(id)?;
        }
        // decode_all re-decodes the entire token history in one shot.
        assert_eq!(stream.decode_all()?.trim(), "Hello world");
        Ok(())
    }

    #[test]
    fn test_into_inner() -> Result<()> {
        let stream = TokenOutputStream::new(create_test_tokenizer()?);
        // Unwrapping must hand back a fully functional tokenizer.
        let tokenizer = stream.into_inner();
        assert!(tokenizer.encode("Test", true).unwrap().get_ids().len() > 0);
        Ok(())
    }
}

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "leptos-app" name = "leptos-app"
version = "0.1.0" version.workspace = true
edition = "2021" edition = "2021"
[lib] [lib]

View File

@@ -5,6 +5,25 @@ use leptos_router::{
StaticSegment, StaticSegment,
}; };
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{FinishReason, Role};
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
Model as OpenAIModel,
},
Client,
};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
@@ -12,25 +31,7 @@ use std::collections::VecDeque;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
use uuid::Uuid; use uuid::Uuid;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent}; use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
},
Client,
};
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{Role, FinishReason};
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -43,11 +44,15 @@ pub struct Message {
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent(pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>); pub struct MessageContent(
pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>,
);
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInnerContent(pub either::Either<String, std::collections::HashMap<String, String>>); pub struct MessageInnerContent(
pub either::Either<String, std::collections::HashMap<String, String>>,
);
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -62,7 +67,9 @@ const DEFAULT_MODEL: &str = "default";
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> { async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"
);
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string()); let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config); let client = Client::with_config(config);
@@ -70,19 +77,30 @@ async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
match client.models().list().await { match client.models().list().await {
Ok(response) => { Ok(response) => {
let model_count = response.data.len(); let model_count = response.data.len();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Successfully fetched {} models",
model_count
);
if model_count > 0 { if model_count > 0 {
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect(); let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Available models: {:?}",
model_names
);
} else { } else {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: No models returned by server"); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: No models returned by server"
);
} }
Ok(response.data) Ok(response.data)
}, }
Err(e) => { Err(e) => {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}",
e
);
Err(format!("Failed to fetch models: {}", e)) Err(format!("Failed to fetch models: {}", e))
} }
} }
@@ -335,7 +353,11 @@ fn ChatInterfaceImpl() -> impl IntoView {
} }
} }
Err(e) => { Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e); leptos::logging::log!(
"[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}",
chunks_received,
e
);
set_messages.update(|msgs| { set_messages.update(|msgs| {
msgs.push_back(Message { msgs.push_back(Message {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
@@ -364,7 +386,10 @@ fn ChatInterfaceImpl() -> impl IntoView {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received); leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
} }
Err(e) => { Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e); leptos::logging::log!(
"[DEBUG_LOG] send_message: Request failed with error: {:?}",
e
);
let error_message = Message { let error_message = Message {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
role: "system".to_string(), role: "system".to_string(),
@@ -404,7 +429,8 @@ fn ChatInterfaceImpl() -> impl IntoView {
}; };
let messages_list = move || { let messages_list = move || {
messages.get() messages
.get()
.into_iter() .into_iter()
.map(|message| { .map(|message| {
let role_class = match message.role.as_str() { let role_class = match message.role.as_str() {

View File

@@ -10,10 +10,10 @@ pub fn hydrate() {
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
pub fn create_leptos_router() -> axum::Router { pub fn create_leptos_router() -> axum::Router {
use crate::app::*;
use axum::Router; use axum::Router;
use leptos::prelude::*; use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes}; use leptos_axum::{generate_route_list, LeptosRoutes};
use crate::app::*;
let conf = get_configuration(None).unwrap(); let conf = get_configuration(None).unwrap();
let leptos_options = conf.leptos_options; let leptos_options = conf.leptos_options;

View File

@@ -1,12 +1,11 @@
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
use axum::Router; use axum::Router;
use leptos::logging::log; use leptos::logging::log;
use leptos::prelude::*; use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use leptos_app::app::*; use leptos_app::app::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
let conf = get_configuration(None).unwrap(); let conf = get_configuration(None).unwrap();
let addr = conf.leptos_options.site_addr; let addr = conf.leptos_options.site_addr;

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "llama-runner" name = "llama-runner"
version = "0.1.0" version.workspace = true
edition = "2021" edition = "2021"
[dependencies] [dependencies]

View File

@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed // Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>"; pub const EOS_TOKEN: &str = "</s>";

View File

@@ -1,14 +1,14 @@
use crate::EOS_TOKEN;
use anyhow::{bail, Error as E}; use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor}; use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model; use candle_transformers::models::llama as model;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use clap::ValueEnum;
use hf_hub::api::sync::Api; use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType}; use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver}; use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::{EOS_TOKEN};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)] #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel { pub enum WhichModel {
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
} }
} }
fn device(cpu: bool) -> anyhow::Result<Device> { fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu { if cpu {
Ok(Device::Cpu) Ok(Device::Cpu)
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
} }
} }
fn hub_load_safetensors( fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo, api: &hf_hub::api::sync::ApiRepo,
json_file: &str, json_file: &str,
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
Ok(rx) Ok(rx)
} }

View File

@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
} }
} }
pub fn run_cli() -> anyhow::Result<()> { pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
let cfg = args.into(); let cfg = args.into();

View File

@@ -2,8 +2,8 @@
extern crate accelerate_src; extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
mod llama_cli;
mod llama_api; mod llama_api;
mod llama_cli;
use anyhow::Result; use anyhow::Result;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>"; const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> { fn main() -> Result<()> {
run_cli() run_cli()
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "predict-otron-9000" name = "predict-otron-9000"
version = "0.1.0" version.workspace = true
edition = "2024" edition = "2024"
[[bin]] [[bin]]
@@ -44,4 +44,8 @@ port = 8080
image = "ghcr.io/geoffsee/predict-otron-9000:latest" image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1 replicas = 1
port = 8080 port = 8080
env = { SERVER_CONFIG = "" } # SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
# you can generate this via node to avoid toil
# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
env = { SERVER_CONFIG = "<your-json-value-here>" }

View File

@@ -1,7 +1,12 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::env; use std::env;
use tracing::info;
#[derive(Debug, Clone, Deserialize, Serialize)] use tracing::log::error;
/// # Generating `SERVER_CONFIG` with Node
// # const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
// # console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
///
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ServerConfig { pub struct ServerConfig {
#[serde(default = "default_server_host")] #[serde(default = "default_server_host")]
@@ -10,14 +15,16 @@ pub struct ServerConfig {
pub server_port: u16, pub server_port: u16,
pub server_mode: ServerMode, pub server_mode: ServerMode,
#[serde(default)] #[serde(default)]
pub services: Services, pub services: Option<Services>,
} }
fn default_server_host() -> String { fn default_server_host() -> String {
"127.0.0.1".to_string() "127.0.0.1".to_string()
} }
fn default_server_port() -> u16 { 8080 } fn default_server_port() -> u16 {
8080
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "PascalCase")] #[serde(rename_all = "PascalCase")]
@@ -34,17 +41,15 @@ impl Default for ServerMode {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Services { pub struct Services {
#[serde(default = "inference_service_url")] pub inference_url: Option<String>,
pub inference_url: String, pub embeddings_url: Option<String>,
#[serde(default = "embeddings_service_url")]
pub embeddings_url: String,
} }
impl Default for Services { impl Default for Services {
fn default() -> Self { fn default() -> Self {
Self { Self {
inference_url: inference_service_url(), inference_url: None,
embeddings_url: embeddings_service_url(), embeddings_url: None,
} }
} }
} }
@@ -63,7 +68,7 @@ impl Default for ServerConfig {
server_host: "127.0.0.1".to_string(), server_host: "127.0.0.1".to_string(),
server_port: 8080, server_port: 8080,
server_mode: ServerMode::Standalone, server_mode: ServerMode::Standalone,
services: Services::default(), services: Some(Services::default()),
} }
} }
} }
@@ -73,8 +78,7 @@ impl ServerConfig {
/// Falls back to default (Local mode) if not set or invalid /// Falls back to default (Local mode) if not set or invalid
pub fn from_env() -> Self { pub fn from_env() -> Self {
match env::var("SERVER_CONFIG") { match env::var("SERVER_CONFIG") {
Ok(config_str) => { Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
match serde_json::from_str::<ServerConfig>(&config_str) {
Ok(config) => { Ok(config) => {
tracing::info!("Loaded server configuration: {:?}", config); tracing::info!("Loaded server configuration: {:?}", config);
config config
@@ -86,8 +90,7 @@ impl ServerConfig {
); );
ServerConfig::default() ServerConfig::default()
} }
} },
}
Err(_) => { Err(_) => {
tracing::info!("SERVER_CONFIG not set, Standalone mode active"); tracing::info!("SERVER_CONFIG not set, Standalone mode active");
ServerConfig::default() ServerConfig::default()
@@ -96,18 +99,52 @@ impl ServerConfig {
} }
/// Check if the server should run in high availability mode /// Check if the server should run in high availability mode
pub fn is_high_availability(&self) -> bool { pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
self.server_mode == ServerMode::HighAvailability if self.server_mode == ServerMode::HighAvailability {
let services_well_defined: bool = self.clone().services.is_some();
let inference_url_well_defined: bool =
services_well_defined && self.clone().services.unwrap().inference_url.is_some();
let embeddings_well_defined: bool =
services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();
let is_well_defined_for_ha =
services_well_defined && inference_url_well_defined && embeddings_well_defined;
if !is_well_defined_for_ha {
let config_string = serde_json::to_string_pretty(&self).unwrap();
error!(
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
config_string
);
let err = std::io::Error::new(
std::io::ErrorKind::Other,
"HighAvailability mode configured but services not well defined!",
);
return Err(err);
}
}
Ok(self.server_mode == ServerMode::HighAvailability)
} }
/// Get the inference service URL for proxying /// Get the inference service URL for proxying
pub fn inference_url(&self) -> &str { pub fn inference_url(&self) -> Option<String> {
&self.services.inference_url if self.services.is_some() {
self.services.clone()?.inference_url
} else {
None
}
} }
/// Get the embeddings service URL for proxying /// Get the embeddings service URL for proxying
pub fn embeddings_url(&self) -> &str { pub fn embeddings_url(&self) -> Option<String> {
&self.services.embeddings_url if self.services.is_some() {
self.services.clone()?.embeddings_url
} else {
None
}
} }
} }
@@ -119,7 +156,7 @@ mod tests {
fn test_default_config() { fn test_default_config() {
let config = ServerConfig::default(); let config = ServerConfig::default();
assert_eq!(config.server_mode, ServerMode::Standalone); assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability()); assert!(!config.is_high_availability().unwrap());
} }
#[test] #[test]
@@ -134,23 +171,26 @@ mod tests {
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::HighAvailability); assert_eq!(config.server_mode, ServerMode::HighAvailability);
assert!(config.is_high_availability()); assert!(config.is_high_availability().unwrap());
assert_eq!(config.inference_url(), "http://inference-service:8080"); assert_eq!(
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080"); config.inference_url().unwrap(),
"http://inference-service:8080"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://embeddings-service:8080"
);
} }
#[test] #[test]
fn test_local_mode_config() { fn test_local_mode_config() {
let config_json = r#"{ let config_json = r#"{
"serverMode": "Local" "serverMode": "Standalone"
}"#; }"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::Standalone); assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability()); assert!(!config.is_high_availability().unwrap());
// Should use default URLs
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
} }
#[test] #[test]
@@ -164,17 +204,26 @@ mod tests {
}"#; }"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.inference_url(), "http://custom-inference:9000"); assert_eq!(
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001"); config.inference_url().unwrap(),
"http://custom-inference:9000"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://custom-embeddings:9001"
);
} }
#[test] #[test]
fn test_minimal_high_availability_config() { fn test_minimal_high_availability_config_error() {
let config_json = r#"{"serverMode": "HighAvailability"}"#; let config_json = r#"{"serverMode": "HighAvailability"}"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert!(config.is_high_availability());
// Should use default URLs let is_high_availability = config.is_high_availability();
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080"); assert!(is_high_availability.is_err());
// // Should use default URLs
// assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
// assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
} }
} }

View File

@@ -1,10 +1,10 @@
use axum::{ use axum::{
Router,
body::Body, body::Body,
extract::{Request, State}, extract::{Request, State},
http::{HeaderMap, Method, StatusCode, Uri}, http::{HeaderMap, Method, StatusCode, Uri},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{get, post}, routing::{get, post},
Router,
}; };
use reqwest::Client; use reqwest::Client;
use serde_json::Value; use serde_json::Value;
@@ -12,6 +12,120 @@ use std::time::Duration;
use crate::config::ServerConfig; use crate::config::ServerConfig;
/// # Generating `SERVER_CONFIG` for TOML using Node.js
///
/// You can still use the Node.js REPL to build the JSON, but when pasting into
/// a `.toml` file you must follow TOML's string rules. Below are the safest patterns.
///
/// ## 1) Generate the JSON in Node
/// ```bash
/// node
/// ```
/// ```javascript
/// const myobject = {
/// serverMode: "HighAvailability",
/// services: {
/// inference_url: "http://custom-inference:9000",
/// embeddings_url: "http://custom-embeddings:9001"
/// }
/// };
/// const json = JSON.stringify(myobject);
/// json
/// // -> '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
/// ```
///
/// ## 2) Put it into `.toml`
///
/// ### Option A (recommended): single-quoted TOML *literal* string
/// Single quotes in TOML mean "no escaping", so your inner double quotes are safe.
/// ```toml
/// SERVER_CONFIG = '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
/// ```
///
/// ### Option B: double-quoted TOML string (must escape inner quotes)
/// If you *must* use double quotes in TOML, escape all `"` inside the JSON.
/// You can have Node do this for you:
/// ```javascript
/// // In Node:
/// const jsonForToml = JSON.stringify(myobject).replace(/"/g, '\\"');
/// jsonForToml
/// // -> \"{\\\"serverMode\\\":\\\"HighAvailability\\\",...}\"
/// ```
/// Then paste into TOML:
/// ```toml
/// SERVER_CONFIG = "{\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}"
/// ```
///
/// ### Option C: multi-line literal (for pretty JSON)
/// If you want pretty-printed JSON in the file, use TOML's triple single quotes:
/// ```javascript
/// // In Node (pretty with 2 spaces):
/// const pretty = JSON.stringify(myobject, null, 2);
/// ```
/// ```toml
/// SERVER_CONFIG = '''{
/// "serverMode": "HighAvailability",
/// "services": {
/// "inference_url": "http://custom-inference:9000",
/// "embeddings_url": "http://custom-embeddings:9001"
/// }
/// }'''
/// ```
///
/// ## 3) Reading it in Rust
///
/// If `SERVER_CONFIG` is stored as a **string** in TOML (Options A/B/C):
/// ```rust
/// use serde_json::Value;
///
/// // Suppose you've already loaded your .toml into a struct or a toml::Value:
/// // e.g., struct FileCfg { pub SERVER_CONFIG: String }
/// fn parse_server_config(raw: &str) -> anyhow::Result<Value> {
/// let v: Value = serde_json::from_str(raw)?;
/// Ok(v)
/// }
/// ```
///
/// ### Alternative: store it as TOML tables and serialize to JSON at runtime
/// Instead of a JSON string, you can make the TOML first-class tables:
/// ```toml
/// [SERVER_CONFIG]
/// serverMode = "HighAvailability"
///
/// [SERVER_CONFIG.services]
/// inference_url = "http://custom-inference:9000"
/// embeddings_url = "http://custom-embeddings:9001"
/// ```
/// ```rust
/// use serde::{Deserialize, Serialize};
/// use serde_json::Value;
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Services {
/// inference_url: String,
/// embeddings_url: String,
/// }
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct ServerConfig {
/// serverMode: String,
/// services: Services,
/// }
///
/// // After loading the .toml (e.g., via `toml::from_str`):
/// // let cfg: ServerConfig = toml::from_str(toml_str)?;
/// // Convert to JSON if needed:
/// fn to_json(cfg: &ServerConfig) -> serde_json::Result<Value> {
/// Ok(serde_json::to_value(cfg)?)
/// }
/// ```
///
/// ## Gotchas
/// - Prefer **single-quoted** TOML strings for raw JSON to avoid escaping.
/// - If you use **double-quoted** TOML strings, escape every inner `"` in the JSON.
/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
/// HTTP client configured for proxying requests /// HTTP client configured for proxying requests
#[derive(Clone)] #[derive(Clone)]
pub struct ProxyClient { pub struct ProxyClient {
@@ -31,7 +145,7 @@ impl ProxyClient {
} }
/// Create a router that proxies requests to external services in HighAvailability mode /// Create a router that proxies requests to external services in HighAvailability mode
pub fn create_proxy_router(config: ServerConfig) -> Router { pub fn create_ha_router(config: ServerConfig) -> Router {
let proxy_client = ProxyClient::new(config.clone()); let proxy_client = ProxyClient::new(config.clone());
Router::new() Router::new()
@@ -47,7 +161,13 @@ async fn proxy_chat_completions(
headers: HeaderMap, headers: HeaderMap,
body: Body, body: Body,
) -> Result<Response, StatusCode> { ) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url()); let target_url = format!(
"{}/v1/chat/completions",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration")
);
tracing::info!("Proxying chat completions request to: {}", target_url); tracing::info!("Proxying chat completions request to: {}", target_url);
@@ -63,7 +183,9 @@ async fn proxy_chat_completions(
// Check if this is a streaming request // Check if this is a streaming request
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) { let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
if let Ok(json) = serde_json::from_str::<Value>(&body_str) { if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false) json.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false)
} else { } else {
false false
} }
@@ -72,7 +194,8 @@ async fn proxy_chat_completions(
}; };
// Forward the request // Forward the request
let mut req_builder = proxy_client.client let mut req_builder = proxy_client
.client
.post(&target_url) .post(&target_url)
.body(body_bytes.to_vec()); .body(body_bytes.to_vec());
@@ -85,8 +208,7 @@ async fn proxy_chat_completions(
match req_builder.send().await { match req_builder.send().await {
Ok(response) => { Ok(response) => {
let mut resp_builder = Response::builder() let mut resp_builder = Response::builder().status(response.status());
.status(response.status());
// Forward response headers // Forward response headers
for (name, value) in response.headers().iter() { for (name, value) in response.headers().iter() {
@@ -99,14 +221,12 @@ async fn proxy_chat_completions(
if is_streaming { if is_streaming {
// For streaming, we need to forward the response as-is // For streaming, we need to forward the response as-is
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder
.header("content-type", "text/plain; charset=utf-8") .header("content-type", "text/plain; charset=utf-8")
.header("cache-control", "no-cache") .header("cache-control", "no-cache")
.header("connection", "keep-alive") .header("connection", "keep-alive")
.body(Body::from(body)) .body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read streaming response body: {}", e); tracing::error!("Failed to read streaming response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +235,9 @@ async fn proxy_chat_completions(
} else { } else {
// For non-streaming, forward the JSON response // For non-streaming, forward the JSON response
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder
.body(Body::from(body)) .body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read response body: {}", e); tracing::error!("Failed to read response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,7 +257,13 @@ async fn proxy_models(
State(proxy_client): State<ProxyClient>, State(proxy_client): State<ProxyClient>,
headers: HeaderMap, headers: HeaderMap,
) -> Result<Response, StatusCode> { ) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/models", proxy_client.config.inference_url()); let target_url = format!(
"{}/v1/models",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration Detected")
);
tracing::info!("Proxying models request to: {}", target_url); tracing::info!("Proxying models request to: {}", target_url);
@@ -154,8 +278,7 @@ async fn proxy_models(
match req_builder.send().await { match req_builder.send().await {
Ok(response) => { Ok(response) => {
let mut resp_builder = Response::builder() let mut resp_builder = Response::builder().status(response.status());
.status(response.status());
// Forward response headers // Forward response headers
for (name, value) in response.headers().iter() { for (name, value) in response.headers().iter() {
@@ -165,11 +288,9 @@ async fn proxy_models(
} }
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder
.body(Body::from(body)) .body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read models response body: {}", e); tracing::error!("Failed to read models response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,7 +310,13 @@ async fn proxy_embeddings(
headers: HeaderMap, headers: HeaderMap,
body: Body, body: Body,
) -> Result<Response, StatusCode> { ) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url()); let target_url = format!(
"{}/v1/embeddings",
proxy_client
.config
.embeddings_url()
.expect("Invalid Configuration Detected")
);
tracing::info!("Proxying embeddings request to: {}", target_url); tracing::info!("Proxying embeddings request to: {}", target_url);
@@ -203,7 +330,8 @@ async fn proxy_embeddings(
}; };
// Forward the request // Forward the request
let mut req_builder = proxy_client.client let mut req_builder = proxy_client
.client
.post(&target_url) .post(&target_url)
.body(body_bytes.to_vec()); .body(body_bytes.to_vec());
@@ -216,8 +344,7 @@ async fn proxy_embeddings(
match req_builder.send().await { match req_builder.send().await {
Ok(response) => { Ok(response) => {
let mut resp_builder = Response::builder() let mut resp_builder = Response::builder().status(response.status());
.status(response.status());
// Forward response headers // Forward response headers
for (name, value) in response.headers().iter() { for (name, value) in response.headers().iter() {
@@ -227,11 +354,9 @@ async fn proxy_embeddings(
} }
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder
.body(Body::from(body)) .body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read embeddings response body: {}", e); tracing::error!("Failed to read embeddings response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -290,14 +415,20 @@ mod tests {
server_host: "127.0.0.1".to_string(), server_host: "127.0.0.1".to_string(),
server_port: 8080, server_port: 8080,
server_mode: ServerMode::HighAvailability, server_mode: ServerMode::HighAvailability,
services: Services { services: Some(Services {
inference_url: "http://test-inference:8080".to_string(), inference_url: Some("http://test-inference:8080".to_string()),
embeddings_url: "http://test-embeddings:8080".to_string(), embeddings_url: Some("http://test-embeddings:8080".to_string()),
}, }),
}; };
let proxy_client = ProxyClient::new(config); let proxy_client = ProxyClient::new(config);
assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080"); assert_eq!(
assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080"); proxy_client.config.inference_url().unwrap().as_str(),
"http://test-inference:8080"
);
assert_eq!(
proxy_client.config.embeddings_url().unwrap().as_str(),
"http://test-embeddings:8080"
);
} }
} }

View File

@@ -1,16 +1,19 @@
mod config; mod config;
mod ha_mode;
mod middleware; mod middleware;
mod proxy; mod standalone_mode;
use crate::standalone_mode::create_standalone_router;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::get; use axum::routing::get;
use axum::{Router, http::Uri, response::Html, serve}; use axum::{Router, http::Uri, response::Html, serve};
use config::ServerConfig; use config::ServerConfig;
use ha_mode::create_ha_router;
use inference_engine::AppState; use inference_engine::AppState;
use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore}; use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use proxy::create_proxy_router;
use rust_embed::Embed; use rust_embed::Embed;
use std::env; use std::env;
use std::path::Component::ParentDir;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_http::classify::ServerErrorsFailureClass::StatusCode; use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
@@ -49,33 +52,19 @@ async fn main() {
let default_host = server_config.server_host.clone(); let default_host = server_config.server_host.clone();
let default_port = server_config.server_port; let default_port = server_config.server_port;
// Create router based on server mode let service_router = match server_config.clone().is_high_availability() {
let service_router = if server_config.clone().is_high_availability() { Ok(is_ha) => {
tracing::info!("Running in HighAvailability mode - proxying to external services"); if is_ha {
tracing::info!(" Inference service URL: {}", server_config.inference_url()); log_config(server_config.clone());
tracing::info!( create_ha_router(server_config.clone())
" Embeddings service URL: {}",
server_config.embeddings_url()
);
// Use proxy router that forwards requests to external services
create_proxy_router(server_config.clone())
} else { } else {
tracing::info!("Running in Standalone mode - using embedded services"); log_config(server_config.clone());
create_standalone_router(server_config)
// Create unified router by merging embeddings and inference routers (existing behavior) }
let embeddings_router = embeddings_engine::create_embeddings_router(); }
Err(error) => {
// Create AppState with correct model configuration panic!("{}", error);
let app_state = AppState::default(); }
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);
// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
}; };
// Create CORS layer // Create CORS layer
@@ -124,5 +113,25 @@ async fn main() {
serve(listener, app).await.unwrap(); serve(listener, app).await.unwrap();
} }
fn log_config(config: ServerConfig) {
match config.is_high_availability() {
Ok(is_high) => {
if is_high {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
tracing::info!(
"Embeddings service URL: {}",
config.embeddings_url().unwrap()
);
} else {
tracing::info!("Running in Standalone mode");
}
}
Err(error) => {
panic!("{}", error);
}
}
}
// Chat completions handler that properly uses the inference server crate's error handling // Chat completions handler that properly uses the inference server crate's error handling
// This function is no longer needed as we're using the inference_engine router directly // This function is no longer needed as we're using the inference_engine router directly

View File

@@ -2,6 +2,8 @@ use axum::{
extract::MatchedPath, extract::MatchedPath,
http::{Request, Response}, http::{Request, Response},
}; };
use std::fmt;
use std::task::ready;
use std::{ use std::{
future::Future, future::Future,
pin::Pin, pin::Pin,
@@ -12,8 +14,6 @@ use std::{
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower::{Layer, Service}; use tower::{Layer, Service};
use tracing::{debug, info}; use tracing::{debug, info};
use std::task::ready;
use std::fmt;
/// Performance metrics for a specific endpoint /// Performance metrics for a specific endpoint
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@@ -56,7 +56,10 @@ impl EndpointMetrics {
pub fn summary(&self) -> String { pub fn summary(&self) -> String {
format!( format!(
"requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms", "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms self.count,
self.avg_time_ms(),
self.min_time_ms,
self.max_time_ms
) )
} }
} }
@@ -79,7 +82,9 @@ impl MetricsStore {
/// Record a request's timing information /// Record a request's timing information
pub async fn record(&self, path: String, time_ms: u64) { pub async fn record(&self, path: String, time_ms: u64) {
let mut endpoints = self.endpoints.lock().await; let mut endpoints = self.endpoints.lock().await;
let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default); let metrics = endpoints
.entry(path)
.or_insert_with(EndpointMetrics::default);
metrics.add_response_time(time_ms); metrics.add_response_time(time_ms);
} }
@@ -178,7 +183,9 @@ where
let time_ms = time.as_millis() as u64; let time_ms = time.as_millis() as u64;
// Record the timing in our metrics store // Record the timing in our metrics store
metrics_store.record(format!("{} {}", method, path), time_ms).await; metrics_store
.record(format!("{} {}", method, path), time_ms)
.await;
// Log the request timing // Log the request timing
debug!("{} {} {} - {} ms", method, path, status, time_ms); debug!("{} {} {} - {} ms", method, path, status, time_ms);

View File

@@ -1,7 +1,3 @@
pub mod metrics; pub mod metrics;
pub use metrics::{ pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
MetricsStore,
MetricsLoggerFuture,
MetricsLayer,
};

View File

@@ -0,0 +1,19 @@
use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;
/// Build the standalone (single-process) router by merging the embedded
/// embeddings and inference routers into one service.
///
/// The `ServerConfig` is currently unused in standalone mode — the
/// underscore prefix suppresses the unused-variable warning — but the
/// parameter is kept so both router factories share the same signature
/// shape. Call sites are positional, so this rename is non-breaking.
pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
    // Embedded embeddings service (no proxying in standalone mode).
    let embeddings_router = embeddings_engine::create_embeddings_router();

    // Inference service backed by the default model configuration.
    let app_state = AppState::default();
    let inference_router = inference_engine::create_router(app_state);

    // One unified router exposing both services.
    Router::new()
        .merge(embeddings_router)
        .merge(inference_router)
}

View File

@@ -1,8 +1,8 @@
{
  "name": "predict-otron-9000",
  "workspaces": ["crates/cli/package"],
  "scripts": {
    "# WORKSPACE ALIASES": "#",
    "cli": "bun --filter crates/cli/package"
  }
}

View File

@@ -1,30 +0,0 @@
#!/usr/bin/env bash
# Run a quick llama inference via the release build of llama-runner,
# compiling it first if the binary is not yet present.
#
# Usage: $0 [PROMPT] [MODEL] [MAX_NEW_TOKENS]
# Env:   FORCE_CPU=1|true   force CPU inference
#        HF_HOME            HuggingFace cache dir (defaults to ./.hf-cache)
set -euo pipefail

PROMPT=${1:-"Say hello in one short sentence."}
MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
MAX_NEW=${3:-64}
FORCE_CPU=${FORCE_CPU:-0}

# Keep the HuggingFace cache local to the repo unless the caller set one.
export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}

BIN="$(dirname "$0")/../target/release/llama_infer"

# Build on demand when the release binary is missing or not executable.
if [[ ! -x "$BIN" ]]; then
    echo "Building llama-runner (release)..."
    cargo build -p llama-runner --release
fi

echo "Running llama inference..." >&2

ARGS=(
    --model-id "$MODEL"
    --prompt "$PROMPT"
    --max-new-tokens "$MAX_NEW"
)

case "$FORCE_CPU" in
    1|true) ARGS+=( --force-cpu ) ;;
esac

"$BIN" "${ARGS[@]}"