32 Commits

Author SHA1 Message Date
geoffsee
4380ac69d3 v0.1.5 already exists 2025-09-04 15:09:30 -04:00
geoffsee
e6f3351ebb minor version 2025-09-04 15:08:43 -04:00
geoffsee
3992532f15 fmt and clippy 2025-09-04 15:07:49 -04:00
geoffsee
3ecdd9ffa0 update deployment tooling to remove dependencies on unused metadata 2025-09-04 15:03:17 -04:00
geoffsee
296d4dbe7e add root dockerfile that contains binaries for all services 2025-09-04 14:54:20 -04:00
geoffsee
fb5098eba6 fix clippy errors 2025-09-04 13:53:00 -04:00
geoffsee
c1c583faab run cargo fmt 2025-09-04 13:45:25 -04:00
geoffsee
1e02b12cda fixes issue with model selection 2025-09-04 13:42:30 -04:00
geoffsee
ff55d882c7 reorg + update docs with new paths 2025-09-04 12:40:59 -04:00
geoffsee
400c70f17d streaming implementaion re-added to UI 2025-09-02 14:45:16 -04:00
geoffsee
bcbc6c4693 fix invalid endpoint in curl_stream_script.sh 2025-09-02 13:58:34 -04:00
geoffsee
21f20470de patch version 2025-09-01 22:55:59 -04:00
geoffsee
2deecb5e51 chat client only displays available models 2025-09-01 22:29:54 -04:00
geoffsee
545e0c9831 make wasm32 availble for all builds in ci 2025-08-31 20:22:12 -04:00
geoffsee
eca61c51ad add build step to ci 2025-08-31 20:08:54 -04:00
geoffsee
d1a7d5b28e fix format error 2025-08-31 19:59:09 -04:00
geoffsee
8d2b85b0b9 update docs 2025-08-31 19:27:15 -04:00
geoffsee
4570780666 release 0.1.3 2025-08-31 18:55:37 -04:00
geoffsee
44e4f9e5e1 put proof in the pudding 2025-08-31 18:54:20 -04:00
geoffsee
64daa77c6b leptos chat ui renders 2025-08-31 18:50:25 -04:00
geoffsee
2b4a8a9df8 chat-ui not functional yet but builds 2025-08-31 18:18:56 -04:00
geoffsee
38d51722f2 Update configuration loading with Cargo.toml path and clean up .gitignore
---

This commit message concisely communicates the key changes:

1. The code now builds an absolute path to the `Cargo.toml` file, enhancing clarity in configuration loading.
2. The addition of `PathBuf` usage improves type safety.
3. The removal of unnecessary entries from `.gitignore` helps maintain a clean project structure.

These updates reflect improvements in both functionality and project organization.
2025-08-31 14:06:44 -04:00
geoffsee
7bc9479a11 fix format issues, needs precommit hook 2025-08-31 13:24:51 -04:00
geoffsee
0580dc8c5e move cli into crates and stage for release 2025-08-31 13:23:50 -04:00
geoffsee
9e9aa69769 bump version in Cargo.toml 2025-08-31 11:04:31 -04:00
geoffsee
3eb1a5329b add rust compiler optimizations at workspace level, bump minor version and publish first release 2025-08-31 11:02:58 -04:00
geoffsee
eb1591aa5d fix fmt error 2025-08-31 10:52:48 -04:00
geoffsee
e6c417bd83 align dependencies across inference features 2025-08-31 10:49:04 -04:00
geoffsee
f5d2a85f2e cleanup, add ci 2025-08-31 10:31:20 -04:00
Geoff Seemueller
419e1c2ea7 fix Kubernetes spelling 2025-08-30 08:24:24 -04:00
Geoff Seemueller
06fdfcf898 clarify project intent 2025-08-30 08:23:38 -04:00
geoffsee
315ef17605 supports small llama and gemma models
Refactor inference

dedicated crates for llama and gemma inferencing, not integrated
2025-08-29 20:00:41 -04:00
97 changed files with 7845 additions and 5160 deletions

35
.dockerignore Normal file
View File

@@ -0,0 +1,35 @@
# Git
.git
.gitignore
# Rust
target/
# Documentation
README.md
*.md
# IDE
.vscode/
.idea/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
# Logs
*.log
# Environment
.env
.env.local
# Dependencies
node_modules/
# Build artifacts
dist/
build/
.fastembed_cache

49
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,49 @@
version: 2
updates:
# Monitor Rust dependencies in the main crate
- package-ecosystem: "cargo"
directory: "/crates/predict-otron-9000"
schedule:
interval: "weekly"
day: "monday"
time: "09:00"
timezone: "UTC"
# Focus on security updates with higher priority
open-pull-requests-limit: 10
reviewers:
- "security-team"
assignees:
- "maintainer"
labels:
- "dependencies"
- "security"
# Security updates get higher priority
allow:
- dependency-type: "all"
# Group minor and patch updates to reduce noise
# Separate major updates for careful review
ignore:
- dependency-name: "*"
update-types: ["version-update:semver-major"]
commit-message:
prefix: "deps"
include: "scope"
# Monitor security updates more frequently
- package-ecosystem: "cargo"
directory: "/crates/predict-otron-9000"
schedule:
interval: "daily"
# Only security updates in daily checks
allow:
- dependency-type: "direct"
update-types: ["security"]
- dependency-type: "indirect"
update-types: ["security"]
open-pull-requests-limit: 5
labels:
- "security-update"
- "high-priority"
commit-message:
prefix: "security"
include: "scope"

56
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,56 @@
name: CI
on:
push:
pull_request:
jobs:
build:
name: build-and-test
runs-on: ubuntu-latest
strategy:
fail-fast: false
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Setup Rust
run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Build
run: |
cargo install --locked cargo-leptos
cd crates/chat-ui && cargo leptos build --release
cargo build --release -p predict-otron-9000 -p cli
- name: Install clippy and rustfmt
run: rustup component add clippy rustfmt
- name: Cargo fmt (check)
run: cargo fmt --all -- --check
- name: Clippy
shell: bash
run: cargo clippy --all
- name: Tests
shell: bash
run: cargo test --all
- name: Build Docs
shell: bash
run: |
cargo doc -p predict-otron-9000 --no-deps

46
.github/workflows/docker.yml vendored Normal file
View File

@@ -0,0 +1,46 @@
name: Build and Push Docker Image
on:
tags:
- 'v*'
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to Container Registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

240
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,240 @@
name: Release
on:
push:
tags:
- 'v*'
env:
CARGO_TERM_COLOR: always
jobs:
test:
name: Test before release
runs-on: ubuntu-latest
defaults:
run:
working-directory: crates/predict-otron-9000
strategy:
fail-fast: false
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Setup Rust
run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
- name: Setup Bun
uses: oven-sh/setup-bun@v2
- name: Install clippy and rustfmt
run: rustup component add clippy rustfmt
- name: Cargo fmt (check)
run: cargo fmt --all -- --check
- name: Clippy
shell: bash
run: cargo clippy --all
- name: Tests
shell: bash
run: cargo test --all
# publish:
# name: Publish to crates.io
# runs-on: ubuntu-latest
# permissions:
# id-token: write # Required for OIDC token exchange https://crates.io/docs/trusted-publishing
# needs: test
# defaults:
# run:
# working-directory: crates/predict-otron-9000
# steps:
# - name: Checkout
# uses: actions/checkout@v4
#
# - uses: actions/cache@v4
# with:
# path: |
# ~/.cargo/bin/
# ~/.cargo/registry/index/
# ~/.cargo/registry/cache/
# ~/.cargo/git/db/
# target/
# key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
#
# - name: Setup Rust
# run: rustup update stable && rustup default stable
#
# - name: Verify tag matches version
# run: |
# TAG_VERSION=${GITHUB_REF#refs/tags/v}
# CARGO_VERSION=$(cargo metadata --no-deps --format-version 1 | jq -r '.packages[0].version')
# if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then
# echo "Tag version ($TAG_VERSION) does not match Cargo.toml version ($CARGO_VERSION)"
# exit 1
# fi
#
# # See Trusted publishing: https://crates.io/docs/trusted-publishing
# - uses: rust-lang/crates-io-auth-action@v1
# id: auth
#
# - run: cargo publish
# env:
# CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }}
build-binaries:
name: Build binaries
runs-on: ${{ matrix.os }}
needs: test
strategy:
fail-fast: false
matrix:
include:
- target: x86_64-unknown-linux-gnu
os: ubuntu-latest
name: predict-otron-9000-x86_64-unknown-linux-gnu
- target: x86_64-apple-darwin
os: macos-latest
name: predict-otron-9000-x86_64-apple-darwin
- target: aarch64-apple-darwin
os: macos-latest
name: predict-otron-9000-aarch64-apple-darwin
- target: x86_64-pc-windows-msvc
os: windows-latest
name: predict-otron-9000-x86_64-pc-windows-msvc.exe
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Setup Rust
run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
- name: Add target
run: rustup target add ${{ matrix.target }}
- name: Build UI
run: cargo install --locked cargo-leptos && cd crates/chat-ui && cargo leptos build --release
env:
CARGO_TERM_COLOR: always
- name: Build Binary
run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000 -p cli
env:
CARGO_TERM_COLOR: always
- name: Package binary (Unix)
if: matrix.os != 'windows-latest'
run: |
cd target/${{ matrix.target }}/release
tar czf ../../../${{ matrix.name }}.tar.gz predict-otron-9000 cli
cd ../../../
- name: Package binary (Windows)
if: matrix.os == 'windows-latest'
run: |
cd target/${{ matrix.target }}/release
7z a ../../../${{ matrix.name }}.zip predict-otron-9000.exe cli.exe
cd ../../../
- name: Upload binary artifacts (Unix)
if: matrix.os != 'windows-latest'
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.name }}
path: ${{ matrix.name }}.tar.gz
- name: Upload binary artifacts (Windows)
if: matrix.os == 'windows-latest'
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.name }}
path: ${{ matrix.name }}.zip
release:
name: Create GitHub Release
runs-on: ubuntu-latest
needs: [test, build-binaries]
permissions:
contents: write
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Extract tag name
id: tag
run: echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
- name: Generate changelog
id: changelog
run: |
# Get the previous tag
PREV_TAG=$(git describe --tags --abbrev=0 HEAD^ 2>/dev/null || echo "")
# Generate changelog
if [ -n "$PREV_TAG" ]; then
echo "## What's Changed" > changelog.md
echo "" >> changelog.md
git log --pretty=format:"* %s (%h)" ${PREV_TAG}..HEAD >> changelog.md
echo "" >> changelog.md
echo "" >> changelog.md
echo "**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREV_TAG}...${{ steps.tag.outputs.tag }}" >> changelog.md
else
echo "## What's Changed" > changelog.md
echo "" >> changelog.md
echo "Initial release of predict-otron-9000" >> changelog.md
echo "" >> changelog.md
echo "OpenAI Compatible Inference Server" >> changelog.md
fi
# Set the changelog as output (handle multiline)
echo "changelog<<EOF" >> $GITHUB_OUTPUT
cat changelog.md >> $GITHUB_OUTPUT
echo "EOF" >> $GITHUB_OUTPUT
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: artifacts
- name: Create Release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
if [[ "${{ steps.tag.outputs.tag }}" == *"-"* ]]; then
PRERELEASE_FLAG="--prerelease"
else
PRERELEASE_FLAG=""
fi
gh release create "${{ steps.tag.outputs.tag }}" \
--title "Release ${{ steps.tag.outputs.tag }}" \
--notes-file changelog.md \
$PRERELEASE_FLAG \
artifacts/predict-otron-9000-x86_64-unknown-linux-gnu/predict-otron-9000-x86_64-unknown-linux-gnu.tar.gz \
artifacts/predict-otron-9000-x86_64-apple-darwin/predict-otron-9000-x86_64-apple-darwin.tar.gz \
artifacts/predict-otron-9000-aarch64-apple-darwin/predict-otron-9000-aarch64-apple-darwin.tar.gz \
artifacts/predict-otron-9000-x86_64-pc-windows-msvc.exe/predict-otron-9000-x86_64-pc-windows-msvc.exe.zip

6
.gitignore vendored
View File

@@ -23,7 +23,6 @@ package-lock.json
# Web frontend build outputs
dist/
.trunk/
# ML model and embedding caches
.fastembed_cache/
@@ -75,4 +74,7 @@ venv/
# Backup files
*.bak
*.backup
*~
!/scripts/cli.ts
/**/.*.bun-build
/AGENTS.md
.claude

1107
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -3,13 +3,42 @@ members = [
"crates/predict-otron-9000",
"crates/inference-engine",
"crates/embeddings-engine",
"crates/leptos-app",
"crates/helm-chart-tool"
]
"integration/helm-chart-tool",
"integration/llama-runner",
"integration/gemma-runner",
"integration/cli",
"crates/chat-ui"
, "integration/utils"]
default-members = ["crates/predict-otron-9000"]
resolver = "2"
[[workspace.metadata.leptos]]
# project name
bin-package = "leptos-app"
lib-package = "leptos-app"
[workspace.package]
version = "0.1.6"
# Compiler optimization profiles for the workspace
[profile.release]
opt-level = 3
debug = false
strip = true
lto = "thin"
codegen-units = 1
panic = "abort"
[profile.dev]
opt-level = 0
debug = true
strip = false
overflow-checks = true
# Profile for fast development builds with some optimization
[profile.dev-opt]
inherits = "dev"
opt-level = 1
debug = true
overflow-checks = true
# Profile for benchmarking and profiling
[profile.bench]
opt-level = 3
debug = true
lto = "thin"

50
Dockerfile Normal file
View File

@@ -0,0 +1,50 @@
# Multi-stage build for predict-otron-9000 workspace
FROM rust:1 AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY crates/ ./crates/
COPY integration/ ./integration/
# Build all 3 main server binaries in release mode
RUN cargo build --release -p predict-otron-9000 --bin predict-otron-9000 --no-default-features -p embeddings-engine --bin embeddings-engine -p inference-engine --bin inference-engine
# Runtime stage
FROM debian:bookworm-slim AS runtime
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
ca-certificates \
libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Create app user
RUN useradd -r -s /bin/false -m -d /app appuser
# Set working directory
WORKDIR /app
# Copy binaries from builder stage
COPY --from=builder /app/target/release/predict-otron-9000 ./bin/
COPY --from=builder /app/target/release/embeddings-engine ./bin/
COPY --from=builder /app/target/release/inference-engine ./bin/
# Make binaries executable and change ownership
RUN chmod +x ./bin/* && chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Expose ports (adjust as needed based on your services)
EXPOSE 8080 8081 8082
# Default command (can be overridden)
CMD ["./bin/predict-otron-9000"]

View File

@@ -1,16 +1,32 @@
# predict-otron-9000
<h1 align="center">
predict-otron-9000
</h1>
<p align="center">
AI inference Server with OpenAI-compatible API (Limited Features)
</p>
<p align="center">
<img src="https://github.com/geoffsee/predict-otron-9001/blob/master/predict-otron-9000.png?raw=true" width="90%" />
</p>
<br/>
> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.
<p align="center">
Powerful local AI inference with OpenAI-compatible APIs
</p>
~~~shell
./scripts/run.sh
~~~
## Project Overview
The predict-otron-9000 is a flexible AI platform that provides:
- **Local LLM Inference**: Run Gemma models locally with CPU or GPU acceleration
- **Local LLM Inference**: Run Gemma and Llama models locally with CPU or GPU acceleration
- **Embeddings Generation**: Create text embeddings with FastEmbed
- **Web Interface**: Interact with models through a Leptos WASM chat interface
- **TypeScript CLI**: Command-line client for testing and automation
@@ -22,31 +38,39 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
- **OpenAI Compatible**: API endpoints match OpenAI's format for easy integration
- **Text Embeddings**: Generate high-quality text embeddings using FastEmbed
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma models (1B, 2B, 7B variants including instruction-tuned models)
- **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants)
- **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput
- **Web Chat Interface**: Leptos-based WebAssembly (WASM) chat interface for browser-based interaction
- **Web Chat Interface**: Leptos chat interface
- **Flexible Deployment**: Run as monolithic service or microservices architecture
## Architecture Overview
### Workspace Structure
The project uses a 4-crate Rust workspace plus TypeScript components:
The project uses a 9-crate Rust workspace plus TypeScript components:
```
crates/
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
├── inference-engine/ # Gemma inference via Candle (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
└── leptos-app/ # WASM web frontend (Rust 2021)
cli.ts # TypeScript/Bun CLI client
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
└── chat-ui/ # WASM web frontend (Rust 2021)
integration/
├── cli/ # CLI client crate (Rust 2024)
│ └── package/
│ └── cli.ts # TypeScript/Bun CLI client
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
├── llama-runner/ # Llama model inference via Candle (Rust 2021)
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
└── utils/ # Shared utilities (Rust 2021)
```
### Service Architecture
- **Main Server** (port 8080): Orchestrates inference and embeddings services
- **Embeddings Service** (port 8080): Standalone FastEmbed service with OpenAI API compatibility
- **Web Frontend** (port 8788): Leptos WASM chat interface served by Trunk
- **Web Frontend** (port 8788): chat-ui WASM app
- **CLI Client**: TypeScript/Bun client for testing and automation
### Deployment Modes
@@ -72,11 +96,6 @@ The architecture supports multiple deployment patterns:
- **Bun**: Required for TypeScript CLI client: `curl -fsSL https://bun.sh/install | bash`
- **Node.js**: Alternative to Bun, supports OpenAI SDK v5.16.0+
#### WASM Frontend Toolchain
- **Trunk**: Required for Leptos frontend builds: `cargo install trunk`
- **wasm-pack**: `cargo install wasm-pack`
- **WASM target**: `rustup target add wasm32-unknown-unknown`
#### ML Framework Dependencies
- **Candle**: Version 0.9.1 with conditional compilation:
- macOS: Metal support with CPU fallback for stability
@@ -121,11 +140,6 @@ cargo build --bin cli --package inference-engine --release
cargo build --bin embeddings-engine --release
```
**Web Frontend:**
```bash
cd crates/leptos-app
trunk build --release
```
### Running Services
@@ -139,26 +153,26 @@ trunk build --release
#### Web Frontend (Port 8788)
```bash
cd crates/leptos-app
cd crates/chat-ui
./run.sh
```
- Serves Leptos WASM frontend on port 8788
- Serves chat-ui WASM frontend on port 8788
- Sets required RUSTFLAGS for WebAssembly getrandom support
- Auto-reloads during development
#### TypeScript CLI Client
```bash
# List available models
bun run cli.ts --list-models
cd integration/cli/package && bun run cli.ts --list-models
# Chat completion
bun run cli.ts "What is the capital of France?"
cd integration/cli/package && bun run cli.ts "What is the capital of France?"
# With specific model
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
# Show help
bun run cli.ts --help
cd integration/cli/package && bun run cli.ts --help
```
## API Usage
@@ -274,7 +288,7 @@ cargo test --workspace
**End-to-end test script:**
```bash
./test.sh
./scripts/smoke_test.sh
```
This script:
@@ -363,7 +377,7 @@ All services include Docker metadata in `Cargo.toml`:
- Port: 8080
**Web Frontend:**
- Image: `ghcr.io/geoffsee/leptos-app:latest`
- Image: `ghcr.io/geoffsee/chat-ui:latest`
- Port: 8788
**Docker Compose:**
@@ -422,8 +436,7 @@ For Kubernetes deployment details, see the [ARCHITECTURE.md](docs/ARCHITECTURE.m
**Symptom:** WASM compilation failures
**Solution:**
1. Install required targets: `rustup target add wasm32-unknown-unknown`
2. Install trunk: `cargo install trunk`
3. Check RUSTFLAGS in leptos-app/run.sh
2. Check RUSTFLAGS in chat-ui/run.sh
### Network/Timeout Issues
**Symptom:** First-time model downloads timing out
@@ -454,24 +467,23 @@ curl -s http://localhost:8080/v1/models | jq
**CLI client test:**
```bash
bun run cli.ts "What is 2+2?"
cd integration/cli/package && bun run cli.ts "What is 2+2?"
```
**Web frontend:**
```bash
cd crates/leptos-app && ./run.sh &
cd crates/chat-ui && ./run.sh &
# Navigate to http://localhost:8788
```
**Integration test:**
```bash
./test.sh
./scripts/smoke_test.sh
```
**Cleanup:**
```bash
pkill -f "predict-otron-9000"
pkill -f "trunk"
```
For networked tests and full functionality, ensure Hugging Face authentication is configured as described above.
@@ -493,4 +505,4 @@ For networked tests and full functionality, ensure Hugging Face authentication i
4. Ensure all tests pass: `cargo test`
5. Submit a pull request
_Warning: Do NOT use this in production unless you are cool like that._
_Warning: Do NOT use this in production unless you are cool like that._

22
bun.lock Normal file
View File

@@ -0,0 +1,22 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "predict-otron-9000",
},
"integration/cli/package": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"cli": ["cli@workspace:integration/cli/package"],
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
"openai": ["openai@5.16.0", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-hoEH8ZNvg1HXjU9mp88L/ZH8O082Z8r6FHCXGiWAzVRrEv443aI57qhch4snu07yQydj+AUAWLenAiBXhu89Tw=="],
}
}

View File

@@ -1,8 +1,9 @@
[package]
name = "leptos-app"
name = "chat-ui"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib", "rlib"]
@@ -15,45 +16,33 @@ leptos_axum = { version = "0.8.0", optional = true }
leptos_meta = { version = "0.8.0" }
tokio = { version = "1", features = ["rt-multi-thread"], optional = true }
wasm-bindgen = { version = "=0.2.100", optional = true }
# Chat interface dependencies
wasm-bindgen-futures = "0.4"
js-sys = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
async-openai-wasm = { version = "0.29", default-features = false }
futures-util = "0.3"
js-sys = { version = "0.3", optional = true }
either = { version = "1.9", features = ["serde"] }
web-sys = { version = "0.3", optional = true, features = [
"console",
"Window",
"Document",
"Element",
"HtmlElement",
"HtmlInputElement",
"HtmlSelectElement",
"HtmlTextAreaElement",
"Event",
"EventTarget",
"KeyboardEvent",
reqwest = { version = "0.12", features = ["json"] }
web-sys = { version = "0.3", features = [
"console",
"EventSource",
"MessageEvent",
"Window",
"Request",
"RequestInit",
"Response",
"Headers",
"ReadableStream",
"ReadableStreamDefaultReader",
"TextDecoder",
"TextDecoderOptions",
"HtmlInputElement"
] }
[dependencies.uuid]
version = "1.0"
features = [
"v4",
"fast-rng",
"macro-diagnostics",
"js",
]
gloo-net = { version = "0.6", features = ["http"] }
[features]
hydrate = [
"leptos/hydrate",
"dep:console_error_panic_hook",
"dep:wasm-bindgen",
"dep:js-sys",
"dep:web-sys",
]
ssr = [
"dep:axum",
@@ -73,8 +62,9 @@ codegen-units = 1
panic = "abort"
[package.metadata.leptos]
name = "chat-ui"
# The name used by wasm-bindgen/cargo-leptos for the JS/WASM bundle. Defaults to the crate name
output-name = "leptos-app"
output-name = "chat-ui"
# The site root folder is where cargo-leptos generate all output. WARNING: all content of this folder will be erased on a rebuild. Use it in your server setup.
site-root = "target/site"
@@ -84,7 +74,7 @@ site-root = "target/site"
site-pkg-dir = "pkg"
# [Optional] The source CSS file. If it ends with .sass or .scss then it will be compiled by dart-sass into CSS. The CSS is optimized by Lightning CSS before being written to <site-root>/<site-pkg>/app.css
style-file = "style/main.scss"
style-file = "./style/main.scss"
# Assets source dir. All files found here will be copied and synchronized to site-root.
# The assets-dir cannot have a sub directory with the same name/path as site-pkg-dir.
#
@@ -132,4 +122,8 @@ lib-default-features = false
# The profile to use for the lib target when compiling for release
#
# Optional. Defaults to "release".
lib-profile-release = "wasm-release"
lib-profile-release = "release"
[[bin]]
name = "chat-ui"
path = "src/main.rs"

41
crates/chat-ui/README.md Normal file
View File

@@ -0,0 +1,41 @@
# chat-ui
A WASM-based web chat interface for the predict-otron-9000 AI platform.
## Overview
The chat-ui provides a real-time web interface for interacting with language models through the predict-otron-9000 server. Built with Leptos and compiled to WebAssembly, it offers a modern chat experience with streaming response support.
## Features
- Real-time chat interface with the inference server
- Streaming response support
- Conversation history
- Responsive web design
- WebAssembly-powered for optimal performance
## Building and Running
### Prerequisites
- Rust toolchain with WASM target: `rustup target add wasm32-unknown-unknown`
- The predict-otron-9000 server must be running on port 8080
### Development Server
```bash
cd crates/chat-ui
./run.sh
```
This starts the development server on port 8788 with auto-reload capabilities.
### Usage
1. Start the predict-otron-9000 server: `./scripts/run.sh`
2. Start the chat-ui: `cd crates/chat-ui && ./run.sh`
3. Navigate to `http://localhost:8788`
4. Start chatting with your AI models!
## Technical Details
- Built with Leptos framework
- Compiled to WebAssembly for browser execution
- Communicates with predict-otron-9000 API via HTTP
- Sets required RUSTFLAGS for WebAssembly getrandom support

View File

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 15 KiB

617
crates/chat-ui/src/app.rs Normal file
View File

@@ -0,0 +1,617 @@
#[cfg(feature = "ssr")]
use axum::Router;
#[cfg(feature = "ssr")]
use leptos::prelude::LeptosOptions;
#[cfg(feature = "ssr")]
use leptos_axum::{generate_route_list, LeptosRoutes};
pub struct AppConfig {
pub config: ConfFile,
pub address: String,
}
impl Default for AppConfig {
fn default() -> Self {
let conf = get_configuration(Some(concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml")))
.expect("failed to read config");
let addr = conf.leptos_options.site_addr;
AppConfig {
config: conf, // or whichever field/string representation you need
address: addr.to_string(),
}
}
}
/// Build the Axum router for this app, including routes, fallback, and state.
/// Call this from another crate (or your bin) when running the server.
#[cfg(feature = "ssr")]
pub fn create_router(leptos_options: LeptosOptions) -> Router {
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
Router::new()
.leptos_routes(&leptos_options, routes, {
let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options)
}
use gloo_net::http::Request;
use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
components::{Route, Router, Routes},
StaticSegment,
};
use serde::{Deserialize, Serialize};
use web_sys::console;
// Data structures for OpenAI-compatible API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
#[derive(Debug, Serialize)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
pub message: ChatMessage,
pub index: u32,
pub finish_reason: Option<String>,
}
// Streaming response structures
#[derive(Debug, Deserialize)]
pub struct StreamDelta {
pub role: Option<String>,
pub content: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct StreamChoice {
pub index: u32,
pub delta: StreamDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct StreamChatResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<StreamChoice>,
}
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
}
// Data structures for models API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
#[derive(Debug, Deserialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// API client function to fetch available models
pub async fn fetch_models() -> Result<Vec<ModelInfo>, String> {
let response = Request::get("/v1/models")
.send()
.await
.map_err(|e| format!("Failed to fetch models: {:?}", e))?;
if response.ok() {
let models_response: ModelsResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse models response: {:?}", e))?;
Ok(models_response.data)
} else {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(format!("Failed to fetch models {}: {}", status, error_text))
}
}
// API client function to send chat completion requests
pub async fn send_chat_completion(
messages: Vec<ChatMessage>,
model: String,
) -> Result<String, String> {
let request = ChatRequest {
model,
messages,
max_tokens: Some(1024),
stream: Some(false),
};
let response = Request::post("/v1/chat/completions")
.header("Content-Type", "application/json")
.json(&request)
.map_err(|e| format!("Failed to create request: {:?}", e))?
.send()
.await
.map_err(|e| format!("Failed to send request: {:?}", e))?;
if response.ok() {
let chat_response: ChatResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse response: {:?}", e))?;
if let Some(choice) = chat_response.choices.first() {
Ok(choice.message.content.clone())
} else {
Err("No response choices available".to_string())
}
} else {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(format!("Server error {}: {}", status, error_text))
}
}
// Streaming chat completion using EventSource
#[cfg(target_arch = "wasm32")]
pub fn send_chat_completion_stream(
messages: Vec<ChatMessage>,
model: String,
on_chunk: impl Fn(String) + 'static,
on_complete: impl Fn() + 'static,
on_error: impl Fn(String) + 'static,
) {
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast;
let request = ChatRequest {
model,
messages,
max_tokens: Some(1024),
stream: Some(true),
};
// We need to send a POST request but EventSource only supports GET
// So we'll use fetch with a readable stream instead
let window = web_sys::window().unwrap();
let request_json = serde_json::to_string(&request).unwrap();
let opts = web_sys::RequestInit::new();
opts.set_method("POST");
opts.set_body(&JsValue::from_str(&request_json));
let headers = web_sys::Headers::new().unwrap();
headers.set("Content-Type", "application/json").unwrap();
headers.set("Accept", "text/event-stream").unwrap();
opts.set_headers(&headers);
let request = web_sys::Request::new_with_str_and_init("/v1/chat/completions", &opts).unwrap();
let promise = window.fetch_with_request(&request);
wasm_bindgen_futures::spawn_local(async move {
match wasm_bindgen_futures::JsFuture::from(promise).await {
Ok(resp_value) => {
let resp: web_sys::Response = resp_value.dyn_into().unwrap();
if !resp.ok() {
on_error(format!("Server error: {}", resp.status()));
return;
}
let body = resp.body();
if body.is_none() {
on_error("No response body".to_string());
return;
}
let reader = body
.unwrap()
.get_reader()
.dyn_into::<web_sys::ReadableStreamDefaultReader>()
.unwrap();
let decoder = web_sys::TextDecoder::new().unwrap();
let mut buffer = String::new();
loop {
match wasm_bindgen_futures::JsFuture::from(reader.read()).await {
Ok(result) => {
let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
.unwrap()
.as_bool()
.unwrap_or(false);
if done {
break;
}
let value =
js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
let array = js_sys::Uint8Array::new(&value);
let mut bytes = vec![0; array.length() as usize];
array.copy_to(&mut bytes);
let text = decoder.decode_with_u8_array(&bytes).unwrap();
buffer.push_str(&text);
// Process complete SSE events from buffer
while let Some(event_end) = buffer.find("\n\n") {
let event = buffer[..event_end].to_string();
buffer = buffer[event_end + 2..].to_string();
// Parse SSE event
for line in event.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
on_complete();
return;
}
// Parse JSON chunk
if let Ok(chunk) =
serde_json::from_str::<StreamChatResponse>(data)
{
if let Some(choice) = chunk.choices.first() {
if let Some(content) = &choice.delta.content {
on_chunk(content.clone());
}
}
}
}
}
}
}
Err(e) => {
on_error(format!("Read error: {:?}", e));
break;
}
}
}
on_complete();
}
Err(e) => {
on_error(format!("Fetch error: {:?}", e));
}
}
});
}
pub fn shell(options: LeptosOptions) -> impl IntoView {
view! {
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<AutoReload options=options.clone() />
<HydrationScripts options/>
<MetaTags/>
</head>
<body>
<App/>
</body>
</html>
}
}
#[component]
pub fn App() -> impl IntoView {
// Provides context that manages stylesheets, titles, meta tags, etc.
provide_meta_context();
view! {
// injects a stylesheet into the document <head>
// id=leptos means cargo-leptos will hot-reload this stylesheet
<Stylesheet id="leptos" href="/pkg/chat-ui.css"/>
// sets the document title
<Title text="Predict-Otron-9000 Chat"/>
// content for this welcome page
<Router>
<main>
<Routes fallback=|| "Page not found.".into_view()>
<Route path=StaticSegment("") view=ChatPage/>
</Routes>
</main>
</Router>
}
}
/// Renders the chat interface page
#[component]
fn ChatPage() -> impl IntoView {
// State for conversation messages
let messages = RwSignal::new(Vec::<ChatMessage>::new());
// State for current user input
let input_text = RwSignal::new(String::new());
// State for loading indicator
let is_loading = RwSignal::new(false);
// State for error messages
let error_message = RwSignal::new(Option::<String>::None);
// State for available models and selected model
let available_models = RwSignal::new(Vec::<ModelInfo>::new());
let selected_model = RwSignal::new(String::from("")); // Default model
// State for streaming response
let streaming_content = RwSignal::new(String::new());
let is_streaming = RwSignal::new(false);
// State for streaming mode toggle
let use_streaming = RwSignal::new(true); // Default to streaming
// Client-side only: Fetch models on component mount
#[cfg(target_arch = "wasm32")]
{
use leptos::task::spawn_local;
spawn_local(async move {
match fetch_models().await {
Ok(models) => {
available_models.set(models);
selected_model.set(String::from("gemma-3-1b-it"));
}
Err(error) => {
console::log_1(&format!("Failed to fetch models: {}", error).into());
error_message.set(Some(format!("Failed to load models: {}", error)));
}
}
});
}
// Shared logic for sending a message
let send_message_logic = move || {
let user_input = input_text.get();
if user_input.trim().is_empty() {
return;
}
// Add user message to conversation
let user_message = ChatMessage {
role: "user".to_string(),
content: user_input.clone(),
};
messages.update(|msgs| msgs.push(user_message.clone()));
input_text.set(String::new());
is_loading.set(true);
error_message.set(None);
// Client-side only: Send chat completion request
#[cfg(target_arch = "wasm32")]
{
use leptos::task::spawn_local;
// Prepare messages for API call
let current_messages = messages.get();
let current_model = selected_model.get();
let should_stream = use_streaming.get();
if should_stream {
// Clear streaming content and set streaming flag
streaming_content.set(String::new());
is_streaming.set(true);
// Use streaming API
send_chat_completion_stream(
current_messages,
current_model,
move |chunk| {
// Append chunk to streaming content
streaming_content.update(|content| content.push_str(&chunk));
},
move || {
// On complete, move streaming content to messages
let final_content = streaming_content.get();
if !final_content.is_empty() {
let assistant_message = ChatMessage {
role: "assistant".to_string(),
content: final_content,
};
messages.update(|msgs| msgs.push(assistant_message));
}
streaming_content.set(String::new());
is_streaming.set(false);
is_loading.set(false);
},
move |error| {
console::log_1(&format!("Streaming Error: {}", error).into());
error_message.set(Some(error));
is_streaming.set(false);
is_loading.set(false);
streaming_content.set(String::new());
},
);
} else {
// Use non-streaming API
spawn_local(async move {
match send_chat_completion(current_messages, current_model).await {
Ok(response_content) => {
let assistant_message = ChatMessage {
role: "assistant".to_string(),
content: response_content,
};
messages.update(|msgs| msgs.push(assistant_message));
is_loading.set(false);
}
Err(error) => {
console::log_1(&format!("API Error: {}", error).into());
error_message.set(Some(error));
is_loading.set(false);
}
}
});
}
}
};
// Button click handler
let on_button_click = {
let send_logic = send_message_logic.clone();
move |_: web_sys::MouseEvent| {
send_logic();
}
};
// Handle enter key press in input field
let on_key_down = move |ev: web_sys::KeyboardEvent| {
if ev.key() == "Enter" && !ev.shift_key() {
ev.prevent_default();
send_message_logic();
}
};
view! {
<div class="chat-container">
<div class="chat-header">
<h1>"Predict-Otron-9000 Chat"</h1>
<div class="model-selector">
<label for="model-select">"Model:"</label>
<select
id="model-select"
prop:value=move || selected_model.get()
on:change=move |ev| {
let new_model = event_target_value(&ev);
selected_model.set(new_model);
}
>
<For
each=move || available_models.get().into_iter()
key=|model| model.id.clone()
children=move |model| {
view! {
<option value=model.id.clone()>
{format!("{} ({})", model.id, model.owned_by)}
</option>
}
}
/>
</select>
<div class="streaming-toggle">
<label>
<input
type="checkbox"
prop:checked=move || use_streaming.get()
on:change=move |ev| {
let target = event_target::<web_sys::HtmlInputElement>(&ev);
use_streaming.set(target.checked());
}
/>
" Use streaming"
</label>
</div>
</div>
</div>
<div class="chat-messages">
<For
each=move || messages.get().into_iter().enumerate()
key=|(i, _)| *i
children=move |(_, message)| {
let role_class = if message.role == "user" { "user-message" } else { "assistant-message" };
view! {
<div class=format!("message {}", role_class)>
<div class="message-role">{message.role.clone()}</div>
<div class="message-content">{message.content.clone()}</div>
</div>
}
}
/>
{move || {
if is_streaming.get() {
let content = streaming_content.get();
if !content.is_empty() {
view! {
<div class="message assistant-message streaming">
<div class="message-role">"assistant"</div>
<div class="message-content">{content}<span class="cursor">""</span></div>
</div>
}.into_any()
} else {
view! {
<div class="message assistant-message loading">
<div class="message-role">"assistant"</div>
<div class="message-content">"Thinking..."</div>
</div>
}.into_any()
}
} else if is_loading.get() && !use_streaming.get() {
view! {
<div class="message assistant-message loading">
<div class="message-role">"assistant"</div>
<div class="message-content">"Thinking..."</div>
</div>
}.into_any()
} else {
view! {}.into_any()
}
}}
</div>
{move || {
if let Some(error) = error_message.get() {
view! {
<div class="error-message">
"Error: " {error}
</div>
}.into_any()
} else {
view! {}.into_any()
}
}}
<div class="chat-input">
<textarea
placeholder="Type your message here... (Press Enter to send, Shift+Enter for new line)"
prop:value=move || input_text.get()
on:input=move |ev| input_text.set(event_target_value(&ev))
on:keydown=on_key_down
class:disabled=move || is_loading.get()
/>
<button
on:click=on_button_click
class:disabled=move || is_loading.get() || input_text.get().trim().is_empty()
>
"Send"
</button>
</div>
</div>
}
}

View File

@@ -0,0 +1,9 @@
pub mod app;
#[cfg(feature = "hydrate")]
#[wasm_bindgen::prelude::wasm_bindgen]
pub fn hydrate() {
use crate::app::*;
console_error_panic_hook::set_once();
leptos::mount::hydrate_body(App);
}

View File

@@ -0,0 +1,26 @@
#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
use axum::Router;
use chat_ui::app::*;
use leptos::logging::log;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
let conf = get_configuration(None).expect("failed to read config");
let addr = conf.leptos_options.site_addr;
// Build the app router with your extracted function
let app: Router = create_router(conf.leptos_options);
log!("listening on http://{}", &addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
#[cfg(not(feature = "ssr"))]
pub fn main() {
// no client-side main function
}

View File

@@ -0,0 +1,265 @@
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #f5f5f5;
height: 100vh;
overflow: hidden;
}
.chat-container {
display: flex;
flex-direction: column;
height: 100vh;
max-width: 800px;
margin: 0 auto;
background-color: white;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.1);
}
.chat-header {
background-color: #000000;
color: white;
padding: 1rem;
text-align: center;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
display: flex;
flex-direction: column;
gap: 1rem;
h1 {
margin: 0;
font-size: 1.5rem;
font-weight: 600;
}
.model-selector {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
flex-wrap: wrap;
label {
font-weight: 500;
font-size: 0.9rem;
}
select {
background-color: white;
color: #374151;
border: 1px solid #d1d5db;
border-radius: 6px;
padding: 0.5rem 0.75rem;
font-size: 0.9rem;
font-family: inherit;
cursor: pointer;
min-width: 200px;
&:focus {
outline: none;
border-color: #663c99;
box-shadow: 0 0 0 2px rgba(29, 78, 216, 0.2);
}
option {
padding: 0.5rem;
}
}
.streaming-toggle {
display: flex;
align-items: center;
margin-left: 1rem;
label {
display: flex;
align-items: center;
gap: 0.5rem;
cursor: pointer;
font-size: 0.9rem;
input[type="checkbox"] {
cursor: pointer;
}
}
}
}
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
display: flex;
flex-direction: column;
gap: 1rem;
background-color: #fafafa;
}
.message {
display: flex;
flex-direction: column;
gap: 0.5rem;
padding: 1rem;
border-radius: 12px;
max-width: 80%;
word-wrap: break-word;
&.user-message {
align-self: flex-end;
background-color: #2563eb;
color: white;
.message-role {
font-weight: 600;
font-size: 0.8rem;
opacity: 0.8;
text-transform: uppercase;
}
.message-content {
line-height: 1.5;
}
}
&.assistant-message {
align-self: flex-start;
background-color: #646873;
border: 1px solid #e5e7eb;
color: #f3f3f3;
.message-role {
font-weight: 600;
font-size: 0.8rem;
color: #c4c5cd;
text-transform: uppercase;
}
.message-content {
line-height: 1.5;
}
&.loading {
background-color: #f3f4f6;
border-color: #d1d5db;
.message-content {
font-style: italic;
color: #6b7280;
}
}
&.streaming {
.message-content {
.cursor {
display: inline-block;
animation: blink 1s infinite;
color: #9ca3af;
}
}
}
}
}
.error-message {
background-color: #fef2f2;
border: 1px solid #fca5a5;
color: #dc2626;
padding: 1rem;
margin: 0 1rem;
border-radius: 8px;
text-align: center;
font-weight: 500;
}
.chat-input {
display: flex;
gap: 0.5rem;
padding: 1rem;
background-color: white;
border-top: 1px solid #e5e7eb;
textarea {
flex: 1;
padding: 0.75rem;
border: 1px solid #d1d5db;
border-radius: 8px;
resize: none;
min-height: 60px;
max-height: 120px;
font-family: inherit;
font-size: 1rem;
line-height: 1.5;
&:focus {
outline: none;
border-color: #663c99;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
}
&.disabled {
background-color: #f9fafb;
color: #6b7280;
cursor: not-allowed;
}
}
button {
padding: 0.75rem 1.5rem;
background-color: #663c99;
color: white;
border: none;
border-radius: 8px;
font-weight: 600;
cursor: pointer;
transition: background-color 0.2s ease;
align-self: flex-end;
&:hover:not(.disabled) {
background-color: #663c99;
}
&.disabled {
background-color: #9ca3af;
cursor: not-allowed;
}
&:focus {
outline: none;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.3);
}
}
}
/* Scrollbar styling for webkit browsers */
.chat-messages::-webkit-scrollbar {
width: 6px;
}
.chat-messages::-webkit-scrollbar-track {
background: #f1f1f1;
}
.chat-messages::-webkit-scrollbar-thumb {
background: #c1c1c1;
border-radius: 3px;
}
.chat-messages::-webkit-scrollbar-thumb:hover {
background: #a8a8a8;
}
/* Cursor blink animation */
@keyframes blink {
0%, 50% {
opacity: 1;
}
51%, 100% {
opacity: 0;
}
}

View File

@@ -1,6 +1,6 @@
[package]
name = "embeddings-engine"
version = "0.1.0"
version.workspace = true
edition = "2024"
[lib]
@@ -25,15 +25,9 @@ rand = "0.8.5"
async-openai = "0.28.3"
once_cell = "1.19.0"
[package.metadata.compose]
image = "ghcr.io/geoffsee/embeddings-service:latest"
port = 8080
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/embeddings-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
cmd = ["./bin/embeddings-engine"]
replicas = 1
port = 8080

View File

@@ -1,42 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Cache deps first
COPY . ./
RUN rm -rf src
RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "// lib" > src/lib.rs && cargo build --release
RUN rm -rf src
# Copy real sources and build
COPY . .
RUN cargo build --release
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install only what the compiled binary needs
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/embeddings-engine /usr/local/bin/
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["embeddings-engine"]

View File

@@ -1,4 +1,100 @@
# Embeddings Engine
A high-performance text embeddings service that generates vector representations of text using state-of-the-art models.
This crate wraps the fastembed crate to provide embeddings and partially adapts the openai specification.
A high-performance text embeddings service that generates vector representations of text using state-of-the-art models. This crate wraps the FastEmbed library to provide embeddings with OpenAI-compatible API endpoints.
## Overview
The embeddings-engine provides a standalone service for generating text embeddings that can be used for semantic search, similarity comparisons, and other NLP tasks. It's designed to be compatible with OpenAI's embeddings API format.
## Features
- **OpenAI-Compatible API**: `/v1/embeddings` endpoint matching OpenAI's specification
- **FastEmbed Integration**: Powered by the FastEmbed library for high-quality embeddings
- **Multiple Model Support**: Support for various embedding models
- **High Performance**: Optimized for fast embedding generation
- **Standalone Service**: Can run independently or as part of the predict-otron-9000 platform
## Building and Running
### Prerequisites
- Rust toolchain
- Internet connection for initial model downloads
### Standalone Server
```bash
cargo run --bin embeddings-engine --release
```
The service will start on port 8080 by default.
## API Usage
### Generate Embeddings
**Endpoint**: `POST /v1/embeddings`
**Request Body**:
```json
{
"input": "Your text to embed",
"model": "nomic-embed-text-v1.5"
}
```
**Response**:
```json
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [0.1, 0.2, 0.3, ...]
}
],
"model": "nomic-embed-text-v1.5",
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
}
```
### Example Usage
**Using cURL**:
```bash
curl -s http://localhost:8080/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"input": "The quick brown fox jumps over the lazy dog",
"model": "nomic-embed-text-v1.5"
}' | jq
```
**Using Python OpenAI Client**:
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="dummy" # Not validated but required by client
)
response = client.embeddings.create(
input="Your text here",
model="nomic-embed-text-v1.5"
)
print(response.data[0].embedding)
```
## Configuration
The service can be configured through environment variables:
- `SERVER_PORT`: Port to run on (default: 8080)
- `RUST_LOG`: Logging level (default: info)
## Integration
This service is designed to work seamlessly with the predict-otron-9000 main server, but can also be deployed independently for dedicated embeddings workloads.

View File

@@ -1,47 +1,231 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{
response::Json as ResponseJson, routing::{post},
Json,
Router,
};
use axum::{Json, Router, http::StatusCode, response::Json as ResponseJson, routing::post};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tower_http::trace::TraceLayer;
use tracing;
// Persistent model instance (singleton pattern)
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
tracing::info!("Initializing persistent embedding model (singleton)");
// Cache for multiple embedding models
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
Lazy::new(|| RwLock::new(HashMap::new()));
#[derive(Serialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub owned_by: String,
pub description: String,
pub dimensions: usize,
}
#[derive(Serialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
match model_name {
// Sentence Transformers models
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
Ok(EmbeddingModel::AllMiniLML6V2)
}
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
Ok(EmbeddingModel::AllMiniLML6V2Q)
}
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
Ok(EmbeddingModel::AllMiniLML12V2)
}
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
Ok(EmbeddingModel::AllMiniLML12V2Q)
}
// BGE models
"BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
"BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
"BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
"BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
"BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
"BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
"BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
"BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
// Nomic models
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
Ok(EmbeddingModel::NomicEmbedTextV1)
}
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
Ok(EmbeddingModel::NomicEmbedTextV15)
}
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
Ok(EmbeddingModel::NomicEmbedTextV15Q)
}
// Paraphrase models
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
| "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
| "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
| "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
// ModernBert
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
Ok(EmbeddingModel::ModernBertEmbedLarge)
}
// Multilingual E5 models
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
Ok(EmbeddingModel::MultilingualE5Small)
}
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
Ok(EmbeddingModel::MultilingualE5Base)
}
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
Ok(EmbeddingModel::MultilingualE5Large)
}
// Mixedbread models
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1)
}
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
}
// GTE models
"Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
Ok(EmbeddingModel::GTEBaseENV15Q)
}
"Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
Ok(EmbeddingModel::GTELargeENV15Q)
}
// CLIP model
"Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
// Jina model
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
}
_ => Err(format!("Unsupported embedding model: {}", model_name)),
}
}
// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
match model {
EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
EmbeddingModel::BGESmallZHV15 => 512,
EmbeddingModel::BGELargeZHV15 => 1024,
EmbeddingModel::NomicEmbedTextV1
| EmbeddingModel::NomicEmbedTextV15
| EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
EmbeddingModel::ModernBertEmbedLarge => 1024,
EmbeddingModel::MultilingualE5Small => 384,
EmbeddingModel::MultilingualE5Base => 768,
EmbeddingModel::MultilingualE5Large => 1024,
EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
EmbeddingModel::ClipVitB32 => 512,
EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
}
}
// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
// First try to get from cache (read lock)
{
let cache = MODEL_CACHE
.read()
.map_err(|e| format!("Failed to acquire read lock: {}", e))?;
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model: {:?}", embedding_model);
return Ok(Arc::clone(model));
}
}
// Model not in cache, create it (write lock)
let mut cache = MODEL_CACHE
.write()
.map_err(|e| format!("Failed to acquire write lock: {}", e))?;
// Double-check after acquiring write lock
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
return Ok(Arc::clone(model));
}
tracing::info!("Initializing new embedding model: {:?}", embedding_model);
let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
)
.expect("Failed to initialize persistent embedding model");
.map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
let model_init_time = model_start_time.elapsed();
tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time);
model
});
tracing::info!(
"Embedding model {:?} initialized in {:.2?}",
embedding_model,
model_init_time
);
let model_arc = Arc::new(model);
cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
Ok(model_arc)
}
pub async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
// Start timing the entire process
let start_time = std::time::Instant::now();
// Phase 1: Access persistent model instance
// Phase 1: Parse and get the embedding model
let model_start_time = std::time::Instant::now();
// Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request
let embedding_model = match parse_embedding_model(&payload.model) {
Ok(model) => model,
Err(e) => {
tracing::error!("Invalid model requested: {}", e);
return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
}
};
let model = match get_or_create_model(embedding_model.clone()) {
Ok(model) => model,
Err(e) => {
tracing::error!("Failed to get/create model: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Model initialization failed: {}", e),
));
}
};
let model_access_time = model_start_time.elapsed();
tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);
tracing::debug!(
"Model access/creation completed in {:.2?}",
model_access_time
);
// Phase 2: Process input
let input_start_time = std::time::Instant::now();
let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
@@ -53,41 +237,62 @@ pub async fn embeddings_create(
panic!("Array of integer arrays not supported for text embeddings");
}
};
let input_processing_time = input_start_time.elapsed();
tracing::debug!("Input processing completed in {:.2?}", input_processing_time);
tracing::debug!(
"Input processing completed in {:.2?}",
input_processing_time
);
// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();
let embeddings = EMBEDDING_MODEL
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
tracing::error!("Failed to generate embeddings: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Embedding generation failed: {}", e),
)
})?;
let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);
tracing::info!(
"Embedding generation completed in {:.2?}",
embedding_generation_time
);
// Memory usage estimation (approximate)
let embedding_size_bytes = embeddings.iter()
let embedding_size_bytes = embeddings
.iter()
.map(|e| e.len() * std::mem::size_of::<f32>())
.sum::<usize>();
tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
tracing::debug!(
"Embedding size: {:.2} MB",
embedding_size_bytes as f64 / 1024.0 / 1024.0
);
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Phase 4: Post-process embeddings
let postprocessing_start_time = std::time::Instant::now();
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
@@ -98,8 +303,9 @@ pub async fn embeddings_create(
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
let expected_dimensions = get_model_dimensions(&embedding_model);
let mut random_embedding = Vec::with_capacity(expected_dimensions);
for _ in 0..expected_dimensions {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
@@ -110,6 +316,8 @@ pub async fn embeddings_create(
// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
#[allow(clippy::needless_range_loop)]
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
@@ -117,31 +325,42 @@ pub async fn embeddings_create(
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
let padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
padded_embedding.extend(vec![0.0; padding_needed]);
// Use the actual model dimensions instead of hardcoded 768
let actual_dimensions = padded_embedding.len();
let expected_dimensions = get_model_dimensions(&embedding_model);
if actual_dimensions != expected_dimensions {
tracing::warn!(
"Model {:?} produced {} dimensions but expected {}",
embedding_model,
actual_dimensions,
expected_dimensions
);
}
padded_embedding
}
};
let postprocessing_time = postprocessing_start_time.elapsed();
tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
tracing::debug!(
"Embedding post-processing completed in {:.2?}",
postprocessing_time
);
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Phase 5: Prepare response
let response_start_time = std::time::Instant::now();
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
@@ -158,10 +377,10 @@ pub async fn embeddings_create(
"total_tokens": 0
}
});
let response_time = response_start_time.elapsed();
tracing::debug!("Response preparation completed in {:.2?}", response_time);
// Log total time and breakdown
let total_time = start_time.elapsed();
tracing::info!(
@@ -171,12 +390,235 @@ pub async fn embeddings_create(
embedding_generation_time,
postprocessing_time
);
ResponseJson(response)
Ok(ResponseJson(response))
}
pub async fn models_list() -> ResponseJson<ModelsResponse> {
let models = vec![
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the small Chinese model".to_string(),
dimensions: 512,
},
ModelInfo {
id: "BAAI/bge-large-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large Chinese model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model"
.to_string(),
dimensions: 768,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search"
.to_string(),
dimensions: 768,
},
ModelInfo {
id: "lightonai/modernbert-embed-large".to_string(),
object: "model".to_string(),
owned_by: "lightonai".to_string(),
description: "Large model of ModernBert Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "intfloat/multilingual-e5-small".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Small model of multilingual E5 Text Embeddings".to_string(),
dimensions: 384,
},
ModelInfo {
id: "intfloat/multilingual-e5-base".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Base model of multilingual E5 Text Embeddings".to_string(),
dimensions: 768,
},
ModelInfo {
id: "intfloat/multilingual-e5-large".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Large model of multilingual E5 Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Quantized Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Qdrant/clip-ViT-B-32-text".to_string(),
object: "model".to_string(),
owned_by: "Qdrant".to_string(),
description: "CLIP text encoder based on ViT-B/32".to_string(),
dimensions: 512,
},
ModelInfo {
id: "jinaai/jina-embeddings-v2-base-code".to_string(),
object: "model".to_string(),
owned_by: "jinaai".to_string(),
description: "Jina embeddings v2 base code".to_string(),
dimensions: 768,
},
];
ResponseJson(ModelsResponse {
object: "list".to_string(),
data: models,
})
}
pub fn create_embeddings_router() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
// .route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
}

View File

@@ -1,11 +1,9 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{
response::Json as ResponseJson, routing::{get, post},
Json,
Router,
Json, Router,
response::Json as ResponseJson,
routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
use std::env;
use tower_http::trace::TraceLayer;
use tracing;
@@ -13,115 +11,28 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
use embeddings_engine;
async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
)
.expect("Failed to initialize model");
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
match embeddings_engine::embeddings_create(Json(payload)).await {
Ok(response) => Ok(response),
Err((status_code, message)) => Err(axum::response::Response::builder()
.status(status_code)
.body(axum::body::Body::from(message))
.unwrap()),
}
}
let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => {
panic!("Integer array input not supported for text embeddings");
}
EmbeddingInput::ArrayOfIntegerArray(_) => {
panic!("Array of integer arrays not supported for text embeddings");
}
};
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
if all_zeros {
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
val = rng.gen_range(-1.0..1.0);
}
random_embedding.push(val);
}
// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
}
};
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": final_embedding
}
],
"model": payload.model,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
});
ResponseJson(response)
async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
embeddings_engine::models_list().await
}
fn create_app() -> Router {
Router::new()
Router::new()
.route("/v1/embeddings", post(embeddings_create))
.route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -143,21 +54,21 @@ async fn main() {
.init();
let app = create_app();
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
let server_address = format!("{}:{}", server_host, server_port);
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
tracing::info!("Listening on {}", listener.local_addr().unwrap());
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
let server_address = format!("{}:{}", server_host, server_port);
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
tracing::info!("Listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::to_bytes;
use axum::body::Body;
use axum::http::StatusCode;
use tower::ServiceExt;
use super::*;
use axum::body::Body;
use axum::body::to_bytes;
use axum::http::StatusCode;
use tower::ServiceExt;
#[tokio::test]
async fn test_embeddings_create() {
@@ -168,11 +79,13 @@ mod tests {
let body = CreateEmbeddingRequest {
model: "nomic-text-embed".to_string(),
input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]),
encoding_format: None,
user: None,
dimensions: Some(768),
};
input: EmbeddingInput::from(vec![
"The food was delicious and the waiter...".to_string(),
]),
encoding_format: None,
user: None,
dimensions: Some(768),
};
let response = app
.oneshot(

View File

@@ -1,34 +1,15 @@
[package]
name = "inference-engine"
version = "0.1.0"
edition = "2021"
[[bin]]
name="cli"
path = "src/cli_main.rs"
version.workspace = true
edition = "2024"
[dependencies]
accelerate-src = { version = "0.3.2", optional = true }
candle-datasets = { version = "=0.9.1", optional = true }
candle-nn = { version = "=0.9.1" }
candle-transformers = { version = "=0.9.1" }
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-flash-attn = { version = "=0.9.1", optional = true }
candle-onnx = { version = "=0.9.1", optional = true }
csv = "1.3.0"
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
hf-hub = { version = "0.4.1", features = ["tokio"] }
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
num-traits = { version = "0.2.15" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
rayon = "1.7.0"
rubato = { version = "0.15.0", optional = true }
safetensors = "0.4.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
symphonia = { version = "0.5.3", features = ["all"], optional = true }
@@ -50,20 +31,22 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
futures-util = "0.3.31"
gemma-runner = { path = "../../integration/gemma-runner" }
llama-runner = { path = "../../integration/llama-runner" }
embeddings-engine = { path = "../embeddings-engine" }
[target.'cfg(target_os = "linux")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", default-features = false }
candle-nn = { git = "https://github.com/huggingface/candle.git", default-features = false }
candle-transformers = { git = "https://github.com/huggingface/candle.git", default-features = false }
# --- Add this section for conditional compilation ---
[target.'cfg(target_os = "macos")'.dependencies]
# Use CPU backend for macOS to avoid Metal rotary-emb implementation issues
candle-core = { version = "=0.9.1", features = ["metal"], optional = false }
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
llama-runner = { path = "../../integration/llama-runner", features = ["metal"] }
[target.'cfg(not(target_os = "macos"))'.dependencies]
# For Linux or other non-macOS systems, you likely want the CPU backend or CUDA
# If you're building on Linux with a CUDA-enabled GPU:
candle-core = { version = "=0.9.1", features = ["cuda"], default-features = false } # Or just "cuda" if not using default features
# If you're building on Linux with only CPU:
# candle-core = { version = "=0.9.1", default-features = false } # CPU is often the default, but good to be explicit
# --- End of conditional compilation section ---
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
@@ -83,15 +66,16 @@ tokio = "1.43.0"
anyhow = { version = "1", features = ["backtrace"] }
bindgen_cuda = { version = "0.1.1", optional = true }
[features]
bin = []
[package.metadata.compose]
image = "ghcr.io/geoffsee/inference-engine:latest"
port = 8080
[[bin]]
name = "inference-engine"
path = "src/main.rs"
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/inference-service:latest"
replicas = 1
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
cmd = ["./bin/inference-engine"]
port = 8080
replicas = 1

View File

@@ -1,86 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies including CUDA toolkit for GPU support
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA toolkit (optional, for GPU support)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
# Copy the entire workspace to get access to all crates
COPY . ./
# Cache dependencies first - create dummy source files
RUN rm -rf crates/inference-engine/src
RUN mkdir -p crates/inference-engine/src && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
cargo build --release --bin cli --package inference-engine
# Remove dummy source and copy real sources
RUN rm -rf crates/inference-engine/src
COPY . .
# Build the actual CLI binary
RUN cargo build --release --bin cli --package inference-engine
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA runtime libraries (optional, for GPU support at runtime)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/cli /usr/local/bin/inference-cli
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["inference-cli"]

View File

@@ -1,72 +0,0 @@
use clap::Parser;
use crate::model::Which;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
pub cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
pub tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
pub server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
pub port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
pub prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
pub temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
pub top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
pub seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
pub sample_len: usize,
#[arg(long)]
pub model_id: Option<String>,
#[arg(long, default_value = "main")]
pub revision: String,
#[arg(long)]
pub tokenizer_file: Option<String>,
#[arg(long)]
pub config_file: Option<String>,
#[arg(long)]
pub weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
pub repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
pub repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
pub which: Which,
#[arg(long)]
pub use_flash_attn: bool,
}

View File

@@ -1,912 +0,0 @@
mod token_output_stream;
mod utilities_lib;
#[cfg(feature = "intel-mkl-src")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate-src")]
extern crate accelerate_src;
#[cfg(feature = "metal")]
extern crate metal_src;
use anyhow::{Error as E, Result};
use axum::{
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Parser;
use either::Either;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tower_http::cors::{Any, CorsLayer};
use utoipa::ToSchema;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// OpenAI API compatible structs
/// Inner content structure for messages that can be either a string or key-value pairs
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageInnerContent(
#[serde(with = "either::serde_untagged")] pub Either<String, HashMap<String, String>>,
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
)
}
}
/// Function for MessageInnerContent Schema generation to handle `Either`
fn message_inner_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
// Either::Left - simple string
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
// Either::Right - object with string values
.item(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::T(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))))
.build(),
))
.build(),
)
}
/// Message content that can be either simple text or complex structured content
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
}
}
/// Function for MessageContent Schema generation to handle `Either`
fn message_content_schema() -> utoipa::openapi::Schema {
use utoipa::openapi::{ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, SchemaType};
Schema::OneOf(
OneOfBuilder::new()
.item(Schema::Object(
ObjectBuilder::new().schema_type(SchemaType::String).build(),
))
.item(Schema::Array(
ArrayBuilder::new()
.items(RefOr::T(Schema::Object(
ObjectBuilder::new()
.schema_type(SchemaType::Object)
.additional_properties(Some(RefOr::Ref(
utoipa::openapi::Ref::from_schema_name("MessageInnerContent"),
)))
.build(),
)))
.build(),
))
.build(),
)
}
/// Represents a single message in a conversation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct Message {
/// The message content
pub content: Option<MessageContent>,
/// The role of the message sender ("user", "assistant", "system", "tool", etc.)
pub role: String,
pub name: Option<String>,
}
/// Stop token configuration for generation
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(untagged)]
pub enum StopTokens {
/// Multiple possible stop sequences
Multi(Vec<String>),
/// Single stop sequence
Single(String),
}
/// Default value helper
fn default_false() -> bool {
false
}
/// Default value helper
fn default_1usize() -> usize {
1
}
/// Default value helper
fn default_model() -> String {
"default".to_string()
}
/// Chat completion request following OpenAI's specification
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ChatCompletionRequest {
#[schema(example = json!([{"role": "user", "content": "Why did the crab cross the road?"}]))]
pub messages: Vec<Message>,
#[schema(example = "gemma-3-1b-it")]
#[serde(default = "default_model")]
pub model: String,
#[serde(default = "default_false")]
#[schema(example = false)]
pub logprobs: bool,
#[schema(example = 256)]
pub max_tokens: Option<usize>,
#[serde(rename = "n")]
#[serde(default = "default_1usize")]
#[schema(example = 1)]
pub n_choices: usize,
#[schema(example = 0.7)]
pub temperature: Option<f64>,
#[schema(example = 0.9)]
pub top_p: Option<f64>,
#[schema(example = false)]
pub stream: Option<bool>,
}
/// Chat completion choice
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionChoice {
pub index: usize,
pub message: Message,
pub finish_reason: String,
}
/// Chat completion response
#[derive(Debug, Serialize, ToSchema)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Usage,
}
/// Token usage information
#[derive(Debug, Serialize, ToSchema)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
// Application state shared between handlers
#[derive(Clone)]
struct AppState {
text_generation: Arc<Mutex<TextGeneration>>,
model_id: String,
}
// Chat completions endpoint handler
async fn chat_completions(
State(state): State<AppState>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, (StatusCode, Json<serde_json::Value>)> {
let mut prompt = String::new();
// Convert messages to a prompt string
for message in &request.messages {
let role = &message.role;
let content = match &message.content {
Some(content) => match &content.0 {
Either::Left(text) => text.clone(),
Either::Right(_) => "".to_string(), // Handle complex content if needed
},
None => "".to_string(),
};
// Format based on role
match role.as_str() {
"system" => prompt.push_str(&format!("System: {}\n", content)),
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => prompt.push_str(&format!("{}: {}\n", role, content)),
}
}
// Add the assistant prefix for the response
prompt.push_str("Assistant: ");
// Capture the output
let mut output = Vec::new();
{
let mut text_gen = state.text_generation.lock().await;
// Buffer to capture the output
let mut buffer = Vec::new();
// Run text generation
let max_tokens = request.max_tokens.unwrap_or(1000);
let result = text_gen.run_with_output(&prompt, max_tokens, &mut buffer);
if let Err(e) = result {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": {
"message": "The OpenAI API is currently not supported due to compatibility issues with the tensor operations. Please use the CLI mode instead with: cargo run --bin inference-engine -- --prompt \"Your prompt here\"",
"type": "unsupported_api"
}
})),
));
}
// Convert buffer to string
if let Ok(text) = String::from_utf8(buffer) {
output.push(text);
}
}
// Create response
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4().to_string().replace("-", "")),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: request.model,
choices: vec![ChatCompletionChoice {
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some(MessageContent(Either::Left(output.join("")))),
name: None,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens: prompt.len() / 4, // Rough estimate
completion_tokens: output.join("").len() / 4, // Rough estimate
total_tokens: (prompt.len() + output.join("").len()) / 4, // Rough estimate
},
};
// Return the response as JSON
Ok(Json(response))
}
use candle_core::{DType, Device, MetalDevice, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType, api::sync::Api};
use serde_json::json;
use tokenizers::Tokenizer;
use crate::token_output_stream::TokenOutputStream;
use crate::utilities_lib::device;
// Create the router with the chat completions endpoint
fn create_router(app_state: AppState) -> Router {
// CORS layer to allow requests from any origin
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
// OpenAI compatible endpoints
.route("/v1/chat/completions", post(chat_completions))
// Add more endpoints as needed
.layer(cors)
.with_state(app_state)
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
// Run text generation and print to stdout
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
println!(
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
);
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
// Run text generation and write to a buffer
fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Write prompt tokens to output
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
write!(output, "{}", t)?;
}
}
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
eos_token
}
};
// Determine if we're using a Model3 (gemma-3) variant
let is_model3 = match &self.model {
Model::V3(_) => true,
_ => false,
};
// For Model3, we need to use a different approach
if is_model3 {
// For gemma-3 models, we'll generate one token at a time with the full context
let start_gen = std::time::Instant::now();
// Initial generation with the full prompt
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let mut logits = self.model.forward(&input, 0)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
for _ in 0..sample_len {
// Apply repeat penalty if needed
let current_logits = if self.repeat_penalty == 1. {
logits.clone()
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&current_logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
// For the next iteration, just use the new token
let new_input = Tensor::new(&[next_token], &self.device)?.unsqueeze(0)?;
logits = self.model.forward(&new_input, tokens.len() - 1)?;
logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
}
return Ok(());
}
// Standard approach for other models
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
// Manual implementation of repeat penalty to avoid type conflicts
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in &tokens[start_at..] {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
logits_vec[token_id] = sign * score / self.repeat_penalty;
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
new_logits.reshape(shape)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
write!(output, "{}", t)?;
}
}
// Write any remaining tokens
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
write!(output, "{}", rest)?;
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Run in server mode with OpenAI compatible API
#[arg(long)]
server: bool,
/// Port to use for the server
#[arg(long, default_value_t = 3777)]
port: u16,
/// Prompt for text generation (not used in server mode)
#[arg(long)]
prompt: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "3-1b-it")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Which::Base2B => "google/gemma-2b".to_string(),
Which::Base7B => "google/gemma-7b".to_string(),
Which::Instruct2B => "google/gemma-2b-it".to_string(),
Which::Instruct7B => "google/gemma-7b-it".to_string(),
Which::CodeBase2B => "google/codegemma-2b".to_string(),
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id.clone(),
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.which {
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => utilities_lib::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let initial_device = utilities_lib::device(args.cpu)?;
// Check if we're using a V3 model (Gemma 3) and if we're on Metal (macOS)
let is_v3_model = matches!(args.which, Which::BaseV3_1B | Which::InstructV3_1B);
let is_metal = !initial_device.is_cpu() && candle_core::utils::metal_is_available() && !args.cpu;
// Use CPU for V3 models on Metal due to missing implementations
let device = if is_v3_model && is_metal {
println!("Note: Using CPU for Gemma 3 model due to missing Metal implementations for required operations (e.g., rotary-emb).");
Device::Cpu
} else {
initial_device
};
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
// Use the selected device and dtype
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
}
Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("loaded the model in {:?}", start.elapsed());
let pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
if args.server {
// Start the server
println!("Starting server on port {}", args.port);
// Create app state
let app_state = AppState {
text_generation: Arc::new(Mutex::new(pipeline)),
model_id,
};
// Create router
let app = create_router(app_state);
// Run the server
let addr = SocketAddr::from(([0, 0, 0, 0], args.port));
// Use tokio to run the server
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async {
axum::serve(tokio::net::TcpListener::bind(&addr).await?, app)
.await
.map_err(|e| anyhow::anyhow!("Server error: {}", e))
})?;
Ok(())
} else {
// Run in CLI mode
if let Some(prompt_text) = &args.prompt {
let prompt = match args.which {
Which::Base2B
| Which::Base7B
| Which::Instruct2B
| Which::Instruct7B
| Which::InstructV1_1_2B
| Which::InstructV1_1_7B
| Which::CodeBase2B
| Which::CodeBase7B
| Which::CodeInstruct2B
| Which::CodeInstruct7B
| Which::BaseV2_2B
| Which::InstructV2_2B
| Which::BaseV2_9B
| Which::InstructV2_9B
| Which::BaseV3_1B => prompt_text.clone(),
Which::InstructV3_1B => {
format!(
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
prompt_text
)
}
};
let mut pipeline = pipeline;
pipeline.run(&prompt, args.sample_len)?;
Ok(())
} else {
anyhow::bail!("Prompt is required in CLI mode. Use --prompt to specify a prompt or --server to run in server mode.")
}
}
}

View File

@@ -0,0 +1,33 @@
use anyhow::Result;
use candle_core::Tensor;
/// ModelInference trait defines the common interface for model inference operations
///
/// This trait serves as an abstraction for different model implementations (Gemma and Llama)
/// to provide a unified interface for the inference engine.
pub trait ModelInference {
/// Perform model inference for the given input tensor starting at the specified position
///
/// # Arguments
///
/// * `input_ids` - The input tensor containing token IDs
/// * `pos` - The position to start generation from
///
/// # Returns
///
/// A tensor containing the logits for the next token prediction
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> Result<Tensor>;
/// Reset the model's internal state, if applicable
///
/// This method can be used to clear any cached state between inference requests
fn reset_state(&mut self) -> Result<()>;
/// Get the model type name
///
/// Returns a string identifier for the model type (e.g., "Gemma", "Llama")
fn model_type(&self) -> &'static str;
}
/// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;

View File

@@ -1,16 +1,13 @@
// Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types;
pub mod cli;
// pub mod cli;
pub mod inference;
pub mod server;
// Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which};
pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
use std::env;

View File

@@ -0,0 +1,26 @@
use inference_engine::{AppState, create_router, get_server_config, init_tracing};
use tokio::net::TcpListener;
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
init_tracing();
let app_state = AppState::default();
let app = create_router(app_state);
let (server_host, server_port, server_address) = get_server_config();
let listener = TcpListener::bind(&server_address).await?;
info!(
"Inference Engine server starting on http://{}",
server_address
);
info!("Available endpoints:");
info!(" POST /v1/chat/completions - OpenAI-compatible chat completions");
info!(" GET /v1/models - List available models");
axum::serve(listener, app).await?;
Ok(())
}

View File

@@ -1,10 +1,57 @@
// use candle_core::Tensor;
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
#[derive(Clone, Debug)]
pub enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
Llama(LlamaModel),
}
impl Model {
pub fn forward(
&mut self,
input_ids: &candle_core::Tensor,
pos: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
Self::Llama(m) => m.forward(input_ids, pos),
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Family {
GemmaV1,
GemmaV2,
GemmaV3,
Llama,
}
#[derive(Clone, Copy, Debug)]
pub struct ModelMeta {
pub id: &'static str,
pub family: Family,
pub instruct: bool,
}
const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
ModelMeta {
id,
family,
instruct,
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
// Gemma 1.x
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
@@ -17,6 +64,8 @@ pub enum Which {
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
// CodeGemma
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
@@ -25,6 +74,8 @@ pub enum Which {
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
// Gemma 2
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
@@ -33,58 +84,73 @@ pub enum Which {
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
// Gemma 3
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
pub enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
// Llama 3.2 (use aliases instead of duplicate variants)
#[value(name = "llama-3.2-1b")]
Llama32_1B,
#[value(name = "llama-3.2-1b-it", alias = "llama-3.2-1b-instruct")]
Llama32_1BInstruct,
#[value(name = "llama-3.2-3b")]
Llama32_3B,
#[value(name = "llama-3.2-3b-it", alias = "llama-3.2-3b-instruct")]
Llama32_3BInstruct,
}
impl Which {
pub fn to_model_id(&self) -> String {
pub const fn meta(&self) -> ModelMeta {
use Family::*;
match self {
Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
Self::Base2B => "google/gemma-2b".to_string(),
Self::Base7B => "google/gemma-7b".to_string(),
Self::Instruct2B => "google/gemma-2b-it".to_string(),
Self::Instruct7B => "google/gemma-7b-it".to_string(),
Self::CodeBase2B => "google/codegemma-2b".to_string(),
Self::CodeBase7B => "google/codegemma-7b".to_string(),
Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
// Gemma 1.x
Self::Base2B => m("google/gemma-2b", GemmaV1, false),
Self::Base7B => m("google/gemma-7b", GemmaV1, false),
Self::Instruct2B => m("google/gemma-2b-it", GemmaV1, true),
Self::Instruct7B => m("google/gemma-7b-it", GemmaV1, true),
Self::InstructV1_1_2B => m("google/gemma-1.1-2b-it", GemmaV1, true),
Self::InstructV1_1_7B => m("google/gemma-1.1-7b-it", GemmaV1, true),
// CodeGemma
Self::CodeBase2B => m("google/codegemma-2b", GemmaV1, false),
Self::CodeBase7B => m("google/codegemma-7b", GemmaV1, false),
Self::CodeInstruct2B => m("google/codegemma-2b-it", GemmaV1, true),
Self::CodeInstruct7B => m("google/codegemma-7b-it", GemmaV1, true),
// Gemma 2
Self::BaseV2_2B => m("google/gemma-2-2b", GemmaV2, false),
Self::InstructV2_2B => m("google/gemma-2-2b-it", GemmaV2, true),
Self::BaseV2_9B => m("google/gemma-2-9b", GemmaV2, false),
Self::InstructV2_9B => m("google/gemma-2-9b-it", GemmaV2, true),
// Gemma 3
Self::BaseV3_1B => m("google/gemma-3-1b-pt", GemmaV3, false),
Self::InstructV3_1B => m("google/gemma-3-1b-it", GemmaV3, true),
// Llama 3.2
Self::Llama32_1B => m("meta-llama/Llama-3.2-1B", Llama, false),
Self::Llama32_1BInstruct => m("meta-llama/Llama-3.2-1B-Instruct", Llama, true),
Self::Llama32_3B => m("meta-llama/Llama-3.2-3B", Llama, false),
Self::Llama32_3BInstruct => m("meta-llama/Llama-3.2-3B-Instruct", Llama, true),
}
}
pub fn to_model_id(&self) -> String {
self.meta().id.to_string()
}
pub fn is_instruct_model(&self) -> bool {
match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
_ => true,
}
self.meta().instruct
}
pub fn is_v3_model(&self) -> bool {
matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
matches!(self.meta().family, Family::GemmaV3)
}
}
pub fn is_llama_model(&self) -> bool {
matches!(self.meta().family, Family::Llama)
}
}

View File

@@ -1,5 +1,6 @@
use either::Either;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use utoipa::ToSchema;
@@ -10,7 +11,10 @@ pub struct MessageInnerContent(
);
impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +49,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent(
#[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
);
impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageContent",
utoipa::openapi::RefOr::T(message_content_schema()),
)
}
}
@@ -213,4 +223,4 @@ pub struct ModelListResponse {
pub object: String,
/// Array of available models
pub data: Vec<Model>,
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,10 @@ mod tests {
// Test a few representative model variants
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
assert_eq!(
Which::InstructV1_1_2B.to_model_id(),
"google/gemma-1.1-2b-it"
);
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
@@ -64,4 +67,4 @@ mod tests {
// Note: Testing the Model enum's forward method would require creating actual model instances,
// which is complex and would require loading model weights. This is better suited for
// integration tests or mocking the models.
}
}

View File

@@ -1,549 +0,0 @@
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use inference_engine::model::Which;
use inference_engine::text_generation::TextGeneration;
use inference_engine::token_output_stream::TokenOutputStream;
use std::collections::HashMap;
use tokenizers::Tokenizer;
#[cfg(test)]
mod tests {
use super::*;
// Helper function to create a simple tokenizer for testing
fn create_test_tokenizer() -> Result<Tokenizer> {
// Create a simple tokenizer from the pretrained model
// This uses the tokenizer from the Hugging Face hub
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
Ok(tokenizer)
}
// Test the Which enum's to_model_id method
#[test]
fn test_which_model_id() {
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
}
// Test the Which enum's is_instruct_model method
#[test]
fn test_which_is_instruct() {
assert!(!Which::Base2B.is_instruct_model());
assert!(Which::Instruct7B.is_instruct_model());
}
// Test the Which enum's is_v3_model method
#[test]
fn test_which_is_v3() {
assert!(!Which::Base2B.is_v3_model());
assert!(Which::BaseV3_1B.is_v3_model());
}
// Test the TokenOutputStream functionality
#[test]
fn test_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Test encoding and decoding
let text = "Hello, world!";
let encoded = token_stream.tokenizer().encode(text, true).unwrap();
let token_ids = encoded.get_ids();
// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}
// Decode all and check
let decoded = token_stream.decode_all()?;
assert_eq!(decoded.trim(), text);
Ok(())
}
// Test the LogitsProcessor
#[test]
fn test_logits_processor() -> Result<()> {
// Create a LogitsProcessor with default settings
let seed = 42;
let temp = Some(0.8);
let top_p = Some(0.9);
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
// Create a simple logits tensor
// In a real test, we would create a tensor with known values and verify
// that sampling produces expected results
// For now, we'll just verify that the LogitsProcessor can be created
assert!(true);
Ok(())
}
// Test the TextGeneration constructor
#[test]
fn test_text_generation_constructor() -> Result<()> {
// We can't easily create a Model instance for testing,
// but we can test that the constructor compiles and the types are correct
// In a real test with a mock Model, we would:
// 1. Create a mock model
// 2. Create a tokenizer
// 3. Call TextGeneration::new
// 4. Verify the properties of the created instance
// For now, we'll just verify that the code compiles
assert!(true);
Ok(())
}
// Test apply_cached_repeat_penalty method with no penalty
#[test]
fn test_apply_cached_repeat_penalty_no_penalty() -> Result<()> {
// Create a simple test setup
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
// Create a mock TextGeneration instance
// Since we can't easily create a full TextGeneration instance without a model,
// we'll test the logic by creating a simple struct with the necessary fields
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
logits: Tensor,
tokens: &[u32],
) -> Result<(Tensor, std::time::Duration)> {
let repeat_start = std::time::Instant::now();
// If no penalty, return the original logits
if self.repeat_penalty == 1.0 {
return Ok((logits, repeat_start.elapsed()));
}
// Get the tokens to penalize (the last n tokens)
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let penalty_tokens = &tokens[start_at..];
// Extract logits to a vector for modification
let mut logits_vec = logits.to_vec1::<f32>()?;
let cache_hits = std::cell::Cell::new(0);
// Apply penalties with caching
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
// Check if we've already calculated this token's penalty
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
// Use cached value
logits_vec[token_id] = *penalized_score;
cache_hits.set(cache_hits.get() + 1);
} else {
// Calculate and cache new value
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / self.repeat_penalty;
logits_vec[token_id] = penalized_score;
self.penalty_cache.insert(token_id, penalized_score);
}
}
}
// Create a new tensor with the modified logits
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
let result = new_logits.reshape(shape)?;
let elapsed = repeat_start.elapsed();
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 1.0, // No penalty
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// With no penalty, logits should be unchanged
assert_eq!(result_data, logits_data);
Ok(())
}
// Test apply_cached_repeat_penalty method with penalty
#[test]
fn test_apply_cached_repeat_penalty_with_penalty() -> Result<()> {
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
logits: Tensor,
tokens: &[u32],
) -> Result<(Tensor, std::time::Duration)> {
let repeat_start = std::time::Instant::now();
if self.repeat_penalty == 1.0 {
return Ok((logits, repeat_start.elapsed()));
}
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
let cache_hits = std::cell::Cell::new(0);
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
logits_vec[token_id] = *penalized_score;
cache_hits.set(cache_hits.get() + 1);
} else {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / self.repeat_penalty;
logits_vec[token_id] = penalized_score;
self.penalty_cache.insert(token_id, penalized_score);
}
}
}
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
let result = new_logits.reshape(shape)?;
let elapsed = repeat_start.elapsed();
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, // Apply penalty
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
assert_eq!(result_data, expected);
Ok(())
}
// Test apply_cached_repeat_penalty caching behavior
#[test]
fn test_apply_cached_repeat_penalty_caching() -> Result<()> {
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
logits: Tensor,
tokens: &[u32],
) -> Result<(Tensor, std::time::Duration)> {
let repeat_start = std::time::Instant::now();
if self.repeat_penalty == 1.0 {
return Ok((logits, repeat_start.elapsed()));
}
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
logits_vec[token_id] = *penalized_score;
} else {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / self.repeat_penalty;
logits_vec[token_id] = penalized_score;
self.penalty_cache.insert(token_id, penalized_score);
}
}
}
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
let result = new_logits.reshape(shape)?;
let elapsed = repeat_start.elapsed();
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
// First call should cache the penalty for token 1
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
// Cache should contain the penalized value for token 1
assert!(mock_gen.penalty_cache.contains_key(&1));
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
Ok(())
}
// Test edge case: empty tokens array
#[test]
fn test_apply_cached_repeat_penalty_empty_tokens() -> Result<()> {
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens: Vec<u32> = vec![]; // Empty tokens
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
logits: Tensor,
tokens: &[u32],
) -> Result<(Tensor, std::time::Duration)> {
let repeat_start = std::time::Instant::now();
if self.repeat_penalty == 1.0 {
return Ok((logits, repeat_start.elapsed()));
}
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
logits_vec[token_id] = *penalized_score;
} else {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / self.repeat_penalty;
logits_vec[token_id] = penalized_score;
self.penalty_cache.insert(token_id, penalized_score);
}
}
}
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
let result = new_logits.reshape(shape)?;
let elapsed = repeat_start.elapsed();
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// With empty tokens, logits should be unchanged
assert_eq!(result_data, logits_data);
Ok(())
}
// Test edge case: out-of-bounds token IDs
#[test]
fn test_apply_cached_repeat_penalty_out_of_bounds() -> Result<()> {
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
struct MockTextGeneration {
repeat_penalty: f32,
repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>,
}
impl MockTextGeneration {
fn apply_cached_repeat_penalty(
&mut self,
logits: Tensor,
tokens: &[u32],
) -> Result<(Tensor, std::time::Duration)> {
let repeat_start = std::time::Instant::now();
if self.repeat_penalty == 1.0 {
return Ok((logits, repeat_start.elapsed()));
}
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
logits_vec[token_id] = *penalized_score;
} else {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / self.repeat_penalty;
logits_vec[token_id] = penalized_score;
self.penalty_cache.insert(token_id, penalized_score);
}
}
}
let device = logits.device().clone();
let shape = logits.shape().clone();
let new_logits = Tensor::new(&logits_vec[..], &device)?;
let result = new_logits.reshape(shape)?;
let elapsed = repeat_start.elapsed();
Ok((result, elapsed))
}
}
let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0,
repeat_last_n: 3,
penalty_cache: HashMap::new(),
};
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?;
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
assert_eq!(result_data, expected);
Ok(())
}
// Test the actual apply_cached_repeat_penalty method from TextGeneration
// This test creates a TextGeneration instance with minimal dependencies to test the real method
#[test]
fn test_actual_apply_cached_repeat_penalty_implementation() -> Result<()> {
// Since creating a real TextGeneration instance requires a Model which needs model weights,
// we'll create a test that demonstrates the method is now public and can be accessed.
// The comprehensive functionality testing is already covered by the mock tests above.
// Test data setup
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
// Test that we can create the necessary components
let tokenizer = create_test_tokenizer()?;
// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
// This test verifies the method signature and that it's accessible from external code
// We could create a TextGeneration instance if we had a way to mock the Model,
// but for now we confirm that the existing mock tests cover the functionality
// and the method is properly exposed as public
println!("apply_cached_repeat_penalty method is now public and accessible for testing");
assert!(true);
Ok(())
}
// Integration test that demonstrates the method usage pattern
#[test]
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
// This test demonstrates how the apply_cached_repeat_penalty method would be used
// in practice, even though we can't create a full TextGeneration instance in unit tests
let device = Device::Cpu;
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
// Test parameters that would be used with TextGeneration
let repeat_penalty = 1.2f32;
let repeat_last_n = 3usize;
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
// Simulate the method's logic to verify it works as expected
let start_time = std::time::Instant::now();
if repeat_penalty != 1.0 {
let start_at = tokens.len().saturating_sub(repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
if let Some(_cached_score) = penalty_cache.get(&token_id) {
// Cache hit simulation
} else {
let score = logits_vec[token_id];
let sign = if score < 0.0 { -1.0 } else { 1.0 };
let penalized_score = sign * score / repeat_penalty;
penalty_cache.insert(token_id, penalized_score);
}
}
}
}
let _duration = start_time.elapsed();
// Verify that tokens were processed correctly
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
Ok(())
}
// Note: Testing the actual text generation functionality would require
// integration tests with real models, which is beyond the scope of these unit tests.
// The tests above focus on the components that can be tested in isolation.
}

View File

@@ -1,129 +0,0 @@
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;
#[cfg(test)]
mod tests {
use super::*;
// Helper function to create a simple tokenizer for testing
fn create_test_tokenizer() -> Result<Tokenizer> {
// Create a simple tokenizer from the pretrained model
// This uses the tokenizer from the Hugging Face hub
let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
Ok(tokenizer)
}
#[test]
fn test_new_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Check that the token stream was created successfully
assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
Ok(())
}
#[test]
fn test_clear() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Add a token
let token_id = token_stream.get_token("<eos>").unwrap();
token_stream.next_token(token_id)?;
// Clear the stream
token_stream.clear();
// Check that the stream is empty by trying to decode all
let decoded = token_stream.decode_all()?;
assert_eq!(decoded, "");
Ok(())
}
#[test]
fn test_get_token() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Get a token that should exist
let eos_token = token_stream.get_token("<eos>");
assert!(eos_token.is_some());
// Get a token that shouldn't exist
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
assert!(nonexistent_token.is_none());
Ok(())
}
#[test]
fn test_next_token_and_decode() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let token_ids = hello_tokens.get_ids();
// Add tokens one by one
let mut output = String::new();
for &token_id in token_ids {
if let Some(text) = token_stream.next_token(token_id)? {
output.push_str(&text);
}
}
// Get any remaining text
if let Some(rest) = token_stream.decode_rest()? {
output.push_str(&rest);
}
// Check the output
assert!(!output.is_empty());
assert_eq!(output.trim(), "Hello world");
Ok(())
}
#[test]
fn test_decode_all() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
let token_ids = hello_tokens.get_ids();
// Add tokens one by one
for &token_id in token_ids {
token_stream.next_token(token_id)?;
}
// Decode all
let decoded = token_stream.decode_all()?;
// Check the output
assert_eq!(decoded.trim(), "Hello world");
Ok(())
}
#[test]
fn test_into_inner() -> Result<()> {
let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer);
// Get the inner tokenizer
let inner_tokenizer = token_stream.into_inner();
// Check that the inner tokenizer works
let encoded = inner_tokenizer.encode("Test", true).unwrap();
assert!(encoded.get_ids().len() > 0);
Ok(())
}
}

View File

@@ -1,3 +0,0 @@
# Ensure getrandom works on wasm32-unknown-unknown without needing manual RUSTFLAGS
[target.wasm32-unknown-unknown]
rustflags = ["--cfg", "getrandom_backend=\"wasm_js\""]

View File

@@ -1,21 +0,0 @@
# Build stage
FROM rust:1-alpine AS builder
# Install build dependencies
RUN apk add --no-cache npm nodejs musl-dev pkgconfig openssl-dev git curl bash
RUN curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash
WORKDIR /app
# Copy manifest first (cache deps)
COPY . .
# Install cargo-leptos
RUN cargo binstall cargo-leptos
# Build release artifacts
RUN cargo leptos build --release
EXPOSE 8788
CMD ["cargo", "leptos", "serve", "--release"]

View File

@@ -1,494 +0,0 @@
use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
components::{Route, Router, Routes},
StaticSegment,
};
#[cfg(feature = "hydrate")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "hydrate")]
use std::collections::VecDeque;
#[cfg(feature = "hydrate")]
use uuid::Uuid;
#[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
},
Client,
};
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{Role, FinishReason};
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub id: String,
pub role: String,
pub content: String,
pub timestamp: f64,
}
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent(pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>);
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInnerContent(pub either::Either<String, std::collections::HashMap<String, String>>);
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: Option<MessageContent>,
pub name: Option<String>,
}
#[cfg(feature = "hydrate")]
const DEFAULT_MODEL: &str = "default";
#[cfg(feature = "hydrate")]
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1");
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);
match client.models().list().await {
Ok(response) => {
let model_count = response.data.len();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count);
if model_count > 0 {
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names);
} else {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: No models returned by server");
}
Ok(response.data)
},
Err(e) => {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e);
Err(format!("Failed to fetch models: {}", e))
}
}
}
pub fn shell(options: LeptosOptions) -> impl IntoView {
view! {
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<AutoReload options=options.clone() />
<HydrationScripts options/>
<MetaTags/>
</head>
<body>
<App/>
</body>
</html>
}
}
#[component]
pub fn App() -> impl IntoView {
// Provides context that manages stylesheets, titles, meta tags, etc.
provide_meta_context();
view! {
// injects a stylesheet into the document <head>
// id=leptos means cargo-leptos will hot-reload this stylesheet
<Stylesheet id="leptos" href="/pkg/leptos-app.css"/>
// sets the document title
<Title text="Chat Interface"/>
// content for this chat interface
<Router>
<main>
<Routes fallback=|| "Page not found.".into_view()>
<Route path=StaticSegment("") view=ChatInterface/>
</Routes>
</main>
</Router>
}
}
/// Renders the home page of your application.
#[component]
fn HomePage() -> impl IntoView {
// Creates a reactive value to update the button
let count = RwSignal::new(0);
let on_click = move |_| *count.write() += 1;
view! {
<h1>"Welcome to Leptos!"</h1>
<button on:click=on_click>"Click Me: " {count}</button>
}
}
/// Renders the chat interface
#[component]
fn ChatInterface() -> impl IntoView {
#[cfg(feature = "hydrate")]
{
ChatInterfaceImpl()
}
#[cfg(not(feature = "hydrate"))]
{
view! {
<div class="chat-container">
<h1>"Chat Interface"</h1>
<p>"Loading chat interface..."</p>
</div>
}
}
}
#[cfg(feature = "hydrate")]
#[component]
fn ChatInterfaceImpl() -> impl IntoView {
let (messages, set_messages) = RwSignal::new(VecDeque::<Message>::new()).split();
let (input_value, set_input_value) = RwSignal::new(String::new()).split();
let (is_loading, set_is_loading) = RwSignal::new(false).split();
let (available_models, set_available_models) = RwSignal::new(Vec::<OpenAIModel>::new()).split();
let (selected_model, set_selected_model) = RwSignal::new(DEFAULT_MODEL.to_string()).split();
let (models_loading, set_models_loading) = RwSignal::new(false).split();
// Fetch models on component initialization
Effect::new(move |_| {
spawn_local(async move {
set_models_loading.set(true);
match fetch_available_models().await {
Ok(models) => {
set_available_models.set(models);
set_models_loading.set(false);
}
Err(e) => {
leptos::logging::log!("Failed to fetch models: {}", e);
set_available_models.set(vec![]);
set_models_loading.set(false);
}
}
});
});
let send_message = Action::new_unsync(move |content: &String| {
let content = content.clone();
async move {
if content.trim().is_empty() {
leptos::logging::log!("[DEBUG_LOG] send_message: Empty content, skipping");
return;
}
leptos::logging::log!("[DEBUG_LOG] send_message: Starting message send process");
set_is_loading.set(true);
// Add user message to chat
let user_message = Message {
id: Uuid::new_v4().to_string(),
role: "user".to_string(),
content: content.clone(),
timestamp: Date::now(),
};
set_messages.update(|msgs| msgs.push_back(user_message.clone()));
set_input_value.set(String::new());
let mut chat_messages = Vec::new();
// Add system message
let system_message = ChatCompletionRequestSystemMessageArgs::default()
.content("You are a helpful assistant.")
.build()
.expect("failed to build system message");
chat_messages.push(system_message.into());
// Add history messages
let history_count = messages.get_untracked().len();
for msg in messages.get_untracked().iter() {
match msg.role.as_str() {
"user" => {
let message = ChatCompletionRequestUserMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build user message");
chat_messages.push(message.into());
}
"assistant" => {
let message = ChatCompletionRequestAssistantMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build assistant message");
chat_messages.push(message.into());
}
_ => {}
}
}
// Add current user message
let message = ChatCompletionRequestUserMessageArgs::default()
.content(user_message.content.clone())
.build()
.expect("failed to build user message");
chat_messages.push(message.into());
let current_model = selected_model.get_untracked();
let total_messages = chat_messages.len();
leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
current_model, history_count, total_messages);
let request = CreateChatCompletionRequestArgs::default()
.model(current_model.as_str())
.max_tokens(512u32)
.messages(chat_messages)
.stream(true)
.build()
.expect("failed to build request");
// Send request
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);
leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);
match client.chat().create_stream(request).await {
Ok(mut stream) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream");
let mut assistant_created = false;
let mut content_appended = false;
let mut chunks_received = 0;
while let Some(next) = stream.next().await {
match next {
Ok(chunk) => {
chunks_received += 1;
if let Some(choice) = chunk.choices.get(0) {
if !assistant_created {
if let Some(role) = &choice.delta.role {
if role == &Role::Assistant {
assistant_created = true;
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id,
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
}
}
}
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
if !assistant_created {
assistant_created = true;
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id,
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
}
content_appended = true;
set_messages.update(|msgs| {
if let Some(last) = msgs.back_mut() {
if last.role == "assistant" {
last.content.push_str(content);
last.timestamp = Date::now();
}
}
});
}
}
if let Some(reason) = &choice.finish_reason {
if reason == &FinishReason::Stop {
leptos::logging::log!("[DEBUG_LOG] send_message: Received finish_reason=stop after {} chunks", chunks_received);
break;
}
}
}
}
Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e);
set_messages.update(|msgs| {
msgs.push_back(Message {
id: Uuid::new_v4().to_string(),
role: "system".to_string(),
content: format!("Stream error: {}", e),
timestamp: Date::now(),
});
});
break;
}
}
}
if assistant_created && !content_appended {
set_messages.update(|msgs| {
let should_pop = msgs
.back()
.map(|m| m.role == "assistant" && m.content.is_empty())
.unwrap_or(false);
if should_pop {
msgs.pop_back();
}
});
}
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
}
Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e);
let error_message = Message {
id: Uuid::new_v4().to_string(),
role: "system".to_string(),
content: format!("Error: Request failed - {}", e),
timestamp: Date::now(),
};
set_messages.update(|msgs| msgs.push_back(error_message));
}
}
set_is_loading.set(false);
}
});
let on_input = move |ev| {
let input = event_target::<HtmlInputElement>(&ev);
set_input_value.set(input.value());
};
let on_submit = move |ev: SubmitEvent| {
ev.prevent_default();
let content = input_value.get();
send_message.dispatch(content);
};
let on_keypress = move |ev: KeyboardEvent| {
if ev.key() == "Enter" && !ev.shift_key() {
ev.prevent_default();
let content = input_value.get();
send_message.dispatch(content);
}
};
let on_model_change = move |ev| {
let select = event_target::<web_sys::HtmlSelectElement>(&ev);
set_selected_model.set(select.value());
};
let messages_list = move || {
messages.get()
.into_iter()
.map(|message| {
let role_class = match message.role.as_str() {
"user" => "user-message",
"assistant" => "assistant-message",
_ => "system-message",
};
view! {
<div class=format!("message {}", role_class)>
<div class="message-role">{message.role}</div>
<div class="message-content">{message.content}</div>
</div>
}
})
.collect::<Vec<_>>()
};
let loading_indicator = move || {
is_loading.get().then(|| {
view! {
<div class="message assistant-message">
<div class="message-role">"assistant"</div>
<div class="message-content">"Thinking..."</div>
</div>
}
})
};
view! {
<div class="chat-container">
<h1>"Chat Interface"</h1>
<div class="model-selector">
<label for="model-select">"Model: "</label>
<select
id="model-select"
on:change=on_model_change
prop:value=selected_model
prop:disabled=models_loading
>
{move || {
if models_loading.get() {
vec![view! {
<option value={String::from("")} selected=false>{String::from("Loading models...")}</option>
}]
} else {
let models = available_models.get();
if models.is_empty() {
vec![view! {
<option value={String::from("default")} selected=true>{String::from("default")}</option>
}]
} else {
models.into_iter().map(|model| {
view! {
<option value=model.id.clone() selected={model.id == DEFAULT_MODEL}>{model.id.clone()}</option>
}
}).collect::<Vec<_>>()
}
}
}}
</select>
</div>
<div class="messages-container">
{messages_list}
{loading_indicator}
</div>
<form class="input-form" on:submit=on_submit>
<input
type="text"
class="message-input"
placeholder="Type your message here..."
prop:value=input_value
on:input=on_input
on:keypress=on_keypress
prop:disabled=is_loading
/>
<button
type="submit"
class="send-button"
prop:disabled=move || is_loading.get() || input_value.get().trim().is_empty()
>
"Send"
</button>
</form>
</div>
}
}

View File

@@ -1,30 +0,0 @@
pub mod app;
#[cfg(feature = "hydrate")]
#[wasm_bindgen::prelude::wasm_bindgen]
pub fn hydrate() {
use crate::app::*;
console_error_panic_hook::set_once();
leptos::mount::hydrate_body(App);
}
#[cfg(feature = "ssr")]
pub fn create_leptos_router() -> axum::Router {
use axum::Router;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use crate::app::*;
let conf = get_configuration(None).unwrap();
let leptos_options = conf.leptos_options;
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
Router::new()
.leptos_routes(&leptos_options, routes, {
let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options)
}

View File

@@ -1,39 +0,0 @@
#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
use axum::Router;
use leptos::logging::log;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use leptos_app::app::*;
let conf = get_configuration(None).unwrap();
let addr = conf.leptos_options.site_addr;
let leptos_options = conf.leptos_options;
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
let app = Router::new()
.leptos_routes(&leptos_options, routes, {
let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options);
// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
log!("listening on http://{}", &addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
#[cfg(not(feature = "ssr"))]
pub fn main() {
// no client-side main function
// unless we want this to work with e.g., Trunk for pure client-side testing
// see lib.rs for hydration function instead
}

View File

@@ -1,4 +0,0 @@
body {
font-family: sans-serif;
text-align: center;
}

View File

@@ -1,6 +1,6 @@
[package]
name = "predict-otron-9000"
version = "0.1.0"
version.workspace = true
edition = "2024"
[[bin]]
@@ -19,7 +19,7 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.7.0", features = ["v4"] }
reqwest = { version = "0.12", features = ["json"] }
rust-embed = { version = "8.7.2", features = ["include-exclude"] }
rust-embed = { version = "8.7.2", features = ["include-exclude", "axum"] }
# Dependencies for embeddings functionality
embeddings-engine = { path = "../embeddings-engine" }
@@ -28,20 +28,24 @@ embeddings-engine = { path = "../embeddings-engine" }
inference-engine = { path = "../inference-engine" }
# Dependencies for leptos web app
leptos-app = { path = "../leptos-app", features = ["ssr"] }
#leptos-app = { path = "../leptos-app", features = ["ssr"] }
chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = true }
mime_guess = "2.0.5"
[package.metadata.compose]
name = "predict-otron-9000"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
port = 8080
log = "0.4.27"
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
env = { SERVER_CONFIG = "" }
cmd = ["./bin/predict-otron-9000"]
# SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
# you can generate this via node to avoid toil
# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
env = { SERVER_CONFIG = "<your-json-value-here>" }
[features]
default = ["ui"]
ui = ["dep:chat-ui"]

View File

@@ -1,89 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies including CUDA toolkit for GPU support (needed for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA toolkit (required for inference-engine dependency)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
# Copy the entire workspace to get access to all crates (needed for local dependencies)
COPY . ./
# Cache dependencies first - create dummy source files for all crates
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
RUN mkdir -p crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src && \
echo "fn main() {}" > crates/predict-otron-9000/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
echo "fn main() {}" > crates/embeddings-engine/src/main.rs && \
echo "// lib" > crates/embeddings-engine/src/lib.rs && \
cargo build --release --bin predict-otron-9000 --package predict-otron-9000
# Remove dummy sources and copy real sources
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
COPY . .
# Build the actual binary
RUN cargo build --release --bin predict-otron-9000 --package predict-otron-9000
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA runtime libraries (required for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/predict-otron-9000 /usr/local/bin/
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["predict-otron-9000"]

View File

@@ -1,7 +1,12 @@
use serde::{Deserialize, Serialize};
use std::env;
#[derive(Debug, Clone, Deserialize, Serialize)]
use tracing::info;
use tracing::log::error;
/// # Generating `SERVER_CONFIG` with Node
// # const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
// # console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
///
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ServerConfig {
#[serde(default = "default_server_host")]
@@ -10,14 +15,16 @@ pub struct ServerConfig {
pub server_port: u16,
pub server_mode: ServerMode,
#[serde(default)]
pub services: Services,
pub services: Option<Services>,
}
fn default_server_host() -> String {
"127.0.0.1".to_string()
}
fn default_server_port() -> u16 { 8080 }
fn default_server_port() -> u16 {
8080
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "PascalCase")]
@@ -32,29 +39,10 @@ impl Default for ServerMode {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Services {
#[serde(default = "inference_service_url")]
pub inference_url: String,
#[serde(default = "embeddings_service_url")]
pub embeddings_url: String,
}
impl Default for Services {
fn default() -> Self {
Self {
inference_url: inference_service_url(),
embeddings_url: embeddings_service_url(),
}
}
}
fn inference_service_url() -> String {
"http://inference-service:8080".to_string()
}
fn embeddings_service_url() -> String {
"http://embeddings-service:8080".to_string()
pub inference_url: Option<String>,
pub embeddings_url: Option<String>,
}
impl Default for ServerConfig {
@@ -63,7 +51,7 @@ impl Default for ServerConfig {
server_host: "127.0.0.1".to_string(),
server_port: 8080,
server_mode: ServerMode::Standalone,
services: Services::default(),
services: Some(Services::default()),
}
}
}
@@ -73,21 +61,19 @@ impl ServerConfig {
/// Falls back to default (Local mode) if not set or invalid
pub fn from_env() -> Self {
match env::var("SERVER_CONFIG") {
Ok(config_str) => {
match serde_json::from_str::<ServerConfig>(&config_str) {
Ok(config) => {
tracing::info!("Loaded server configuration: {:?}", config);
config
}
Err(e) => {
tracing::warn!(
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
e
);
ServerConfig::default()
}
Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
Ok(config) => {
tracing::info!("Loaded server configuration: {:?}", config);
config
}
}
Err(e) => {
tracing::warn!(
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
e
);
ServerConfig::default()
}
},
Err(_) => {
tracing::info!("SERVER_CONFIG not set, Standalone mode active");
ServerConfig::default()
@@ -96,18 +82,51 @@ impl ServerConfig {
}
/// Check if the server should run in high availability mode
pub fn is_high_availability(&self) -> bool {
self.server_mode == ServerMode::HighAvailability
pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
if self.server_mode == ServerMode::HighAvailability {
let services_well_defined: bool = self.clone().services.is_some();
let inference_url_well_defined: bool =
services_well_defined && self.clone().services.unwrap().inference_url.is_some();
let embeddings_well_defined: bool =
services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();
let is_well_defined_for_ha =
services_well_defined && inference_url_well_defined && embeddings_well_defined;
if !is_well_defined_for_ha {
let config_string = serde_json::to_string_pretty(&self).unwrap();
error!(
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
config_string
);
let err = std::io::Error::other(
"HighAvailability mode configured but services not well defined!",
);
return Err(err);
}
}
Ok(self.server_mode == ServerMode::HighAvailability)
}
/// Get the inference service URL for proxying
pub fn inference_url(&self) -> &str {
&self.services.inference_url
pub fn inference_url(&self) -> Option<String> {
if self.services.is_some() {
self.services.clone()?.inference_url
} else {
None
}
}
/// Get the embeddings service URL for proxying
pub fn embeddings_url(&self) -> &str {
&self.services.embeddings_url
pub fn embeddings_url(&self) -> Option<String> {
if self.services.is_some() {
self.services.clone()?.embeddings_url
} else {
None
}
}
}
@@ -119,7 +138,7 @@ mod tests {
fn test_default_config() {
let config = ServerConfig::default();
assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability());
assert!(!config.is_high_availability().unwrap());
}
#[test]
@@ -134,23 +153,26 @@ mod tests {
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::HighAvailability);
assert!(config.is_high_availability());
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
assert!(config.is_high_availability().unwrap());
assert_eq!(
config.inference_url().unwrap(),
"http://inference-service:8080"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://embeddings-service:8080"
);
}
#[test]
fn test_local_mode_config() {
let config_json = r#"{
"serverMode": "Local"
"serverMode": "Standalone"
}"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability());
// Should use default URLs
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
assert!(!config.is_high_availability().unwrap());
}
#[test]
@@ -164,17 +186,26 @@ mod tests {
}"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.inference_url(), "http://custom-inference:9000");
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
assert_eq!(
config.inference_url().unwrap(),
"http://custom-inference:9000"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://custom-embeddings:9001"
);
}
#[test]
fn test_minimal_high_availability_config() {
fn test_minimal_high_availability_config_error() {
let config_json = r#"{"serverMode": "HighAvailability"}"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert!(config.is_high_availability());
// Should use default URLs
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
let is_high_availability = config.is_high_availability();
assert!(is_high_availability.is_err());
// // Should use default URLs
// assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
// assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
}
}
}

View File

@@ -1,10 +1,10 @@
use axum::{
Router,
body::Body,
extract::{Request, State},
http::{HeaderMap, Method, StatusCode, Uri},
response::{IntoResponse, Response},
routing::{get, post},
Router,
};
use reqwest::Client;
use serde_json::Value;
@@ -12,7 +12,121 @@ use std::time::Duration;
use crate::config::ServerConfig;
/// HTTP client configured for proxying requests
/// # Generating `SERVER_CONFIG` for TOML using Node.js
///
/// You can still use the Node.js REPL to build the JSON, but when pasting into
/// a `.toml` file you must follow TOML's string rules. Below are the safest patterns.
///
/// ## 1) Generate the JSON in Node
/// ```bash
/// node
/// ```
/// ```javascript
/// const myobject = {
/// serverMode: "HighAvailability",
/// services: {
/// inference_url: "http://custom-inference:9000",
/// embeddings_url: "http://custom-embeddings:9001"
/// }
/// };
/// const json = JSON.stringify(myobject);
/// json
/// // -> '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
/// ```
///
/// ## 2) Put it into `.toml`
///
/// ### Option A (recommended): single-quoted TOML *literal* string
/// Single quotes in TOML mean "no escaping", so your inner double quotes are safe.
/// ```toml
/// SERVER_CONFIG = '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
/// ```
///
/// ### Option B: double-quoted TOML string (must escape inner quotes)
/// If you *must* use double quotes in TOML, escape all `"` inside the JSON.
/// You can have Node do this for you:
/// ```javascript
/// // In Node:
/// const jsonForToml = JSON.stringify(myobject).replace(/"/g, '\\"');
/// jsonForToml
/// // -> \"{\\\"serverMode\\\":\\\"HighAvailability\\\",...}\"
/// ```
/// Then paste into TOML:
/// ```toml
/// SERVER_CONFIG = "{\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}"
/// ```
///
/// ### Option C: multi-line literal (for pretty JSON)
/// If you want pretty-printed JSON in the file, use TOML's triple single quotes:
/// ```javascript
/// // In Node (pretty with 2 spaces):
/// const pretty = JSON.stringify(myobject, null, 2);
/// ```
/// ```toml
/// SERVER_CONFIG = '''{
/// "serverMode": "HighAvailability",
/// "services": {
/// "inference_url": "http://custom-inference:9000",
/// "embeddings_url": "http://custom-embeddings:9001"
/// }
/// }'''
/// ```
///
/// ## 3) Reading it in Rust
///
/// If `SERVER_CONFIG` is stored as a **string** in TOML (Options A/B/C):
/// ```rust
/// use serde_json::Value;
///
/// // Suppose you've already loaded your .toml into a struct or a toml::Value:
/// // e.g., struct FileCfg { pub SERVER_CONFIG: String }
/// fn parse_server_config(raw: &str) -> anyhow::Result<Value> {
/// let v: Value = serde_json::from_str(raw)?;
/// Ok(v)
/// }
/// ```
///
/// ### Alternative: store it as TOML tables and serialize to JSON at runtime
/// Instead of a JSON string, you can make the TOML first-class tables:
/// ```toml
/// [SERVER_CONFIG]
/// serverMode = "HighAvailability"
///
/// [SERVER_CONFIG.services]
/// inference_url = "http://custom-inference:9000"
/// embeddings_url = "http://custom-embeddings:9001"
/// ```
/// ```rust
/// use serde::{Deserialize, Serialize};
/// use serde_json::Value;
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Services {
/// inference_url: String,
/// embeddings_url: String,
/// }
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct ServerConfig {
/// serverMode: String,
/// services: Services,
/// }
///
/// // After loading the .toml (e.g., via `toml::from_str`):
/// // let cfg: ServerConfig = toml::from_str(toml_str)?;
/// // Convert to JSON if needed:
/// fn to_json(cfg: &ServerConfig) -> serde_json::Result<Value> {
/// Ok(serde_json::to_value(cfg)?)
/// }
/// ```
///
/// ## Gotchas
/// - Prefer **single-quoted** TOML strings for raw JSON to avoid escaping.
/// - If you use **double-quoted** TOML strings, escape every inner `"` in the JSON.
/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
/// HTTP client configured for proxying requests
#[derive(Clone)]
pub struct ProxyClient {
client: Client,
@@ -31,7 +145,7 @@ impl ProxyClient {
}
/// Create a router that proxies requests to external services in HighAvailability mode
pub fn create_proxy_router(config: ServerConfig) -> Router {
pub fn create_ha_router(config: ServerConfig) -> Router {
let proxy_client = ProxyClient::new(config.clone());
Router::new()
@@ -47,10 +161,16 @@ async fn proxy_chat_completions(
headers: HeaderMap,
body: Body,
) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());
let target_url = format!(
"{}/v1/chat/completions",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration")
);
tracing::info!("Proxying chat completions request to: {}", target_url);
// Extract body as bytes
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
@@ -63,7 +183,9 @@ async fn proxy_chat_completions(
// Check if this is a streaming request
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
json.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false)
} else {
false
}
@@ -72,7 +194,8 @@ async fn proxy_chat_completions(
};
// Forward the request
let mut req_builder = proxy_client.client
let mut req_builder = proxy_client
.client
.post(&target_url)
.body(body_bytes.to_vec());
@@ -85,8 +208,7 @@ async fn proxy_chat_completions(
match req_builder.send().await {
Ok(response) => {
let mut resp_builder = Response::builder()
.status(response.status());
let mut resp_builder = Response::builder().status(response.status());
// Forward response headers
for (name, value) in response.headers().iter() {
@@ -99,14 +221,12 @@ async fn proxy_chat_completions(
if is_streaming {
// For streaming, we need to forward the response as-is
match response.bytes().await {
Ok(body) => {
resp_builder
.header("content-type", "text/plain; charset=utf-8")
.header("cache-control", "no-cache")
.header("connection", "keep-alive")
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.header("content-type", "text/plain; charset=utf-8")
.header("cache-control", "no-cache")
.header("connection", "keep-alive")
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read streaming response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +235,9 @@ async fn proxy_chat_completions(
} else {
// For non-streaming, forward the JSON response
match response.bytes().await {
Ok(body) => {
resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,10 +257,16 @@ async fn proxy_models(
State(proxy_client): State<ProxyClient>,
headers: HeaderMap,
) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/models", proxy_client.config.inference_url());
let target_url = format!(
"{}/v1/models",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration Detected")
);
tracing::info!("Proxying models request to: {}", target_url);
let mut req_builder = proxy_client.client.get(&target_url);
// Forward relevant headers
@@ -154,8 +278,7 @@ async fn proxy_models(
match req_builder.send().await {
Ok(response) => {
let mut resp_builder = Response::builder()
.status(response.status());
let mut resp_builder = Response::builder().status(response.status());
// Forward response headers
for (name, value) in response.headers().iter() {
@@ -165,11 +288,9 @@ async fn proxy_models(
}
match response.bytes().await {
Ok(body) => {
resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read models response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,10 +310,16 @@ async fn proxy_embeddings(
headers: HeaderMap,
body: Body,
) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());
let target_url = format!(
"{}/v1/embeddings",
proxy_client
.config
.embeddings_url()
.expect("Invalid Configuration Detected")
);
tracing::info!("Proxying embeddings request to: {}", target_url);
// Extract body as bytes
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
@@ -203,7 +330,8 @@ async fn proxy_embeddings(
};
// Forward the request
let mut req_builder = proxy_client.client
let mut req_builder = proxy_client
.client
.post(&target_url)
.body(body_bytes.to_vec());
@@ -216,8 +344,7 @@ async fn proxy_embeddings(
match req_builder.send().await {
Ok(response) => {
let mut resp_builder = Response::builder()
.status(response.status());
let mut resp_builder = Response::builder().status(response.status());
// Forward response headers
for (name, value) in response.headers().iter() {
@@ -227,11 +354,9 @@ async fn proxy_embeddings(
}
match response.bytes().await {
Ok(body) => {
resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Ok(body) => resp_builder
.body(Body::from(body))
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
Err(e) => {
tracing::error!("Failed to read embeddings response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -250,7 +375,7 @@ fn should_forward_header(header_name: &str) -> bool {
match header_name.to_lowercase().as_str() {
"content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
"host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
_ => true, // Forward other headers by default
_ => true, // Forward other headers by default
}
}
@@ -259,7 +384,7 @@ fn should_forward_response_header(header_name: &str) -> bool {
match header_name.to_lowercase().as_str() {
"content-type" | "content-length" | "cache-control" | "connection" => true,
"server" | "date" => false, // Don't forward server-specific headers
_ => true, // Forward other headers by default
_ => true, // Forward other headers by default
}
}
@@ -290,14 +415,20 @@ mod tests {
server_host: "127.0.0.1".to_string(),
server_port: 8080,
server_mode: ServerMode::HighAvailability,
services: Services {
inference_url: "http://test-inference:8080".to_string(),
embeddings_url: "http://test-embeddings:8080".to_string(),
},
services: Some(Services {
inference_url: Some("http://test-inference:8080".to_string()),
embeddings_url: Some("http://test-embeddings:8080".to_string()),
}),
};
let proxy_client = ProxyClient::new(config);
assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
assert_eq!(
proxy_client.config.inference_url().unwrap().as_str(),
"http://test-inference:8080"
);
assert_eq!(
proxy_client.config.embeddings_url().unwrap().as_str(),
"http://test-embeddings:8080"
);
}
}
}

View File

@@ -1,22 +1,63 @@
mod config;
mod ha_mode;
mod middleware;
mod proxy;
mod standalone_mode;
use axum::response::IntoResponse;
use crate::standalone_mode::create_standalone_router;
use axum::routing::get;
use axum::{Router, http::Uri, response::Html, serve};
use axum::{Router, serve};
use config::ServerConfig;
use inference_engine::AppState;
use ha_mode::create_ha_router;
use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use proxy::create_proxy_router;
use rust_embed::Embed;
use std::env;
#[cfg(feature = "ui")]
use axum::http::StatusCode as AxumStatusCode;
#[cfg(feature = "ui")]
use axum::http::Uri;
#[cfg(feature = "ui")]
use axum::http::header;
#[cfg(feature = "ui")]
use axum::response::IntoResponse;
#[cfg(feature = "ui")]
use mime_guess::from_path;
#[cfg(feature = "ui")]
use rust_embed::Embed;
use tokio::net::TcpListener;
use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[cfg(feature = "ui")]
#[derive(Embed)]
#[folder = "../../target/site"]
#[include = "*.js"]
#[include = "*.wasm"]
#[include = "*.css"]
#[include = "*.ico"]
struct Asset;
#[cfg(feature = "ui")]
async fn static_handler(uri: Uri) -> axum::response::Response {
// Strip the leading `/`
let path = uri.path().trim_start_matches('/');
tracing::info!("Static file: {}", &path);
// If root is requested, serve index.html
let path = if path.is_empty() { "index.html" } else { path };
match Asset::get(path) {
Some(content) => {
let body = content.data.into_owned();
let mime = from_path(path).first_or_octet_stream();
([(header::CONTENT_TYPE, mime.as_ref())], body).into_response()
}
None => (AxumStatusCode::NOT_FOUND, "404 Not Found").into_response(),
}
}
#[tokio::main]
async fn main() {
// Initialize tracing
@@ -49,44 +90,19 @@ async fn main() {
let default_host = server_config.server_host.clone();
let default_port = server_config.server_port;
// Create router based on server mode
let service_router = if server_config.clone().is_high_availability() {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!(" Inference service URL: {}", server_config.inference_url());
tracing::info!(
" Embeddings service URL: {}",
server_config.embeddings_url()
);
// Use proxy router that forwards requests to external services
create_proxy_router(server_config.clone())
} else {
tracing::info!("Running in Standalone mode - using embedded services");
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();
// Create AppState with correct model configuration
use inference_engine::Which;
use inference_engine::server::{PipelineArgs, build_pipeline};
let mut pipeline_args = PipelineArgs::default();
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
pipeline_args.which = Which::InstructV3_1B;
let text_generation = build_pipeline(pipeline_args.clone());
let app_state = AppState {
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
model_id: "google/gemma-3-1b-it".to_string(),
build_args: pipeline_args,
};
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);
// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
let service_router = match server_config.clone().is_high_availability() {
Ok(is_ha) => {
if is_ha {
log_config(server_config.clone());
create_ha_router(server_config.clone())
} else {
log_config(server_config.clone());
create_standalone_router(server_config)
}
}
Err(error) => {
panic!("{}", error);
}
};
// Create CORS layer
@@ -99,20 +115,28 @@ async fn main() {
// Create metrics layer
let metrics_layer = MetricsLayer::new(metrics_store);
// Create the leptos router for the web frontend
let leptos_router = leptos_app::create_leptos_router();
// Merge the service router with base routes and add middleware layers
let app = Router::new()
let mut app = Router::new()
.route("/health", get(|| async { "ok" }))
.merge(service_router)
.merge(leptos_router) // Add leptos web frontend routes
.merge(service_router);
// Add UI routes if the UI feature is enabled
#[cfg(feature = "ui")]
{
let leptos_config = chat_ui::app::AppConfig::default();
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
app = app
.route("/pkg/{*path}", get(static_handler))
.merge(leptos_router);
}
let app = app
.layer(metrics_layer) // Add metrics tracking
.layer(cors)
.layer(TraceLayer::new_for_http());
// Server configuration
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| String::from(default_host));
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| default_host.to_string());
let server_port = env::var("SERVER_PORT")
.map(|v| v.parse::<u16>().unwrap_or(default_port))
@@ -127,12 +151,34 @@ async fn main() {
);
tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
tracing::info!("Available endpoints:");
#[cfg(feature = "ui")]
tracing::info!(" GET / - Leptos chat web application");
tracing::info!(" GET /health - Health check");
tracing::info!(" POST /v1/models - List Models");
tracing::info!(" POST /v1/embeddings - Text embeddings API");
tracing::info!(" POST /v1/chat/completions - Chat completions API");
serve(listener, app).await.unwrap();
serve(listener, app.into_make_service()).await.unwrap();
}
fn log_config(config: ServerConfig) {
match config.is_high_availability() {
Ok(is_high) => {
if is_high {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
tracing::info!(
"Embeddings service URL: {}",
config.embeddings_url().unwrap()
);
} else {
tracing::info!("Running in Standalone mode");
}
}
Err(error) => {
panic!("{}", error);
}
}
}
// Chat completions handler that properly uses the inference server crate's error handling

View File

@@ -2,6 +2,8 @@ use axum::{
extract::MatchedPath,
http::{Request, Response},
};
use std::fmt;
use std::task::ready;
use std::{
future::Future,
pin::Pin,
@@ -12,8 +14,6 @@ use std::{
use tokio::sync::Mutex;
use tower::{Layer, Service};
use tracing::{debug, info};
use std::task::ready;
use std::fmt;
/// Performance metrics for a specific endpoint
#[derive(Debug, Clone, Default)]
@@ -33,16 +33,16 @@ impl EndpointMetrics {
pub fn add_response_time(&mut self, time_ms: u64) {
self.count += 1;
self.total_time_ms += time_ms;
if self.min_time_ms == 0 || time_ms < self.min_time_ms {
self.min_time_ms = time_ms;
}
if time_ms > self.max_time_ms {
self.max_time_ms = time_ms;
}
}
/// Get the average response time in milliseconds
pub fn avg_time_ms(&self) -> f64 {
if self.count == 0 {
@@ -51,12 +51,15 @@ impl EndpointMetrics {
self.total_time_ms as f64 / self.count as f64
}
}
/// Get a human-readable summary of the metrics
pub fn summary(&self) -> String {
format!(
"requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
self.count,
self.avg_time_ms(),
self.min_time_ms,
self.max_time_ms
)
}
}
@@ -75,14 +78,16 @@ impl MetricsStore {
endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
}
}
/// Record a request's timing information
pub async fn record(&self, path: String, time_ms: u64) {
let mut endpoints = self.endpoints.lock().await;
let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
let metrics = endpoints
.entry(path)
.or_insert_with(EndpointMetrics::default);
metrics.add_response_time(time_ms);
}
/// Get metrics for all endpoints
pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
let endpoints = self.endpoints.lock().await;
@@ -91,12 +96,12 @@ impl MetricsStore {
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
/// Log a summary of all metrics
pub async fn log_summary(&self) {
let metrics = self.get_all().await;
info!("Performance metrics summary:");
for (path, metric) in metrics {
info!(" {}: {}", path, metric.summary());
}
@@ -163,26 +168,28 @@ where
} else {
req.uri().path().to_string()
};
let method = req.method().clone();
let start = Instant::now();
let metrics_store = self.metrics_store.clone();
let future = self.inner.call(req);
Box::pin(async move {
let response = future.await?;
let time = start.elapsed();
let status = response.status();
let time_ms = time.as_millis() as u64;
// Record the timing in our metrics store
metrics_store.record(format!("{} {}", method, path), time_ms).await;
metrics_store
.record(format!("{} {}", method, path), time_ms)
.await;
// Log the request timing
debug!("{} {} {} - {} ms", method, path, status, time_ms);
Ok(response)
})
}
@@ -214,7 +221,7 @@ impl Future for MetricsLoggerFuture {
metrics_store.log_summary().await;
});
}
Poll::Pending
}
}
}

View File

@@ -1,7 +1,3 @@
pub mod metrics;
pub use metrics::{
MetricsStore,
MetricsLoggerFuture,
MetricsLayer,
};
pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};

View File

@@ -0,0 +1,20 @@
use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;
pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();
// Create AppState - no default model, must be configured explicitly
// This removes the hardcoded gemma-3-1b-it default behavior
let app_state = AppState::default();
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);
// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
}

View File

@@ -22,7 +22,7 @@ The Predict-Otron-9000 is a comprehensive multi-service AI platform built around
graph TB
subgraph "Core Components"
A[Main Server<br/>predict-otron-9000]
B[Inference Engine<br/>Gemma via Candle]
B[Inference Engine<br/>Gemma/Llama via Candle]
C[Embeddings Engine<br/>FastEmbed]
D[Web Frontend<br/>Leptos WASM]
end
@@ -52,7 +52,7 @@ graph TB
## Workspace Structure
The project uses a 4-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
The project uses a 9-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
```mermaid
graph TD
@@ -61,25 +61,33 @@ graph TD
A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
end
subgraph "AI Services"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Candle ML]
subgraph "AI Services (crates/)"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
end
subgraph "Frontend"
D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
subgraph "Frontend (crates/)"
D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
end
subgraph "Integration Tools (integration/)"
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
O[utils<br/>Edition: 2021<br/>Shared utilities]
end
end
subgraph "External Tooling"
E[cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
end
subgraph "Dependencies"
A --> B
A --> C
A --> D
B -.-> F[Candle 0.9.1]
B --> M
B --> N
M -.-> F[Candle 0.9.1]
N -.-> F
C -.-> G[FastEmbed 4.x]
D -.-> H[Leptos 0.8.0]
E -.-> I[OpenAI SDK 5.16+]
@@ -90,6 +98,10 @@ graph TD
style C fill:#e8f5e8
style D fill:#fff3e0
style E fill:#fce4ec
style L fill:#fff9c4
style M fill:#f3e5f5
style N fill:#f3e5f5
style O fill:#fff9c4
```
## Deployment Configurations
@@ -181,7 +193,7 @@ graph TB
end
subgraph "Frontend"
D[leptos-app Pod<br/>:8788<br/>ClusterIP Service]
D[chat-ui Pod<br/>:8788<br/>ClusterIP Service]
end
subgraph "Ingress"

View File

@@ -0,0 +1,11 @@
[package]
name = "cli"
version.workspace = true
edition = "2021"
build = "build.rs"
[[bin]]
name = "cli"
path = "src/main.rs"
[dependencies]

24
integration/cli/README.md Normal file
View File

@@ -0,0 +1,24 @@
# cli
A Rust/Typescript Hybrid
```console
bun run cli.ts [options] [prompt]
Simple CLI tool for testing the local OpenAI-compatible API server.
Options:
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message
Examples:
cd integration/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"
bun run cli.ts --list-models
The server must be running at http://localhost:8080
```

202
integration/cli/build.rs Normal file
View File

@@ -0,0 +1,202 @@
use std::env;
use std::fs;
use std::io::{self, BufRead, Write};
use std::path::{Path, PathBuf};
use std::process::{ChildStderr, ChildStdout, Command, Stdio};
use std::thread;
use std::time::{Duration, SystemTime};
mod bun_target;
use bun_target::BunTarget;
fn main() {
println!("cargo:rerun-if-changed=");
if let Err(e) = run_build() {
println!("cargo:warning=build.rs failed: {e}");
std::process::exit(1);
}
}
fn run_build() -> io::Result<()> {
let manifest_dir =
PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"));
let package_dir = manifest_dir.join("package");
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
let output_path = out_dir.join("client-cli");
let bun_tgt = BunTarget::from_cargo_env().map_err(|e| io::Error::other(e.to_string()))?;
// Optional: warn if using a Bun target thats marked unsupported in your chart
if matches!(bun_tgt, BunTarget::WindowsArm64) {
println!(
"cargo:warning=bun-windows-arm64 is marked unsupported in the compatibility chart"
);
}
warn(&format!("Building CLI into: {}", output_path.display()));
// --- bun install (in ./package), keep temps inside OUT_DIR ---
let mut install = Command::new("bun")
.current_dir(&package_dir)
.env("TMPDIR", &out_dir)
.arg("install")
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun install`: {e}")))?;
let install_join = stream_child("bun install", install.stdout.take(), install.stderr.take());
let install_status = install.wait()?;
// ensure streams finish
join_streams(install_join);
if !install_status.success() {
let code = install_status.code().unwrap_or(1);
return Err(io::Error::other(format!(
"bun install failed with status {code}"
)));
}
let _target = env::var("TARGET").unwrap();
// --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
let mut build = Command::new("bun")
.current_dir(&package_dir)
.env("TMPDIR", &out_dir)
.arg("build")
.arg("./cli.ts")
.arg(format!("--target={}", bun_tgt.as_bun_flag()))
.arg("--compile")
.arg("--outfile")
.arg(&output_path)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun build`: {e}")))?;
let build_join = stream_child("bun build", build.stdout.take(), build.stderr.take());
let status = build.wait()?;
// ensure streams finish
join_streams(build_join);
if status.success() {
info("bun build succeeded");
} else {
let code = status.code().unwrap_or(1);
warn(&format!("bun build failed with status: {code}"));
return Err(io::Error::other("bun build failed"));
}
// Ensure the output is executable (after it exists)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(&output_path)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&output_path, perms)?;
}
println!("cargo:warning=Built CLI at {}", output_path.display());
println!("cargo:rustc-env=CLIENT_CLI_BIN={}", output_path.display());
// --- Cleanup stray .bun-build temp files (conservative: older than 5 minutes) ---
for dir in [&manifest_dir, &package_dir, &out_dir] {
if let Err(e) = remove_bun_temp_files(dir, Some(Duration::from_secs(5 * 60))) {
println!("cargo:warning=cleanup in {} failed: {e}", dir.display());
}
}
Ok(())
}
// Spawn readers for child's stdout/stderr so we don't deadlock on pipe buffers
fn stream_child(
tag: &str,
stdout: Option<ChildStdout>,
stderr: Option<ChildStderr>,
) -> (
Option<thread::JoinHandle<()>>,
Option<thread::JoinHandle<()>>,
) {
let t1 = stdout.map(|out| {
let tag = tag.to_string();
thread::spawn(move || {
let reader = io::BufReader::new(out);
for line in reader.lines() {
info(&format!("[{tag} stdout] {}", line.unwrap_or_default()));
}
})
});
let t2 = stderr.map(|err| {
let tag = tag.to_string();
thread::spawn(move || {
let reader = io::BufReader::new(err);
for line in reader.lines() {
warn(&format!("[{tag} stderr] {}", line.unwrap_or_default()));
}
})
});
(t1, t2)
}
fn join_streams(
joins: (
Option<thread::JoinHandle<()>>,
Option<thread::JoinHandle<()>>,
),
) {
if let Some(j) = joins.0 {
let _ = j.join();
}
if let Some(j) = joins.1 {
let _ = j.join();
}
}
fn remove_bun_temp_files(dir: &Path, older_than: Option<Duration>) -> io::Result<()> {
let now = SystemTime::now();
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if !path.is_file() {
continue;
}
// Files like ".1860e7df40ff1bef-00000000.bun-build"
let name = entry.file_name();
let name = name.to_string_lossy();
let looks_like_bun_temp = name.starts_with('.') && name.ends_with(".bun-build");
if !looks_like_bun_temp {
continue;
}
if let Some(age) = older_than {
if let Ok(meta) = entry.metadata() {
if let Ok(modified) = meta.modified() {
if now.duration_since(modified).unwrap_or_default() < age {
// too new; skip to avoid racing an in-flight builder
continue;
}
}
}
}
match fs::remove_file(&path) {
Ok(_) => println!("cargo:warning=removed stray bun temp {}", path.display()),
Err(e) => println!("cargo:warning=failed to remove {}: {e}", path.display()),
}
}
Ok(())
}
fn warn(msg: &str) {
let _ = writeln!(io::stderr(), "[build.rs] {msg}");
println!("cargo:warning={msg}");
}
fn info(msg: &str) {
let _ = writeln!(io::stderr(), "[build.rs] {msg}");
println!("cargo:warning=INFO|{msg}");
}

View File

@@ -0,0 +1,131 @@
use std::env;
use std::fmt;
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum BunTarget {
LinuxX64Glibc,
LinuxArm64Glibc,
LinuxX64Musl,
LinuxArm64Musl,
WindowsX64,
WindowsArm64,
MacX64,
MacArm64,
}
impl BunTarget {
pub const fn as_bun_flag(self) -> &'static str {
match self {
BunTarget::LinuxX64Glibc => "bun-linux-x64",
BunTarget::LinuxArm64Glibc => "bun-linux-arm64",
BunTarget::LinuxX64Musl => "bun-linux-x64-musl",
BunTarget::LinuxArm64Musl => "bun-linux-arm64-musl",
BunTarget::WindowsX64 => "bun-windows-x64",
BunTarget::WindowsArm64 => "bun-windows-arm64",
BunTarget::MacX64 => "bun-darwin-x64",
BunTarget::MacArm64 => "bun-darwin-arm64",
}
}
pub const fn rust_triples(self) -> &'static [&'static str] {
match self {
BunTarget::LinuxX64Glibc => {
&["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-gnu.2.17"]
}
BunTarget::LinuxArm64Glibc => &["aarch64-unknown-linux-gnu"],
BunTarget::LinuxX64Musl => &["x86_64-unknown-linux-musl"],
BunTarget::LinuxArm64Musl => &["aarch64-unknown-linux-musl"],
BunTarget::WindowsX64 => &["x86_64-pc-windows-msvc"],
BunTarget::WindowsArm64 => &["aarch64-pc-windows-msvc"], // chart says unsupported; still map
BunTarget::MacX64 => &["x86_64-apple-darwin"],
BunTarget::MacArm64 => &["aarch64-apple-darwin"],
}
}
pub fn from_rust_target(triple: &str) -> Option<Self> {
let norm = triple.trim();
if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
return Some(BunTarget::LinuxX64Glibc);
}
if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
return Some(BunTarget::LinuxArm64Glibc);
}
if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("musl") {
return Some(BunTarget::LinuxX64Musl);
}
if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("musl") {
return Some(BunTarget::LinuxArm64Musl);
}
if norm == "x86_64-pc-windows-msvc" {
return Some(BunTarget::WindowsX64);
}
if norm == "aarch64-pc-windows-msvc" {
return Some(BunTarget::WindowsArm64);
}
if norm == "x86_64-apple-darwin" {
return Some(BunTarget::MacX64);
}
if norm == "aarch64-apple-darwin" {
return Some(BunTarget::MacArm64);
}
for bt in [
BunTarget::LinuxX64Glibc,
BunTarget::LinuxArm64Glibc,
BunTarget::LinuxX64Musl,
BunTarget::LinuxArm64Musl,
BunTarget::WindowsX64,
BunTarget::WindowsArm64,
BunTarget::MacX64,
BunTarget::MacArm64,
] {
for &t in bt.rust_triples() {
if t == norm {
return Some(bt);
}
}
}
None
}
pub fn from_cargo_env() -> Result<Self, BunTargetError> {
if let Ok(triple) = env::var("TARGET") {
if let Some(bt) = Self::from_rust_target(&triple) {
return Ok(bt);
}
return Err(BunTargetError::UnknownTriple(triple));
}
let os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
let envv = env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default();
let vendor = env::var("CARGO_CFG_TARGET_VENDOR").unwrap_or_else(|_| "unknown".into());
let triple = format!(
"{}-{}-{}-{}",
arch,
vendor,
os,
if envv.is_empty() { "gnu" } else { &envv }
);
if let Some(bt) = Self::from_rust_target(&triple) {
Ok(bt)
} else {
Err(BunTargetError::UnknownTriple(triple))
}
}
}
#[derive(Debug)]
pub enum BunTargetError {
UnknownTriple(String),
}
impl fmt::Display for BunTargetError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BunTargetError::UnknownTriple(t) => write!(f, "unrecognized Rust target triple: {t}"),
}
}
}
impl std::error::Error for BunTargetError {}

View File

@@ -0,0 +1,17 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
"openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
}
}

View File

@@ -30,24 +30,23 @@ type ChunkStat = {
function printHelp() {
console.log(`
Usage: bun client_cli.ts [options] [prompt]
./cli [options] [prompt]
Simple CLI tool for testing the local OpenAI-compatible API server.
Options:
--model <model> Model to use (default: ${DEFAULT_MODEL})
--model <model> Model to use (default: gemma-3-1b-it)
--prompt <prompt> The prompt to send (can also be provided as positional argument)
--list-models List all available models from the server
--help Show this help message
Examples:
./cli.ts "What is the capital of France?"
./cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
./cli.ts --prompt "Who was the 16th president of the United States?"
./cli.ts --list-models
./cli "What is the capital of France?"
./cli --model gemma-3-1b-it --prompt "Hello, world!"
./cli --prompt "Who was the 16th president of the United States?"
./cli --list-models
The server should be running at http://localhost:8080
Start it with: ./run_server.sh
The server must be running at http://localhost:8080
`);
}

View File

@@ -0,0 +1,11 @@
{
"name": "cli",
"main": "cli.ts",
"scripts": {
"build": "bun build cli.ts --compile --outfile cli"
},
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0"
}
}

View File

@@ -0,0 +1,32 @@
use std::{env, fs, io, path::PathBuf, process::Command};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
fn main() -> io::Result<()> {
// Absolute path provided by build.rs at compile time.
// `include_bytes!` accepts string literals; `env!` expands to a literal at compile time.
const CLIENT_CLI: &[u8] = include_bytes!(env!("CLIENT_CLI_BIN"));
// Write to a temp file
let mut tmp = env::temp_dir();
tmp.push("client-cli-embedded");
fs::write(&tmp, CLIENT_CLI)?;
// Ensure it's executable on Unix
#[cfg(unix)]
{
let mut perms = fs::metadata(&tmp)?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&tmp, perms)?;
}
// Run it
let status = Command::new(&tmp).arg("--version").status()?;
if !status.success() {
return Err(io::Error::other("client-cli failed"));
}
Ok(())
}

View File

@@ -0,0 +1,32 @@
[package]
name = "gemma-runner"
version.workspace = true
edition = "2021"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
hf-hub = "0.4"
tokenizers = "0.22.0"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
utils = {path = "../utils" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -0,0 +1,137 @@
# Gemma Runner
Fast Gemma inference with Candle framework in Rust.
## Features
- Support for multiple Gemma model versions (v1, v2, v3)
- GPU acceleration with CUDA and Metal
- Configurable sampling parameters
- Multiple model variants including instruct and code models
## Supported Models
### Gemma v1
- `gemma-2b` - Base 2B model
- `gemma-7b` - Base 7B model
- `gemma-2b-it` - Instruct 2B model
- `gemma-7b-it` - Instruct 7B model
- `gemma-1.1-2b-it` - Instruct 2B v1.1 model
- `gemma-1.1-7b-it` - Instruct 7B v1.1 model
### CodeGemma
- `codegemma-2b` - Code base 2B model
- `codegemma-7b` - Code base 7B model
- `codegemma-2b-it` - Code instruct 2B model
- `codegemma-7b-it` - Code instruct 7B model
### Gemma v2
- `gemma-2-2b` - Base 2B v2 model (default)
- `gemma-2-2b-it` - Instruct 2B v2 model
- `gemma-2-9b` - Base 9B v2 model
- `gemma-2-9b-it` - Instruct 9B v2 model
### Gemma v3
- `gemma-3-1b` - Base 1B v3 model
- `gemma-3-1b-it` - Instruct 1B v3 model
## Installation
```bash
cd gemma-runner
cargo build --release
```
For GPU support:
```bash
# CUDA
cargo build --release --features cuda
# Metal (macOS)
cargo build --release --features metal
```
## Usage
### Basic Usage
```bash
# Run with default model (gemma-2-2b)
cargo run -- --prompt "The capital of France is"
# Specify a different model
cargo run -- --model gemma-2b-it --prompt "Explain quantum computing"
# Generate more tokens
cargo run -- --model codegemma-2b-it --prompt "Write a Python function to sort a list" --max-tokens 200
```
### Advanced Options
```bash
# Use CPU instead of GPU
cargo run -- --cpu --prompt "Hello world"
# Adjust sampling parameters
cargo run -- --temperature 0.8 --top-p 0.9 --prompt "Write a story about"
# Use custom model from HuggingFace Hub
cargo run -- --model-id "google/gemma-2-2b-it" --prompt "What is AI?"
# Enable tracing for performance analysis
cargo run -- --tracing --prompt "Explain machine learning"
```
### Command Line Arguments
- `--prompt, -p` - The prompt to generate text from (default: "The capital of France is")
- `--model, -m` - The model to use (default: "gemma-2-2b")
- `--cpu` - Run on CPU rather than GPU
- `--temperature, -t` - Sampling temperature (optional)
- `--top-p` - Nucleus sampling probability cutoff (optional)
- `--seed` - Random seed (default: 299792458)
- `--max-tokens, -n` - Maximum tokens to generate (default: 100)
- `--model-id` - Custom model ID from HuggingFace Hub
- `--revision` - Model revision (default: "main")
- `--use-flash-attn` - Use flash attention
- `--repeat-penalty` - Repetition penalty (default: 1.1)
- `--repeat-last-n` - Context size for repeat penalty (default: 64)
- `--dtype` - Data type (f16, bf16, f32)
- `--tracing` - Enable performance tracing
## Examples
### Text Generation
```bash
cargo run -- --model gemma-2b-it --prompt "Explain the theory of relativity" --max-tokens 150
```
### Code Generation
```bash
cargo run -- --model codegemma-7b-it --prompt "Write a Rust function to calculate factorial" --max-tokens 100
```
### Creative Writing
```bash
cargo run -- --model gemma-7b-it --temperature 0.9 --prompt "Once upon a time in a magical forest" --max-tokens 200
```
### Chat with Gemma 3 (Instruct format)
```bash
cargo run -- --model gemma-3-1b-it --prompt "How do I learn Rust programming?"
```
## Performance Notes
- GPU acceleration is automatically detected and used when available
- BF16 precision is used on CUDA for better performance
- F32 precision is used on CPU
- Flash attention can be enabled with `--use-flash-attn` for supported models
- Model files are cached locally after first download
## Requirements
- Rust 1.70+
- CUDA toolkit (for CUDA support)
- Metal (automatically available on macOS)
- Internet connection for first-time model download

View File

@@ -0,0 +1,447 @@
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use std::fmt;
use std::str::FromStr;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
#[value(name = "gemma-7b")]
Base7B,
#[value(name = "gemma-2b-it")]
Instruct2B,
#[value(name = "gemma-7b-it")]
Instruct7B,
#[value(name = "gemma-1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "gemma-1.1-7b-it")]
InstructV1_1_7B,
#[value(name = "codegemma-2b")]
CodeBase2B,
#[value(name = "codegemma-7b")]
CodeBase7B,
#[value(name = "codegemma-2b-it")]
CodeInstruct2B,
#[value(name = "codegemma-7b-it")]
CodeInstruct7B,
#[value(name = "gemma-2-2b")]
BaseV2_2B,
#[value(name = "gemma-2-2b-it")]
InstructV2_2B,
#[value(name = "gemma-2-9b")]
BaseV2_9B,
#[value(name = "gemma-2-9b-it")]
InstructV2_9B,
#[value(name = "gemma-3-1b")]
BaseV3_1B,
#[value(name = "gemma-3-1b-it")]
InstructV3_1B,
}
impl FromStr for WhichModel {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gemma-2b" => Ok(Self::Base2B),
"gemma-7b" => Ok(Self::Base7B),
"gemma-2b-it" => Ok(Self::Instruct2B),
"gemma-7b-it" => Ok(Self::Instruct7B),
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
"codegemma-2b" => Ok(Self::CodeBase2B),
"codegemma-7b" => Ok(Self::CodeBase7B),
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
"gemma-2-2b" => Ok(Self::BaseV2_2B),
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
"gemma-2-9b" => Ok(Self::BaseV2_9B),
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
"gemma-3-1b" => Ok(Self::BaseV3_1B),
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
_ => Err(format!("Unknown model: {}", s)),
}
}
}
impl fmt::Display for WhichModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = match self {
Self::Base2B => "gemma-2b",
Self::Base7B => "gemma-7b",
Self::Instruct2B => "gemma-2b-it",
Self::Instruct7B => "gemma-7b-it",
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
Self::CodeBase2B => "codegemma-2b",
Self::CodeBase7B => "codegemma-7b",
Self::CodeInstruct2B => "codegemma-2b-it",
Self::CodeInstruct7B => "codegemma-7b-it",
Self::BaseV2_2B => "gemma-2-2b",
Self::InstructV2_2B => "gemma-2-2b-it",
Self::BaseV2_9B => "gemma-2-9b",
Self::InstructV2_9B => "gemma-2-9b-it",
Self::BaseV3_1B => "gemma-3-1b",
Self::InstructV3_1B => "gemma-3-1b-it",
};
write!(f, "{}", name)
}
}
enum Model {
V1(Model1),
V2(Model2),
V3(Model3),
}
impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle_core::Result<Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
Self::V3(m) => m.forward(input_ids, pos),
}
}
}
pub struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if candle_core::utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if candle_core::utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: tokenizers::Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
/// Stream-only generation: sends freshly generated token strings over `tx`.
/// (Does not send the prompt tokens; only newly generated model tokens.)
fn run_stream(
&mut self,
prompt: &str,
sample_len: usize,
tx: Sender<Result<String>>,
) -> Result<()> {
self.tokenizer.clear();
// Encode prompt (context only; do not emit prompt tokens to the stream).
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
// Warm the tokenizer's internal state with prompt tokens (so merges are correct),
// but do not send them to the receiver.
for &t in tokens.iter() {
let _ = self.tokenizer.next_token(t)?;
}
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
Some(token) => token,
None => {
eprintln!("Warning: <end_of_turn> token not found, using <eos> as backup");
eos_token
}
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
if next_token == eos_token || next_token == eot_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
// Best-effort send; ignore if receiver dropped.
let _ = tx.send(Ok(t));
}
}
let _dt = start_gen.elapsed();
// Flush any remaining buffered bytes as one final chunk.
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
let _ = tx.send(Ok(rest));
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
pub model: Option<WhichModel>,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: String,
pub use_flash_attn: bool,
pub seed: u64,
pub temperature: f64,
pub top_p: Option<f64>,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
pub max_tokens: usize,
}
impl Default for GemmaInferenceConfig {
fn default() -> Self {
Self {
tracing: false,
prompt: "Hello".to_string(),
model: Some(WhichModel::InstructV2_2B),
cpu: false,
dtype: None,
model_id: None,
revision: "main".to_string(),
use_flash_attn: false,
seed: 299792458,
temperature: 0.8,
top_p: None,
repeat_penalty: 1.1,
repeat_last_n: 128,
max_tokens: 100,
}
}
}
// Removed From<Args> implementation as Args is not available and not needed for API usage
/// Builds the model and returns a channel that streams generated token strings.
/// If model setup fails, the `Result` is returned immediately.
pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String>>> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let _guard = if cfg.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle_core::utils::with_avx(),
candle_core::utils::with_neon(),
candle_core::utils::with_simd128(),
candle_core::utils::with_f16c()
);
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => anyhow::bail!("Unsupported dtype {dtype}"),
None => {
if device.is_cuda() {
DType::BF16
} else {
DType::F16
}
}
};
println!("Using dtype: {:?}", dtype);
println!("Raw model string: {:?}", cfg.model_id);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
Some(WhichModel::Base2B) => "google/gemma-2b",
Some(WhichModel::Base7B) => "google/gemma-7b",
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
None => "google/gemma-2-2b-it", // default fallback
}
.to_string()
});
println!("Loading model: {}", &model_id);
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, cfg.revision));
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
vec![repo.get("model.safetensors")?]
}
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model {
Some(WhichModel::Base2B)
| Some(WhichModel::Base7B)
| Some(WhichModel::Instruct2B)
| Some(WhichModel::Instruct7B)
| Some(WhichModel::InstructV1_1_2B)
| Some(WhichModel::InstructV1_1_7B)
| Some(WhichModel::CodeBase2B)
| Some(WhichModel::CodeBase7B)
| Some(WhichModel::CodeInstruct2B)
| Some(WhichModel::CodeInstruct7B) => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
Some(WhichModel::BaseV2_2B)
| Some(WhichModel::InstructV2_2B)
| Some(WhichModel::BaseV2_9B)
| Some(WhichModel::InstructV2_9B)
| None => {
// default to V2 model
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
}
};
println!("Loaded model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
cfg.seed,
cfg.temperature.into(),
cfg.top_p,
cfg.repeat_penalty,
cfg.repeat_last_n,
&device,
);
let prompt = match cfg.model {
Some(WhichModel::InstructV3_1B) => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt
)
}
_ => cfg.prompt,
};
println!("Starting inference...");
// Create the channel after successful setup.
let (tx, rx) = mpsc::channel::<Result<String>>();
// Spawn generation thread; send tokens to the channel.
thread::spawn(move || {
// If generation fails, forward the error once.
if let Err(e) = pipeline.run_stream(&prompt, cfg.max_tokens, tx.clone()) {
let _ = tx.send(Err(e));
}
// Channel closes when tx is dropped.
});
Ok(rx)
}

View File

@@ -0,0 +1,97 @@
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
pub struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
pub(crate) prompt: String,
/// The model to use
#[arg(short, long, default_value = "gemma-2-2b")]
pub(crate) model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
pub(crate) cpu: bool,
/// The temperature used to generate samples
#[arg(short, long)]
pub(crate) temperature: Option<f64>,
/// Nucleus sampling probability cutoff
#[arg(long)]
pub(crate) top_p: Option<f64>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
pub(crate) seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
pub(crate) max_tokens: usize,
/// Use different dtype than default
#[arg(long)]
pub(crate) dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
pub(crate) model_id: Option<String>,
/// Model revision
#[arg(long, default_value = "main")]
pub(crate) revision: String,
/// Use flash attention
#[arg(long)]
pub(crate) use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty
#[arg(long, default_value_t = 1.1)]
pub(crate) repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 64)]
pub(crate) repeat_last_n: usize,
/// Enable tracing
#[arg(long)]
pub(crate) tracing: bool,
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = GemmaInferenceConfig {
tracing: args.tracing,
prompt: args.prompt,
model: Some(args.model),
cpu: args.cpu,
dtype: args.dtype,
model_id: args.model_id,
revision: args.revision,
use_flash_attn: args.use_flash_attn,
seed: args.seed,
temperature: args.temperature.unwrap_or(0.8),
top_p: args.top_p,
repeat_penalty: args.repeat_penalty,
repeat_last_n: args.repeat_last_n,
max_tokens: args.max_tokens,
};
let rx = run_gemma_api(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,3 @@
pub mod gemma_api;
pub use gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};

View File

@@ -0,0 +1,15 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod gemma_api;
mod gemma_cli;
use anyhow::Error;
use crate::gemma_cli::run_cli;
/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
run_cli()
}

View File

@@ -1,10 +1,8 @@
[package]
name = "helm-chart-tool"
version = "0.1.0"
version.workspace = true
edition = "2021"
[workspace]
[[bin]]
name = "helm-chart-tool"
path = "src/main.rs"

View File

@@ -64,14 +64,9 @@ version = "0.1.0"
# Required: Kubernetes metadata
[package.metadata.kube]
image = "ghcr.io/myorg/my-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
# Optional: Docker Compose metadata (currently not used but parsed)
[package.metadata.compose]
image = "ghcr.io/myorg/my-service:latest"
port = 8080
```
### Required Fields
@@ -137,7 +132,7 @@ Parsing workspace at: ..
Output directory: ../generated-helm-chart
Chart name: predict-otron-9000
Found 4 services:
- leptos-app: ghcr.io/geoffsee/leptos-app:latest (port 8788)
- chat-ui: ghcr.io/geoffsee/chat-ui:latest (port 8788)
- inference-engine: ghcr.io/geoffsee/inference-service:latest (port 8080)
- embeddings-engine: ghcr.io/geoffsee/embeddings-service:latest (port 8080)
- predict-otron-9000: ghcr.io/geoffsee/predict-otron-9000:latest (port 8080)

View File

@@ -1,9 +1,8 @@
use anyhow::{Context, Result};
use clap::{Arg, Command};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::Deserialize;
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
@@ -20,7 +19,6 @@ struct Package {
#[derive(Debug, Deserialize)]
struct Metadata {
kube: Option<KubeMetadata>,
compose: Option<ComposeMetadata>,
}
#[derive(Debug, Deserialize)]
@@ -30,12 +28,6 @@ struct KubeMetadata {
port: u16,
}
#[derive(Debug, Deserialize)]
struct ComposeMetadata {
image: Option<String>,
port: Option<u16>,
}
#[derive(Debug, Clone)]
struct ServiceInfo {
name: String,
@@ -84,7 +76,10 @@ fn main() -> Result<()> {
let services = discover_services(workspace_path)?;
println!("Found {} services:", services.len());
for service in &services {
println!(" - {}: {} (port {})", service.name, service.image, service.port);
println!(
" - {}: {} (port {})",
service.name, service.image, service.port
);
}
generate_helm_chart(output_path, chart_name, &services)?;
@@ -102,7 +97,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
.into_iter()
.filter_map(|e| e.ok())
{
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
if entry.file_name() == "Cargo.toml"
&& entry.path() != workspace_root.join("../../../Cargo.toml")
{
if let Ok(service_info) = parse_cargo_toml(entry.path()) {
services.push(service_info);
}
@@ -115,17 +112,20 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?;
let cargo_toml: CargoToml = toml::from_str(&content)
.with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;
let package = cargo_toml.package
let package = cargo_toml
.package
.ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;
let metadata = package.metadata
let metadata = package
.metadata
.ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;
let kube_metadata = metadata.kube
let kube_metadata = metadata
.kube
.ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;
Ok(ServiceInfo {
@@ -136,7 +136,11 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
})
}
fn generate_helm_chart(output_path: &str, chart_name: &str, services: &[ServiceInfo]) -> Result<()> {
fn generate_helm_chart(
output_path: &str,
chart_name: &str,
services: &[ServiceInfo],
) -> Result<()> {
let chart_dir = Path::new(output_path);
let templates_dir = chart_dir.join("templates");
@@ -365,7 +369,7 @@ spec:
Ok(())
}
fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
fn generate_ingress_template(templates_dir: &Path, _services: &[ServiceInfo]) -> Result<()> {
let ingress_template = r#"{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
@@ -512,4 +516,4 @@ fn generate_helmignore(chart_dir: &Path) -> Result<()> {
fs::write(chart_dir.join(".helmignore"), helmignore_content)?;
Ok(())
}
}

View File

@@ -0,0 +1,24 @@
[package]
name = "llama-runner"
version.workspace = true
edition = "2021"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git"}
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
clap = { version = "4.0", features = ["derive", "string"] }
serde_json = "1.0"
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features]
default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]

View File

@@ -0,0 +1,188 @@
# Llama Runner
A fast Rust implementation for running Llama and other language models using the Candle deep learning framework. Built on the official Candle examples with optimizations for speed and usability.
## Features
- 🚀 **High Performance**: Metal GPU acceleration on macOS, CUDA support on Linux/Windows
- 🤖 **Multiple Models**: Supports Llama 3.2, SmolLM2, TinyLlama, and more
-**Fast Inference**: Optimized with F16 precision and KV caching
- 🎯 **Advanced Sampling**: Top-k, top-p, temperature, and repeat penalty controls
- 📊 **Performance Metrics**: Real-time tokens/second reporting
- 🔧 **Easy CLI**: Simple command-line interface with sensible defaults
## Supported Models
| Model | Size | Command | Description |
|-------|------|---------|-------------|
| SmolLM2-135M | 135M | `smollm2-135m` | Tiny, fast model for testing |
| SmolLM2-360M | 360M | `smollm2-360m` | Small, efficient model |
| SmolLM2-1.7B | 1.7B | `smollm2-1.7b` | Balanced performance/speed |
| Llama-3.2-1B | 1B | `llama-3.2-1b` | Meta's compact model |
| Llama-3.2-3B | 3B | `llama-3.2-3b` | Larger Llama model |
| TinyLlama-1.1B | 1.1B | `tinyllama-1.1b-chat` | Chat-optimized small model |
Add `-instruct` suffix for instruction-tuned variants (e.g., `smollm2-135m-instruct`).
## Installation
```bash
# Clone the repository
git clone <repository-url>
cd llama-runner
# Build with GPU acceleration (recommended)
cargo build --release --features metal # macOS
cargo build --release --features cuda # Linux/Windows with NVIDIA GPU
# CPU-only build
cargo build --release
```
## Quick Start
```bash
# Fast inference with GPU acceleration
cargo run --features metal -- --prompt "What is quantum computing?"
# Specify a model and parameters
cargo run --features metal -- \
--prompt "Write a short story about space exploration" \
--model smollm2-360m \
--max-tokens 100 \
--temperature 0.8
# Use CPU (slower but works everywhere)
cargo run -- --prompt "Hello, world!" --model smollm2-135m --cpu
```
## Usage Examples
### Basic Text Generation
```bash
# Simple completion
cargo run --features metal -- --prompt "The capital of France is"
# Creative writing with higher temperature
cargo run --features metal -- \
--prompt "Once upon a time" \
--temperature 1.0 \
--max-tokens 200
```
### Advanced Sampling
```bash
# Top-k and top-p sampling
cargo run --features metal -- \
--prompt "Explain artificial intelligence" \
--top-k 40 \
--top-p 0.9 \
--temperature 0.7
# Reduce repetition
cargo run --features metal -- \
--prompt "List the benefits of renewable energy" \
--repeat-penalty 1.2 \
--repeat-last-n 64
```
### Different Models
```bash
# Ultra-fast with tiny model
cargo run --features metal -- \
--prompt "Quick test" \
--model smollm2-135m
# Better quality with larger model
cargo run --features metal -- \
--prompt "Explain quantum physics" \
--model llama-3.2-1b \
--max-tokens 150
```
## Command-Line Options
| Option | Short | Default | Description |
|--------|-------|---------|-------------|
| `--prompt` | `-p` | "The capital of France is" | Input prompt |
| `--model` | `-m` | `smollm2-135m` | Model to use |
| `--max-tokens` | `-n` | 100 | Maximum tokens to generate |
| `--temperature` | `-t` | 0.8 | Sampling temperature (0.0 = deterministic) |
| `--top-k` | | None | Top-k sampling |
| `--top-p` | | None | Top-p (nucleus) sampling |
| `--seed` | | 299792458 | Random seed for reproducibility |
| `--repeat-penalty` | | 1.1 | Repetition penalty (1.0 = no penalty) |
| `--repeat-last-n` | | 128 | Context window for repeat penalty |
| `--cpu` | | false | Force CPU usage |
| `--dtype` | | f16 | Data type: f16, bf16, f32 |
| `--no-kv-cache` | | false | Disable key-value caching |
## Performance
Typical performance on Apple M2 with Metal acceleration:
| Model | Size | Speed | Memory |
|-------|------|-------|--------|
| SmolLM2-135M | 135M | ~100 tok/s | ~500MB |
| SmolLM2-360M | 360M | ~80 tok/s | ~1GB |
| SmolLM2-1.7B | 1.7B | ~50 tok/s | ~3GB |
| Llama-3.2-1B | 1B | ~40 tok/s | ~2GB |
## Requirements
- **Rust**: 1.70+ (latest stable recommended)
- **Memory**: 2-8GB RAM depending on model size
- **Storage**: 1-10GB for model weights
- **Network**: Internet connection for first-time model download
- **GPU** (optional): Metal on macOS, CUDA on Linux/Windows
## GPU Support
### macOS (Metal)
```bash
cargo run --features metal -- [options]
```
### Linux/Windows (CUDA)
```bash
cargo run --features cuda -- [options]
```
### CPU Only
```bash
cargo run -- --cpu [options]
```
## Model Downloads
Models are automatically downloaded from HuggingFace Hub on first use and cached locally. Download times:
- SmolLM2-135M: ~1 minute
- SmolLM2-360M: ~2 minutes
- Llama-3.2-1B: ~5 minutes
- Larger models: 10+ minutes
## Troubleshooting
### Slow Performance
- Use `--features metal` on macOS or `--features cuda` on Linux/Windows
- Try smaller models like `smollm2-135m` for faster inference
- Ensure sufficient RAM for your chosen model
### Out of Memory
- Use `--cpu` to use system RAM instead of GPU memory
- Try smaller models or reduce `--max-tokens`
- Use `--dtype f32` if f16 causes issues
### Model Download Issues
- Check internet connection
- Some models may require HuggingFace Hub authentication
- Verify sufficient disk space in `~/.cache/huggingface/`
## Contributing
Contributions welcome! This project is based on the [Candle](https://github.com/huggingface/candle) framework by HuggingFace.
## License
MIT License - see LICENSE file for details.

View File

@@ -0,0 +1,6 @@
pub mod llama_api;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>";

View File

@@ -0,0 +1,354 @@
use crate::EOS_TOKEN;
use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama as model;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use clap::ValueEnum;
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel {
#[value(name = "llama-3.2-1b")]
#[default]
Llama32_1B,
#[value(name = "llama-3.2-1b-instruct")]
Llama32_1BInstruct,
#[value(name = "llama-3.2-3b")]
Llama32_3B,
#[value(name = "llama-3.2-3b-instruct")]
Llama32_3BInstruct,
#[value(name = "smollm2-135m")]
SmolLM2_135M,
#[value(name = "smollm2-135m-instruct")]
SmolLM2_135MInstruct,
#[value(name = "smollm2-360m")]
SmolLM2_360M,
#[value(name = "smollm2-360m-instruct")]
SmolLM2_360MInstruct,
#[value(name = "smollm2-1.7b")]
SmolLM2_1_7B,
#[value(name = "smollm2-1.7b-instruct")]
SmolLM2_1_7BInstruct,
#[value(name = "tinyllama-1.1b-chat")]
TinyLlama1_1BChat,
}
#[derive(Debug, Clone)]
pub struct LlamaInferenceConfig {
pub prompt: String,
pub model: WhichModel,
pub cpu: bool,
pub temperature: f64,
pub top_p: Option<f64>,
pub top_k: Option<usize>,
pub seed: u64,
pub max_tokens: usize,
pub no_kv_cache: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
pub revision: Option<String>,
pub use_flash_attn: bool,
pub repeat_penalty: f32,
pub repeat_last_n: usize,
}
impl LlamaInferenceConfig {
pub fn new(model: WhichModel) -> Self {
Self {
prompt: String::new(),
model,
cpu: false,
temperature: 1.0,
top_p: None,
top_k: None,
seed: 42,
max_tokens: 512,
no_kv_cache: false,
dtype: None,
model_id: None,
revision: None,
use_flash_attn: true,
repeat_penalty: 1.1,
repeat_last_n: 64,
}
}
}
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
// Leave prompt empty by default; let call sites set it.
prompt: String::new(),
// Keep your existing model choice; swap at call-site if needed.
model: WhichModel::Llama32_1BInstruct,
// Prefer GPU if available.
cpu: false,
// Sampling: balanced + stable
temperature: 0.7,
top_p: Some(0.95),
top_k: Some(50),
// Reproducible by default; override for variability.
seed: 42,
// Dont run unbounded generations.
max_tokens: 512,
// Performance flags
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: false, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()),
// Optional model source pinning (None = app defaults)
model_id: None,
revision: None,
// Anti-repeat heuristics
repeat_penalty: 1.15,
repeat_last_n: 128,
}
}
}
fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if utils::cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if utils::metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
Ok(Device::Cpu)
}
}
fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
let json_file = api.get(json_file)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value = serde_json::from_reader(&json_file)?;
let weight_map = match json.get("weight_map") {
None => bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| api.get(v))
.collect::<anyhow::Result<Vec<_>, _>>()?;
Ok(safetensors_files)
}
pub fn run_llama_inference(
cfg: LlamaInferenceConfig,
) -> anyhow::Result<Receiver<anyhow::Result<String>>, anyhow::Error> {
// ---- Device & dtype -----------------------------------------------------
let device = device(cfg.cpu)?;
println!("Device: {:?}", device);
let dtype = match cfg.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
println!("Using dtype: {:?}", dtype);
// ---- Load model & tokenizer --------------------------------------------
let (llama, tokenizer, mut cache) = {
let api = Api::new()?;
let model_id = cfg.model_id.clone().unwrap_or_else(|| {
match cfg.model {
WhichModel::Llama32_1B => "meta-llama/Llama-3.2-1B",
WhichModel::Llama32_1BInstruct => "meta-llama/Llama-3.2-1B-Instruct",
WhichModel::Llama32_3B => "meta-llama/Llama-3.2-3B",
WhichModel::Llama32_3BInstruct => "meta-llama/Llama-3.2-3B-Instruct",
WhichModel::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
WhichModel::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
WhichModel::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
WhichModel::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
WhichModel::SmolLM2_1_7B => "HuggingFaceTB/SmolLM2-1.7B",
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
.to_string()
});
println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(cfg.use_flash_attn);
let filenames = match cfg.model {
WhichModel::Llama32_3B | WhichModel::Llama32_3BInstruct => {
hub_load_safetensors(&api, "model.safetensors.index.json")?
}
_ => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let llama = Llama::load(vb, &config)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
(llama, tokenizer, cache)
};
// ---- Prepare prompt & sampler ------------------------------------------
let eos_token_id = tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single);
let mut tokens = tokenizer
.encode(cfg.prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
println!("Starting inference...");
let mut logits_processor = {
let temperature = cfg.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (cfg.top_k, cfg.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(cfg.seed, sampling)
};
// Channel for streaming decoded fragments to the caller.
let (tx, rx) = mpsc::channel::<anyhow::Result<String>>();
// ---- Spawn generation thread -------------------------------------------
std::thread::spawn(move || {
let start_gen = std::time::Instant::now();
let mut index_pos = 0usize;
let mut token_generated = 0usize;
for index in 0..cfg.max_tokens {
// Use KV-cache for single-token step after the first pass.
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = match Tensor::new(ctxt, &device).and_then(|t| t.unsqueeze(0)) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match llama.forward(&input, context_index, &mut cache) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = match logits.squeeze(0) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
let logits = if cfg.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
match candle_transformers::utils::apply_repeat_penalty(
&logits,
cfg.repeat_penalty,
&tokens[start_at..],
) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
}
};
index_pos += ctxt.len();
let next_token = match logits_processor.sample(&logits) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(Err(e.into()));
break;
}
};
token_generated += 1;
tokens.push(next_token);
// Early stop on EOS.
let stop = match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id)) => next_token == eos_tok_id,
Some(model::LlamaEosToks::Multiple(ref eos_ids)) => eos_ids.contains(&next_token),
None => false,
};
if stop {
break;
}
// Decode this token's text and stream it out.
match tokenizer.decode(&[next_token], false) {
Ok(text) => {
if !text.is_empty() {
// Best-effort send; if receiver is gone, just stop.
if tx.send(Ok(text)).is_err() {
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(anyhow::anyhow!("{}", e)));
break;
}
}
}
// Optional: final stats as a debug line (not sent through the stream).
let dt = start_gen.elapsed();
eprintln!(
"[llama-runner] {} tokens generated ({:.2} tokens/s)",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
// Dropping tx closes the stream.
});
Ok(rx)
}

View File

@@ -0,0 +1,108 @@
use crate::llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug, Default)]
#[command(author, version, about = "Fast Llama inference with Candle", long_about = None)]
struct Args {
/// The prompt to generate text from
#[arg(short, long, default_value = "The capital of France is")]
prompt: String,
/// The model to use
#[arg(short, long, default_value = "llama-3.2-1b-instruct")]
model: WhichModel,
/// Run on CPU rather than GPU
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples
#[arg(short, long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens)
#[arg(short = 'n', long, default_value_t = 100)]
max_tokens: usize,
/// Disable the key-value cache
#[arg(long)]
no_kv_cache: bool,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Custom model ID from HuggingFace Hub
#[arg(long)]
model_id: Option<String>,
/// Model revision
#[arg(long)]
revision: Option<String>,
/// Use flash attention
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
impl Into<LlamaInferenceConfig> for Args {
fn into(self) -> LlamaInferenceConfig {
LlamaInferenceConfig {
prompt: self.prompt,
model: self.model,
cpu: self.cpu,
temperature: self.temperature,
top_p: self.top_p,
top_k: self.top_k,
seed: self.seed,
max_tokens: self.max_tokens,
no_kv_cache: self.no_kv_cache,
dtype: self.dtype,
model_id: self.model_id,
revision: self.revision,
use_flash_attn: self.use_flash_attn,
repeat_penalty: self.repeat_penalty,
repeat_last_n: self.repeat_last_n,
}
}
}
pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse();
let cfg = args.into();
let rx = run_llama_inference(cfg)?;
for msg in rx {
match msg {
Ok(tok) => {
print!("{tok}");
let _ = std::io::stdout().flush(); // <- force it out now
}
Err(e) => {
eprintln!("generation error: {e}");
break;
}
}
}
Ok(())
}

View File

@@ -0,0 +1,16 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod llama_api;
mod llama_cli;
use anyhow::Result;
use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> {
run_cli()
}

View File

@@ -0,0 +1,96 @@
[package]
name = "utils"
edition = "2021"
[lib]
path = "src/lib.rs"
[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
half = {version = "2.6.0", optional = true }
hf-hub = {version = "0.4.3", features = ["tokio"] }
image = {version = "0.25.6" }
intel-mkl-src = {version = "0.8.1", optional = true }
num-traits = {version = "0.2.19" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.22.0", features = [
"auto-initialize",
"abi3-py311",
], optional = true }
rayon = {version = "1.11.0" }
rubato = { version = "0.15.0", optional = true }
safetensors = {version = "0.6.2" }
serde = {version = "1.0.219" }
serde_json = {version = "1.0.143" }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = {version = "0.22.0", features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2", optional = true }
tekken-rs = { version = "0.1.1", optional = true }
[dev-dependencies]
anyhow = {version = "1.0.99" }
byteorder = {version = "1.5.0" }
clap = {version = "4.5.46" }
imageproc = {version = "0.25.0" }
memmap2 = {version = "0.9.8" }
rand = {version = "0.9.2" }
ab_glyph = {version = "0.2.31" }
tracing = {version = "0.1.41" }
tracing-chrome = {version = "0.7.2" }
tracing-subscriber = {version = "0.3.20" }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"
[build-dependencies]
anyhow = {version = "1.0.99" }
bindgen_cuda = { version = "0.1.1", optional = true }
#
[features]
default = []
accelerate = [
"dep:accelerate-src",
"candle-core/accelerate",
"candle-nn/accelerate",
"candle-transformers/accelerate",
]
cuda = [
"candle-core/cuda",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:bindgen_cuda",
]
cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = [
"dep:intel-mkl-src",
"candle-core/mkl",
"candle-nn/mkl",
"candle-transformers/mkl",
]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle-core/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]
# Platform-specific candle dependencies
[target.'cfg(target_os = "linux")'.dependencies]
candle-nn = {version = "0.9.1", default-features = false }
candle-transformers = {version = "0.9.1", default-features = false }
candle-core = {version = "0.9.1", default-features = false }
[target.'cfg(not(target_os = "linux"))'.dependencies]
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }
candle-core = {version = "0.9.1" }

View File

@@ -0,0 +1,138 @@
use candle_core::{Result, Tensor};
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
pub fn normalize_loudness(
wav: &Tensor,
sample_rate: u32,
loudness_compressor: bool,
) -> Result<Tensor> {
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
if energy < 2e-3 {
return Ok(wav.clone());
}
let wav_array = wav.to_vec1::<f32>()?;
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
meter.push(wav_array.into_iter());
let power = meter.as_100ms_windows();
let loudness = match crate::bs1770::gated_mean(power) {
None => return Ok(wav.clone()),
Some(gp) => gp.loudness_lkfs() as f64,
};
let delta_loudness = -14. - loudness;
let gain = 10f64.powf(delta_loudness / 20.);
let wav = (wav * gain)?;
if loudness_compressor {
wav.tanh()
} else {
Ok(wav)
}
}
#[cfg(feature = "symphonia")]
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;
fn conv<T>(
samples: &mut Vec<f32>,
data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
) where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
// Open the media source.
let src = std::fs::File::open(path).map_err(candle::Error::wrap)?;
// Create the media source stream.
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
// Create a probe hint using the file's extension. [Optional]
let hint = symphonia::core::probe::Hint::new();
// Use the default options for metadata and format readers.
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
// Probe the media source.
let probed = symphonia::default::get_probe()
.format(&hint, mss, &fmt_opts, &meta_opts)
.map_err(candle::Error::wrap)?;
// Get the instantiated format reader.
let mut format = probed.format;
// Find the first audio track with a known (decodeable) codec.
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?;
// Use the default options for the decoder.
let dec_opts: DecoderOptions = Default::default();
// Create a decoder for the track.
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?;
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
// The decode loop.
while let Ok(packet) = format.next_packet() {
// Consume any new metadata that has been read since the last packet.
while !format.metadata().is_latest() {
format.metadata().pop();
}
// If the packet does not belong to the selected track, skip over it.
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet).map_err(candle::Error::wrap)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[cfg(feature = "rubato")]
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
.map_err(candle::Error::wrap)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) = resampler
.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
.map_err(candle::Error::wrap)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler
.process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
.map_err(candle::Error::wrap)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@@ -0,0 +1,506 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! # [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//! meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//! channel_power[0].as_ref(),
//! channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//! stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```
use std::f32;
/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
a1: f32,
a2: f32,
b0: f32,
b1: f32,
b2: f32,
// The past two input and output samples.
x1: f32,
x2: f32,
y1: f32,
y2: f32,
}
impl Filter {
/// Stage 1 of th BS.1770-4 pre-filter.
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let gain_db = 3.999_843_8;
let q = 0.707_175_25;
let center_hz = 1_681.974_5;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
let vh = 10.0_f32.powf(gain_db / 20.0);
let vb = vh.powf(0.499_666_78);
let a0 = 1.0 + k / q + k * k;
Filter {
b0: (vh + vb * k / q + k * k) / a0,
b1: 2.0 * (k * k - vh) / a0,
b2: (vh - vb * k / q + k * k) / a0,
a1: 2.0 * (k * k - 1.0) / a0,
a2: (1.0 - k / q + k * k) / a0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Stage 2 of th BS.1770-4 pre-filter.
pub fn high_pass(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let q = 0.500_327_05;
let center_hz = 38.135_47;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
Filter {
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
b0: 1.0,
b1: -2.0,
b2: 1.0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Feed the next input sample, get the next output sample.
#[inline(always)]
pub fn apply(&mut self, x0: f32) -> f32 {
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
- self.a1 * self.y1
- self.a2 * self.y2;
self.x2 = self.x1;
self.x1 = x0;
self.y2 = self.y1;
self.y1 = y0;
y0
}
}
/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
sum: f32,
residue: f32,
}
impl Sum {
#[inline(always)]
fn zero() -> Sum {
Sum {
sum: 0.0,
residue: 0.0,
}
}
#[inline(always)]
fn add(&mut self, x: f32) {
let sum = self.sum + (self.residue + x);
self.residue = (self.residue + x) - (sum - self.sum);
self.sum = sum;
}
}
/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);
impl Power {
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
///
/// This is the inverse of `loudness_lkfs`.
pub fn from_lkfs(lkfs: f32) -> Power {
// The inverse of the formula below.
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
}
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
///
/// This is the inverse of `from_lkfs`.
pub fn loudness_lkfs(&self) -> f32 {
// Equation 2 (p.5) of BS.1770-4.
-0.691 + 10.0 * self.0.log10()
}
}
/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
pub inner: T,
}
impl<T> Windows100ms<T> {
/// Wrap a new empty vector.
pub fn new() -> Windows100ms<Vec<T>> {
Windows100ms { inner: Vec::new() }
}
/// Apply `as_ref` to the inner value.
pub fn as_ref(&self) -> Windows100ms<&[Power]>
where
T: AsRef<[Power]>,
{
Windows100ms {
inner: self.inner.as_ref(),
}
}
/// Apply `as_mut` to the inner value.
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
where
T: AsMut<[Power]>,
{
Windows100ms {
inner: self.inner.as_mut(),
}
}
#[allow(clippy::len_without_is_empty)]
/// Apply `len` to the inner value.
pub fn len(&self) -> usize
where
T: AsRef<[Power]>,
{
self.inner.as_ref().len()
}
}
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
/// .unwrap_or(bs1770::Power(0.0))
/// .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
/// The number of samples that fit in 100ms of audio.
samples_per_100ms: u32,
/// Stage 1 filter (head effects, high shelf).
filter_stage1: Filter,
/// Stage 2 filter (high-pass).
filter_stage2: Filter,
/// Sum of the squares over non-overlapping windows of 100ms.
windows: Windows100ms<Vec<Power>>,
/// The number of samples in the current unfinished window.
count: u32,
/// The sum of the squares of the samples in the current unfinished window.
square_sum: Sum,
}
impl ChannelLoudnessMeter {
/// Construct a new loudness meter for the given sample rate.
pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
ChannelLoudnessMeter {
samples_per_100ms: sample_rate_hz / 10,
filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
filter_stage2: Filter::high_pass(sample_rate_hz as f32),
windows: Windows100ms::new(),
count: 0,
square_sum: Sum::zero(),
}
}
/// Feed input samples for loudness analysis.
///
/// # Full scale
///
/// Full scale for the input samples is the interval [-1.0, 1.0]. If your
/// input consists of signed integer samples, you can convert as follows:
///
/// ```ignore
/// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
/// # let bits_per_sample = 16_usize;
/// # let samples = &[0_i16];
/// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
/// // one bit is the sign bit.
/// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
/// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
/// ```
///
/// # Repeated calls
///
/// You can call `push` multiple times to feed multiple batches of samples.
/// This is equivalent to feeding a single chained iterator. The leftover of
/// samples that did not fill a full 100ms window is not discarded:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::ChannelLoudnessMeter;
/// let sample_rate_hz = 44_100;
/// let samples_per_100ms = sample_rate_hz / 10;
/// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
///
/// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
/// assert_eq!(meter.as_100ms_windows().len(), 0);
///
/// meter.push(iter::once(0.0));
/// assert_eq!(meter.as_100ms_windows().len(), 1);
/// ```
pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
let normalizer = 1.0 / self.samples_per_100ms as f32;
// LLVM, if you could go ahead and inline those apply calls, and then
// unroll and vectorize the loop, that'd be terrific.
for x in samples {
let y = self.filter_stage1.apply(x);
let z = self.filter_stage2.apply(y);
self.square_sum.add(z * z);
self.count += 1;
// TODO: Should this branch be marked cold?
if self.count == self.samples_per_100ms {
let mean_squares = Power(self.square_sum.sum * normalizer);
self.windows.inner.push(mean_squares);
// We intentionally do not reset the residue. That way, leftover
// energy from this window is not lost, so for the file overall,
// the sum remains more accurate.
self.square_sum.sum = 0.0;
self.count = 0;
}
}
}
/// Return a reference to the 100ms windows analyzed so far.
pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
self.windows.as_ref()
}
/// Return all 100ms windows analyzed so far.
pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
self.windows
}
}
/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
left: Windows100ms<&[Power]>,
right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
let mut result = Vec::with_capacity(left.len());
for (l, r) in left.inner.iter().zip(right.inner) {
result.push(Power(l.0 + r.0));
}
Windows100ms { inner: result }
}
/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
for (l, r) in left.inner.iter_mut().zip(right.inner) {
l.0 += r.0;
}
}
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
let absolute_threshold = Power::from_lkfs(-70.0);
// Iterate over all 400ms windows.
for window in windows_100ms.inner.windows(4) {
// Note that the sum over channels has already been performed at this point.
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
if gating_block_power > absolute_threshold {
gating_blocks.push(gating_block_power);
}
}
if gating_blocks.is_empty() {
return None;
}
// Compute the loudness after applying the absolute gate, in order to
// determine the threshold for the relative gate.
let mut sum_power = Sum::zero();
for &gating_block_power in &gating_blocks {
sum_power.add(gating_block_power.0);
}
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
// Stage 2: Apply the relative gate.
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
let mut sum_power = Sum::zero();
let mut n_blocks = 0_usize;
for &gating_block_power in &gating_blocks {
if gating_block_power > relative_threshold {
sum_power.add(gating_block_power.0);
n_blocks += 1;
}
}
if n_blocks == 0 {
return None;
}
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
Some(relative_gated_power)
}

View File

@@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,19 @@
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result, Tensor};
extern crate candle_core;
extern crate candle_transformers;
extern crate tokenizers;
pub fn device(cpu: bool) -> Result<Device> {
pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{
utils::{cuda_is_available, metal_is_available},
Device, Tensor,
};
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
@@ -26,7 +38,7 @@ pub fn device(cpu: bool) -> Result<Device> {
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize)> {
) -> Result<(Tensor, usize, usize), anyhow::Error> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?;
@@ -57,7 +69,7 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P,
width: usize,
height: usize,
) -> Result<Tensor> {
) -> candle_core::Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?
@@ -73,60 +85,36 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
anyhow::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
None => anyhow::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
pub fn save_image_resize<P: AsRef<std::path::Path>>(
img: &Tensor,
p: P,
h: usize,
w: usize,
) -> Result<()> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
candle_core::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => candle_core::bail!("error saving image {p:?}"),
};
let image = image::DynamicImage::from(image);
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
@@ -136,22 +124,23 @@ pub fn hub_load_safetensors(
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle_core::Error::wrap))
.collect::<Result<Vec<_>>>()?;
.map(|v| repo.get(v).map_err(std::io::Error::other))
.collect::<Result<Vec<_>, std::io::Error>>()?;
Ok(safetensors_files)
}
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"),
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
@@ -164,4 +153,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}
}

View File

@@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

View File

@@ -1,7 +1,6 @@
use candle_core::Result;
use tokenizers::Tokenizer;
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
@@ -40,8 +39,7 @@ impl TokenOutputStream {
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() {
// Modified to include all tokens, not just alphanumeric ones
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();
@@ -84,4 +82,4 @@ impl TokenOutputStream {
self.prev_index = 0;
self.current_index = 0;
}
}
}

View File

@@ -0,0 +1,56 @@
use std::io::prelude::*;
pub trait Sample {
fn to_i16(&self) -> i16;
}
impl Sample for f32 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for f64 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for i16 {
fn to_i16(&self) -> i16 {
*self
}
}
pub fn write_pcm_as_wav<W: Write, S: Sample>(
w: &mut W,
samples: &[S],
sample_rate: u32,
) -> std::io::Result<()> {
let len = 12u32; // header
let len = len + 24u32; // fmt
let len = len + samples.len() as u32 * 2 + 8; // data
let n_channels = 1u16;
let bytes_per_second = sample_rate * 2 * n_channels as u32;
w.write_all(b"RIFF")?;
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
w.write_all(b"WAVE")?;
// Format block
w.write_all(b"fmt ")?;
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
w.write_all(&1u16.to_le_bytes())?; // PCM
w.write_all(&n_channels.to_le_bytes())?; // one channel
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&bytes_per_second.to_le_bytes())?;
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
w.write_all(&16u16.to_le_bytes())?; // bits per sample
// Data block
w.write_all(b"data")?;
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
for sample in samples.iter() {
w.write_all(&sample.to_i16().to_le_bytes())?
}
Ok(())
}

View File

@@ -1,8 +1,8 @@
{
"dependencies": {
"openai": "^5.16.0"
},
"name": "predict-otron-9000",
"workspaces": ["integration/cli/package"],
"scripts": {
"cli": "./scripts/cli.ts"
"# WORKSPACE ALIASES": "#",
"cli": "bun --filter integration/cli/package"
}
}

BIN
predict-otron-9000.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 248 KiB

14
scripts/build_ui.sh Executable file
View File

@@ -0,0 +1,14 @@
#!/usr/bin/env sh
# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
# Move into the chat-ui crate
cd "$PROJECT_ROOT/crates/chat-ui" || exit 1
# Build with cargo leptos
cargo leptos build --release
# Move the wasm file, keeping paths relative to the project root
mv "$PROJECT_ROOT/target/site/pkg/chat-ui.wasm" \
"$PROJECT_ROOT/target/site/pkg/chat-ui_bg.wasm"

View File

@@ -15,7 +15,7 @@ CONNECT_TIMEOUT=${CONNECT_TIMEOUT:-10}
MAX_TIME=${MAX_TIME:-30}
cat <<EOF
[info] POST $SERVER_URL/v1/chat/completions/stream (SSE)
[info] POST $SERVER_URL/v1/chat/completions (SSE)
[info] model=$MODEL_ID, max_tokens=$MAX_TOKENS
[info] prompt=$PROMPT
[info] timeouts: connect=${CONNECT_TIMEOUT}s, max=${MAX_TIME}s
@@ -35,7 +35,7 @@ curl -N -sS -X POST \
--connect-timeout "$CONNECT_TIMEOUT" \
--max-time "$MAX_TIME" \
-H "Content-Type: application/json" \
"$SERVER_URL/v1/chat/completions/stream" \
"$SERVER_URL/v1/chat/completions" \
-d @- <<JSON
{
"model": "${MODEL_ID}",

17
scripts/run.sh Executable file
View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
# todo, conditionally run this only when those files change
"$PROJECT_ROOT/scripts/build_ui.sh"
# build the frontend first
# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}
cd "$PROJECT_ROOT" || exit 1
cargo run --bin predict-otron-9000 --release

View File

@@ -1,7 +0,0 @@
#!/bin/bash
# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}
cargo run --bin predict-otron-9000 --release