Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00.

Compare commits: inference-... to v0.1.5 (22 commits)

Commit SHA1s:
400c70f17d
bcbc6c4693
21f20470de
2deecb5e51
545e0c9831
eca61c51ad
d1a7d5b28e
8d2b85b0b9
4570780666
44e4f9e5e1
64daa77c6b
2b4a8a9df8
38d51722f2
7bc9479a11
0580dc8c5e
9e9aa69769
3eb1a5329b
eb1591aa5d
e6c417bd83
f5d2a85f2e
419e1c2ea7
06fdfcf898
.github/dependabot.yml (new file, 49 lines, vendored)
@@ -0,0 +1,49 @@
version: 2
updates:
  # Monitor Rust dependencies in the main crate
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "weekly"
      day: "monday"
      time: "09:00"
      timezone: "UTC"
    # Focus on security updates with higher priority
    open-pull-requests-limit: 10
    reviewers:
      - "security-team"
    assignees:
      - "maintainer"
    labels:
      - "dependencies"
      - "security"
    # Security updates get higher priority
    allow:
      - dependency-type: "all"
    # Group minor and patch updates to reduce noise
    # Separate major updates for careful review
    ignore:
      - dependency-name: "*"
        update-types: ["version-update:semver-major"]
    commit-message:
      prefix: "deps"
      include: "scope"

  # Monitor security updates more frequently
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "daily"
    # Only security updates in daily checks
    allow:
      - dependency-type: "direct"
        update-types: ["security"]
      - dependency-type: "indirect"
        update-types: ["security"]
    open-pull-requests-limit: 5
    labels:
      - "security-update"
      - "high-priority"
    commit-message:
      prefix: "security"
      include: "scope"
.github/workflows/ci.yml (new file, 56 lines, vendored)
@@ -0,0 +1,56 @@
name: CI

on:
  push:
  pull_request:

jobs:
  build:
    name: build-and-test
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown

      - name: Setup Bun
        uses: oven-sh/setup-bun@v2

      - name: Build
        run: |
          cargo install --locked cargo-leptos
          cd crates/chat-ui && cargo leptos build --release
          cargo build --release -p predict-otron-9000 -p cli

      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt

      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check

      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets

      - name: Tests
        shell: bash
        run: cargo test --all

      - name: Build Docs
        shell: bash
        run: |
          cargo doc -p predict-otron-9000 --no-deps
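The CI job above is a single build-and-test pipeline; the same checks can be reproduced locally before pushing. A minimal sketch, assuming a stable Rust toolchain and the same tools the workflow installs:

```bash
# Mirror the CI checks locally (commands taken from the workflow above)
rustup component add clippy rustfmt
cargo fmt --all -- --check
cargo clippy --all-targets
cargo test --all
cargo doc -p predict-otron-9000 --no-deps
```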
.github/workflows/release.yml (new file, 240 lines, vendored)
@@ -0,0 +1,240 @@
name: Release

on:
  push:
    tags:
      - 'v*'

env:
  CARGO_TERM_COLOR: always

jobs:
  test:
    name: Test before release
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: crates/predict-otron-9000
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown

      - name: Setup Bun
        uses: oven-sh/setup-bun@v2

      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt

      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check

      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets

      - name: Tests
        shell: bash
        run: cargo test --all

#  publish:
#    name: Publish to crates.io
#    runs-on: ubuntu-latest
#    permissions:
#      id-token: write  # Required for OIDC token exchange https://crates.io/docs/trusted-publishing
#    needs: test
#    defaults:
#      run:
#        working-directory: crates/predict-otron-9000
#    steps:
#      - name: Checkout
#        uses: actions/checkout@v4
#
#      - uses: actions/cache@v4
#        with:
#          path: |
#            ~/.cargo/bin/
#            ~/.cargo/registry/index/
#            ~/.cargo/registry/cache/
#            ~/.cargo/git/db/
#            target/
#          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
#
#      - name: Setup Rust
#        run: rustup update stable && rustup default stable
#
#      - name: Verify tag matches version
#        run: |
#          TAG_VERSION=${GITHUB_REF#refs/tags/v}
#          CARGO_VERSION=$(cargo metadata --no-deps --format-version 1 | jq -r '.packages[0].version')
#          if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then
#            echo "Tag version ($TAG_VERSION) does not match Cargo.toml version ($CARGO_VERSION)"
#            exit 1
#          fi
#
#      # See Trusted publishing: https://crates.io/docs/trusted-publishing
#      - uses: rust-lang/crates-io-auth-action@v1
#        id: auth
#
#      - run: cargo publish
#        env:
#          CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }}

  build-binaries:
    name: Build binaries
    runs-on: ${{ matrix.os }}
    needs: test
    strategy:
      fail-fast: false
      matrix:
        include:
          - target: x86_64-unknown-linux-gnu
            os: ubuntu-latest
            name: predict-otron-9000-x86_64-unknown-linux-gnu
          - target: x86_64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-x86_64-apple-darwin
          - target: aarch64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-aarch64-apple-darwin
          - target: x86_64-pc-windows-msvc
            os: windows-latest
            name: predict-otron-9000-x86_64-pc-windows-msvc.exe
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown

      - name: Add target
        run: rustup target add ${{ matrix.target }}

      - name: Build UI
        run: cargo install --locked cargo-leptos && cd crates/chat-ui && cargo leptos build --release
        env:
          CARGO_TERM_COLOR: always

      - name: Build Binary
        run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000 -p cli
        env:
          CARGO_TERM_COLOR: always

      - name: Package binary (Unix)
        if: matrix.os != 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          tar czf ../../../${{ matrix.name }}.tar.gz predict-otron-9000 cli
          cd ../../../

      - name: Package binary (Windows)
        if: matrix.os == 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          7z a ../../../${{ matrix.name }}.zip predict-otron-9000.exe cli.exe
          cd ../../../

      - name: Upload binary artifacts (Unix)
        if: matrix.os != 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.tar.gz

      - name: Upload binary artifacts (Windows)
        if: matrix.os == 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.zip

  release:
    name: Create GitHub Release
    runs-on: ubuntu-latest
    needs: [test, build-binaries]
    permissions:
      contents: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Extract tag name
        id: tag
        run: echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT

      - name: Generate changelog
        id: changelog
        run: |
          # Get the previous tag
          PREV_TAG=$(git describe --tags --abbrev=0 HEAD^ 2>/dev/null || echo "")

          # Generate changelog
          if [ -n "$PREV_TAG" ]; then
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            git log --pretty=format:"* %s (%h)" ${PREV_TAG}..HEAD >> changelog.md
            echo "" >> changelog.md
            echo "" >> changelog.md
            echo "**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREV_TAG}...${{ steps.tag.outputs.tag }}" >> changelog.md
          else
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            echo "Initial release of predict-otron-9000" >> changelog.md
            echo "" >> changelog.md
            echo "OpenAI Compatible Inference Server" >> changelog.md
          fi

          # Set the changelog as output (handle multiline)
          echo "changelog<<EOF" >> $GITHUB_OUTPUT
          cat changelog.md >> $GITHUB_OUTPUT
          echo "EOF" >> $GITHUB_OUTPUT

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: artifacts

      - name: Create Release
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          if [[ "${{ steps.tag.outputs.tag }}" == *"-"* ]]; then
            PRERELEASE_FLAG="--prerelease"
          else
            PRERELEASE_FLAG=""
          fi

          gh release create "${{ steps.tag.outputs.tag }}" \
            --title "Release ${{ steps.tag.outputs.tag }}" \
            --notes-file changelog.md \
            $PRERELEASE_FLAG \
            artifacts/predict-otron-9000-x86_64-unknown-linux-gnu/predict-otron-9000-x86_64-unknown-linux-gnu.tar.gz \
            artifacts/predict-otron-9000-x86_64-apple-darwin/predict-otron-9000-x86_64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-aarch64-apple-darwin/predict-otron-9000-aarch64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-x86_64-pc-windows-msvc.exe/predict-otron-9000-x86_64-pc-windows-msvc.exe.zip
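The release workflow above is triggered by pushing a `v*` tag, and tags containing a hyphen are published as prereleases. A minimal sketch of cutting a release, assuming the tag matches the workspace version in Cargo.toml (the tag value here is only an example):

```bash
# Tag the current commit and push the tag to trigger .github/workflows/release.yml
git tag v0.1.5            # example tag; any "v*" tag starts the workflow
git push origin v0.1.5    # a tag like "v0.1.5-rc1" would be marked as a prerelease
```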
.gitignore (5 changed lines, vendored)
@@ -23,7 +23,6 @@ package-lock.json

 # Web frontend build outputs
 dist/
-.trunk/

 # ML model and embedding caches
 .fastembed_cache/
@@ -75,4 +74,6 @@ venv/
 # Backup files
 *.bak
 *.backup
-*~
+!/scripts/cli.ts
+/**/.*.bun-build
+/AGENTS.md
Cargo.lock (generated, 857 changed lines): diff suppressed because it is too large.
Cargo.toml (41 changed lines)
@@ -3,15 +3,42 @@ members = [
     "crates/predict-otron-9000",
     "crates/inference-engine",
     "crates/embeddings-engine",
-    "crates/leptos-app",
     "crates/helm-chart-tool",
     "crates/llama-runner",
-    "crates/gemma-runner"
-]
+    "crates/gemma-runner",
+    "crates/cli",
+    "crates/chat-ui"
+, "crates/utils"]
 default-members = ["crates/predict-otron-9000"]
 resolver = "2"

-[[workspace.metadata.leptos]]
-# project name
-bin-package = "leptos-app"
-lib-package = "leptos-app"
+[workspace.package]
+version = "0.1.4"
+
+# Compiler optimization profiles for the workspace
+[profile.release]
+opt-level = 3
+debug = false
+strip = true
+lto = "thin"
+codegen-units = 1
+panic = "abort"
+
+[profile.dev]
+opt-level = 0
+debug = true
+strip = false
+overflow-checks = true
+
+# Profile for fast development builds with some optimization
+[profile.dev-opt]
+inherits = "dev"
+opt-level = 1
+debug = true
+overflow-checks = true
+
+# Profile for benchmarking and profiling
+[profile.bench]
+opt-level = 3
+debug = true
+lto = "thin"
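The new `[profile.dev-opt]` section defines a custom profile that inherits from `dev`; Cargo selects custom profiles with `--profile`. A minimal sketch of using it (package name taken from the workspace members above):

```bash
# Faster-than-dev build with opt-level = 1, as defined by [profile.dev-opt]
cargo build --profile dev-opt -p predict-otron-9000
```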
README.md (75 changed lines)
@@ -1,10 +1,26 @@
-# predict-otron-9000
+<h1 align="center">
+predict-otron-9000
+</h1>
+<p align="center">
+AI inference Server with OpenAI-compatible API (Limited Features)
+</p>
+<p align="center">
+<img src="https://github.com/geoffsee/predict-otron-9001/blob/master/predict-otron-9000.png?raw=true" width="90%" />
+</p>
+
+<br/>
+> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
+
+> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
+Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
+
 A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.

-<p align="center">
-Powerful local AI inference with OpenAI-compatible APIs
-</p>
+~~~shell
+./scripts/run.sh
+~~~
+

 ## Project Overview
@@ -24,14 +40,14 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
 - **Text Embeddings**: Generate high-quality text embeddings using FastEmbed
 - **Text Generation**: Chat completions with OpenAI-compatible API using Gemma and Llama models (various sizes including instruction-tuned variants)
 - **Performance Optimized**: Efficient caching and platform-specific optimizations for improved throughput
-- **Web Chat Interface**: Leptos-based WebAssembly (WASM) chat interface for browser-based interaction
+- **Web Chat Interface**: Leptos chat interface
 - **Flexible Deployment**: Run as monolithic service or microservices architecture

 ## Architecture Overview

 ### Workspace Structure

-The project uses a 7-crate Rust workspace plus TypeScript components:
+The project uses a 9-crate Rust workspace plus TypeScript components:

 ```
 crates/
@@ -40,17 +56,18 @@ crates/
 ├── gemma-runner/       # Gemma model inference via Candle (Rust 2021)
 ├── llama-runner/       # Llama model inference via Candle (Rust 2021)
 ├── embeddings-engine/  # FastEmbed embeddings service (Rust 2024)
-├── leptos-app/         # WASM web frontend (Rust 2021)
+├── chat-ui/            # WASM web frontend (Rust 2021)
 ├── helm-chart-tool/    # Kubernetes deployment tooling (Rust 2024)
-└── scripts/
-    └── cli.ts          # TypeScript/Bun CLI client
+└── cli/                # CLI client crate (Rust 2024)
+    └── package/
+        └── cli.ts      # TypeScript/Bun CLI client
 ```

 ### Service Architecture

 - **Main Server** (port 8080): Orchestrates inference and embeddings services
 - **Embeddings Service** (port 8080): Standalone FastEmbed service with OpenAI API compatibility
-- **Web Frontend** (port 8788): Leptos WASM chat interface served by Trunk
+- **Web Frontend** (port 8788): chat-ui WASM app
 - **CLI Client**: TypeScript/Bun client for testing and automation

 ### Deployment Modes
@@ -76,11 +93,6 @@ The architecture supports multiple deployment patterns:
 - **Bun**: Required for TypeScript CLI client: `curl -fsSL https://bun.sh/install | bash`
 - **Node.js**: Alternative to Bun, supports OpenAI SDK v5.16.0+

-#### WASM Frontend Toolchain
-- **Trunk**: Required for Leptos frontend builds: `cargo install trunk`
-- **wasm-pack**: `cargo install wasm-pack`
-- **WASM target**: `rustup target add wasm32-unknown-unknown`
-
 #### ML Framework Dependencies
 - **Candle**: Version 0.9.1 with conditional compilation:
   - macOS: Metal support with CPU fallback for stability
@@ -125,11 +137,6 @@ cargo build --bin cli --package inference-engine --release
 cargo build --bin embeddings-engine --release
 ```

-**Web Frontend:**
-```bash
-cd crates/leptos-app
-trunk build --release
-```
-
 ### Running Services

@@ -143,26 +150,26 @@
 #### Web Frontend (Port 8788)
 ```bash
-cd crates/leptos-app
+cd crates/chat-ui
 ./run.sh
 ```
-- Serves Leptos WASM frontend on port 8788
+- Serves chat-ui WASM frontend on port 8788
 - Sets required RUSTFLAGS for WebAssembly getrandom support
 - Auto-reloads during development

 #### TypeScript CLI Client
 ```bash
 # List available models
-bun run scripts/cli.ts --list-models
+cd crates/cli/package && bun run cli.ts --list-models

 # Chat completion
-bun run scripts/cli.ts "What is the capital of France?"
+cd crates/cli/package && bun run cli.ts "What is the capital of France?"

 # With specific model
-bun run scripts/cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"

 # Show help
-bun run scripts/cli.ts --help
+cd crates/cli/package && bun run cli.ts --help
 ```

 ## API Usage
@@ -278,7 +285,7 @@ cargo test --workspace
 **End-to-end test script:**
 ```bash
-./test.sh
+./scripts/smoke_test.sh
 ```

 This script:
@@ -367,7 +374,7 @@ All services include Docker metadata in `Cargo.toml`:
 - Port: 8080

 **Web Frontend:**
-- Image: `ghcr.io/geoffsee/leptos-app:latest`
+- Image: `ghcr.io/geoffsee/chat-ui:latest`
 - Port: 8788

 **Docker Compose:**
@@ -426,8 +433,7 @@ For Kubernetes deployment details, see the [ARCHITECTURE.md](docs/ARCHITECTURE.m
 **Symptom:** WASM compilation failures
 **Solution:**
 1. Install required targets: `rustup target add wasm32-unknown-unknown`
-2. Install trunk: `cargo install trunk`
-3. Check RUSTFLAGS in leptos-app/run.sh
+2. Check RUSTFLAGS in chat-ui/run.sh

 ### Network/Timeout Issues
 **Symptom:** First-time model downloads timing out
@@ -458,24 +464,23 @@ curl -s http://localhost:8080/v1/models | jq
 **CLI client test:**
 ```bash
-bun run scripts/cli.ts "What is 2+2?"
+cd crates/cli/package && bun run cli.ts "What is 2+2?"
 ```

 **Web frontend:**
 ```bash
-cd crates/leptos-app && ./run.sh &
+cd crates/chat-ui && ./run.sh &
 # Navigate to http://localhost:8788
 ```

 **Integration test:**
 ```bash
-./test.sh
+./scripts/smoke_test.sh
 ```

 **Cleanup:**
 ```bash
 pkill -f "predict-otron-9000"
-pkill -f "trunk"
 ```

 For networked tests and full functionality, ensure Hugging Face authentication is configured as described above.
@@ -497,4 +502,4 @@ For networked tests and full functionality, ensure Hugging Face authentication is configured as described above.
 4. Ensure all tests pass: `cargo test`
 5. Submit a pull request

 _Warning: Do NOT use this in production unless you are cool like that._
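The README changes keep the OpenAI-compatible API on port 8080; the `/v1/models` check already appears in the hunk context above (`curl -s http://localhost:8080/v1/models | jq`). A minimal chat-completion request against the same server is sketched below, assuming the server is running via `./scripts/run.sh` and that `gemma-3-1b-it` (the model used in the CLI examples) is available:

```bash
# Non-streaming chat completion against the OpenAI-compatible endpoint
curl -s http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gemma-3-1b-it",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "max_tokens": 1024,
    "stream": false
  }' | jq
```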
bun.lock (new file, 22 lines)
@@ -0,0 +1,22 @@
{
  "lockfileVersion": 1,
  "workspaces": {
    "": {
      "name": "predict-otron-9000",
    },
    "crates/cli/package": {
      "name": "cli",
      "dependencies": {
        "install": "^0.13.0",
        "openai": "^5.16.0",
      },
    },
  },
  "packages": {
    "cli": ["cli@workspace:crates/cli/package"],

    "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],

    "openai": ["openai@5.16.0", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-hoEH8ZNvg1HXjU9mp88L/ZH8O082Z8r6FHCXGiWAzVRrEv443aI57qhch4snu07yQydj+AUAWLenAiBXhu89Tw=="],
  }
}
@@ -1,5 +1,5 @@
 [package]
-name = "leptos-app"
+name = "chat-ui"
 version = "0.1.0"
 edition = "2021"

@@ -15,45 +15,33 @@ leptos_axum = { version = "0.8.0", optional = true }
 leptos_meta = { version = "0.8.0" }
 tokio = { version = "1", features = ["rt-multi-thread"], optional = true }
 wasm-bindgen = { version = "=0.2.100", optional = true }
+wasm-bindgen-futures = "0.4"
-# Chat interface dependencies
+js-sys = "0.3"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
-async-openai-wasm = { version = "0.29", default-features = false }
-futures-util = "0.3"
-js-sys = { version = "0.3", optional = true }
-either = { version = "1.9", features = ["serde"] }
-
-web-sys = { version = "0.3", optional = true, features = [
-    "console",
-    "Window",
-    "Document",
-    "Element",
-    "HtmlElement",
-    "HtmlInputElement",
-    "HtmlSelectElement",
-    "HtmlTextAreaElement",
-    "Event",
-    "EventTarget",
-    "KeyboardEvent",
+reqwest = { version = "0.12", features = ["json"] }
+web-sys = { version = "0.3", features = [
+    "console",
+    "EventSource",
+    "MessageEvent",
+    "Window",
+    "Request",
+    "RequestInit",
+    "Response",
+    "Headers",
+    "ReadableStream",
+    "ReadableStreamDefaultReader",
+    "TextDecoder",
+    "TextDecoderOptions",
+    "HtmlInputElement"
 ] }
+gloo-net = { version = "0.6", features = ["http"] }
-
-[dependencies.uuid]
-version = "1.0"
-features = [
-    "v4",
-    "fast-rng",
-    "macro-diagnostics",
-    "js",
-]

 [features]
 hydrate = [
     "leptos/hydrate",
     "dep:console_error_panic_hook",
     "dep:wasm-bindgen",
-    "dep:js-sys",
-    "dep:web-sys",
 ]
 ssr = [
     "dep:axum",
@@ -73,8 +61,9 @@ codegen-units = 1
 panic = "abort"

 [package.metadata.leptos]
+name = "chat-ui"
 # The name used by wasm-bindgen/cargo-leptos for the JS/WASM bundle. Defaults to the crate name
-output-name = "leptos-app"
+output-name = "chat-ui"

 # The site root folder is where cargo-leptos generate all output. WARNING: all content of this folder will be erased on a rebuild. Use it in your server setup.
 site-root = "target/site"
@@ -84,7 +73,7 @@ site-root = "target/site"
 site-pkg-dir = "pkg"

 # [Optional] The source CSS file. If it ends with .sass or .scss then it will be compiled by dart-sass into CSS. The CSS is optimized by Lightning CSS before being written to <site-root>/<site-pkg>/app.css
-style-file = "style/main.scss"
+style-file = "./style/main.scss"
 # Assets source dir. All files found here will be copied and synchronized to site-root.
 # The assets-dir cannot have a sub directory with the same name/path as site-pkg-dir.
 #
@@ -132,4 +121,4 @@
 # The profile to use for the lib target when compiling for release
 #
 # Optional. Defaults to "release".
-lib-profile-release = "wasm-release"
+lib-profile-release = "release"
crates/chat-ui/README.md (new file, 41 lines)
@@ -0,0 +1,41 @@
# chat-ui

A WASM-based web chat interface for the predict-otron-9000 AI platform.

## Overview

The chat-ui provides a real-time web interface for interacting with language models through the predict-otron-9000 server. Built with Leptos and compiled to WebAssembly, it offers a modern chat experience with streaming response support.

## Features

- Real-time chat interface with the inference server
- Streaming response support
- Conversation history
- Responsive web design
- WebAssembly-powered for optimal performance

## Building and Running

### Prerequisites
- Rust toolchain with WASM target: `rustup target add wasm32-unknown-unknown`
- The predict-otron-9000 server must be running on port 8080

### Development Server
```bash
cd crates/chat-ui
./run.sh
```

This starts the development server on port 8788 with auto-reload capabilities.

### Usage
1. Start the predict-otron-9000 server: `./scripts/run.sh`
2. Start the chat-ui: `cd crates/chat-ui && ./run.sh`
3. Navigate to `http://localhost:8788`
4. Start chatting with your AI models!

## Technical Details
- Built with Leptos framework
- Compiled to WebAssembly for browser execution
- Communicates with predict-otron-9000 API via HTTP
- Sets required RUSTFLAGS for WebAssembly getrandom support
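The chat-ui talks to the same `/v1/chat/completions` endpoint with `stream: true` and parses the `data:`-prefixed server-sent events up to the `[DONE]` sentinel (see `src/app.rs` below). A minimal sketch of observing that stream outside the browser, assuming the server is reachable on port 8080 and the model name is one listed by `/v1/models`:

```bash
# Stream a chat completion as server-sent events; the UI stops at the "data: [DONE]" sentinel
curl -N http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Accept: text/event-stream" \
  -d '{"model": "gemma-3-1b-it", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'
```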
(binary image file changed: 15 KiB before, 15 KiB after)
crates/chat-ui/src/app.rs (new file, 613 lines)
@@ -0,0 +1,613 @@
#[cfg(feature = "ssr")]
use axum::Router;
#[cfg(feature = "ssr")]
use leptos::prelude::LeptosOptions;
#[cfg(feature = "ssr")]
use leptos_axum::{generate_route_list, LeptosRoutes};

pub struct AppConfig {
    pub config: ConfFile,
    pub address: String,
}

impl Default for AppConfig {
    fn default() -> Self {
        let conf = get_configuration(Some(concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml")))
            .expect("failed to read config");

        let addr = conf.leptos_options.site_addr;

        AppConfig {
            config: conf, // or whichever field/string representation you need
            address: addr.to_string(),
        }
    }
}

/// Build the Axum router for this app, including routes, fallback, and state.
/// Call this from another crate (or your bin) when running the server.
#[cfg(feature = "ssr")]
pub fn create_router(leptos_options: LeptosOptions) -> Router {
    // Generate the list of routes in your Leptos App
    let routes = generate_route_list(App);

    Router::new()
        .leptos_routes(&leptos_options, routes, {
            let leptos_options = leptos_options.clone();
            move || shell(leptos_options.clone())
        })
        .fallback(leptos_axum::file_and_error_handler(shell))
        .with_state(leptos_options)
}

use gloo_net::http::Request;
use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
    components::{Route, Router, Routes},
    StaticSegment,
};
use serde::{Deserialize, Serialize};
use web_sys::console;

// Data structures for OpenAI-compatible API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

#[derive(Debug, Serialize)]
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    pub max_tokens: Option<u32>,
    pub stream: Option<bool>,
}

#[derive(Debug, Deserialize)]
pub struct ChatChoice {
    pub message: ChatMessage,
    pub index: u32,
    pub finish_reason: Option<String>,
}

// Streaming response structures
#[derive(Debug, Deserialize)]
pub struct StreamDelta {
    pub role: Option<String>,
    pub content: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct StreamChoice {
    pub index: u32,
    pub delta: StreamDelta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct StreamChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<StreamChoice>,
}

#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
}

// Data structures for models API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

#[derive(Debug, Deserialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelInfo>,
}

// API client function to fetch available models
pub async fn fetch_models() -> Result<Vec<ModelInfo>, String> {
    let response = Request::get("/v1/models")
        .send()
        .await
        .map_err(|e| format!("Failed to fetch models: {:?}", e))?;

    if response.ok() {
        let models_response: ModelsResponse = response
            .json()
            .await
            .map_err(|e| format!("Failed to parse models response: {:?}", e))?;
        Ok(models_response.data)
    } else {
        let status = response.status();
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        Err(format!("Failed to fetch models {}: {}", status, error_text))
    }
}

// API client function to send chat completion requests
pub async fn send_chat_completion(
    messages: Vec<ChatMessage>,
    model: String,
) -> Result<String, String> {
    let request = ChatRequest {
        model,
        messages,
        max_tokens: Some(1024),
        stream: Some(false),
    };

    let response = Request::post("/v1/chat/completions")
        .header("Content-Type", "application/json")
        .json(&request)
        .map_err(|e| format!("Failed to create request: {:?}", e))?
        .send()
        .await
        .map_err(|e| format!("Failed to send request: {:?}", e))?;

    if response.ok() {
        let chat_response: ChatResponse = response
            .json()
            .await
            .map_err(|e| format!("Failed to parse response: {:?}", e))?;

        if let Some(choice) = chat_response.choices.first() {
            Ok(choice.message.content.clone())
        } else {
            Err("No response choices available".to_string())
        }
    } else {
        let status = response.status();
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        Err(format!("Server error {}: {}", status, error_text))
    }
}

// Streaming chat completion using EventSource
#[cfg(target_arch = "wasm32")]
pub fn send_chat_completion_stream(
    messages: Vec<ChatMessage>,
    model: String,
    on_chunk: impl Fn(String) + 'static,
    on_complete: impl Fn() + 'static,
    on_error: impl Fn(String) + 'static,
) {
    use wasm_bindgen::prelude::*;
    use wasm_bindgen::JsCast;

    let request = ChatRequest {
        model,
        messages,
        max_tokens: Some(1024),
        stream: Some(true),
    };

    // We need to send a POST request but EventSource only supports GET
    // So we'll use fetch with a readable stream instead
    let window = web_sys::window().unwrap();
    let request_json = serde_json::to_string(&request).unwrap();

    let opts = web_sys::RequestInit::new();
    opts.set_method("POST");
    opts.set_body(&JsValue::from_str(&request_json));

    let headers = web_sys::Headers::new().unwrap();
    headers.set("Content-Type", "application/json").unwrap();
    headers.set("Accept", "text/event-stream").unwrap();
    opts.set_headers(&headers);

    let request = web_sys::Request::new_with_str_and_init("/v1/chat/completions", &opts).unwrap();

    let promise = window.fetch_with_request(&request);

    wasm_bindgen_futures::spawn_local(async move {
        match wasm_bindgen_futures::JsFuture::from(promise).await {
            Ok(resp_value) => {
                let resp: web_sys::Response = resp_value.dyn_into().unwrap();

                if !resp.ok() {
                    on_error(format!("Server error: {}", resp.status()));
                    return;
                }

                let body = resp.body();
                if body.is_none() {
                    on_error("No response body".to_string());
                    return;
                }

                let reader = body
                    .unwrap()
                    .get_reader()
                    .dyn_into::<web_sys::ReadableStreamDefaultReader>()
                    .unwrap();

                let decoder = web_sys::TextDecoder::new().unwrap();
                let mut buffer = String::new();

                loop {
                    match wasm_bindgen_futures::JsFuture::from(reader.read()).await {
                        Ok(result) => {
                            let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
                                .unwrap()
                                .as_bool()
                                .unwrap_or(false);

                            if done {
                                break;
                            }

                            let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
                            let array = js_sys::Uint8Array::new(&value);
                            let mut bytes = vec![0; array.length() as usize];
                            array.copy_to(&mut bytes);
                            let text = decoder.decode_with_u8_array(&bytes).unwrap();

                            buffer.push_str(&text);

                            // Process complete SSE events from buffer
                            while let Some(event_end) = buffer.find("\n\n") {
                                let event = buffer[..event_end].to_string();
                                buffer = buffer[event_end + 2..].to_string();

                                // Parse SSE event
                                for line in event.lines() {
                                    if let Some(data) = line.strip_prefix("data: ") {
                                        if data == "[DONE]" {
                                            on_complete();
                                            return;
                                        }

                                        // Parse JSON chunk
                                        if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
                                            if let Some(choice) = chunk.choices.first() {
                                                if let Some(content) = &choice.delta.content {
                                                    on_chunk(content.clone());
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            on_error(format!("Read error: {:?}", e));
                            break;
                        }
                    }
                }

                on_complete();
            }
            Err(e) => {
                on_error(format!("Fetch error: {:?}", e));
            }
        }
    });
}

pub fn shell(options: LeptosOptions) -> impl IntoView {
    view! {
        <!DOCTYPE html>
        <html lang="en">
            <head>
                <meta charset="utf-8"/>
                <meta name="viewport" content="width=device-width, initial-scale=1"/>
                <AutoReload options=options.clone() />
                <HydrationScripts options/>
                <MetaTags/>
            </head>
            <body>
                <App/>
            </body>
        </html>
    }
}

#[component]
pub fn App() -> impl IntoView {
    // Provides context that manages stylesheets, titles, meta tags, etc.
    provide_meta_context();

    view! {
        // injects a stylesheet into the document <head>
        // id=leptos means cargo-leptos will hot-reload this stylesheet
        <Stylesheet id="leptos" href="/pkg/chat-ui.css"/>

        // sets the document title
        <Title text="Predict-Otron-9000 Chat"/>

        // content for this welcome page
        <Router>
            <main>
                <Routes fallback=|| "Page not found.".into_view()>
                    <Route path=StaticSegment("") view=ChatPage/>
                </Routes>
            </main>
        </Router>
    }
}

/// Renders the chat interface page
#[component]
fn ChatPage() -> impl IntoView {
    // State for conversation messages
    let messages = RwSignal::new(Vec::<ChatMessage>::new());

    // State for current user input
    let input_text = RwSignal::new(String::new());

    // State for loading indicator
    let is_loading = RwSignal::new(false);

    // State for error messages
    let error_message = RwSignal::new(Option::<String>::None);

    // State for available models and selected model
    let available_models = RwSignal::new(Vec::<ModelInfo>::new());
    let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model

    // State for streaming response
    let streaming_content = RwSignal::new(String::new());
    let is_streaming = RwSignal::new(false);

    // State for streaming mode toggle
    let use_streaming = RwSignal::new(true); // Default to streaming

    // Client-side only: Fetch models on component mount
    #[cfg(target_arch = "wasm32")]
    {
        use leptos::task::spawn_local;
        spawn_local(async move {
            match fetch_models().await {
                Ok(models) => {
                    available_models.set(models);
                }
                Err(error) => {
                    console::log_1(&format!("Failed to fetch models: {}", error).into());
                    error_message.set(Some(format!("Failed to load models: {}", error)));
                }
            }
        });
    }

    // Shared logic for sending a message
    let send_message_logic = move || {
        let user_input = input_text.get();
        if user_input.trim().is_empty() {
            return;
        }

        // Add user message to conversation
        let user_message = ChatMessage {
            role: "user".to_string(),
            content: user_input.clone(),
        };

        messages.update(|msgs| msgs.push(user_message.clone()));
        input_text.set(String::new());
        is_loading.set(true);
        error_message.set(None);

        // Client-side only: Send chat completion request
        #[cfg(target_arch = "wasm32")]
        {
            use leptos::task::spawn_local;

            // Prepare messages for API call
            let current_messages = messages.get();
            let current_model = selected_model.get();
            let should_stream = use_streaming.get();

            if should_stream {
                // Clear streaming content and set streaming flag
                streaming_content.set(String::new());
                is_streaming.set(true);

                // Use streaming API
                send_chat_completion_stream(
                    current_messages,
                    current_model,
                    move |chunk| {
                        // Append chunk to streaming content
                        streaming_content.update(|content| content.push_str(&chunk));
                    },
                    move || {
                        // On complete, move streaming content to messages
                        let final_content = streaming_content.get();
                        if !final_content.is_empty() {
                            let assistant_message = ChatMessage {
                                role: "assistant".to_string(),
                                content: final_content,
                            };
                            messages.update(|msgs| msgs.push(assistant_message));
                        }
                        streaming_content.set(String::new());
                        is_streaming.set(false);
                        is_loading.set(false);
                    },
                    move |error| {
                        console::log_1(&format!("Streaming Error: {}", error).into());
                        error_message.set(Some(error));
                        is_streaming.set(false);
                        is_loading.set(false);
                        streaming_content.set(String::new());
                    },
                );
            } else {
                // Use non-streaming API
                spawn_local(async move {
                    match send_chat_completion(current_messages, current_model).await {
                        Ok(response_content) => {
                            let assistant_message = ChatMessage {
                                role: "assistant".to_string(),
                                content: response_content,
                            };
                            messages.update(|msgs| msgs.push(assistant_message));
                            is_loading.set(false);
                        }
                        Err(error) => {
                            console::log_1(&format!("API Error: {}", error).into());
                            error_message.set(Some(error));
                            is_loading.set(false);
                        }
                    }
                });
            }
        }
    };

    // Button click handler
    let on_button_click = {
        let send_logic = send_message_logic.clone();
        move |_: web_sys::MouseEvent| {
            send_logic();
        }
    };

    // Handle enter key press in input field
    let on_key_down = move |ev: web_sys::KeyboardEvent| {
        if ev.key() == "Enter" && !ev.shift_key() {
            ev.prevent_default();
            send_message_logic();
        }
    };

    view! {
        <div class="chat-container">
            <div class="chat-header">
                <h1>"Predict-Otron-9000 Chat"</h1>
                <div class="model-selector">
                    <label for="model-select">"Model:"</label>
                    <select
                        id="model-select"
                        prop:value=move || selected_model.get()
                        on:change=move |ev| {
                            let new_model = event_target_value(&ev);
                            selected_model.set(new_model);
                        }
                    >
                        <For
                            each=move || available_models.get().into_iter()
                            key=|model| model.id.clone()
                            children=move |model| {
                                view! {
                                    <option value=model.id.clone()>
                                        {format!("{} ({})", model.id, model.owned_by)}
                                    </option>
                                }
                            }
                        />
                    </select>
                    <div class="streaming-toggle">
                        <label>
                            <input
                                type="checkbox"
                                prop:checked=move || use_streaming.get()
                                on:change=move |ev| {
                                    let target = event_target::<web_sys::HtmlInputElement>(&ev);
                                    use_streaming.set(target.checked());
                                }
                            />
                            " Use streaming"
                        </label>
                    </div>
                </div>
            </div>

            <div class="chat-messages">
                <For
                    each=move || messages.get().into_iter().enumerate()
                    key=|(i, _)| *i
                    children=move |(_, message)| {
                        let role_class = if message.role == "user" { "user-message" } else { "assistant-message" };
                        view! {
                            <div class=format!("message {}", role_class)>
                                <div class="message-role">{message.role.clone()}</div>
                                <div class="message-content">{message.content.clone()}</div>
                            </div>
                        }
                    }
                />

                {move || {
                    if is_streaming.get() {
                        let content = streaming_content.get();
                        if !content.is_empty() {
                            view! {
                                <div class="message assistant-message streaming">
                                    <div class="message-role">"assistant"</div>
                                    <div class="message-content">{content}<span class="cursor">"▊"</span></div>
                                </div>
                            }.into_any()
                        } else {
                            view! {
                                <div class="message assistant-message loading">
                                    <div class="message-role">"assistant"</div>
                                    <div class="message-content">"Thinking..."</div>
                                </div>
                            }.into_any()
                        }
                    } else if is_loading.get() && !use_streaming.get() {
                        view! {
                            <div class="message assistant-message loading">
                                <div class="message-role">"assistant"</div>
                                <div class="message-content">"Thinking..."</div>
                            </div>
                        }.into_any()
                    } else {
                        view! {}.into_any()
                    }
                }}
            </div>

            {move || {
                if let Some(error) = error_message.get() {
                    view! {
                        <div class="error-message">
                            "Error: " {error}
                        </div>
                    }.into_any()
                } else {
                    view! {}.into_any()
                }
            }}

            <div class="chat-input">
                <textarea
                    placeholder="Type your message here... (Press Enter to send, Shift+Enter for new line)"
                    prop:value=move || input_text.get()
                    on:input=move |ev| input_text.set(event_target_value(&ev))
                    on:keydown=on_key_down
                    class:disabled=move || is_loading.get()
                />
                <button
                    on:click=on_button_click
                    class:disabled=move || is_loading.get() || input_text.get().trim().is_empty()
                >
                    "Send"
                </button>
            </div>
        </div>
    }
}
crates/chat-ui/src/lib.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
pub mod app;

#[cfg(feature = "hydrate")]
#[wasm_bindgen::prelude::wasm_bindgen]
pub fn hydrate() {
    use crate::app::*;
    console_error_panic_hook::set_once();
    leptos::mount::hydrate_body(App);
}
crates/chat-ui/src/main.rs (new file, 26 lines)
@@ -0,0 +1,26 @@
#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
    use axum::Router;
    use chat_ui::app::*;
    use leptos::logging::log;
    use leptos::prelude::*;
    use leptos_axum::{generate_route_list, LeptosRoutes};

    let conf = get_configuration(None).expect("failed to read config");
    let addr = conf.leptos_options.site_addr;

    // Build the app router with your extracted function
    let app: Router = create_router(conf.leptos_options);

    log!("listening on http://{}", &addr);
    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
    axum::serve(listener, app.into_make_service())
        .await
        .unwrap();
}

#[cfg(not(feature = "ssr"))]
pub fn main() {
    // no client-side main function
}
crates/chat-ui/style/main.scss (new file, 265 lines; listing truncated below)
@@ -0,0 +1,265 @@
* {
  margin: 0;
  padding: 0;
  box-sizing: border-box;
}

body {
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
  background-color: #f5f5f5;
  height: 100vh;
  overflow: hidden;
}

.chat-container {
  display: flex;
  flex-direction: column;
  height: 100vh;
  max-width: 800px;
  margin: 0 auto;
  background-color: white;
  box-shadow: 0 0 20px rgba(0, 0, 0, 0.1);
}

.chat-header {
  background-color: #000000;
  color: white;
  padding: 1rem;
  text-align: center;
  box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
  display: flex;
  flex-direction: column;
  gap: 1rem;

  h1 {
    margin: 0;
    font-size: 1.5rem;
    font-weight: 600;
  }

  .model-selector {
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 0.5rem;
    flex-wrap: wrap;

    label {
      font-weight: 500;
      font-size: 0.9rem;
    }

    select {
      background-color: white;
      color: #374151;
      border: 1px solid #d1d5db;
      border-radius: 6px;
      padding: 0.5rem 0.75rem;
      font-size: 0.9rem;
      font-family: inherit;
      cursor: pointer;
      min-width: 200px;

      &:focus {
        outline: none;
        border-color: #663c99;
        box-shadow: 0 0 0 2px rgba(29, 78, 216, 0.2);
      }

      option {
        padding: 0.5rem;
      }
    }

    .streaming-toggle {
      display: flex;
      align-items: center;
      margin-left: 1rem;

      label {
        display: flex;
        align-items: center;
        gap: 0.5rem;
        cursor: pointer;
        font-size: 0.9rem;

        input[type="checkbox"] {
          cursor: pointer;
        }
      }
    }
  }
}

.chat-messages {
  flex: 1;
  overflow-y: auto;
  padding: 1rem;
  display: flex;
  flex-direction: column;
  gap: 1rem;
  background-color: #fafafa;
}

.message {
  display: flex;
  flex-direction: column;
  gap: 0.5rem;
  padding: 1rem;
  border-radius: 12px;
  max-width: 80%;
  word-wrap: break-word;

  &.user-message {
    align-self: flex-end;
    background-color: #2563eb;
    color: white;

    .message-role {
      font-weight: 600;
      font-size: 0.8rem;
      opacity: 0.8;
      text-transform: uppercase;
    }

    .message-content {
      line-height: 1.5;
    }
  }

  &.assistant-message {
    align-self: flex-start;
    background-color: #646873;
    border: 1px solid #e5e7eb;
    color: #f3f3f3;

    .message-role {
      font-weight: 600;
      font-size: 0.8rem;
      color: #c4c5cd;
      text-transform: uppercase;
    }

    .message-content {
      line-height: 1.5;
    }

    &.loading {
      background-color: #f3f4f6;
      border-color: #d1d5db;

      .message-content {
        font-style: italic;
        color: #6b7280;
      }
    }

    &.streaming {
      .message-content {
        .cursor {
          display: inline-block;
          animation: blink 1s infinite;
          color: #9ca3af;
        }
      }
    }
  }
}

.error-message {
  background-color: #fef2f2;
  border: 1px solid #fca5a5;
  color: #dc2626;
  padding: 1rem;
  margin: 0 1rem;
  border-radius: 8px;
  text-align: center;
  font-weight: 500;
}

.chat-input {
  display: flex;
  gap: 0.5rem;
  padding: 1rem;
  background-color: white;
  border-top: 1px solid #e5e7eb;

  textarea {
    flex: 1;
    padding: 0.75rem;
    border: 1px solid #d1d5db;
    border-radius: 8px;
    resize: none;
    min-height: 60px;
    max-height: 120px;
    font-family: inherit;
    font-size: 1rem;
    line-height: 1.5;

    &:focus {
      outline: none;
      border-color: #663c99;
      box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
    }

    &.disabled {
      background-color: #f9fafb;
      color: #6b7280;
      cursor: not-allowed;
    }
  }

  button {
    padding: 0.75rem 1.5rem;
    background-color: #663c99;
    color: white;
    border: none;
    border-radius: 8px;
    font-weight: 600;
    cursor: pointer;
    transition: background-color 0.2s ease;
    align-self: flex-end;

    &:hover:not(.disabled) {
      background-color: #663c99;
    }

    &.disabled {
      background-color: #9ca3af;
      cursor: not-allowed;
    }

    &:focus {
      outline: none;
      box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.3);
    }
  }
}

/* Scrollbar styling for webkit browsers */
.chat-messages::-webkit-scrollbar {
  width: 6px;
}

.chat-messages::-webkit-scrollbar-track {
  background: #f1f1f1;
}

.chat-messages::-webkit-scrollbar-thumb {
  background: #c1c1c1;
  border-radius: 3px;
}

.chat-messages::-webkit-scrollbar-thumb:hover {
  background: #a8a8a8;
}

/* Cursor blink animation */
@keyframes blink {
  0%, 50% {
    opacity: 1;
  }
  51%, 100% {
    opacity: 0;
  }
}
11
crates/cli/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
[package]
name = "cli"
version.workspace = true
edition = "2021"
build = "build.rs"

[[bin]]
name = "cli"
path = "src/main.rs"

[dependencies]
24
crates/cli/README.md
Normal file
@@ -0,0 +1,24 @@
# cli

A Rust/Typescript Hybrid

```console
bun run cli.ts [options] [prompt]

Simple CLI tool for testing the local OpenAI-compatible API server.

Options:
--model <model>       Model to use (default: gemma-3-1b-it)
--prompt <prompt>     The prompt to send (can also be provided as positional argument)
--list-models         List all available models from the server
--help                Show this help message

Examples:
cd crates/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"
bun run cli.ts --list-models

The server must be running at http://localhost:8080
```
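The crate's build.rs (below) compiles cli.ts into a standalone executable with `bun build --compile`, so the same options can be passed without invoking Bun at runtime. A minimal sketch, assuming the workspace has been built and the compiled binary is on the current path as `cli` (the exact output location depends on the build):

```console
./cli --list-models
./cli --model gemma-3-1b-it "What is the capital of France?"
```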
204
crates/cli/build.rs
Normal file
@@ -0,0 +1,204 @@
use std::env;
use std::fs;
use std::io::{self, BufRead, Write};
use std::path::{Path, PathBuf};
use std::process::{ChildStderr, ChildStdout, Command, Stdio};
use std::thread;
use std::time::{Duration, SystemTime};
mod bun_target;
use bun_target::BunTarget;

fn main() {
    println!("cargo:rerun-if-changed=");

    if let Err(e) = run_build() {
        println!("cargo:warning=build.rs failed: {e}");
        std::process::exit(1);
    }
}

fn run_build() -> io::Result<()> {
    let manifest_dir =
        PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"));
    let package_dir = manifest_dir.join("package");
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
    let output_path = out_dir.join("client-cli");

    let bun_tgt = BunTarget::from_cargo_env()
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;

    // Optional: warn if using a Bun target that’s marked unsupported in your chart
    if matches!(bun_tgt, BunTarget::WindowsArm64) {
        println!(
            "cargo:warning=bun-windows-arm64 is marked unsupported in the compatibility chart"
        );
    }

    warn(&format!("Building CLI into: {}", output_path.display()));

    // --- bun install (in ./package), keep temps inside OUT_DIR ---
    let mut install = Command::new("bun")
        .current_dir(&package_dir)
        .env("TMPDIR", &out_dir)
        .arg("install")
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun install`: {e}")))?;

    let install_join = stream_child("bun install", install.stdout.take(), install.stderr.take());
    let install_status = install.wait()?;
    // ensure streams finish
    join_streams(install_join);

    if !install_status.success() {
        let code = install_status.code().unwrap_or(1);
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("bun install failed with status {code}"),
        ));
    }

    let target = env::var("TARGET").unwrap();

    // --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
    let mut build = Command::new("bun")
        .current_dir(&package_dir)
        .env("TMPDIR", &out_dir)
        .arg("build")
        .arg("./cli.ts")
        .arg(format!("--target={}", bun_tgt.as_bun_flag()))
        .arg("--compile")
        .arg("--outfile")
        .arg(&output_path)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|e| io::Error::new(e.kind(), format!("Failed to spawn `bun build`: {e}")))?;

    let build_join = stream_child("bun build", build.stdout.take(), build.stderr.take());
    let status = build.wait()?;
    // ensure streams finish
    join_streams(build_join);

    if status.success() {
        info("bun build succeeded");
    } else {
        let code = status.code().unwrap_or(1);
        warn(&format!("bun build failed with status: {code}"));
        return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
    }

    // Ensure the output is executable (after it exists)
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let mut perms = fs::metadata(&output_path)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&output_path, perms)?;
    }

    println!("cargo:warning=Built CLI at {}", output_path.display());
    println!("cargo:rustc-env=CLIENT_CLI_BIN={}", output_path.display());

    // --- Cleanup stray .bun-build temp files (conservative: older than 5 minutes) ---
    for dir in [&manifest_dir, &package_dir, &out_dir] {
        if let Err(e) = remove_bun_temp_files(dir, Some(Duration::from_secs(5 * 60))) {
            println!("cargo:warning=cleanup in {} failed: {e}", dir.display());
        }
    }

    Ok(())
}

// Spawn readers for child's stdout/stderr so we don't deadlock on pipe buffers
fn stream_child(
    tag: &str,
    stdout: Option<ChildStdout>,
    stderr: Option<ChildStderr>,
) -> (
    Option<thread::JoinHandle<()>>,
    Option<thread::JoinHandle<()>>,
) {
    let t1 = stdout.map(|out| {
        let tag = tag.to_string();
        thread::spawn(move || {
            let reader = io::BufReader::new(out);
            for line in reader.lines() {
                info(&format!("[{tag} stdout] {}", line.unwrap_or_default()));
            }
        })
    });
    let t2 = stderr.map(|err| {
        let tag = tag.to_string();
        thread::spawn(move || {
            let reader = io::BufReader::new(err);
            for line in reader.lines() {
                warn(&format!("[{tag} stderr] {}", line.unwrap_or_default()));
            }
        })
    });
    (t1, t2)
}

fn join_streams(
    joins: (
        Option<thread::JoinHandle<()>>,
        Option<thread::JoinHandle<()>>,
    ),
) {
    if let Some(j) = joins.0 {
        let _ = j.join();
    }
    if let Some(j) = joins.1 {
        let _ = j.join();
    }
}

fn remove_bun_temp_files(dir: &Path, older_than: Option<Duration>) -> io::Result<()> {
    let now = SystemTime::now();
    for entry in fs::read_dir(dir)? {
        let entry = entry?;
        let path = entry.path();
        if !path.is_file() {
            continue;
        }

        // Files like ".1860e7df40ff1bef-00000000.bun-build"
        let name = entry.file_name();
        let name = name.to_string_lossy();
        let looks_like_bun_temp = name.starts_with('.') && name.ends_with(".bun-build");

        if !looks_like_bun_temp {
            continue;
        }

        if let Some(age) = older_than {
            if let Ok(meta) = entry.metadata() {
                if let Ok(modified) = meta.modified() {
                    if now.duration_since(modified).unwrap_or_default() < age {
                        // too new; skip to avoid racing an in-flight builder
                        continue;
                    }
                }
            }
        }

        match fs::remove_file(&path) {
            Ok(_) => println!("cargo:warning=removed stray bun temp {}", path.display()),
            Err(e) => println!("cargo:warning=failed to remove {}: {e}", path.display()),
        }
    }
    Ok(())
}

fn warn(msg: &str) {
    let _ = writeln!(io::stderr(), "[build.rs] {msg}");
    println!("cargo:warning={msg}");
}

fn info(msg: &str) {
    let _ = writeln!(io::stderr(), "[build.rs] {msg}");
    println!("cargo:warning=INFO|{msg}");
}
131
crates/cli/bun_target.rs
Normal file
@@ -0,0 +1,131 @@
use std::env;
use std::fmt;

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub enum BunTarget {
    LinuxX64Glibc,
    LinuxArm64Glibc,
    LinuxX64Musl,
    LinuxArm64Musl,
    WindowsX64,
    WindowsArm64,
    MacX64,
    MacArm64,
}

impl BunTarget {
    pub const fn as_bun_flag(self) -> &'static str {
        match self {
            BunTarget::LinuxX64Glibc => "bun-linux-x64",
            BunTarget::LinuxArm64Glibc => "bun-linux-arm64",
            BunTarget::LinuxX64Musl => "bun-linux-x64-musl",
            BunTarget::LinuxArm64Musl => "bun-linux-arm64-musl",
            BunTarget::WindowsX64 => "bun-windows-x64",
            BunTarget::WindowsArm64 => "bun-windows-arm64",
            BunTarget::MacX64 => "bun-darwin-x64",
            BunTarget::MacArm64 => "bun-darwin-arm64",
        }
    }

    pub const fn rust_triples(self) -> &'static [&'static str] {
        match self {
            BunTarget::LinuxX64Glibc => {
                &["x86_64-unknown-linux-gnu", "x86_64-unknown-linux-gnu.2.17"]
            }
            BunTarget::LinuxArm64Glibc => &["aarch64-unknown-linux-gnu"],
            BunTarget::LinuxX64Musl => &["x86_64-unknown-linux-musl"],
            BunTarget::LinuxArm64Musl => &["aarch64-unknown-linux-musl"],
            BunTarget::WindowsX64 => &["x86_64-pc-windows-msvc"],
            BunTarget::WindowsArm64 => &["aarch64-pc-windows-msvc"], // chart says unsupported; still map
            BunTarget::MacX64 => &["x86_64-apple-darwin"],
            BunTarget::MacArm64 => &["aarch64-apple-darwin"],
        }
    }

    pub fn from_rust_target(triple: &str) -> Option<Self> {
        let norm = triple.trim();
        if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
            return Some(BunTarget::LinuxX64Glibc);
        }
        if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("gnu") {
            return Some(BunTarget::LinuxArm64Glibc);
        }
        if norm.starts_with("x86_64-") && norm.contains("-linux-") && norm.ends_with("musl") {
            return Some(BunTarget::LinuxX64Musl);
        }
        if norm.starts_with("aarch64-") && norm.contains("-linux-") && norm.ends_with("musl") {
            return Some(BunTarget::LinuxArm64Musl);
        }
        if norm == "x86_64-pc-windows-msvc" {
            return Some(BunTarget::WindowsX64);
        }
        if norm == "aarch64-pc-windows-msvc" {
            return Some(BunTarget::WindowsArm64);
        }
        if norm == "x86_64-apple-darwin" {
            return Some(BunTarget::MacX64);
        }
        if norm == "aarch64-apple-darwin" {
            return Some(BunTarget::MacArm64);
        }
        for bt in [
            BunTarget::LinuxX64Glibc,
            BunTarget::LinuxArm64Glibc,
            BunTarget::LinuxX64Musl,
            BunTarget::LinuxArm64Musl,
            BunTarget::WindowsX64,
            BunTarget::WindowsArm64,
            BunTarget::MacX64,
            BunTarget::MacArm64,
        ] {
            for &t in bt.rust_triples() {
                if t == norm {
                    return Some(bt);
                }
            }
        }
        None
    }

    pub fn from_cargo_env() -> Result<Self, BunTargetError> {
        if let Ok(triple) = env::var("TARGET") {
            if let Some(bt) = Self::from_rust_target(&triple) {
                return Ok(bt);
            }
            return Err(BunTargetError::UnknownTriple(triple));
        }

        let os = env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
        let arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap_or_default();
        let envv = env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default();
        let vendor = env::var("CARGO_CFG_TARGET_VENDOR").unwrap_or_else(|_| "unknown".into());

        let triple = format!(
            "{}-{}-{}-{}",
            arch,
            vendor,
            os,
            if envv.is_empty() { "gnu" } else { &envv }
        );
        if let Some(bt) = Self::from_rust_target(&triple) {
            Ok(bt)
        } else {
            Err(BunTargetError::UnknownTriple(triple))
        }
    }
}

#[derive(Debug)]
pub enum BunTargetError {
    UnknownTriple(String),
}

impl fmt::Display for BunTargetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BunTargetError::UnknownTriple(t) => write!(f, "unrecognized Rust target triple: {t}"),
        }
    }
}

impl std::error::Error for BunTargetError {}
@@ -30,24 +30,23 @@ type ChunkStat = {
 
 function printHelp() {
   console.log(`
-Usage: bun client_cli.ts [options] [prompt]
+./cli [options] [prompt]
 
 Simple CLI tool for testing the local OpenAI-compatible API server.
 
 Options:
-  --model <model>       Model to use (default: ${DEFAULT_MODEL})
+  --model <model>       Model to use (default: gemma-3-1b-it)
   --prompt <prompt>     The prompt to send (can also be provided as positional argument)
   --list-models         List all available models from the server
   --help                Show this help message
 
 Examples:
-  ./cli.ts "What is the capital of France?"
-  ./cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
-  ./cli.ts --prompt "Who was the 16th president of the United States?"
-  ./cli.ts --list-models
+  ./cli "What is the capital of France?"
+  ./cli --model gemma-3-1b-it --prompt "Hello, world!"
+  ./cli --prompt "Who was the 16th president of the United States?"
+  ./cli --list-models
 
-The server should be running at http://localhost:8080
-Start it with: ./run_server.sh
+The server must be running at http://localhost:8080
   `);
 }
 
11
crates/cli/package/package.json
Normal file
@@ -0,0 +1,11 @@
{
  "name": "cli",
  "main": "cli.ts",
  "scripts": {
    "build": "bun build cli.ts --compile --outfile cli"
  },
  "dependencies": {
    "install": "^0.13.0",
    "openai": "^5.16.0"
  }
}
32
crates/cli/src/main.rs
Normal file
@@ -0,0 +1,32 @@
use std::{env, fs, io, path::PathBuf, process::Command};

#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;

fn main() -> io::Result<()> {
    // Absolute path provided by build.rs at compile time.
    // `include_bytes!` accepts string literals; `env!` expands to a literal at compile time.
    const CLIENT_CLI: &[u8] = include_bytes!(env!("CLIENT_CLI_BIN"));

    // Write to a temp file
    let mut tmp = env::temp_dir();
    tmp.push("client-cli-embedded");

    fs::write(&tmp, CLIENT_CLI)?;

    // Ensure it's executable on Unix
    #[cfg(unix)]
    {
        let mut perms = fs::metadata(&tmp)?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&tmp, perms)?;
    }

    // Run it
    let status = Command::new(&tmp).arg("--version").status()?;
    if !status.success() {
        return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
    }

    Ok(())
}
@@ -1,6 +1,6 @@
 [package]
 name = "embeddings-engine"
-version = "0.1.0"
+version.workspace = true
 edition = "2024"
 
 [lib]
@@ -1,4 +1,100 @@
 # Embeddings Engine
 
-A high-performance text embeddings service that generates vector representations of text using state-of-the-art models.
-This crate wraps the fastembed crate to provide embeddings and partially adapts the openai specification.
+A high-performance text embeddings service that generates vector representations of text using state-of-the-art models. This crate wraps the FastEmbed library to provide embeddings with OpenAI-compatible API endpoints.
+
+## Overview
+
+The embeddings-engine provides a standalone service for generating text embeddings that can be used for semantic search, similarity comparisons, and other NLP tasks. It's designed to be compatible with OpenAI's embeddings API format.
+
+## Features
+
+- **OpenAI-Compatible API**: `/v1/embeddings` endpoint matching OpenAI's specification
+- **FastEmbed Integration**: Powered by the FastEmbed library for high-quality embeddings
+- **Multiple Model Support**: Support for various embedding models
+- **High Performance**: Optimized for fast embedding generation
+- **Standalone Service**: Can run independently or as part of the predict-otron-9000 platform
+
+## Building and Running
+
+### Prerequisites
+- Rust toolchain
+- Internet connection for initial model downloads
+
+### Standalone Server
+```bash
+cargo run --bin embeddings-engine --release
+```
+
+The service will start on port 8080 by default.
+
+## API Usage
+
+### Generate Embeddings
+
+**Endpoint**: `POST /v1/embeddings`
+
+**Request Body**:
+```json
+{
+  "input": "Your text to embed",
+  "model": "nomic-embed-text-v1.5"
+}
+```
+
+**Response**:
+```json
+{
+  "object": "list",
+  "data": [
+    {
+      "object": "embedding",
+      "index": 0,
+      "embedding": [0.1, 0.2, 0.3, ...]
+    }
+  ],
+  "model": "nomic-embed-text-v1.5",
+  "usage": {
+    "prompt_tokens": 0,
+    "total_tokens": 0
+  }
+}
+```
+
+### Example Usage
+
+**Using cURL**:
+```bash
+curl -s http://localhost:8080/v1/embeddings \
+  -H "Content-Type: application/json" \
+  -d '{
+    "input": "The quick brown fox jumps over the lazy dog",
+    "model": "nomic-embed-text-v1.5"
+  }' | jq
+```
+
+**Using Python OpenAI Client**:
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:8080/v1",
+    api_key="dummy"  # Not validated but required by client
+)
+
+response = client.embeddings.create(
+    input="Your text here",
+    model="nomic-embed-text-v1.5"
+)
+
+print(response.data[0].embedding)
+```
+
+## Configuration
+
+The service can be configured through environment variables:
+- `SERVER_PORT`: Port to run on (default: 8080)
+- `RUST_LOG`: Logging level (default: info)
+
+## Integration
+
+This service is designed to work seamlessly with the predict-otron-9000 main server, but can also be deployed independently for dedicated embeddings workloads.
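A minimal sketch of overriding the documented environment variables when launching the standalone server; the specific port and log filter chosen here are illustrative, not defaults from the project:

```bash
# Assumes the embeddings-engine binary builds as documented above.
SERVER_PORT=8081 RUST_LOG=debug cargo run --bin embeddings-engine --release

# Then point requests at the overridden port:
curl -s http://localhost:8081/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{"input": "hello", "model": "nomic-embed-text-v1.5"}' | jq
```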
@@ -1,9 +1,5 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
-use axum::{
-    response::Json as ResponseJson, routing::{post},
-    Json,
-    Router,
-};
+use axum::{Json, Router, response::Json as ResponseJson, routing::post};
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use once_cell::sync::Lazy;
 use tower_http::trace::TraceLayer;
@@ -13,15 +9,18 @@ use tracing;
 static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
     tracing::info!("Initializing persistent embedding model (singleton)");
     let model_start_time = std::time::Instant::now();
 
     let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
+        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
     )
     .expect("Failed to initialize persistent embedding model");
 
     let model_init_time = model_start_time.elapsed();
-    tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time);
+    tracing::info!(
+        "Persistent embedding model initialized in {:.2?}",
+        model_init_time
+    );
 
     model
 });
 
|
 ) -> ResponseJson<serde_json::Value> {
     // Start timing the entire process
     let start_time = std::time::Instant::now();
 
     // Phase 1: Access persistent model instance
     let model_start_time = std::time::Instant::now();
 
     // Access the lazy-initialized persistent model instance
     // This will only initialize the model on the first request
     let model_access_time = model_start_time.elapsed();
-    tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);
+    tracing::debug!(
+        "Persistent model access completed in {:.2?}",
+        model_access_time
+    );
 
     // Phase 2: Process input
     let input_start_time = std::time::Instant::now();
 
     let embedding_input = payload.input;
     let texts_from_embedding_input = match embedding_input {
         EmbeddingInput::String(text) => vec![text],
@@ -53,41 +55,58 @@ pub async fn embeddings_create(
             panic!("Array of integer arrays not supported for text embeddings");
         }
     };
 
     let input_processing_time = input_start_time.elapsed();
-    tracing::debug!("Input processing completed in {:.2?}", input_processing_time);
+    tracing::debug!(
+        "Input processing completed in {:.2?}",
+        input_processing_time
+    );
 
     // Phase 3: Generate embeddings
     let embedding_start_time = std::time::Instant::now();
 
     let embeddings = EMBEDDING_MODEL
         .embed(texts_from_embedding_input, None)
         .expect("failed to embed document");
 
     let embedding_generation_time = embedding_start_time.elapsed();
-    tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);
+    tracing::info!(
+        "Embedding generation completed in {:.2?}",
+        embedding_generation_time
+    );
 
     // Memory usage estimation (approximate)
-    let embedding_size_bytes = embeddings.iter()
+    let embedding_size_bytes = embeddings
+        .iter()
         .map(|e| e.len() * std::mem::size_of::<f32>())
         .sum::<usize>();
-    tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
+    tracing::debug!(
+        "Embedding size: {:.2} MB",
+        embedding_size_bytes as f64 / 1024.0 / 1024.0
+    );
 
     // Only log detailed embedding information at trace level to reduce log volume
     tracing::trace!("Embeddings length: {}", embeddings.len());
     tracing::info!("Embedding dimension: {}", embeddings[0].len());
 
     // Log the first 10 values of the original embedding at trace level
-    tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
+    tracing::trace!(
+        "Original embedding preview: {:?}",
+        &embeddings[0][..10.min(embeddings[0].len())]
+    );
 
     // Check if there are any NaN or zero values in the original embedding
     let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
     let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
-    tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
+    tracing::trace!(
+        "Original embedding stats: NaN count={}, zero count={}",
+        nan_count,
+        zero_count
+    );
 
     // Phase 4: Post-process embeddings
     let postprocessing_start_time = std::time::Instant::now();
 
     // Create the final embedding
     let final_embedding = {
         // Check if the embedding is all zeros
@@ -110,6 +129,8 @@ pub async fn embeddings_create(
 
             // Normalize the random embedding
             let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
+
+            #[allow(clippy::needless_range_loop)]
             for i in 0..random_embedding.len() {
                 random_embedding[i] /= norm;
             }
@@ -123,25 +144,35 @@ pub async fn embeddings_create(
             let target_dimension = 768;
             if padded_embedding.len() < target_dimension {
                 let padding_needed = target_dimension - padded_embedding.len();
-                tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
+                tracing::trace!(
+                    "Padding embedding with {} zeros to reach {} dimensions",
+                    padding_needed,
+                    target_dimension
+                );
                 padded_embedding.extend(vec![0.0; padding_needed]);
             }
 
             padded_embedding
         }
     };
 
     let postprocessing_time = postprocessing_start_time.elapsed();
-    tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
+    tracing::debug!(
+        "Embedding post-processing completed in {:.2?}",
+        postprocessing_time
+    );
 
     tracing::trace!("Final embedding dimension: {}", final_embedding.len());
 
     // Log the first 10 values of the final embedding at trace level
-    tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
+    tracing::trace!(
+        "Final embedding preview: {:?}",
+        &final_embedding[..10.min(final_embedding.len())]
+    );
 
     // Phase 5: Prepare response
     let response_start_time = std::time::Instant::now();
 
     // Return a response that matches the OpenAI API format
     let response = serde_json::json!({
         "object": "list",
@@ -158,10 +189,10 @@ pub async fn embeddings_create(
             "total_tokens": 0
         }
     });
 
     let response_time = response_start_time.elapsed();
     tracing::debug!("Response preparation completed in {:.2?}", response_time);
 
     // Log total time and breakdown
     let total_time = start_time.elapsed();
     tracing::info!(
@@ -171,7 +202,7 @@ pub async fn embeddings_create(
         embedding_generation_time,
         postprocessing_time
     );
 
     ResponseJson(response)
 }
 
@@ -179,4 +210,4 @@ pub fn create_embeddings_router() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
         .layer(TraceLayer::new_for_http())
 }
@@ -1,8 +1,8 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
 use axum::{
-    response::Json as ResponseJson, routing::{get, post},
-    Json,
-    Router,
+    Json, Router,
+    response::Json as ResponseJson,
+    routing::{get, post},
 };
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use serde::{Deserialize, Serialize};
@@ -13,19 +13,17 @@ use tracing;
 const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
 const DEFAULT_SERVER_PORT: &str = "8080";
 
 
 async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
 ) -> ResponseJson<serde_json::Value> {
     let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
+        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
     )
     .expect("Failed to initialize model");
 
     let embedding_input = payload.input;
-
     let texts_from_embedding_input = match embedding_input {
-
         EmbeddingInput::String(text) => vec![text],
         EmbeddingInput::StringArray(texts) => texts,
         EmbeddingInput::IntegerArray(_) => {
@@ -45,12 +43,19 @@ async fn embeddings_create(
     tracing::info!("Embedding dimension: {}", embeddings[0].len());
 
     // Log the first 10 values of the original embedding at trace level
-    tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
+    tracing::trace!(
+        "Original embedding preview: {:?}",
+        &embeddings[0][..10.min(embeddings[0].len())]
+    );
 
     // Check if there are any NaN or zero values in the original embedding
     let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
     let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
-    tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
+    tracing::trace!(
+        "Original embedding stats: NaN count={}, zero count={}",
+        nan_count,
+        zero_count
+    );
 
     // Create the final embedding
     let final_embedding = {
@@ -87,7 +92,11 @@ async fn embeddings_create(
             let target_dimension = 768;
             if padded_embedding.len() < target_dimension {
                 let padding_needed = target_dimension - padded_embedding.len();
-                tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
+                tracing::trace!(
+                    "Padding embedding with {} zeros to reach {} dimensions",
+                    padding_needed,
+                    target_dimension
+                );
                 padded_embedding.extend(vec![0.0; padding_needed]);
             }
 
@@ -98,7 +107,10 @@ async fn embeddings_create(
     tracing::trace!("Final embedding dimension: {}", final_embedding.len());
 
     // Log the first 10 values of the final embedding at trace level
-    tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
+    tracing::trace!(
+        "Final embedding preview: {:?}",
+        &final_embedding[..10.min(final_embedding.len())]
+    );
 
     // Return a response that matches the OpenAI API format
     let response = serde_json::json!({
@@ -120,7 +132,7 @@ async fn embeddings_create(
 }
 
 fn create_app() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
         .layer(TraceLayer::new_for_http())
 }
@@ -143,21 +155,21 @@ async fn main() {
         .init();
     let app = create_app();
 
     let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
     let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
     let server_address = format!("{}:{}", server_host, server_port);
     let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
     tracing::info!("Listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use axum::body::to_bytes;
-    use axum::body::Body;
+    use axum::body::Body;
+    use axum::body::to_bytes;
     use axum::http::StatusCode;
     use tower::ServiceExt;
 
     #[tokio::test]
     async fn test_embeddings_create() {
@@ -168,11 +180,13 @@ mod tests {
 
         let body = CreateEmbeddingRequest {
             model: "nomic-text-embed".to_string(),
-            input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]),
+            input: EmbeddingInput::from(vec![
+                "The food was delicious and the waiter...".to_string(),
+            ]),
             encoding_format: None,
             user: None,
             dimensions: Some(768),
         };
 
         let response = app
             .oneshot(
@@ -1,26 +1,30 @@
 [package]
 name = "gemma-runner"
-version = "0.1.0"
+version.workspace = true
 edition = "2021"
 
 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
 candle-transformers = { git = "https://github.com/huggingface/candle.git" }
-candle-examples = { git = "https://github.com/huggingface/candle.git" }
-
-[target.'cfg(target_os = "macos")'.dependencies]
-candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 hf-hub = "0.4"
-tokenizers = "0.21"
+tokenizers = "0.22.0"
 anyhow = "1.0"
 clap = { version = "4.0", features = ["derive", "string"] }
 serde_json = "1.0"
 tracing = "0.1"
 tracing-chrome = "0.7"
 tracing-subscriber = "0.3"
+utils = {path = "../utils"}
+
+[target.'cfg(target_os = "macos")'.dependencies]
+candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 
 [features]
 default = []
@@ -4,22 +4,23 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use clap::ValueEnum;
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
+use clap::ValueEnum;
 
 // Removed gemma_cli import as it's not needed for the API
-use candle_core::{utils, DType, Device, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
-use tokenizers::Tokenizer;
 
 use std::sync::mpsc::{self, Receiver, Sender};
 use std::thread;
+use tokenizers::Tokenizer;
+use utils::hub_load_safetensors;
+use utils::token_output_stream::TokenOutputStream;
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
 pub enum WhichModel {
@@ -85,9 +86,9 @@ pub struct TextGeneration {
 fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
-    } else if utils::cuda_is_available() {
+    } else if candle_core::utils::cuda_is_available() {
         Ok(Device::new_cuda(0)?)
-    } else if utils::metal_is_available() {
+    } else if candle_core::utils::metal_is_available() {
         Ok(Device::new_metal(0)?)
     } else {
         Ok(Device::Cpu)
@@ -98,7 +99,7 @@ impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
     fn new(
         model: Model,
-        tokenizer: Tokenizer,
+        tokenizer: tokenizers::Tokenizer,
         seed: u64,
         temp: Option<f64>,
         top_p: Option<f64>,
@@ -119,7 +120,12 @@ impl TextGeneration {
 
     /// Stream-only generation: sends freshly generated token strings over `tx`.
     /// (Does not send the prompt tokens; only newly generated model tokens.)
-    fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
+    fn run_stream(
+        &mut self,
+        prompt: &str,
+        sample_len: usize,
+        tx: Sender<Result<String>>,
+    ) -> Result<()> {
         self.tokenizer.clear();
 
         // Encode prompt (context only; do not emit prompt tokens to the stream).
@@ -257,10 +263,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
 
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        utils::with_avx(),
-        utils::with_neon(),
-        utils::with_simd128(),
-        utils::with_f16c()
+        candle_core::utils::with_avx(),
+        candle_core::utils::with_neon(),
+        candle_core::utils::with_simd128(),
+        candle_core::utils::with_f16c()
     );
 
     let device = device(cfg.cpu)?;
@@ -303,7 +309,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
             WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
             WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
         }
         .to_string()
     });
 
     println!("Loading model: {}", &model_id);
@@ -313,7 +319,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let config_filename = repo.get("config.json")?;
     let filenames = match cfg.model {
         WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
-        _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+        _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("Retrieved files in {:?}", start.elapsed());
 
@@ -337,7 +343,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
             let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
             Model::V1(model)
         }
-        WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
+        WhichModel::BaseV2_2B
+        | WhichModel::InstructV2_2B
+        | WhichModel::BaseV2_9B
+        | WhichModel::InstructV2_9B => {
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
             Model::V2(model)
@@ -1,6 +1,6 @@
-use std::io::Write;
-use clap::Parser;
 use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
+use clap::Parser;
+use std::io::Write;
 
 #[derive(Parser, Debug)]
 #[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
@@ -94,4 +94,4 @@ pub fn run_cli() -> anyhow::Result<()> {
         }
     }
     Ok(())
 }
@@ -2,8 +2,8 @@
 extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-mod gemma_cli;
 mod gemma_api;
+mod gemma_cli;
 
 use anyhow::Error;
 use clap::{Parser, ValueEnum};
@@ -14,4 +14,4 @@ use std::io::Write;
|
|||||||
/// just a placeholder, not used for anything
|
/// just a placeholder, not used for anything
|
||||||
fn main() -> std::result::Result<(), Error> {
|
fn main() -> std::result::Result<(), Error> {
|
||||||
run_cli()
|
run_cli()
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,6 @@
 [package]
 name = "helm-chart-tool"
-version = "0.1.0"
+version.workspace = true
 edition = "2021"

 [[bin]]
@@ -137,7 +137,7 @@ Parsing workspace at: ..
 Output directory: ../generated-helm-chart
 Chart name: predict-otron-9000
 Found 4 services:
-  - leptos-app: ghcr.io/geoffsee/leptos-app:latest (port 8788)
+  - chat-ui: ghcr.io/geoffsee/chat-ui:latest (port 8788)
   - inference-engine: ghcr.io/geoffsee/inference-service:latest (port 8080)
   - embeddings-engine: ghcr.io/geoffsee/embeddings-service:latest (port 8080)
   - predict-otron-9000: ghcr.io/geoffsee/predict-otron-9000:latest (port 8080)
@@ -84,7 +84,10 @@ fn main() -> Result<()> {
     let services = discover_services(workspace_path)?;
     println!("Found {} services:", services.len());
     for service in &services {
-        println!("  - {}: {} (port {})", service.name, service.image, service.port);
+        println!(
+            "  - {}: {} (port {})",
+            service.name, service.image, service.port
+        );
     }

     generate_helm_chart(output_path, chart_name, &services)?;
@@ -115,17 +118,20 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
 fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
     let content = fs::read_to_string(path)
         .with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?;

     let cargo_toml: CargoToml = toml::from_str(&content)
         .with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;

-    let package = cargo_toml.package
+    let package = cargo_toml
+        .package
         .ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;

-    let metadata = package.metadata
+    let metadata = package
+        .metadata
         .ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;

-    let kube_metadata = metadata.kube
+    let kube_metadata = metadata
+        .kube
         .ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;

     Ok(ServiceInfo {
@@ -136,7 +142,11 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
     })
 }

-fn generate_helm_chart(output_path: &str, chart_name: &str, services: &[ServiceInfo]) -> Result<()> {
+fn generate_helm_chart(
+    output_path: &str,
+    chart_name: &str,
+    services: &[ServiceInfo],
+) -> Result<()> {
     let chart_dir = Path::new(output_path);
     let templates_dir = chart_dir.join("templates");

@@ -512,4 +522,4 @@ fn generate_helmignore(chart_dir: &Path) -> Result<()> {

     fs::write(chart_dir.join(".helmignore"), helmignore_content)?;
     Ok(())
 }
@@ -1,41 +1,15 @@
 [package]
 name = "inference-engine"
-version = "0.1.0"
+version.workspace = true
 edition = "2021"

-[[bin]]
-name="gemma_inference"
-path = "src/gemma_inference.rs"
-required-features = ["bin"]
-
-[[bin]]
-name="llama_inference"
-path = "src/llama_inference.rs"
-required-features = ["bin"]
-
 [dependencies]
-accelerate-src = { version = "0.3.2", optional = true }
-candle-datasets = { version = "=0.9.1", optional = true }
-candle-nn = { version = "=0.9.1" }
-candle-transformers = { version = "=0.9.1" }
+candle-core = { git = "https://github.com/huggingface/candle.git" }
+candle-nn = { git = "https://github.com/huggingface/candle.git" }
+candle-transformers = { git = "https://github.com/huggingface/candle.git" }
 candle-flash-attn = { version = "=0.9.1", optional = true }
 candle-onnx = { version = "=0.9.1", optional = true }

-csv = "1.3.0"
-cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false, optional = true }
-half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"], optional = true }
-hf-hub = { version = "0.4.1", features = ["tokio"] }
-image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
-intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"], optional = true }
-num-traits = { version = "0.2.15" }
-palette = { version = "0.7.6", optional = true }
-enterpolation = { version = "0.2.1", optional = true}
-pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
-rayon = "1.7.0"
-rubato = { version = "0.15.0", optional = true }
-safetensors = "0.4.1"
 serde = { version = "1.0.171", features = ["derive"] }
 serde_json = "1.0.99"
 symphonia = { version = "0.5.3", features = ["all"], optional = true }
@@ -57,22 +31,14 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 reborrow = "0.5.5"
 futures-util = "0.3.31"
-gemma-runner = { path = "../gemma-runner" }
-llama-runner = { path = "../llama-runner" }
+gemma-runner = { path = "../gemma-runner", features = ["metal"] }
+llama-runner = { path = "../llama-runner", features = ["metal"]}

-# --- Add this section for conditional compilation ---
 [target.'cfg(target_os = "macos")'.dependencies]
-# Use CPU backend for macOS to avoid Metal rotary-emb implementation issues
-candle-core = { version = "=0.9.1", features = ["metal"], optional = false }
+candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

-[target.'cfg(not(target_os = "macos"))'.dependencies]
-# For Linux or other non-macOS systems, you likely want the CPU backend or CUDA
-# If you're building on Linux with a CUDA-enabled GPU:
-candle-core = { version = "=0.9.1", features = ["cuda"], default-features = false } # Or just "cuda" if not using default features
-
-# If you're building on Linux with only CPU:
-# candle-core = { version = "=0.9.1", default-features = false } # CPU is often the default, but good to be explicit
-# --- End of conditional compilation section ---

 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
@@ -30,4 +30,4 @@ pub trait ModelInference {
 }

 /// Factory function type for creating model inference implementations
 pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;
@@ -1,19 +1,14 @@
 // Expose modules for testing and library usage
-pub mod token_output_stream;
 pub mod model;
-pub mod text_generation;
-pub mod utilities_lib;
 pub mod openai_types;
 // pub mod cli;
-pub mod server;
 pub mod inference;
+pub mod server;

 // Re-export key components for easier access
-pub use model::{Model, Which};
-pub use text_generation::TextGeneration;
-pub use token_output_stream::TokenOutputStream;
-pub use server::{AppState, create_router};
 pub use inference::ModelInference;
+pub use model::{Model, Which};
+pub use server::{create_router, AppState};

 use std::env;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -1,11 +1,53 @@
-// use candle_core::Tensor;
+use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
-use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
+
+#[derive(Clone, Debug)]
+pub enum Model {
+    V1(Model1),
+    V2(Model2),
+    V3(Model3),
+    Llama(LlamaModel),
+}
+
+impl Model {
+    pub fn forward(
+        &mut self,
+        input_ids: &candle_core::Tensor,
+        pos: usize,
+    ) -> candle_core::Result<candle_core::Tensor> {
+        match self {
+            Self::V1(m) => m.forward(input_ids, pos),
+            Self::V2(m) => m.forward(input_ids, pos),
+            Self::V3(m) => m.forward(input_ids, pos),
+            Self::Llama(m) => m.forward(input_ids, pos),
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Family {
+    GemmaV1,
+    GemmaV2,
+    GemmaV3,
+    Llama,
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct ModelMeta {
+    pub id: &'static str,
+    pub family: Family,
+    pub instruct: bool,
+}
+
+const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
+    ModelMeta { id, family, instruct }
+}

 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 pub enum Which {
+    // Gemma 1.x
     #[value(name = "2b")]
     Base2B,
     #[value(name = "7b")]
@@ -18,6 +60,8 @@ pub enum Which {
     InstructV1_1_2B,
     #[value(name = "1.1-7b-it")]
     InstructV1_1_7B,
+
+    // CodeGemma
     #[value(name = "code-2b")]
     CodeBase2B,
     #[value(name = "code-7b")]
@@ -26,6 +70,8 @@ pub enum Which {
     CodeInstruct2B,
     #[value(name = "code-7b-it")]
     CodeInstruct7B,
+
+    // Gemma 2
     #[value(name = "2-2b")]
     BaseV2_2B,
     #[value(name = "2-2b-it")]
@@ -34,70 +80,73 @@ pub enum Which {
     BaseV2_9B,
     #[value(name = "2-9b-it")]
     InstructV2_9B,
+
+    // Gemma 3
     #[value(name = "3-1b")]
     BaseV3_1B,
     #[value(name = "3-1b-it")]
     InstructV3_1B,
-    #[value(name = "llama-3.2-1b-it")]
-    LlamaInstruct3_2_1B,
-    #[value(name = "llama-3.2-3b-it")]
-    LlamaInstruct3_2_3B,
-}
-
-pub enum Model {
-    V1(Model1),
-    V2(Model2),
-    V3(Model3),
-    Llama(LlamaModel),
-}
-
-impl Model {
-    pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
-        match self {
-            Self::V1(m) => m.forward(input_ids, pos),
-            Self::V2(m) => m.forward(input_ids, pos),
-            Self::V3(m) => m.forward(input_ids, pos),
-            Self::Llama(m) => m.forward(input_ids, pos),
-        }
-    }
-}
+
+    // Llama 3.2 (use aliases instead of duplicate variants)
+    #[value(name = "llama-3.2-1b")]
+    Llama32_1B,
+    #[value(name = "llama-3.2-1b-it", alias = "llama-3.2-1b-instruct")]
+    Llama32_1BInstruct,
+    #[value(name = "llama-3.2-3b")]
+    Llama32_3B,
+    #[value(name = "llama-3.2-3b-it", alias = "llama-3.2-3b-instruct")]
+    Llama32_3BInstruct,
+}

 impl Which {
-    pub fn to_model_id(&self) -> String {
+    pub const fn meta(&self) -> ModelMeta {
+        use Family::*;
         match self {
-            Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
-            Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
-            Self::Base2B => "google/gemma-2b".to_string(),
-            Self::Base7B => "google/gemma-7b".to_string(),
-            Self::Instruct2B => "google/gemma-2b-it".to_string(),
-            Self::Instruct7B => "google/gemma-7b-it".to_string(),
-            Self::CodeBase2B => "google/codegemma-2b".to_string(),
-            Self::CodeBase7B => "google/codegemma-7b".to_string(),
-            Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
-            Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
-            Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
-            Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
-            Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
-            Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
-            Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
-            Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
-            Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
-            Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
+            // Gemma 1.x
+            Self::Base2B => m("google/gemma-2b", GemmaV1, false),
+            Self::Base7B => m("google/gemma-7b", GemmaV1, false),
+            Self::Instruct2B => m("google/gemma-2b-it", GemmaV1, true),
+            Self::Instruct7B => m("google/gemma-7b-it", GemmaV1, true),
+            Self::InstructV1_1_2B => m("google/gemma-1.1-2b-it", GemmaV1, true),
+            Self::InstructV1_1_7B => m("google/gemma-1.1-7b-it", GemmaV1, true),
+
+            // CodeGemma
+            Self::CodeBase2B => m("google/codegemma-2b", GemmaV1, false),
+            Self::CodeBase7B => m("google/codegemma-7b", GemmaV1, false),
+            Self::CodeInstruct2B => m("google/codegemma-2b-it", GemmaV1, true),
+            Self::CodeInstruct7B => m("google/codegemma-7b-it", GemmaV1, true),
+
+            // Gemma 2
+            Self::BaseV2_2B => m("google/gemma-2-2b", GemmaV2, false),
+            Self::InstructV2_2B => m("google/gemma-2-2b-it", GemmaV2, true),
+            Self::BaseV2_9B => m("google/gemma-2-9b", GemmaV2, false),
+            Self::InstructV2_9B => m("google/gemma-2-9b-it", GemmaV2, true),
+
+            // Gemma 3
+            Self::BaseV3_1B => m("google/gemma-3-1b-pt", GemmaV3, false),
+            Self::InstructV3_1B => m("google/gemma-3-1b-it", GemmaV3, true),
+
+            // Llama 3.2
+            Self::Llama32_1B => m("meta-llama/Llama-3.2-1B", Llama, false),
+            Self::Llama32_1BInstruct => m("meta-llama/Llama-3.2-1B-Instruct", Llama, true),
+            Self::Llama32_3B => m("meta-llama/Llama-3.2-3B", Llama, false),
+            Self::Llama32_3BInstruct => m("meta-llama/Llama-3.2-3B-Instruct", Llama, true),
         }
     }

+    pub fn to_model_id(&self) -> String {
+        self.meta().id.to_string()
+    }
+
     pub fn is_instruct_model(&self) -> bool {
-        match self {
-            Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
-            _ => true,
-        }
+        self.meta().instruct
     }

     pub fn is_v3_model(&self) -> bool {
-        matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
+        matches!(self.meta().family, Family::GemmaV3)
     }

     pub fn is_llama_model(&self) -> bool {
-        matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
+        matches!(self.meta().family, Family::Llama)
     }
 }
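As a quick orientation, here is a minimal, hypothetical sketch of how the ModelMeta-backed Which API introduced above could be exercised. The import path assumes the re-exports shown in lib.rs, and the commented values simply follow the mappings in the meta() table; this is not part of the diff itself.

use inference_engine::Which;

fn main() {
    // InstructV3_1B maps to the "google/gemma-3-1b-it" repo in meta()
    let which = Which::InstructV3_1B;
    println!("repo id: {}", which.to_model_id());
    println!("instruct: {}", which.is_instruct_model()); // true, from meta().instruct
    println!("gemma v3: {}", which.is_v3_model()); // true, Family::GemmaV3
    println!("llama: {}", which.is_llama_model()); // false
}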
@@ -1,5 +1,6 @@
 use either::Either;
 use serde::{Deserialize, Serialize};
+use serde_json::json;
 use std::collections::HashMap;
 use utoipa::ToSchema;

@@ -10,7 +11,10 @@ pub struct MessageInnerContent(
 );

 impl ToSchema<'_> for MessageInnerContent {
-    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+    fn schema() -> (
+        &'static str,
+        utoipa::openapi::RefOr<utoipa::openapi::Schema>,
+    ) {
         (
             "MessageInnerContent",
             utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +49,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct MessageContent(
     #[serde(with = "either::serde_untagged")]
     pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
 );

 impl ToSchema<'_> for MessageContent {
-    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
-        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
+    fn schema() -> (
+        &'static str,
+        utoipa::openapi::RefOr<utoipa::openapi::Schema>,
+    ) {
+        (
+            "MessageContent",
+            utoipa::openapi::RefOr::T(message_content_schema()),
+        )
     }
 }

@@ -213,4 +223,4 @@ pub struct ModelListResponse {
     pub object: String,
     /// Array of available models
     pub data: Vec<Model>,
 }
@@ -6,19 +6,22 @@ use axum::{
     Json, Router,
 };
 use futures_util::stream::{self, Stream};
-use tokio_stream::wrappers::UnboundedReceiverStream;
 use std::convert::Infallible;
 use std::sync::Arc;
-use tokio::sync::{Mutex, mpsc};
+use tokio::sync::{mpsc, Mutex};
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;

-use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
+use crate::openai_types::{
+    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
+    ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
+};
 use crate::Which;
 use either::Either;
-use serde_json::Value;
 use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
 use llama_runner::{run_llama_inference, LlamaInferenceConfig};
+use serde_json::Value;
 // -------------------------
 // Shared app state
 // -------------------------
@@ -39,13 +42,18 @@ pub struct AppState {

 impl Default for AppState {
     fn default() -> Self {
+        // Configure a default model to prevent 503 errors from the chat-ui
+        // This can be overridden by environment variables if needed
+        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+
         let gemma_config = GemmaInferenceConfig {
             model: gemma_runner::WhichModel::InstructV3_1B,
             ..Default::default()
         };

         Self {
             model_type: ModelType::Gemma,
-            model_id: "gemma-3-1b-it".to_string(),
+            model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
         }
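A small, illustrative sketch of the DEFAULT_MODEL fallback that the new AppState::default() relies on. Only the environment variable name and the "gemma-3-1b-it" default come from the diff; the standalone helper and the test values are hypothetical.

// Mirrors the lookup used in AppState::default(); illustrative only.
fn default_model_id() -> String {
    std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string())
}

fn main() {
    std::env::set_var("DEFAULT_MODEL", "llama-3.2-1b-instruct");
    assert_eq!(default_model_id(), "llama-3.2-1b-instruct");
}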
@@ -56,18 +64,49 @@ impl Default for AppState {
 // Helper functions
 // -------------------------

+fn model_id_to_which(model_id: &str) -> Option<Which> {
+    let normalized = normalize_model_id(model_id);
+    match normalized.as_str() {
+        "gemma-2b" => Some(Which::Base2B),
+        "gemma-7b" => Some(Which::Base7B),
+        "gemma-2b-it" => Some(Which::Instruct2B),
+        "gemma-7b-it" => Some(Which::Instruct7B),
+        "gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B),
+        "gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B),
+        "codegemma-2b" => Some(Which::CodeBase2B),
+        "codegemma-7b" => Some(Which::CodeBase7B),
+        "codegemma-2b-it" => Some(Which::CodeInstruct2B),
+        "codegemma-7b-it" => Some(Which::CodeInstruct7B),
+        "gemma-2-2b" => Some(Which::BaseV2_2B),
+        "gemma-2-2b-it" => Some(Which::InstructV2_2B),
+        "gemma-2-9b" => Some(Which::BaseV2_9B),
+        "gemma-2-9b-it" => Some(Which::InstructV2_9B),
+        "gemma-3-1b" => Some(Which::BaseV3_1B),
+        "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+        "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+        "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
+        _ => None,
+    }
+}
+
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }

 fn build_gemma_prompt(messages: &[Message]) -> String {
     let mut prompt = String::new();

     for message in messages {
         match message.role.as_str() {
             "system" => {
                 if let Some(MessageContent(Either::Left(content))) = &message.content {
-                    prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
+                    prompt.push_str(&format!(
+                        "<start_of_turn>system\n{}<end_of_turn>\n",
+                        content
+                    ));
                 }
             }
             "user" => {
@@ -83,7 +122,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
             _ => {}
         }
     }

     prompt.push_str("<start_of_turn>model\n");
     prompt
 }
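To illustrate how these helpers fit together at request time: the server lower-cases the incoming model id, swaps underscores for dashes, and maps it to a Which variant, rejecting anything unknown. The two mapping entries below are copied from model_id_to_which; the standalone resolve function is hypothetical, since the real helpers are private to server.rs.

fn resolve(model_id: &str) -> Result<&'static str, String> {
    // normalize_model_id: lower-case and replace '_' with '-'
    let normalized = model_id.to_lowercase().replace("_", "-");
    match normalized.as_str() {
        "gemma-3-1b-it" => Ok("google/gemma-3-1b-it"),
        "llama-3.2-1b-instruct" => Ok("meta-llama/Llama-3.2-1B-Instruct"),
        other => Err(format!("Unsupported model: {}", other)),
    }
}

fn main() {
    assert_eq!(resolve("Gemma_3_1B_IT"), Ok("google/gemma-3-1b-it"));
    assert!(resolve("unknown-model").is_err());
}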
@@ -97,95 +136,88 @@ pub async fn chat_completions(
     Json(request): Json<ChatCompletionRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, String)> {
     if !request.stream.unwrap_or(false) {
-        return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
+        return Ok(chat_completions_non_streaming_proxy(state, request)
+            .await
+            .into_response());
     }
-    Ok(chat_completions_stream(state, request).await.into_response())
+    Ok(chat_completions_stream(state, request)
+        .await
+        .into_response())
 }

 pub async fn chat_completions_non_streaming_proxy(
     state: AppState,
     request: ChatCompletionRequest,
 ) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
-    // Enforce model selection behavior: reject if a different model is requested
-    let configured_model = state.model_id.clone();
-    let requested_model = request.model.clone();
-    if requested_model.to_lowercase() != "default" {
-        let normalized_requested = normalize_model_id(&requested_model);
-        let normalized_configured = normalize_model_id(&configured_model);
-        if normalized_requested != normalized_configured {
+    // Use the model specified in the request
+    let model_id = request.model.clone();
+    let which_model = model_id_to_which(&model_id);
+
+    // Validate that the requested model is supported
+    let which_model = match which_model {
+        Some(model) => model,
+        None => {
             return Err((
                 StatusCode::BAD_REQUEST,
                 Json(serde_json::json!({
                     "error": {
-                        "message": format!(
-                            "Requested model '{}' is not available. This server is running '{}' only.",
-                            requested_model, configured_model
-                        ),
-                        "type": "model_mismatch"
+                        "message": format!("Unsupported model: {}", model_id),
+                        "type": "model_not_supported"
                     }
                 })),
             ));
         }
-    }
+    };

-    let model_id = state.model_id.clone();
     let max_tokens = request.max_tokens.unwrap_or(1000);

     // Build prompt based on model type
-    let prompt = match state.model_type {
-        ModelType::Gemma => build_gemma_prompt(&request.messages),
-        ModelType::Llama => {
-            // For Llama, just use the last user message for now
-            request.messages.last()
-                .and_then(|m| m.content.as_ref())
-                .and_then(|c| match c {
-                    MessageContent(Either::Left(text)) => Some(text.clone()),
-                    _ => None,
-                })
-                .unwrap_or_default()
-        }
+    let prompt = if which_model.is_llama_model() {
+        // For Llama, just use the last user message for now
+        request
+            .messages
+            .last()
+            .and_then(|m| m.content.as_ref())
+            .and_then(|c| match c {
+                MessageContent(Either::Left(text)) => Some(text.clone()),
+                _ => None,
+            })
+            .unwrap_or_default()
+    } else {
+        build_gemma_prompt(&request.messages)
     };

     // Get streaming receiver based on model type
-    let rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_gemma_api(config).map_err(|e| (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                    }))
-                ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    }))
-                ));
-            }
-        }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_llama_inference(config).map_err(|e| (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": format!("Error initializing Llama model: {}", e) }
-                    }))
-                ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    }))
-                ));
-            }
-        }
+    let rx = if which_model.is_llama_model() {
+        // Create Llama configuration dynamically
+        let mut config = LlamaInferenceConfig::default();
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        run_llama_inference(config).map_err(|e| (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(serde_json::json!({
+                "error": { "message": format!("Error initializing Llama model: {}", e) }
+            }))
+        ))?
+    } else {
+        // Create Gemma configuration dynamically
+        let gemma_model = if which_model.is_v3_model() {
+            gemma_runner::WhichModel::InstructV3_1B
+        } else {
+            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        };
+
+        let mut config = GemmaInferenceConfig {
+            model: gemma_model,
+            ..Default::default()
+        };
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        run_gemma_api(config).map_err(|e| (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(serde_json::json!({
+                "error": { "message": format!("Error initializing Gemma model: {}", e) }
+            }))
+        ))?
     };

     // Collect all tokens from the stream
@@ -245,27 +277,25 @@ async fn handle_streaming_request(
     state: AppState,
     request: ChatCompletionRequest,
 ) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
-    // Validate requested model vs configured model
-    let configured_model = state.model_id.clone();
-    let requested_model = request.model.clone();
-    if requested_model.to_lowercase() != "default" {
-        let normalized_requested = normalize_model_id(&requested_model);
-        let normalized_configured = normalize_model_id(&configured_model);
-        if normalized_requested != normalized_configured {
+    // Use the model specified in the request
+    let model_id = request.model.clone();
+    let which_model = model_id_to_which(&model_id);
+
+    // Validate that the requested model is supported
+    let which_model = match which_model {
+        Some(model) => model,
+        None => {
             return Err((
                 StatusCode::BAD_REQUEST,
                 Json(serde_json::json!({
                     "error": {
-                        "message": format!(
-                            "Requested model '{}' is not available. This server is running '{}' only.",
-                            requested_model, configured_model
-                        ),
-                        "type": "model_mismatch"
+                        "message": format!("Unsupported model: {}", model_id),
+                        "type": "model_not_supported"
                     }
                 })),
             ));
         }
-    }
+    };

     // Generate a unique ID and metadata
     let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
@@ -273,22 +303,22 @@ async fn handle_streaming_request(
     .duration_since(std::time::UNIX_EPOCH)
     .unwrap_or_default()
     .as_secs();
-    let model_id = state.model_id.clone();
     let max_tokens = request.max_tokens.unwrap_or(1000);

     // Build prompt based on model type
-    let prompt = match state.model_type {
-        ModelType::Gemma => build_gemma_prompt(&request.messages),
-        ModelType::Llama => {
-            // For Llama, just use the last user message for now
-            request.messages.last()
-                .and_then(|m| m.content.as_ref())
-                .and_then(|c| match c {
-                    MessageContent(Either::Left(text)) => Some(text.clone()),
-                    _ => None,
-                })
-                .unwrap_or_default()
-        }
+    let prompt = if which_model.is_llama_model() {
+        // For Llama, just use the last user message for now
+        request
+            .messages
+            .last()
+            .and_then(|m| m.content.as_ref())
+            .and_then(|c| match c {
+                MessageContent(Either::Left(text)) => Some(text.clone()),
+                _ => None,
+            })
+            .unwrap_or_default()
+    } else {
+        build_gemma_prompt(&request.messages)
     };
     tracing::debug!("Formatted prompt: {}", prompt);

@@ -303,7 +333,10 @@ async fn handle_streaming_request(
         model: model_id.clone(),
         choices: vec![ChatCompletionChunkChoice {
             index: 0,
-            delta: Delta { role: Some("assistant".to_string()), content: None },
+            delta: Delta {
+                role: Some("assistant".to_string()),
+                content: None,
+            },
             finish_reason: None,
         }],
     };
@@ -312,52 +345,44 @@ async fn handle_streaming_request(
     }

     // Get streaming receiver based on model type
-    let model_rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                match run_gemma_api(config) {
-                    Ok(rx) => rx,
-                    Err(e) => {
-                        return Err((
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(serde_json::json!({
-                                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                            }))
-                        ));
-                    }
-                }
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    }))
-                ));
-            }
-        }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                match run_llama_inference(config) {
-                    Ok(rx) => rx,
-                    Err(e) => {
-                        return Err((
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(serde_json::json!({
-                                "error": { "message": format!("Error initializing Llama model: {}", e) }
-                            }))
-                        ));
-                    }
-                }
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    }))
-                ));
-            }
-        }
+    let model_rx = if which_model.is_llama_model() {
+        // Create Llama configuration dynamically
+        let mut config = LlamaInferenceConfig::default();
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        match run_llama_inference(config) {
+            Ok(rx) => rx,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Llama model: {}", e) }
+                    })),
+                ));
+            }
+        }
+    } else {
+        // Create Gemma configuration dynamically
+        let gemma_model = if which_model.is_v3_model() {
+            gemma_runner::WhichModel::InstructV3_1B
+        } else {
+            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        };
+
+        let mut config = GemmaInferenceConfig {
+            model: gemma_model,
+            ..Default::default()
+        };
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        match run_gemma_api(config) {
+            Ok(rx) => rx,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                    })),
+                ));
+            }
+        }
@@ -386,16 +411,20 @@ async fn handle_streaming_request(
             if recent_tokens.len() > REPETITION_WINDOW {
                 recent_tokens.remove(0);
             }

             // Check for repetitive patterns
             if recent_tokens.len() >= 4 {
                 let last_token = &recent_tokens[recent_tokens.len() - 1];
                 let second_last = &recent_tokens[recent_tokens.len() - 2];

                 if last_token == second_last {
                     repetition_count += 1;
-                    tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
+                    tracing::warn!(
+                        "Detected repetition pattern: '{}' (count: {})",
+                        last_token,
+                        repetition_count
+                    );
+
                     if repetition_count >= MAX_REPETITION_COUNT {
                         tracing::info!("Stopping generation due to excessive repetition");
                         break;
@@ -412,11 +441,14 @@ async fn handle_streaming_request(
                 model: model_id_clone.clone(),
                 choices: vec![ChatCompletionChunkChoice {
                     index: 0,
-                    delta: Delta { role: None, content: Some(token) },
+                    delta: Delta {
+                        role: None,
+                        content: Some(token),
+                    },
                     finish_reason: None,
                 }],
             };

             if let Ok(json) = serde_json::to_string(&chunk) {
                 let _ = tx.send(Ok(Event::default().data(json)));
             }
@@ -436,7 +468,10 @@ async fn handle_streaming_request(
             model: model_id_clone.clone(),
             choices: vec![ChatCompletionChunkChoice {
                 index: 0,
-                delta: Delta { role: None, content: None },
+                delta: Delta {
+                    role: None,
+                    content: None,
+                },
                 finish_reason: Some("stop".to_string()),
             }],
         };
@@ -451,8 +486,6 @@ async fn handle_streaming_request(
     Ok(Sse::new(stream))
 }

-
-
 // -------------------------
 // Router
 // -------------------------
@@ -474,172 +507,69 @@ pub fn create_router(app_state: AppState) -> Router {
 /// Handler for GET /v1/models - returns list of available models
 pub async fn list_models() -> Json<ModelListResponse> {
     // Get all available model variants from the Which enum
-    let models = vec![
-        // Gemma models
-        Model { id: "gemma-2b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() }, // Using same timestamp as OpenAI example
-        Model { id: "gemma-7b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-2b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-7b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-1.1-2b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-1.1-7b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "codegemma-2b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "codegemma-7b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "codegemma-2b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "codegemma-7b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-2-2b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-2-2b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-2-9b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-2-9b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-3-1b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        Model { id: "gemma-3-1b-it".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "google".to_string() },
-        // Llama models
-        Model { id: "llama-3.2-1b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "meta".to_string() },
-        Model { id: "llama-3.2-1b-instruct".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "meta".to_string() },
-        Model { id: "llama-3.2-3b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "meta".to_string() },
-        Model { id: "llama-3.2-3b-instruct".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "meta".to_string() },
-        Model { id: "smollm2-135m".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "huggingface".to_string() },
-        Model { id: "smollm2-135m-instruct".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "huggingface".to_string() },
-        Model { id: "smollm2-360m".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "huggingface".to_string() },
-        Model { id: "smollm2-360m-instruct".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "huggingface".to_string() },
-        Model { id: "smollm2-1.7b".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "huggingface".to_string() },
-        Model { id: "smollm2-1.7b-instruct".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "huggingface".to_string() },
-        Model { id: "tinyllama-1.1b-chat".to_string(), object: "model".to_string(), created: 1686935002, owned_by: "tinyllama".to_string() },
-    ];
+    let which_variants = vec![
+        Which::Base2B,
+        Which::Base7B,
+        Which::Instruct2B,
+        Which::Instruct7B,
+        Which::InstructV1_1_2B,
+        Which::InstructV1_1_7B,
+        Which::CodeBase2B,
+        Which::CodeBase7B,
+        Which::CodeInstruct2B,
+        Which::CodeInstruct7B,
+        Which::BaseV2_2B,
+        Which::InstructV2_2B,
+        Which::BaseV2_9B,
+        Which::InstructV2_9B,
+        Which::BaseV3_1B,
+        Which::InstructV3_1B,
+        Which::Llama32_1B,
+        Which::Llama32_1BInstruct,
+        Which::Llama32_3B,
+        Which::Llama32_3BInstruct,
+    ];
+
+    let models: Vec<Model> = which_variants.into_iter().map(|which| {
+        let meta = which.meta();
+        let model_id = match which {
+            Which::Base2B => "gemma-2b",
+            Which::Base7B => "gemma-7b",
+            Which::Instruct2B => "gemma-2b-it",
+            Which::Instruct7B => "gemma-7b-it",
+            Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+            Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+            Which::CodeBase2B => "codegemma-2b",
+            Which::CodeBase7B => "codegemma-7b",
+            Which::CodeInstruct2B => "codegemma-2b-it",
+            Which::CodeInstruct7B => "codegemma-7b-it",
+            Which::BaseV2_2B => "gemma-2-2b",
+            Which::InstructV2_2B => "gemma-2-2b-it",
+            Which::BaseV2_9B => "gemma-2-9b",
+            Which::InstructV2_9B => "gemma-2-9b-it",
+            Which::BaseV3_1B => "gemma-3-1b",
+            Which::InstructV3_1B => "gemma-3-1b-it",
+            Which::Llama32_1B => "llama-3.2-1b",
+            Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+            Which::Llama32_3B => "llama-3.2-3b",
+            Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+        };
+
+        let owned_by = if meta.id.starts_with("google/") {
+            "google"
+        } else if meta.id.starts_with("meta-llama/") {
+            "meta"
+        } else {
+            "unknown"
+        };
+
+        Model {
+            id: model_id.to_string(),
+            object: "model".to_string(),
+            created: 1686935002, // Using same timestamp as OpenAI example
+            owned_by: owned_by.to_string(),
+        }
+    }).collect();

     Json(ModelListResponse {
         object: "list".to_string(),
@@ -647,7 +577,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
     })
 }

-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -681,10 +610,7 @@ mod tests {

         let prompt = build_gemma_prompt(&messages);

-        let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
-            <start_of_turn>model\nWho's there?<end_of_turn>\n\
-            <start_of_turn>user\nGemma.<end_of_turn>\n\
-            <start_of_turn>model\n";
+        let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";

         assert_eq!(prompt, expected);
     }
@@ -698,15 +624,13 @@ mod tests {

     #[test]
     fn test_missing_content() {
-        let messages = vec![
-            Message {
-                role: "user".to_string(),
-                content: None,
-                name: None,
-            }
-        ];
+        let messages = vec![Message {
+            role: "user".to_string(),
+            content: None,
+            name: None,
+        }];

         let prompt = build_gemma_prompt(&messages);
-        assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
+        assert_eq!(prompt, "<start_of_turn>model\n");
     }
 }
File diff suppressed because it is too large
@@ -9,7 +9,10 @@ mod tests {
     // Test a few representative model variants
     assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
     assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
-    assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
+    assert_eq!(
+        Which::InstructV1_1_2B.to_model_id(),
+        "google/gemma-1.1-2b-it"
+    );
     assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
     assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
     assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
@@ -64,4 +67,4 @@ mod tests {
     // Note: Testing the Model enum's forward method would require creating actual model instances,
     // which is complex and would require loading model weights. This is better suited for
     // integration tests or mocking the models.
 }
@@ -1,549 +0,0 @@
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use inference_engine::model::Which;
use inference_engine::text_generation::TextGeneration;
use inference_engine::token_output_stream::TokenOutputStream;
use std::collections::HashMap;
use tokenizers::Tokenizer;

#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a simple tokenizer for testing
    fn create_test_tokenizer() -> Result<Tokenizer> {
        // Create a simple tokenizer from the pretrained model
        // This uses the tokenizer from the Hugging Face hub
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }

    // Test the Which enum's to_model_id method
    #[test]
    fn test_which_model_id() {
        assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
        assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
    }

    // Test the Which enum's is_instruct_model method
    #[test]
    fn test_which_is_instruct() {
        assert!(!Which::Base2B.is_instruct_model());
        assert!(Which::Instruct7B.is_instruct_model());
    }

    // Test the Which enum's is_v3_model method
    #[test]
    fn test_which_is_v3() {
        assert!(!Which::Base2B.is_v3_model());
        assert!(Which::BaseV3_1B.is_v3_model());
    }

    // Test the TokenOutputStream functionality
    #[test]
    fn test_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Test encoding and decoding
        let text = "Hello, world!";
        let encoded = token_stream.tokenizer().encode(text, true).unwrap();
        let token_ids = encoded.get_ids();

        // Add tokens one by one
        for &token_id in token_ids {
            token_stream.next_token(token_id)?;
        }

        // Decode all and check
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded.trim(), text);

        Ok(())
    }

    // Test the LogitsProcessor
    #[test]
    fn test_logits_processor() -> Result<()> {
        // Create a LogitsProcessor with default settings
        let seed = 42;
        let temp = Some(0.8);
        let top_p = Some(0.9);
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);

        // Create a simple logits tensor
        // In a real test, we would create a tensor with known values and verify
        // that sampling produces expected results

        // For now, we'll just verify that the LogitsProcessor can be created
        assert!(true);
        Ok(())
    }

    // Test the TextGeneration constructor
    #[test]
    fn test_text_generation_constructor() -> Result<()> {
        // We can't easily create a Model instance for testing,
        // but we can test that the constructor compiles and the types are correct

        // In a real test with a mock Model, we would:
        // 1. Create a mock model
        // 2. Create a tokenizer
        // 3. Call TextGeneration::new
        // 4. Verify the properties of the created instance

        // For now, we'll just verify that the code compiles
        assert!(true);
        Ok(())
    }

    // Test apply_cached_repeat_penalty method with no penalty
    #[test]
    fn test_apply_cached_repeat_penalty_no_penalty() -> Result<()> {
        // Create a simple test setup
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 2u32, 3u32];

        // Create a mock TextGeneration instance
        // Since we can't easily create a full TextGeneration instance without a model,
        // we'll test the logic by creating a simple struct with the necessary fields
        struct MockTextGeneration {
            repeat_penalty: f32,
            repeat_last_n: usize,
            penalty_cache: HashMap<usize, f32>,
        }

        impl MockTextGeneration {
            fn apply_cached_repeat_penalty(
                &mut self,
                logits: Tensor,
                tokens: &[u32],
            ) -> Result<(Tensor, std::time::Duration)> {
                let repeat_start = std::time::Instant::now();

                // If no penalty, return the original logits
                if self.repeat_penalty == 1.0 {
                    return Ok((logits, repeat_start.elapsed()));
                }

                // Get the tokens to penalize (the last n tokens)
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                let penalty_tokens = &tokens[start_at..];

                // Extract logits to a vector for modification
                let mut logits_vec = logits.to_vec1::<f32>()?;
                let cache_hits = std::cell::Cell::new(0);

                // Apply penalties with caching
                for &token_id in penalty_tokens {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        // Check if we've already calculated this token's penalty
                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
                            // Use cached value
                            logits_vec[token_id] = *penalized_score;
                            cache_hits.set(cache_hits.get() + 1);
                        } else {
                            // Calculate and cache new value
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            let penalized_score = sign * score / self.repeat_penalty;
                            logits_vec[token_id] = penalized_score;
                            self.penalty_cache.insert(token_id, penalized_score);
                        }
                    }
                }

                // Create a new tensor with the modified logits
                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                let result = new_logits.reshape(shape)?;

                let elapsed = repeat_start.elapsed();
                Ok((result, elapsed))
            }
        }

        let mut mock_gen = MockTextGeneration {
            repeat_penalty: 1.0, // No penalty
            repeat_last_n: 3,
            penalty_cache: HashMap::new(),
        };

        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
        let result_data = result_logits.to_vec1::<f32>()?;

        // With no penalty, logits should be unchanged
        assert_eq!(result_data, logits_data);
        Ok(())
    }

    // Test apply_cached_repeat_penalty method with penalty
    #[test]
    fn test_apply_cached_repeat_penalty_with_penalty() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 2u32, 3u32];

        struct MockTextGeneration {
            repeat_penalty: f32,
            repeat_last_n: usize,
            penalty_cache: HashMap<usize, f32>,
        }

        impl MockTextGeneration {
            fn apply_cached_repeat_penalty(
                &mut self,
                logits: Tensor,
                tokens: &[u32],
            ) -> Result<(Tensor, std::time::Duration)> {
                let repeat_start = std::time::Instant::now();

                if self.repeat_penalty == 1.0 {
                    return Ok((logits, repeat_start.elapsed()));
                }

                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                let penalty_tokens = &tokens[start_at..];
                let mut logits_vec = logits.to_vec1::<f32>()?;
                let cache_hits = std::cell::Cell::new(0);

                for &token_id in penalty_tokens {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
                            logits_vec[token_id] = *penalized_score;
                            cache_hits.set(cache_hits.get() + 1);
                        } else {
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            let penalized_score = sign * score / self.repeat_penalty;
                            logits_vec[token_id] = penalized_score;
                            self.penalty_cache.insert(token_id, penalized_score);
                        }
                    }
                }

                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                let result = new_logits.reshape(shape)?;

                let elapsed = repeat_start.elapsed();
                Ok((result, elapsed))
            }
        }

        let mut mock_gen = MockTextGeneration {
            repeat_penalty: 2.0, // Apply penalty
            repeat_last_n: 3,
            penalty_cache: HashMap::new(),
        };

        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
        let result_data = result_logits.to_vec1::<f32>()?;

        // Tokens 1, 2, 3 should be penalized (divided by 2.0)
        let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
        assert_eq!(result_data, expected);
        Ok(())
    }

    // Test apply_cached_repeat_penalty caching behavior
    #[test]
    fn test_apply_cached_repeat_penalty_caching() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache

        struct MockTextGeneration {
            repeat_penalty: f32,
            repeat_last_n: usize,
            penalty_cache: HashMap<usize, f32>,
        }

        impl MockTextGeneration {
            fn apply_cached_repeat_penalty(
                &mut self,
                logits: Tensor,
                tokens: &[u32],
            ) -> Result<(Tensor, std::time::Duration)> {
                let repeat_start = std::time::Instant::now();

                if self.repeat_penalty == 1.0 {
                    return Ok((logits, repeat_start.elapsed()));
                }

                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                let penalty_tokens = &tokens[start_at..];
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in penalty_tokens {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
                            logits_vec[token_id] = *penalized_score;
                        } else {
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            let penalized_score = sign * score / self.repeat_penalty;
                            logits_vec[token_id] = penalized_score;
                            self.penalty_cache.insert(token_id, penalized_score);
                        }
                    }
                }

                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                let result = new_logits.reshape(shape)?;

                let elapsed = repeat_start.elapsed();
                Ok((result, elapsed))
            }
        }

        let mut mock_gen = MockTextGeneration {
            repeat_penalty: 2.0,
            repeat_last_n: 3,
            penalty_cache: HashMap::new(),
        };

        // First call should cache the penalty for token 1
        let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;

        // Cache should contain the penalized value for token 1
        assert!(mock_gen.penalty_cache.contains_key(&1));
        assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0

        Ok(())
    }

    // Test edge case: empty tokens array
    #[test]
    fn test_apply_cached_repeat_penalty_empty_tokens() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens: Vec<u32> = vec![]; // Empty tokens

        struct MockTextGeneration {
            repeat_penalty: f32,
            repeat_last_n: usize,
            penalty_cache: HashMap<usize, f32>,
        }

        impl MockTextGeneration {
            fn apply_cached_repeat_penalty(
                &mut self,
                logits: Tensor,
                tokens: &[u32],
            ) -> Result<(Tensor, std::time::Duration)> {
                let repeat_start = std::time::Instant::now();

                if self.repeat_penalty == 1.0 {
                    return Ok((logits, repeat_start.elapsed()));
                }

                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                let penalty_tokens = &tokens[start_at..];
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in penalty_tokens {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
                            logits_vec[token_id] = *penalized_score;
                        } else {
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            let penalized_score = sign * score / self.repeat_penalty;
                            logits_vec[token_id] = penalized_score;
                            self.penalty_cache.insert(token_id, penalized_score);
                        }
                    }
                }

                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                let result = new_logits.reshape(shape)?;

                let elapsed = repeat_start.elapsed();
                Ok((result, elapsed))
            }
        }

        let mut mock_gen = MockTextGeneration {
            repeat_penalty: 2.0,
            repeat_last_n: 3,
            penalty_cache: HashMap::new(),
        };

        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
        let result_data = result_logits.to_vec1::<f32>()?;

        // With empty tokens, logits should be unchanged
        assert_eq!(result_data, logits_data);
        Ok(())
    }

    // Test edge case: out-of-bounds token IDs
    #[test]
    fn test_apply_cached_repeat_penalty_out_of_bounds() -> Result<()> {
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds

        struct MockTextGeneration {
            repeat_penalty: f32,
            repeat_last_n: usize,
            penalty_cache: HashMap<usize, f32>,
        }

        impl MockTextGeneration {
            fn apply_cached_repeat_penalty(
                &mut self,
                logits: Tensor,
                tokens: &[u32],
            ) -> Result<(Tensor, std::time::Duration)> {
                let repeat_start = std::time::Instant::now();

                if self.repeat_penalty == 1.0 {
                    return Ok((logits, repeat_start.elapsed()));
                }

                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                let penalty_tokens = &tokens[start_at..];
                let mut logits_vec = logits.to_vec1::<f32>()?;

                for &token_id in penalty_tokens {
                    let token_id = token_id as usize;
                    if token_id < logits_vec.len() {
                        if let Some(penalized_score) = self.penalty_cache.get(&token_id) {
                            logits_vec[token_id] = *penalized_score;
                        } else {
                            let score = logits_vec[token_id];
                            let sign = if score < 0.0 { -1.0 } else { 1.0 };
                            let penalized_score = sign * score / self.repeat_penalty;
                            logits_vec[token_id] = penalized_score;
                            self.penalty_cache.insert(token_id, penalized_score);
                        }
                    }
                }

                let device = logits.device().clone();
                let shape = logits.shape().clone();
                let new_logits = Tensor::new(&logits_vec[..], &device)?;
                let result = new_logits.reshape(shape)?;

                let elapsed = repeat_start.elapsed();
                Ok((result, elapsed))
            }
        }

        let mut mock_gen = MockTextGeneration {
            repeat_penalty: 2.0,
            repeat_last_n: 3,
            penalty_cache: HashMap::new(),
        };

        let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
        let result_data = result_logits.to_vec1::<f32>()?;

        // Only token 1 should be penalized, out-of-bounds tokens should be ignored
        let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
        assert_eq!(result_data, expected);
        Ok(())
    }

    // Test the actual apply_cached_repeat_penalty method from TextGeneration
    // This test creates a TextGeneration instance with minimal dependencies to test the real method
    #[test]
    fn test_actual_apply_cached_repeat_penalty_implementation() -> Result<()> {
        // Since creating a real TextGeneration instance requires a Model which needs model weights,
        // we'll create a test that demonstrates the method is now public and can be accessed.
        // The comprehensive functionality testing is already covered by the mock tests above.

        // Test data setup
        let device = Device::Cpu;
        let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 2u32, 3u32];

        // Test that we can create the necessary components
        let tokenizer = create_test_tokenizer()?;

        // The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
        // This test verifies the method signature and that it's accessible from external code

        // We could create a TextGeneration instance if we had a way to mock the Model,
        // but for now we confirm that the existing mock tests cover the functionality
        // and the method is properly exposed as public

        println!("apply_cached_repeat_penalty method is now public and accessible for testing");
        assert!(true);
        Ok(())
    }

    // Integration test that demonstrates the method usage pattern
    #[test]
    fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
        // This test demonstrates how the apply_cached_repeat_penalty method would be used
        // in practice, even though we can't create a full TextGeneration instance in unit tests

        let device = Device::Cpu;
        let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
        let logits = Tensor::new(&logits_data[..], &device)?;
        let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching

        // Test parameters that would be used with TextGeneration
        let repeat_penalty = 1.2f32;
        let repeat_last_n = 3usize;
        let mut penalty_cache: HashMap<usize, f32> = HashMap::new();

        // Simulate the method's logic to verify it works as expected
        let start_time = std::time::Instant::now();

        if repeat_penalty != 1.0 {
            let start_at = tokens.len().saturating_sub(repeat_last_n);
            let penalty_tokens = &tokens[start_at..];
            let mut logits_vec = logits.to_vec1::<f32>()?;

            for &token_id in penalty_tokens {
                let token_id = token_id as usize;
                if token_id < logits_vec.len() {
                    if let Some(_cached_score) = penalty_cache.get(&token_id) {
                        // Cache hit simulation
                    } else {
                        let score = logits_vec[token_id];
                        let sign = if score < 0.0 { -1.0 } else { 1.0 };
                        let penalized_score = sign * score / repeat_penalty;
                        penalty_cache.insert(token_id, penalized_score);
                    }
                }
            }
        }

        let _duration = start_time.elapsed();

        // Verify that tokens were processed correctly
        assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
        assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
        assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached

        println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
        Ok(())
    }

    // Note: Testing the actual text generation functionality would require
    // integration tests with real models, which is beyond the scope of these unit tests.
    // The tests above focus on the components that can be tested in isolation.
}
@@ -1,129 +0,0 @@
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result;

#[cfg(test)]
mod tests {
    use super::*;

    // Helper function to create a simple tokenizer for testing
    fn create_test_tokenizer() -> Result<Tokenizer> {
        // Create a simple tokenizer from the pretrained model
        // This uses the tokenizer from the Hugging Face hub
        let tokenizer = Tokenizer::from_pretrained("google/gemma-2b", None).unwrap();
        Ok(tokenizer)
    }

    #[test]
    fn test_new_token_output_stream() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Check that the token stream was created successfully
        assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Add a token
        let token_id = token_stream.get_token("<eos>").unwrap();
        token_stream.next_token(token_id)?;

        // Clear the stream
        token_stream.clear();

        // Check that the stream is empty by trying to decode all
        let decoded = token_stream.decode_all()?;
        assert_eq!(decoded, "");

        Ok(())
    }

    #[test]
    fn test_get_token() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get a token that should exist
        let eos_token = token_stream.get_token("<eos>");
        assert!(eos_token.is_some());

        // Get a token that shouldn't exist
        let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
        assert!(nonexistent_token.is_none());

        Ok(())
    }

    #[test]
    fn test_next_token_and_decode() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
        let token_ids = hello_tokens.get_ids();

        // Add tokens one by one
        let mut output = String::new();
        for &token_id in token_ids {
            if let Some(text) = token_stream.next_token(token_id)? {
                output.push_str(&text);
            }
        }

        // Get any remaining text
        if let Some(rest) = token_stream.decode_rest()? {
            output.push_str(&rest);
        }

        // Check the output
        assert!(!output.is_empty());
        assert_eq!(output.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_decode_all() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let mut token_stream = TokenOutputStream::new(tokenizer);

        // Get some tokens
        let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
        let token_ids = hello_tokens.get_ids();

        // Add tokens one by one
        for &token_id in token_ids {
            token_stream.next_token(token_id)?;
        }

        // Decode all
        let decoded = token_stream.decode_all()?;

        // Check the output
        assert_eq!(decoded.trim(), "Hello world");

        Ok(())
    }

    #[test]
    fn test_into_inner() -> Result<()> {
        let tokenizer = create_test_tokenizer()?;
        let token_stream = TokenOutputStream::new(tokenizer);

        // Get the inner tokenizer
        let inner_tokenizer = token_stream.into_inner();

        // Check that the inner tokenizer works
        let encoded = inner_tokenizer.encode("Test", true).unwrap();
        assert!(encoded.get_ids().len() > 0);

        Ok(())
    }
}
@@ -1,3 +0,0 @@
# Ensure getrandom works on wasm32-unknown-unknown without needing manual RUSTFLAGS
[target.wasm32-unknown-unknown]
rustflags = ["--cfg", "getrandom_backend=\"wasm_js\""]
@@ -1,21 +0,0 @@
# Build stage
FROM rust:1-alpine AS builder

# Install build dependencies
RUN apk add --no-cache npm nodejs musl-dev pkgconfig openssl-dev git curl bash

RUN curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash

WORKDIR /app

# Copy manifest first (cache deps)
COPY . .

# Install cargo-leptos
RUN cargo binstall cargo-leptos

# Build release artifacts
RUN cargo leptos build --release

EXPOSE 8788
CMD ["cargo", "leptos", "serve", "--release"]
@@ -1,494 +0,0 @@
use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
    components::{Route, Router, Routes},
    StaticSegment,
};

#[cfg(feature = "hydrate")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "hydrate")]
use std::collections::VecDeque;
#[cfg(feature = "hydrate")]
use uuid::Uuid;
#[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
    types::{
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
    },
    Client,
};
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{Role, FinishReason};
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub id: String,
    pub role: String,
    pub content: String,
    pub timestamp: f64,
}

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent(pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>);

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInnerContent(pub either::Either<String, std::collections::HashMap<String, String>>);

#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: Option<MessageContent>,
    pub name: Option<String>,
}

#[cfg(feature = "hydrate")]
const DEFAULT_MODEL: &str = "default";

#[cfg(feature = "hydrate")]
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
    leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1");

    let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
    let client = Client::with_config(config);

    match client.models().list().await {
        Ok(response) => {
            let model_count = response.data.len();
            leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count);

            if model_count > 0 {
                let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
                leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names);
            } else {
                leptos::logging::log!("[DEBUG_LOG] fetch_available_models: No models returned by server");
            }

            Ok(response.data)
        },
        Err(e) => {
            leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e);
            Err(format!("Failed to fetch models: {}", e))
        }
    }
}

pub fn shell(options: LeptosOptions) -> impl IntoView {
    view! {
        <!DOCTYPE html>
        <html lang="en">
            <head>
                <meta charset="utf-8"/>
                <meta name="viewport" content="width=device-width, initial-scale=1"/>
                <AutoReload options=options.clone() />
                <HydrationScripts options/>
                <MetaTags/>
            </head>
            <body>
                <App/>
            </body>
        </html>
    }
}

#[component]
pub fn App() -> impl IntoView {
    // Provides context that manages stylesheets, titles, meta tags, etc.
    provide_meta_context();

    view! {
        // injects a stylesheet into the document <head>
        // id=leptos means cargo-leptos will hot-reload this stylesheet
        <Stylesheet id="leptos" href="/pkg/leptos-app.css"/>

        // sets the document title
        <Title text="Chat Interface"/>

        // content for this chat interface
        <Router>
            <main>
                <Routes fallback=|| "Page not found.".into_view()>
                    <Route path=StaticSegment("") view=ChatInterface/>
                </Routes>
            </main>
        </Router>
    }
}

/// Renders the home page of your application.
#[component]
fn HomePage() -> impl IntoView {
    // Creates a reactive value to update the button
    let count = RwSignal::new(0);
    let on_click = move |_| *count.write() += 1;

    view! {
        <h1>"Welcome to Leptos!"</h1>
        <button on:click=on_click>"Click Me: " {count}</button>
    }
}

/// Renders the chat interface
#[component]
fn ChatInterface() -> impl IntoView {
    #[cfg(feature = "hydrate")]
    {
        ChatInterfaceImpl()
    }

    #[cfg(not(feature = "hydrate"))]
    {
        view! {
            <div class="chat-container">
                <h1>"Chat Interface"</h1>
                <p>"Loading chat interface..."</p>
            </div>
        }
    }
}

#[cfg(feature = "hydrate")]
#[component]
fn ChatInterfaceImpl() -> impl IntoView {
    let (messages, set_messages) = RwSignal::new(VecDeque::<Message>::new()).split();
    let (input_value, set_input_value) = RwSignal::new(String::new()).split();
    let (is_loading, set_is_loading) = RwSignal::new(false).split();
    let (available_models, set_available_models) = RwSignal::new(Vec::<OpenAIModel>::new()).split();
    let (selected_model, set_selected_model) = RwSignal::new(DEFAULT_MODEL.to_string()).split();
    let (models_loading, set_models_loading) = RwSignal::new(false).split();

    // Fetch models on component initialization
    Effect::new(move |_| {
        spawn_local(async move {
            set_models_loading.set(true);
            match fetch_available_models().await {
                Ok(models) => {
                    set_available_models.set(models);
                    set_models_loading.set(false);
                }
                Err(e) => {
                    leptos::logging::log!("Failed to fetch models: {}", e);
                    set_available_models.set(vec![]);
                    set_models_loading.set(false);
                }
            }
        });
    });

    let send_message = Action::new_unsync(move |content: &String| {
        let content = content.clone();
        async move {
            if content.trim().is_empty() {
                leptos::logging::log!("[DEBUG_LOG] send_message: Empty content, skipping");
                return;
            }

            leptos::logging::log!("[DEBUG_LOG] send_message: Starting message send process");
            set_is_loading.set(true);

            // Add user message to chat
            let user_message = Message {
                id: Uuid::new_v4().to_string(),
                role: "user".to_string(),
                content: content.clone(),
                timestamp: Date::now(),
            };

            set_messages.update(|msgs| msgs.push_back(user_message.clone()));
            set_input_value.set(String::new());

            let mut chat_messages = Vec::new();

            // Add system message
            let system_message = ChatCompletionRequestSystemMessageArgs::default()
                .content("You are a helpful assistant.")
                .build()
                .expect("failed to build system message");
            chat_messages.push(system_message.into());

            // Add history messages
            let history_count = messages.get_untracked().len();
            for msg in messages.get_untracked().iter() {
                match msg.role.as_str() {
                    "user" => {
                        let message = ChatCompletionRequestUserMessageArgs::default()
                            .content(msg.content.clone())
                            .build()
                            .expect("failed to build user message");
                        chat_messages.push(message.into());
                    }
                    "assistant" => {
                        let message = ChatCompletionRequestAssistantMessageArgs::default()
                            .content(msg.content.clone())
                            .build()
                            .expect("failed to build assistant message");
                        chat_messages.push(message.into());
                    }
                    _ => {}
                }
            }

            // Add current user message
            let message = ChatCompletionRequestUserMessageArgs::default()
                .content(user_message.content.clone())
                .build()
                .expect("failed to build user message");
            chat_messages.push(message.into());

            let current_model = selected_model.get_untracked();
            let total_messages = chat_messages.len();

            leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
                current_model, history_count, total_messages);

            let request = CreateChatCompletionRequestArgs::default()
                .model(current_model.as_str())
                .max_tokens(512u32)
                .messages(chat_messages)
                .stream(true)
                .build()
                .expect("failed to build request");

            // Send request
            let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
            let client = Client::with_config(config);

            leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);

            match client.chat().create_stream(request).await {
                Ok(mut stream) => {
                    leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream");

                    let mut assistant_created = false;
                    let mut content_appended = false;
                    let mut chunks_received = 0;

                    while let Some(next) = stream.next().await {
                        match next {
                            Ok(chunk) => {
                                chunks_received += 1;
                                if let Some(choice) = chunk.choices.get(0) {
                                    if !assistant_created {
                                        if let Some(role) = &choice.delta.role {
                                            if role == &Role::Assistant {
                                                assistant_created = true;
                                                let assistant_id = Uuid::new_v4().to_string();
                                                set_messages.update(|msgs| {
                                                    msgs.push_back(Message {
                                                        id: assistant_id,
                                                        role: "assistant".to_string(),
                                                        content: String::new(),
                                                        timestamp: Date::now(),
                                                    });
                                                });
                                            }
                                        }
                                    }

                                    if let Some(content) = &choice.delta.content {
                                        if !content.is_empty() {
                                            if !assistant_created {
                                                assistant_created = true;
                                                let assistant_id = Uuid::new_v4().to_string();
                                                set_messages.update(|msgs| {
                                                    msgs.push_back(Message {
                                                        id: assistant_id,
                                                        role: "assistant".to_string(),
                                                        content: String::new(),
                                                        timestamp: Date::now(),
                                                    });
                                                });
                                            }
                                            content_appended = true;
                                            set_messages.update(|msgs| {
                                                if let Some(last) = msgs.back_mut() {
                                                    if last.role == "assistant" {
                                                        last.content.push_str(content);
                                                        last.timestamp = Date::now();
                                                    }
                                                }
                                            });
                                        }
                                    }

                                    if let Some(reason) = &choice.finish_reason {
                                        if reason == &FinishReason::Stop {
                                            leptos::logging::log!("[DEBUG_LOG] send_message: Received finish_reason=stop after {} chunks", chunks_received);
                                            break;
                                        }
                                    }
                                }
                            }
                            Err(e) => {
                                leptos::logging::log!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e);
                                set_messages.update(|msgs| {
                                    msgs.push_back(Message {
                                        id: Uuid::new_v4().to_string(),
                                        role: "system".to_string(),
                                        content: format!("Stream error: {}", e),
                                        timestamp: Date::now(),
                                    });
                                });
                                break;
                            }
                        }
                    }

                    if assistant_created && !content_appended {
                        set_messages.update(|msgs| {
                            let should_pop = msgs
                                .back()
                                .map(|m| m.role == "assistant" && m.content.is_empty())
                                .unwrap_or(false);
                            if should_pop {
                                msgs.pop_back();
                            }
                        });
                    }

                    leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
                }
                Err(e) => {
                    leptos::logging::log!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e);
                    let error_message = Message {
                        id: Uuid::new_v4().to_string(),
                        role: "system".to_string(),
                        content: format!("Error: Request failed - {}", e),
                        timestamp: Date::now(),
                    };
                    set_messages.update(|msgs| msgs.push_back(error_message));
                }
            }

            set_is_loading.set(false);
        }
    });

    let on_input = move |ev| {
        let input = event_target::<HtmlInputElement>(&ev);
        set_input_value.set(input.value());
    };

    let on_submit = move |ev: SubmitEvent| {
        ev.prevent_default();
        let content = input_value.get();
        send_message.dispatch(content);
    };

    let on_keypress = move |ev: KeyboardEvent| {
        if ev.key() == "Enter" && !ev.shift_key() {
            ev.prevent_default();
            let content = input_value.get();
            send_message.dispatch(content);
        }
    };

    let on_model_change = move |ev| {
        let select = event_target::<web_sys::HtmlSelectElement>(&ev);
        set_selected_model.set(select.value());
    };

    let messages_list = move || {
        messages.get()
            .into_iter()
            .map(|message| {
                let role_class = match message.role.as_str() {
                    "user" => "user-message",
                    "assistant" => "assistant-message",
                    _ => "system-message",
                };

                view! {
                    <div class=format!("message {}", role_class)>
                        <div class="message-role">{message.role}</div>
                        <div class="message-content">{message.content}</div>
                    </div>
                }
            })
            .collect::<Vec<_>>()
    };

    let loading_indicator = move || {
        is_loading.get().then(|| {
            view! {
                <div class="message assistant-message">
                    <div class="message-role">"assistant"</div>
                    <div class="message-content">"Thinking..."</div>
                </div>
            }
        })
    };

    view! {
        <div class="chat-container">
            <h1>"Chat Interface"</h1>
            <div class="model-selector">
                <label for="model-select">"Model: "</label>
                <select
                    id="model-select"
                    on:change=on_model_change
                    prop:value=selected_model
                    prop:disabled=models_loading
                >
                    {move || {
                        if models_loading.get() {
                            vec![view! {
                                <option value={String::from("")} selected=false>{String::from("Loading models...")}</option>
                            }]
                        } else {
                            let models = available_models.get();
                            if models.is_empty() {
                                vec![view! {
                                    <option value={String::from("default")} selected=true>{String::from("default")}</option>
                                }]
                            } else {
                                models.into_iter().map(|model| {
                                    view! {
                                        <option value=model.id.clone() selected={model.id == DEFAULT_MODEL}>{model.id.clone()}</option>
                                    }
                                }).collect::<Vec<_>>()
                            }
                        }
                    }}
                </select>
            </div>
            <div class="messages-container">
                {messages_list}
                {loading_indicator}
            </div>
            <form class="input-form" on:submit=on_submit>
                <input
                    type="text"
                    class="message-input"
                    placeholder="Type your message here..."
                    prop:value=input_value
                    on:input=on_input
                    on:keypress=on_keypress
                    prop:disabled=is_loading
                />
                <button
                    type="submit"
                    class="send-button"
                    prop:disabled=move || is_loading.get() || input_value.get().trim().is_empty()
                >
                    "Send"
                </button>
            </form>
        </div>
    }
}
@@ -1,30 +0,0 @@
pub mod app;

#[cfg(feature = "hydrate")]
#[wasm_bindgen::prelude::wasm_bindgen]
pub fn hydrate() {
    use crate::app::*;
    console_error_panic_hook::set_once();
    leptos::mount::hydrate_body(App);
}

#[cfg(feature = "ssr")]
pub fn create_leptos_router() -> axum::Router {
    use axum::Router;
    use leptos::prelude::*;
    use leptos_axum::{generate_route_list, LeptosRoutes};
    use crate::app::*;

    let conf = get_configuration(None).unwrap();
    let leptos_options = conf.leptos_options;
    // Generate the list of routes in your Leptos App
    let routes = generate_route_list(App);

    Router::new()
        .leptos_routes(&leptos_options, routes, {
            let leptos_options = leptos_options.clone();
            move || shell(leptos_options.clone())
        })
        .fallback(leptos_axum::file_and_error_handler(shell))
        .with_state(leptos_options)
}
@@ -1,39 +0,0 @@
#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
    use axum::Router;
    use leptos::logging::log;
    use leptos::prelude::*;
    use leptos_axum::{generate_route_list, LeptosRoutes};
    use leptos_app::app::*;

    let conf = get_configuration(None).unwrap();
    let addr = conf.leptos_options.site_addr;
    let leptos_options = conf.leptos_options;
    // Generate the list of routes in your Leptos App
    let routes = generate_route_list(App);

    let app = Router::new()
        .leptos_routes(&leptos_options, routes, {
            let leptos_options = leptos_options.clone();
            move || shell(leptos_options.clone())
        })
        .fallback(leptos_axum::file_and_error_handler(shell))
        .with_state(leptos_options);

    // run our app with hyper
    // `axum::Server` is a re-export of `hyper::Server`
    log!("listening on http://{}", &addr);
    let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
    axum::serve(listener, app.into_make_service())
        .await
        .unwrap();
}

#[cfg(not(feature = "ssr"))]
pub fn main() {
    // no client-side main function
    // unless we want this to work with e.g., Trunk for pure client-side testing
    // see lib.rs for hydration function instead
}
@@ -1,4 +0,0 @@
body {
    font-family: sans-serif;
    text-align: center;
}
@@ -1,12 +1,12 @@
 [package]
 name = "llama-runner"
-version = "0.1.0"
+version.workspace = true
 edition = "2021"
 
 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
-candle-transformers = { git = "https://github.com/huggingface/candle.git" }
+candle-transformers = { git = "https://github.com/huggingface/candle.git"}
 hf-hub = "0.3"
 tokenizers = "0.20"
 anyhow = "1.0"
@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
 
 // Re-export constants and types that might be needed
 pub const EOS_TOKEN: &str = "</s>";
-
@@ -1,14 +1,14 @@
+use crate::EOS_TOKEN;
 use anyhow::{bail, Error as E};
 use candle_core::{utils, DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
-use candle_transformers::models::llama::{Llama, LlamaConfig};
 use candle_transformers::models::llama as model;
+use candle_transformers::models::llama::{Llama, LlamaConfig};
+use clap::ValueEnum;
 use hf_hub::api::sync::Api;
 use hf_hub::{Repo, RepoType};
 use std::sync::mpsc::{self, Receiver};
-use clap::ValueEnum;
-use crate::{EOS_TOKEN};
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
 pub enum WhichModel {
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
             max_tokens: 512,
 
             // Performance flags
             no_kv_cache: false, // keep cache ON for speed
-            use_flash_attn: true, // great speed boost if supported
+            use_flash_attn: false, // great speed boost if supported
 
             // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
             dtype: Some("bf16".to_string()),
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
     }
 }
 
-
-
 fn device(cpu: bool) -> anyhow::Result<Device> {
     if cpu {
         Ok(Device::Cpu)
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
     }
 }
 
-
 fn hub_load_safetensors(
     api: &hf_hub::api::sync::ApiRepo,
     json_file: &str,
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
             WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
             WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         }
         .to_string()
     });
     println!("Loading model: {}", model_id);
     let revision = cfg.revision.clone().unwrap_or("main".to_string());
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
 
     Ok(rx)
 }
-
@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
     }
 }
 
-
 pub fn run_cli() -> anyhow::Result<()> {
     let args = Args::parse();
     let cfg = args.into();
@@ -106,4 +105,4 @@ pub fn run_cli() -> anyhow::Result<()> {
         }
     }
     Ok(())
 }
@@ -2,8 +2,8 @@
 extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-mod llama_cli;
 mod llama_api;
+mod llama_cli;
 
 use anyhow::Result;
 use clap::{Parser, ValueEnum};
@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;
 
 const EOS_TOKEN: &str = "</s>";
 
-
 fn main() -> Result<()> {
     run_cli()
 }
@@ -1,6 +1,6 @@
 [package]
 name = "predict-otron-9000"
-version = "0.1.0"
+version.workspace = true
 edition = "2024"
 
 [[bin]]
@@ -19,7 +19,7 @@ tracing = "0.1"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 reqwest = { version = "0.12", features = ["json"] }
-rust-embed = { version = "8.7.2", features = ["include-exclude"] }
+rust-embed = { version = "8.7.2", features = ["include-exclude", "axum"] }
 
 # Dependencies for embeddings functionality
 embeddings-engine = { path = "../embeddings-engine" }
@@ -28,9 +28,11 @@ embeddings-engine = { path = "../embeddings-engine" }
 inference-engine = { path = "../inference-engine" }
 
 # Dependencies for leptos web app
-leptos-app = { path = "../leptos-app", features = ["ssr"] }
+#leptos-app = { path = "../leptos-app", features = ["ssr"] }
+chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = false }
 
 mime_guess = "2.0.5"
+log = "0.4.27"
 
 
 [package.metadata.compose]
@@ -44,4 +46,8 @@ port = 8080
 image = "ghcr.io/geoffsee/predict-otron-9000:latest"
 replicas = 1
 port = 8080
-env = { SERVER_CONFIG = "" }
+# SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
+# you can generate this via node to avoid toil
+# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
+# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
+env = { SERVER_CONFIG = "<your-json-value-here>" }
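A note on the placeholder above: the compose metadata expects the whole JSON configuration as a single escaped string. A minimal sketch of what the filled-in line might look like, using the placeholder service URLs from the comments (not a tested configuration, substitute your own values):

```toml
# Hypothetical example only — the URLs are illustrative placeholders.
env = { SERVER_CONFIG = "{\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}" }
```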
@@ -1,7 +1,12 @@
 use serde::{Deserialize, Serialize};
 use std::env;
+use tracing::info;
+use tracing::log::error;
-#[derive(Debug, Clone, Deserialize, Serialize)]
+/// # Generating `SERVER_CONFIG` with Node
+// # const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
+// # console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
+///
+#[derive(Serialize, Deserialize, Clone, Debug)]
 #[serde(rename_all = "camelCase")]
 pub struct ServerConfig {
     #[serde(default = "default_server_host")]
@@ -10,14 +15,16 @@ pub struct ServerConfig {
     pub server_port: u16,
     pub server_mode: ServerMode,
     #[serde(default)]
-    pub services: Services,
+    pub services: Option<Services>,
 }
 
 fn default_server_host() -> String {
     "127.0.0.1".to_string()
 }
 
-fn default_server_port() -> u16 { 8080 }
+fn default_server_port() -> u16 {
+    8080
+}
 
 #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
 #[serde(rename_all = "PascalCase")]
@@ -34,17 +41,15 @@ impl Default for ServerMode {
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Services {
-    #[serde(default = "inference_service_url")]
-    pub inference_url: String,
-    #[serde(default = "embeddings_service_url")]
-    pub embeddings_url: String,
+    pub inference_url: Option<String>,
+    pub embeddings_url: Option<String>,
 }
 
 impl Default for Services {
     fn default() -> Self {
         Self {
-            inference_url: inference_service_url(),
-            embeddings_url: embeddings_service_url(),
+            inference_url: None,
+            embeddings_url: None,
         }
     }
 }
@@ -63,7 +68,7 @@ impl Default for ServerConfig {
|
|||||||
server_host: "127.0.0.1".to_string(),
|
server_host: "127.0.0.1".to_string(),
|
||||||
server_port: 8080,
|
server_port: 8080,
|
||||||
server_mode: ServerMode::Standalone,
|
server_mode: ServerMode::Standalone,
|
||||||
services: Services::default(),
|
services: Some(Services::default()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,21 +78,19 @@ impl ServerConfig {
|
|||||||
/// Falls back to default (Local mode) if not set or invalid
|
/// Falls back to default (Local mode) if not set or invalid
|
||||||
pub fn from_env() -> Self {
|
pub fn from_env() -> Self {
|
||||||
match env::var("SERVER_CONFIG") {
|
match env::var("SERVER_CONFIG") {
|
||||||
Ok(config_str) => {
|
Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
|
||||||
match serde_json::from_str::<ServerConfig>(&config_str) {
|
Ok(config) => {
|
||||||
Ok(config) => {
|
tracing::info!("Loaded server configuration: {:?}", config);
|
||||||
tracing::info!("Loaded server configuration: {:?}", config);
|
config
|
||||||
config
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(
|
|
||||||
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
|
|
||||||
e
|
|
||||||
);
|
|
||||||
ServerConfig::default()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
ServerConfig::default()
|
||||||
|
}
|
||||||
|
},
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
tracing::info!("SERVER_CONFIG not set, Standalone mode active");
|
tracing::info!("SERVER_CONFIG not set, Standalone mode active");
|
||||||
ServerConfig::default()
|
ServerConfig::default()
|
||||||
@@ -96,18 +99,52 @@ impl ServerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Check if the server should run in high availability mode
|
/// Check if the server should run in high availability mode
|
||||||
pub fn is_high_availability(&self) -> bool {
|
pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
|
||||||
self.server_mode == ServerMode::HighAvailability
|
if self.server_mode == ServerMode::HighAvailability {
|
||||||
|
let services_well_defined: bool = self.clone().services.is_some();
|
||||||
|
|
||||||
|
let inference_url_well_defined: bool =
|
||||||
|
services_well_defined && self.clone().services.unwrap().inference_url.is_some();
|
||||||
|
|
||||||
|
let embeddings_well_defined: bool =
|
||||||
|
services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();
|
||||||
|
|
||||||
|
let is_well_defined_for_ha =
|
||||||
|
services_well_defined && inference_url_well_defined && embeddings_well_defined;
|
||||||
|
|
||||||
|
if !is_well_defined_for_ha {
|
||||||
|
let config_string = serde_json::to_string_pretty(&self).unwrap();
|
||||||
|
error!(
|
||||||
|
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
|
||||||
|
config_string
|
||||||
|
);
|
||||||
|
let err = std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
"HighAvailability mode configured but services not well defined!",
|
||||||
|
);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.server_mode == ServerMode::HighAvailability)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the inference service URL for proxying
|
/// Get the inference service URL for proxying
|
||||||
pub fn inference_url(&self) -> &str {
|
pub fn inference_url(&self) -> Option<String> {
|
||||||
&self.services.inference_url
|
if self.services.is_some() {
|
||||||
|
self.services.clone()?.inference_url
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the embeddings service URL for proxying
|
/// Get the embeddings service URL for proxying
|
||||||
pub fn embeddings_url(&self) -> &str {
|
pub fn embeddings_url(&self) -> Option<String> {
|
||||||
&self.services.embeddings_url
|
if self.services.is_some() {
|
||||||
|
self.services.clone()?.embeddings_url
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +156,7 @@ mod tests {
|
|||||||
fn test_default_config() {
|
fn test_default_config() {
|
||||||
let config = ServerConfig::default();
|
let config = ServerConfig::default();
|
||||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||||
assert!(!config.is_high_availability());
|
assert!(!config.is_high_availability().unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -134,23 +171,26 @@ mod tests {
|
|||||||
|
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert_eq!(config.server_mode, ServerMode::HighAvailability);
|
assert_eq!(config.server_mode, ServerMode::HighAvailability);
|
||||||
assert!(config.is_high_availability());
|
assert!(config.is_high_availability().unwrap());
|
||||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
assert_eq!(
|
||||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
config.inference_url().unwrap(),
|
||||||
|
"http://inference-service:8080"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.embeddings_url().unwrap(),
|
||||||
|
"http://embeddings-service:8080"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_local_mode_config() {
|
fn test_local_mode_config() {
|
||||||
let config_json = r#"{
|
let config_json = r#"{
|
||||||
"serverMode": "Local"
|
"serverMode": "Standalone"
|
||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||||
assert!(!config.is_high_availability());
|
assert!(!config.is_high_availability().unwrap());
|
||||||
// Should use default URLs
|
|
||||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
|
||||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -164,17 +204,26 @@ mod tests {
|
|||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert_eq!(config.inference_url(), "http://custom-inference:9000");
|
assert_eq!(
|
||||||
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
|
config.inference_url().unwrap(),
|
||||||
|
"http://custom-inference:9000"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.embeddings_url().unwrap(),
|
||||||
|
"http://custom-embeddings:9001"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_minimal_high_availability_config() {
|
fn test_minimal_high_availability_config_error() {
|
||||||
let config_json = r#"{"serverMode": "HighAvailability"}"#;
|
let config_json = r#"{"serverMode": "HighAvailability"}"#;
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert!(config.is_high_availability());
|
|
||||||
// Should use default URLs
|
let is_high_availability = config.is_high_availability();
|
||||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
|
||||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
assert!(is_high_availability.is_err());
|
||||||
|
// // Should use default URLs
|
||||||
|
// assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
|
||||||
|
// assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
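A minimal launch sketch exercising `ServerConfig::from_env` (the binary name is assumed from the crate name; the JSON mirrors the shape used in the tests above):

export SERVER_CONFIG='{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
./predict-otron-9000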
@@ -1,10 +1,10 @@
 use axum::{
+    Router,
     body::Body,
     extract::{Request, State},
     http::{HeaderMap, Method, StatusCode, Uri},
     response::{IntoResponse, Response},
     routing::{get, post},
-    Router,
 };
 use reqwest::Client;
 use serde_json::Value;
@@ -12,6 +12,120 @@ use std::time::Duration;
 
 use crate::config::ServerConfig;
 
+/// # Generating `SERVER_CONFIG` for TOML using Node.js
+///
+/// You can still use the Node.js REPL to build the JSON, but when pasting into
+/// a `.toml` file you must follow TOML's string rules. Below are the safest patterns.
+///
+/// ## 1) Generate the JSON in Node
+/// ```bash
+/// node
+/// ```
+/// ```javascript
+/// const myobject = {
+///   serverMode: "HighAvailability",
+///   services: {
+///     inference_url: "http://custom-inference:9000",
+///     embeddings_url: "http://custom-embeddings:9001"
+///   }
+/// };
+/// const json = JSON.stringify(myobject);
+/// json
+/// // -> '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
+/// ```
+///
+/// ## 2) Put it into `.toml`
+///
+/// ### Option A (recommended): single-quoted TOML *literal* string
+/// Single quotes in TOML mean "no escaping", so your inner double quotes are safe.
+/// ```toml
+/// SERVER_CONFIG = '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
+/// ```
+///
+/// ### Option B: double-quoted TOML string (must escape inner quotes)
+/// If you *must* use double quotes in TOML, escape all `"` inside the JSON.
+/// You can have Node do this for you:
+/// ```javascript
+/// // In Node:
+/// const jsonForToml = JSON.stringify(myobject).replace(/"/g, '\\"');
+/// jsonForToml
+/// // -> \"{\\\"serverMode\\\":\\\"HighAvailability\\\",...}\"
+/// ```
+/// Then paste into TOML:
+/// ```toml
+/// SERVER_CONFIG = "{\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}"
+/// ```
+///
+/// ### Option C: multi-line literal (for pretty JSON)
+/// If you want pretty-printed JSON in the file, use TOML's triple single quotes:
+/// ```javascript
+/// // In Node (pretty with 2 spaces):
+/// const pretty = JSON.stringify(myobject, null, 2);
+/// ```
+/// ```toml
+/// SERVER_CONFIG = '''{
+///   "serverMode": "HighAvailability",
+///   "services": {
+///     "inference_url": "http://custom-inference:9000",
+///     "embeddings_url": "http://custom-embeddings:9001"
+///   }
+/// }'''
+/// ```
+///
+/// ## 3) Reading it in Rust
+///
+/// If `SERVER_CONFIG` is stored as a **string** in TOML (Options A/B/C):
+/// ```rust
+/// use serde_json::Value;
+///
+/// // Suppose you've already loaded your .toml into a struct or a toml::Value:
+/// // e.g., struct FileCfg { pub SERVER_CONFIG: String }
+/// fn parse_server_config(raw: &str) -> anyhow::Result<Value> {
+///     let v: Value = serde_json::from_str(raw)?;
+///     Ok(v)
+/// }
+/// ```
+///
+/// ### Alternative: store it as TOML tables and serialize to JSON at runtime
+/// Instead of a JSON string, you can make the TOML first-class tables:
+/// ```toml
+/// [SERVER_CONFIG]
+/// serverMode = "HighAvailability"
+///
+/// [SERVER_CONFIG.services]
+/// inference_url = "http://custom-inference:9000"
+/// embeddings_url = "http://custom-embeddings:9001"
+/// ```
+/// ```rust
+/// use serde::{Deserialize, Serialize};
+/// use serde_json::Value;
+///
+/// #[derive(Debug, Serialize, Deserialize)]
+/// struct Services {
+///     inference_url: String,
+///     embeddings_url: String,
+/// }
+///
+/// #[derive(Debug, Serialize, Deserialize)]
+/// struct ServerConfig {
+///     serverMode: String,
+///     services: Services,
+/// }
+///
+/// // After loading the .toml (e.g., via `toml::from_str`):
+/// // let cfg: ServerConfig = toml::from_str(toml_str)?;
+/// // Convert to JSON if needed:
+/// fn to_json(cfg: &ServerConfig) -> serde_json::Result<Value> {
+///     Ok(serde_json::to_value(cfg)?)
+/// }
+/// ```
+///
+/// ## Gotchas
+/// - Prefer **single-quoted** TOML strings for raw JSON to avoid escaping.
+/// - If you use **double-quoted** TOML strings, escape every inner `"` in the JSON.
+/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
+/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
+
 /// HTTP client configured for proxying requests
 #[derive(Clone)]
 pub struct ProxyClient {
@@ -31,7 +145,7 @@ impl ProxyClient {
 }
 
 /// Create a router that proxies requests to external services in HighAvailability mode
-pub fn create_proxy_router(config: ServerConfig) -> Router {
+pub fn create_ha_router(config: ServerConfig) -> Router {
     let proxy_client = ProxyClient::new(config.clone());
 
     Router::new()
@@ -47,10 +161,16 @@ async fn proxy_chat_completions(
     headers: HeaderMap,
     body: Body,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());
+    let target_url = format!(
+        "{}/v1/chat/completions",
+        proxy_client
+            .config
+            .inference_url()
+            .expect("Invalid Configuration")
+    );
 
     tracing::info!("Proxying chat completions request to: {}", target_url);
 
     // Extract body as bytes
     let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
         Ok(bytes) => bytes,
@@ -63,7 +183,9 @@ async fn proxy_chat_completions(
     // Check if this is a streaming request
     let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
         if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
-            json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
+            json.get("stream")
+                .and_then(|v| v.as_bool())
+                .unwrap_or(false)
         } else {
             false
         }
@@ -72,7 +194,8 @@ async fn proxy_chat_completions(
     };
 
     // Forward the request
-    let mut req_builder = proxy_client.client
+    let mut req_builder = proxy_client
+        .client
         .post(&target_url)
         .body(body_bytes.to_vec());
 
@@ -85,8 +208,7 @@ async fn proxy_chat_completions(
 
     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());
 
             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -99,14 +221,12 @@ async fn proxy_chat_completions(
             if is_streaming {
                 // For streaming, we need to forward the response as-is
                 match response.bytes().await {
-                    Ok(body) => {
-                        resp_builder
-                            .header("content-type", "text/plain; charset=utf-8")
-                            .header("cache-control", "no-cache")
-                            .header("connection", "keep-alive")
-                            .body(Body::from(body))
-                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                    }
+                    Ok(body) => resp_builder
+                        .header("content-type", "text/plain; charset=utf-8")
+                        .header("cache-control", "no-cache")
+                        .header("connection", "keep-alive")
+                        .body(Body::from(body))
+                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                     Err(e) => {
                         tracing::error!("Failed to read streaming response body: {}", e);
                         Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +235,9 @@ async fn proxy_chat_completions(
             } else {
                 // For non-streaming, forward the JSON response
                 match response.bytes().await {
-                    Ok(body) => {
-                        resp_builder
-                            .body(Body::from(body))
-                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                    }
+                    Ok(body) => resp_builder
+                        .body(Body::from(body))
+                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                     Err(e) => {
                         tracing::error!("Failed to read response body: {}", e);
                         Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,10 +257,16 @@ async fn proxy_models(
     State(proxy_client): State<ProxyClient>,
     headers: HeaderMap,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/models", proxy_client.config.inference_url());
+    let target_url = format!(
+        "{}/v1/models",
+        proxy_client
+            .config
+            .inference_url()
+            .expect("Invalid Configuration Detected")
+    );
 
     tracing::info!("Proxying models request to: {}", target_url);
 
     let mut req_builder = proxy_client.client.get(&target_url);
 
     // Forward relevant headers
@@ -154,8 +278,7 @@ async fn proxy_models(
 
     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());
 
             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -165,11 +288,9 @@ async fn proxy_models(
             }
 
             match response.bytes().await {
-                Ok(body) => {
-                    resp_builder
-                        .body(Body::from(body))
-                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                }
+                Ok(body) => resp_builder
+                    .body(Body::from(body))
+                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                 Err(e) => {
                     tracing::error!("Failed to read models response body: {}", e);
                     Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,10 +310,16 @@ async fn proxy_embeddings(
     headers: HeaderMap,
     body: Body,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());
+    let target_url = format!(
+        "{}/v1/embeddings",
+        proxy_client
+            .config
+            .embeddings_url()
+            .expect("Invalid Configuration Detected")
+    );
 
     tracing::info!("Proxying embeddings request to: {}", target_url);
 
     // Extract body as bytes
     let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
         Ok(bytes) => bytes,
@@ -203,7 +330,8 @@ async fn proxy_embeddings(
     };
 
     // Forward the request
-    let mut req_builder = proxy_client.client
+    let mut req_builder = proxy_client
+        .client
         .post(&target_url)
         .body(body_bytes.to_vec());
 
@@ -216,8 +344,7 @@ async fn proxy_embeddings(
 
     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());
 
             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -227,11 +354,9 @@ async fn proxy_embeddings(
             }
 
             match response.bytes().await {
-                Ok(body) => {
-                    resp_builder
-                        .body(Body::from(body))
-                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                }
+                Ok(body) => resp_builder
+                    .body(Body::from(body))
+                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                 Err(e) => {
                     tracing::error!("Failed to read embeddings response body: {}", e);
                     Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -250,7 +375,7 @@ fn should_forward_header(header_name: &str) -> bool {
     match header_name.to_lowercase().as_str() {
         "content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
         "host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
        _ => true, // Forward other headers by default
     }
 }
 
@@ -259,7 +384,7 @@ fn should_forward_response_header(header_name: &str) -> bool {
     match header_name.to_lowercase().as_str() {
         "content-type" | "content-length" | "cache-control" | "connection" => true,
         "server" | "date" => false, // Don't forward server-specific headers
         _ => true, // Forward other headers by default
     }
 }
 
@@ -290,14 +415,20 @@ mod tests {
             server_host: "127.0.0.1".to_string(),
             server_port: 8080,
             server_mode: ServerMode::HighAvailability,
-            services: Services {
-                inference_url: "http://test-inference:8080".to_string(),
-                embeddings_url: "http://test-embeddings:8080".to_string(),
-            },
+            services: Some(Services {
+                inference_url: Some("http://test-inference:8080".to_string()),
+                embeddings_url: Some("http://test-embeddings:8080".to_string()),
+            }),
         };
 
         let proxy_client = ProxyClient::new(config);
-        assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
-        assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
+        assert_eq!(
+            proxy_client.config.inference_url().unwrap().as_str(),
+            "http://test-inference:8080"
+        );
+        assert_eq!(
+            proxy_client.config.embeddings_url().unwrap().as_str(),
+            "http://test-embeddings:8080"
+        );
     }
 }
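A hedged smoke test against the proxied chat completions endpoint once the gateway is running (host and port taken from the config defaults; the model name is a placeholder):

curl -s http://127.0.0.1:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"<model-name>","messages":[{"role":"user","content":"hello"}],"stream":false}'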
@@ -1,22 +1,59 @@
 mod config;
+mod ha_mode;
 mod middleware;
-mod proxy;
+mod standalone_mode;
 
+use crate::standalone_mode::create_standalone_router;
+use axum::handler::Handler;
+use axum::http::StatusCode as AxumStatusCode;
+use axum::http::header;
 use axum::response::IntoResponse;
 use axum::routing::get;
-use axum::{Router, http::Uri, response::Html, serve};
+use axum::{Router, ServiceExt, http::Uri, response::Html, serve};
 use config::ServerConfig;
+use ha_mode::create_ha_router;
 use inference_engine::AppState;
+use log::info;
 use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
-use proxy::create_proxy_router;
+use mime_guess::from_path;
 use rust_embed::Embed;
 use std::env;
+use std::path::Component::ParentDir;
 use tokio::net::TcpListener;
+use tower::MakeService;
 use tower_http::classify::ServerErrorsFailureClass::StatusCode;
 use tower_http::cors::{Any, CorsLayer};
 use tower_http::trace::TraceLayer;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
+#[derive(Embed)]
+#[folder = "../../target/site"]
+#[include = "*.js"]
+#[include = "*.wasm"]
+#[include = "*.css"]
+#[include = "*.ico"]
+struct Asset;
+
+async fn static_handler(uri: Uri) -> axum::response::Response {
+    // Strip the leading `/`
+    let path = uri.path().trim_start_matches('/');
+
+    tracing::info!("Static file: {}", &path);
+
+    // If root is requested, serve index.html
+    let path = if path.is_empty() { "index.html" } else { path };
+
+    match Asset::get(path) {
+        Some(content) => {
+            let body = content.data.into_owned();
+            let mime = from_path(path).first_or_octet_stream();
+
+            ([(header::CONTENT_TYPE, mime.as_ref())], body).into_response()
+        }
+        None => (AxumStatusCode::NOT_FOUND, "404 Not Found").into_response(),
+    }
+}
+
 #[tokio::main]
 async fn main() {
     // Initialize tracing
@@ -49,33 +86,19 @@ async fn main() {
     let default_host = server_config.server_host.clone();
     let default_port = server_config.server_port;
 
-    // Create router based on server mode
-    let service_router = if server_config.clone().is_high_availability() {
-        tracing::info!("Running in HighAvailability mode - proxying to external services");
-        tracing::info!(" Inference service URL: {}", server_config.inference_url());
-        tracing::info!(
-            " Embeddings service URL: {}",
-            server_config.embeddings_url()
-        );
-
-        // Use proxy router that forwards requests to external services
-        create_proxy_router(server_config.clone())
-    } else {
-        tracing::info!("Running in Standalone mode - using embedded services");
-
-        // Create unified router by merging embeddings and inference routers (existing behavior)
-        let embeddings_router = embeddings_engine::create_embeddings_router();
-
-        // Create AppState with correct model configuration
-        let app_state = AppState::default();
-
-        // Get the inference router directly from the inference engine
-        let inference_router = inference_engine::create_router(app_state);
-
-        // Merge the local routers
-        Router::new()
-            .merge(embeddings_router)
-            .merge(inference_router)
-    };
+    let service_router = match server_config.clone().is_high_availability() {
+        Ok(is_ha) => {
+            if is_ha {
+                log_config(server_config.clone());
+                create_ha_router(server_config.clone())
+            } else {
+                log_config(server_config.clone());
+                create_standalone_router(server_config)
+            }
+        }
+        Err(error) => {
+            panic!("{}", error);
+        }
+    };
 
     // Create CORS layer
@@ -88,14 +111,17 @@ async fn main() {
     // Create metrics layer
     let metrics_layer = MetricsLayer::new(metrics_store);
 
+    let leptos_config = chat_ui::app::AppConfig::default();
+
     // Create the leptos router for the web frontend
-    let leptos_router = leptos_app::create_leptos_router();
+    let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
 
     // Merge the service router with base routes and add middleware layers
     let app = Router::new()
+        .route("/pkg/{*path}", get(static_handler))
         .route("/health", get(|| async { "ok" }))
         .merge(service_router)
-        .merge(leptos_router) // Add leptos web frontend routes
+        .merge(leptos_router)
        .layer(metrics_layer) // Add metrics tracking
        .layer(cors)
        .layer(TraceLayer::new_for_http());
@@ -121,7 +147,27 @@ async fn main() {
     tracing::info!(" POST /v1/embeddings - Text embeddings API");
     tracing::info!(" POST /v1/chat/completions - Chat completions API");
 
-    serve(listener, app).await.unwrap();
+    serve(listener, app.into_make_service()).await.unwrap();
+}
+
+fn log_config(config: ServerConfig) {
+    match config.is_high_availability() {
+        Ok(is_high) => {
+            if is_high {
+                tracing::info!("Running in HighAvailability mode - proxying to external services");
+                tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
+                tracing::info!(
+                    "Embeddings service URL: {}",
+                    config.embeddings_url().unwrap()
+                );
+            } else {
+                tracing::info!("Running in Standalone mode");
+            }
+        }
+        Err(error) => {
+            panic!("{}", error);
+        }
+    }
 }
 
 // Chat completions handler that properly uses the inference server crate's error handling
@@ -2,6 +2,8 @@ use axum::{
     extract::MatchedPath,
     http::{Request, Response},
 };
+use std::fmt;
+use std::task::ready;
 use std::{
     future::Future,
     pin::Pin,
@@ -12,8 +14,6 @@ use std::{
 use tokio::sync::Mutex;
 use tower::{Layer, Service};
 use tracing::{debug, info};
-use std::task::ready;
-use std::fmt;
 
 /// Performance metrics for a specific endpoint
 #[derive(Debug, Clone, Default)]
@@ -33,16 +33,16 @@ impl EndpointMetrics {
     pub fn add_response_time(&mut self, time_ms: u64) {
         self.count += 1;
         self.total_time_ms += time_ms;
 
         if self.min_time_ms == 0 || time_ms < self.min_time_ms {
             self.min_time_ms = time_ms;
         }
 
         if time_ms > self.max_time_ms {
             self.max_time_ms = time_ms;
         }
     }
 
     /// Get the average response time in milliseconds
     pub fn avg_time_ms(&self) -> f64 {
         if self.count == 0 {
@@ -51,12 +51,15 @@ impl EndpointMetrics {
         self.total_time_ms as f64 / self.count as f64
     }
 
     /// Get a human-readable summary of the metrics
     pub fn summary(&self) -> String {
         format!(
             "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
-            self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
+            self.count,
+            self.avg_time_ms(),
+            self.min_time_ms,
+            self.max_time_ms
         )
     }
 }
@@ -75,14 +78,16 @@ impl MetricsStore {
             endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
         }
     }
 
     /// Record a request's timing information
     pub async fn record(&self, path: String, time_ms: u64) {
         let mut endpoints = self.endpoints.lock().await;
-        let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
+        let metrics = endpoints
+            .entry(path)
+            .or_insert_with(EndpointMetrics::default);
         metrics.add_response_time(time_ms);
     }
 
     /// Get metrics for all endpoints
     pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
         let endpoints = self.endpoints.lock().await;
@@ -91,12 +96,12 @@ impl MetricsStore {
             .map(|(k, v)| (k.clone(), v.clone()))
             .collect()
     }
 
     /// Log a summary of all metrics
     pub async fn log_summary(&self) {
         let metrics = self.get_all().await;
         info!("Performance metrics summary:");
 
         for (path, metric) in metrics {
             info!("  {}: {}", path, metric.summary());
         }
@@ -163,26 +168,28 @@
         } else {
             req.uri().path().to_string()
         };
 
         let method = req.method().clone();
         let start = Instant::now();
         let metrics_store = self.metrics_store.clone();
 
         let future = self.inner.call(req);
 
         Box::pin(async move {
             let response = future.await?;
 
             let time = start.elapsed();
             let status = response.status();
             let time_ms = time.as_millis() as u64;
 
             // Record the timing in our metrics store
-            metrics_store.record(format!("{} {}", method, path), time_ms).await;
+            metrics_store
+                .record(format!("{} {}", method, path), time_ms)
+                .await;
 
             // Log the request timing
             debug!("{} {} {} - {} ms", method, path, status, time_ms);
 
             Ok(response)
         })
     }
@@ -214,7 +221,7 @@ impl Future for MetricsLoggerFuture {
                 metrics_store.log_summary().await;
             });
         }
 
         Poll::Pending
     }
 }
@@ -1,7 +1,3 @@
 pub mod metrics;
 
-pub use metrics::{
-    MetricsStore,
-    MetricsLoggerFuture,
-    MetricsLayer,
-};
+pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
20
crates/predict-otron-9000/src/standalone_mode.rs
Normal file
@@ -0,0 +1,20 @@
use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;

pub fn create_standalone_router(server_config: ServerConfig) -> Router {
    // Create unified router by merging embeddings and inference routers (existing behavior)
    let embeddings_router = embeddings_engine::create_embeddings_router();

    // Create AppState - no default model, must be configured explicitly
    // This removes the hardcoded gemma-3-1b-it default behavior
    let app_state = AppState::default();

    // Get the inference router directly from the inference engine
    let inference_router = inference_engine::create_router(app_state);

    // Merge the local routers
    Router::new()
        .merge(embeddings_router)
        .merge(inference_router)
}
88
crates/utils/Cargo.toml
Normal file
@@ -0,0 +1,88 @@
[package]
name = "utils"

[lib]
path = "src/lib.rs"

[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }

candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
candle-core="0.9.1"
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
half = {version = "2.6.0", optional = true }
hf-hub = {version = "0.4.3", features = ["tokio"] }
image = {version = "0.25.6" }
intel-mkl-src = {version = "0.8.1", optional = true }
num-traits = {version = "0.2.19" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.22.0", features = [
    "auto-initialize",
    "abi3-py311",
], optional = true }
rayon = {version = "1.11.0" }
rubato = { version = "0.15.0", optional = true }
safetensors = {version = "0.6.2" }
serde = {version = "1.0.219" }
serde_json = {version = "1.0.143" }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = {version = "0.22.0", features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2", optional = true }
tekken-rs = { version = "0.1.1", optional = true }

[dev-dependencies]
anyhow = {version = "1.0.99" }
byteorder = {version = "1.5.0" }
clap = {version = "4.5.46" }
imageproc = {version = "0.25.0" }
memmap2 = {version = "0.9.8" }
rand = {version = "0.9.2" }
ab_glyph = {version = "0.2.31" }
tracing = {version = "0.1.41" }
tracing-chrome = {version = "0.7.2" }
tracing-subscriber = {version = "0.3.20" }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"

[build-dependencies]
anyhow = {version = "1.0.99" }
bindgen_cuda = { version = "0.1.1", optional = true }
#
[features]
default = []
accelerate = [
    "dep:accelerate-src",
    "candle-core/accelerate",
    "candle-nn/accelerate",
    "candle-transformers/accelerate",
]
cuda = [
    "candle-core/cuda",
    "candle-nn/cuda",
    "candle-transformers/cuda",
    "dep:bindgen_cuda",
]
cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = [
    "dep:intel-mkl-src",
    "candle-core/mkl",
    "candle-nn/mkl",
    "candle-transformers/mkl",
]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle-core/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]
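The feature flags above mirror candle's backend split; as a hedged example, the crate could be built with Metal acceleration enabled (package name taken from the manifest above):

cargo build -p utils --features metal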
138
crates/utils/src/audio.rs
Normal file
@@ -0,0 +1,138 @@
use candle_core::{Result, Tensor};

// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
pub fn normalize_loudness(
    wav: &Tensor,
    sample_rate: u32,
    loudness_compressor: bool,
) -> Result<Tensor> {
    let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
    if energy < 2e-3 {
        return Ok(wav.clone());
    }
    let wav_array = wav.to_vec1::<f32>()?;
    let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
    meter.push(wav_array.into_iter());
    let power = meter.as_100ms_windows();
    let loudness = match crate::bs1770::gated_mean(power) {
        None => return Ok(wav.clone()),
        Some(gp) => gp.loudness_lkfs() as f64,
    };
    let delta_loudness = -14. - loudness;
    let gain = 10f64.powf(delta_loudness / 20.);
    let wav = (wav * gain)?;
    if loudness_compressor {
        wav.tanh()
    } else {
        Ok(wav)
    }
}

#[cfg(feature = "symphonia")]
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};
    use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
    use symphonia::core::conv::FromSample;

    fn conv<T>(
        samples: &mut Vec<f32>,
        data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
    ) where
        T: symphonia::core::sample::Sample,
        f32: symphonia::core::conv::FromSample<T>,
    {
        samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
    }

    // Open the media source.
    let src = std::fs::File::open(path).map_err(candle::Error::wrap)?;

    // Create the media source stream.
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());

    // Create a probe hint using the file's extension. [Optional]
    let hint = symphonia::core::probe::Hint::new();

    // Use the default options for metadata and format readers.
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();

    // Probe the media source.
    let probed = symphonia::default::get_probe()
        .format(&hint, mss, &fmt_opts, &meta_opts)
        .map_err(candle::Error::wrap)?;
    // Get the instantiated format reader.
    let mut format = probed.format;

    // Find the first audio track with a known (decodeable) codec.
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?;

    // Use the default options for the decoder.
    let dec_opts: DecoderOptions = Default::default();

    // Create a decoder for the track.
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &dec_opts)
        .map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?;
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    // The decode loop.
    while let Ok(packet) = format.next_packet() {
        // Consume any new metadata that has been read since the last packet.
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }

        // If the packet does not belong to the selected track, skip over it.
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet).map_err(candle::Error::wrap)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

#[cfg(feature = "rubato")]
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
    use rubato::Resampler;

    let mut pcm_out =
        Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);

    let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
        .map_err(candle::Error::wrap)?;
    let mut output_buffer = resampler.output_buffer_allocate(true);
    let mut pos_in = 0;
    while pos_in + resampler.input_frames_next() < pcm_in.len() {
        let (in_len, out_len) = resampler
            .process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
            .map_err(candle::Error::wrap)?;
        pos_in += in_len;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    if pos_in < pcm_in.len() {
        let (_in_len, out_len) = resampler
            .process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
            .map_err(candle::Error::wrap)?;
        pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
    }

    Ok(pcm_out)
}
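A short usage sketch of the new helper (hedged: it assumes the crate exposes the module as `utils::audio`, which this diff does not show), normalizing a mono f32 buffer toward the -14 LUFS target hard-coded in `normalize_loudness`:

use candle_core::{Device, Tensor};

fn normalize_example() -> candle_core::Result<Tensor> {
    // One second of placeholder samples at 48 kHz.
    let samples = vec![0.0f32; 48_000];
    let wav = Tensor::from_vec(samples, 48_000, &Device::Cpu)?;
    // Last argument disables the tanh loudness compressor.
    utils::audio::normalize_loudness(&wav, 48_000, false)
}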
506
crates/utils/src/bs1770.rs
Normal file
506
crates/utils/src/bs1770.rs
Normal file
@@ -0,0 +1,506 @@
|
|||||||
|
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
|
||||||
|
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
|
||||||
|
// Copyright 2020 Ruud van Asseldonk
|
||||||
|
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// A copy of the License has been included in the root of the repository.
|
||||||
|
|
||||||
|
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
|
||||||
|
//!
|
||||||
|
//! This library offers the building blocks to perform BS.1770 loudness
|
||||||
|
//! measurements, but you need to put the pieces together yourself.
|
||||||
|
//!
|
||||||
|
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
|
||||||
|
//!
|
||||||
|
//! # Stereo integrated loudness example
|
||||||
|
//!
|
||||||
|
//! ```ignore
|
||||||
|
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
|
||||||
|
//! # [vec![0; 48_000], vec![0; 48_000]]
|
||||||
|
//! # }
|
||||||
|
//! #
|
||||||
|
//! let sample_rate_hz = 44_100;
|
||||||
|
//! let bits_per_sample = 16;
|
||||||
|
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
|
||||||
|
//!
|
||||||
|
//! // When converting integer samples to float, note that the maximum amplitude
|
||||||
|
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
|
||||||
|
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
|
||||||
|
//!
|
||||||
|
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
|
||||||
|
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
|
||||||
|
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
|
||||||
|
//! meter.into_100ms_windows()
|
||||||
|
//! }).collect();
|
||||||
|
//!
|
||||||
|
//! let stereo_power = bs1770::reduce_stereo(
|
||||||
|
//! channel_power[0].as_ref(),
|
||||||
|
//! channel_power[1].as_ref(),
|
||||||
|
//! );
|
||||||
|
//!
|
||||||
|
//! let gated_power = bs1770::gated_mean(
|
||||||
|
//! stereo_power.as_ref()
|
||||||
|
//! ).unwrap_or(bs1770::Power(0.0));
|
||||||
|
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use std::f32;

/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
    a1: f32,
    a2: f32,
    b0: f32,
    b1: f32,
    b2: f32,

    // The past two input and output samples.
    x1: f32,
    x2: f32,
    y1: f32,
    y2: f32,
}

impl Filter {
    /// Stage 1 of the BS.1770-4 pre-filter.
    pub fn high_shelf(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let gain_db = 3.999_843_8;
        let q = 0.707_175_25;
        let center_hz = 1_681.974_5;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        let vh = 10.0_f32.powf(gain_db / 20.0);
        let vb = vh.powf(0.499_666_78);
        let a0 = 1.0 + k / q + k * k;
        Filter {
            b0: (vh + vb * k / q + k * k) / a0,
            b1: 2.0 * (k * k - vh) / a0,
            b2: (vh - vb * k / q + k * k) / a0,
            a1: 2.0 * (k * k - 1.0) / a0,
            a2: (1.0 - k / q + k * k) / a0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Stage 2 of the BS.1770-4 pre-filter.
    pub fn high_pass(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let q = 0.500_327_05;
        let center_hz = 38.135_47;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        Filter {
            a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
            a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
            b0: 1.0,
            b1: -2.0,
            b2: 1.0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Feed the next input sample, get the next output sample.
    #[inline(always)]
    pub fn apply(&mut self, x0: f32) -> f32 {
        let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
            - self.a1 * self.y1
            - self.a2 * self.y2;

        self.x2 = self.x1;
        self.x1 = x0;
        self.y2 = self.y1;
        self.y1 = y0;

        y0
    }
}
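
// Editorial note (not part of the upstream file): `apply` above is the
// Direct Form I biquad recurrence
//
//     y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]
//
// with a0 normalized to 1.0, which is why both constructors divide every
// coefficient by `a0` before storing it.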

/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
    sum: f32,
    residue: f32,
}

impl Sum {
    #[inline(always)]
    fn zero() -> Sum {
        Sum {
            sum: 0.0,
            residue: 0.0,
        }
    }

    #[inline(always)]
    fn add(&mut self, x: f32) {
        let sum = self.sum + (self.residue + x);
        self.residue = (self.residue + x) - (sum - self.sum);
        self.sum = sum;
    }
}
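
// Editorial sketch (not part of the upstream file): why the compensated sum
// matters. Adding a value of roughly 1e-8 to an f32 accumulator that already
// holds 1.0 is lost entirely to rounding, but the residue keeps track of it.
#[allow(dead_code)]
fn compensated_sum_demo() {
    let tiny = 1e-4_f32 * 1e-4_f32; // one squared sample, about 1e-8
    let mut naive = 1.0_f32;
    let mut compensated = Sum::zero();
    compensated.add(1.0);
    for _ in 0..4_800_000 {
        naive += tiny; // rounds back to 1.0 every single time
        compensated.add(tiny); // the residue accumulates the lost bits
    }
    assert_eq!(naive, 1.0);
    // 4.8 million additions of ~1e-8 add up to roughly 0.048.
    assert!((compensated.sum - 1.048).abs() < 0.005);
}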

/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a given amount of power in low frequencies as
/// less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after
/// applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);

impl Power {
    /// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
    ///
    /// This is the inverse of `loudness_lkfs`.
    pub fn from_lkfs(lkfs: f32) -> Power {
        // The inverse of the formula below.
        Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
    }

    /// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
    ///
    /// This is the inverse of `from_lkfs`.
    pub fn loudness_lkfs(&self) -> f32 {
        // Equation 2 (p.5) of BS.1770-4.
        -0.691 + 10.0 * self.0.log10()
    }
}
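
// Editorial sketch (not part of the upstream file): a worked round trip
// between power and loudness. A full-scale mean-square power of 1.0 maps to
// -0.691 + 10*log10(1.0) = -0.691 LKFS, and `from_lkfs` undoes that exactly.
#[allow(dead_code)]
fn lkfs_round_trip_demo() {
    let power = Power(1.0);
    let lkfs = power.loudness_lkfs(); // -0.691
    let back = Power::from_lkfs(lkfs); // 10^((-0.691 + 0.691) * 0.1) = 1.0
    assert!((lkfs + 0.691).abs() < 1e-6);
    assert!((back.0 - power.0).abs() < 1e-6);
}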

/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
    pub inner: T,
}

impl<T> Windows100ms<T> {
    /// Wrap a new empty vector.
    pub fn new() -> Windows100ms<Vec<T>> {
        Windows100ms { inner: Vec::new() }
    }

    /// Apply `as_ref` to the inner value.
    pub fn as_ref(&self) -> Windows100ms<&[Power]>
    where
        T: AsRef<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_ref(),
        }
    }

    /// Apply `as_mut` to the inner value.
    pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
    where
        T: AsMut<[Power]>,
    {
        Windows100ms {
            inner: self.inner.as_mut(),
        }
    }

    #[allow(clippy::len_without_is_empty)]
    /// Apply `len` to the inner value.
    pub fn len(&self) -> usize
    where
        T: AsRef<[Power]>,
    {
        self.inner.as_ref().len()
    }
}

/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
///     .unwrap_or(bs1770::Power(0.0))
///     .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
    /// The number of samples that fit in 100ms of audio.
    samples_per_100ms: u32,

    /// Stage 1 filter (head effects, high shelf).
    filter_stage1: Filter,

    /// Stage 2 filter (high-pass).
    filter_stage2: Filter,

    /// Sum of the squares over non-overlapping windows of 100ms.
    windows: Windows100ms<Vec<Power>>,

    /// The number of samples in the current unfinished window.
    count: u32,

    /// The sum of the squares of the samples in the current unfinished window.
    square_sum: Sum,
}
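
// Editorial sketch (not part of the upstream file): the 400ms windows
// mentioned in the docs above are just averages of four consecutive 100ms
// windows; a hypothetical helper (not provided by this module) could look
// like this.
#[allow(dead_code)]
fn windows_400ms(windows_100ms: Windows100ms<&[Power]>) -> Vec<Power> {
    windows_100ms
        .inner
        .windows(4)
        .map(|w| Power(0.25 * w.iter().map(|p| p.0).sum::<f32>()))
        .collect()
}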

impl ChannelLoudnessMeter {
    /// Construct a new loudness meter for the given sample rate.
    pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
        ChannelLoudnessMeter {
            samples_per_100ms: sample_rate_hz / 10,
            filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
            filter_stage2: Filter::high_pass(sample_rate_hz as f32),
            windows: Windows100ms::new(),
            count: 0,
            square_sum: Sum::zero(),
        }
    }

    /// Feed input samples for loudness analysis.
    ///
    /// # Full scale
    ///
    /// Full scale for the input samples is the interval [-1.0, 1.0]. If your
    /// input consists of signed integer samples, you can convert as follows:
    ///
    /// ```ignore
    /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
    /// # let bits_per_sample = 16_usize;
    /// # let samples = &[0_i16];
    /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
    /// // one bit is the sign bit.
    /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
    /// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
    /// ```
    ///
    /// # Repeated calls
    ///
    /// You can call `push` multiple times to feed multiple batches of samples.
    /// This is equivalent to feeding a single chained iterator. Any leftover
    /// samples that did not fill a full 100ms window are not discarded:
    ///
    /// ```ignore
    /// # use std::iter;
    /// # use bs1770::ChannelLoudnessMeter;
    /// let sample_rate_hz = 44_100;
    /// let samples_per_100ms = sample_rate_hz / 10;
    /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    ///
    /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
    /// assert_eq!(meter.as_100ms_windows().len(), 0);
    ///
    /// meter.push(iter::once(0.0));
    /// assert_eq!(meter.as_100ms_windows().len(), 1);
    /// ```
    pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
        let normalizer = 1.0 / self.samples_per_100ms as f32;

        // LLVM, if you could go ahead and inline those apply calls, and then
        // unroll and vectorize the loop, that'd be terrific.
        for x in samples {
            let y = self.filter_stage1.apply(x);
            let z = self.filter_stage2.apply(y);

            self.square_sum.add(z * z);
            self.count += 1;

            // TODO: Should this branch be marked cold?
            if self.count == self.samples_per_100ms {
                let mean_squares = Power(self.square_sum.sum * normalizer);
                self.windows.inner.push(mean_squares);
                // We intentionally do not reset the residue. That way, leftover
                // energy from this window is not lost, so for the file overall,
                // the sum remains more accurate.
                self.square_sum.sum = 0.0;
                self.count = 0;
            }
        }
    }

    /// Return a reference to the 100ms windows analyzed so far.
    pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
        self.windows.as_ref()
    }

    /// Return all 100ms windows analyzed so far.
    pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
        self.windows
    }
}

/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
    left: Windows100ms<&[Power]>,
    right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    let mut result = Vec::with_capacity(left.len());
    for (l, r) in left.inner.iter().zip(right.inner) {
        result.push(Power(l.0 + r.0));
    }
    Windows100ms { inner: result }
}

/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    for (l, r) in left.inner.iter_mut().zip(right.inner) {
        l.0 += r.0;
    }
}
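
// Editorial sketch (not part of the upstream file): as the doc comment above
// notes, a mono signal played back on stereo speakers should still go through
// `reduce_stereo`, with the same windows passed for both channels. Summing a
// channel with itself doubles the measured power (about +3 dB), which is the
// intended behaviour rather than an accident.
#[allow(dead_code)]
fn mono_as_stereo_demo(mono: Windows100ms<&[Power]>) -> Windows100ms<Vec<Power>> {
    // `Windows100ms<&[Power]>` is `Copy`, so the same view can be used twice.
    reduce_stereo(mono, mono)
}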

/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
    let mut gating_blocks = Vec::with_capacity(windows_100ms.len());

    // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
    let absolute_threshold = Power::from_lkfs(-70.0);

    // Iterate over all 400ms windows.
    for window in windows_100ms.inner.windows(4) {
        // Note that the sum over channels has already been performed at this point.
        let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());

        if gating_block_power > absolute_threshold {
            gating_blocks.push(gating_block_power);
        }
    }

    if gating_blocks.is_empty() {
        return None;
    }

    // Compute the loudness after applying the absolute gate, in order to
    // determine the threshold for the relative gate.
    let mut sum_power = Sum::zero();
    for &gating_block_power in &gating_blocks {
        sum_power.add(gating_block_power.0);
    }
    let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));

    // Stage 2: Apply the relative gate.
    let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
    let mut sum_power = Sum::zero();
    let mut n_blocks = 0_usize;
    for &gating_block_power in &gating_blocks {
        if gating_block_power > relative_threshold {
            sum_power.add(gating_block_power.0);
            n_blocks += 1;
        }
    }

    if n_blocks == 0 {
        return None;
    }

    let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
    Some(relative_gated_power)
}
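
// Editorial sketch (not part of the upstream file): how the two gates in
// `gated_mean` interact. The absolute gate drops 400ms blocks quieter than
// -70 LKFS; the relative gate then drops blocks more than 10 LU below the
// loudness of whatever survived the absolute gate.
#[allow(dead_code)]
fn gating_demo() {
    // 0.8 seconds of near-silence followed by 3.2 seconds of a loud passage.
    let quiet = Power::from_lkfs(-80.0);
    let loud = Power::from_lkfs(-18.0);
    let mut inner = vec![quiet; 8];
    inner.extend(std::iter::repeat(loud).take(32));
    let windows = Windows100ms { inner };

    let integrated = gated_mean(windows.as_ref()).unwrap_or(Power(0.0));
    // The near-silent blocks fall below the absolute gate, so the result sits
    // close to the loudness of the loud passage instead of being dragged down
    // by the quiet lead-in (roughly -18.2 LKFS for this input).
    assert!((integrated.loudness_lkfs() + 18.0).abs() < 0.5);
}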

crates/utils/src/coco_classes.rs (new file, 82 lines)
@@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
    "person",
    "bicycle",
    "car",
    "motorbike",
    "aeroplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "sofa",
    "pottedplant",
    "bed",
    "diningtable",
    "toilet",
    "tvmonitor",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
];

crates/utils/src/imagenet.rs (new file, 1056 lines)
File diff suppressed because it is too large.

@@ -1,7 +1,17 @@
-use candle_core::utils::{cuda_is_available, metal_is_available};
-use candle_core::{Device, Result, Tensor};
+extern crate candle_core;
+extern crate candle_transformers;
+extern crate tokenizers;
 
-pub fn device(cpu: bool) -> Result<Device> {
+pub mod audio;
+pub mod bs1770;
+pub mod coco_classes;
+pub mod imagenet;
+pub mod token_output_stream;
+pub mod wav;
+use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
+
+
+pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
     if cpu {
         Ok(Device::Cpu)
     } else if cuda_is_available() {
@@ -26,7 +36,7 @@ pub fn device(cpu: bool) -> Result<Device> {
 pub fn load_image<P: AsRef<std::path::Path>>(
     p: P,
     resize_longest: Option<usize>,
-) -> Result<(Tensor, usize, usize)> {
+) -> Result<(Tensor, usize, usize), anyhow::Error> {
     let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle_core::Error::wrap)?;
@@ -57,7 +67,7 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
     p: P,
     width: usize,
     height: usize,
-) -> Result<Tensor> {
+) -> candle_core::Result<Tensor> {
     let img = image::ImageReader::open(p)?
         .decode()
         .map_err(candle_core::Error::wrap)?
@@ -73,60 +83,36 @@ pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
 
 /// Saves an image to disk using the image crate, this expects an input with shape
 /// (c, height, width).
-pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> {
     let p = p.as_ref();
     let (channel, height, width) = img.dims3()?;
     if channel != 3 {
-        candle_core::bail!("save_image expects an input of shape (3, height, width)")
+        anyhow::bail!("save_image expects an input of shape (3, height, width)")
     }
     let img = img.permute((1, 2, 0))?.flatten_all()?;
     let pixels = img.to_vec1::<u8>()?;
     let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
         match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
             Some(image) => image,
-            None => candle_core::bail!("error saving image {p:?}"),
+            None => anyhow::bail!("error saving image {p:?}"),
         };
     image.save(p).map_err(candle_core::Error::wrap)?;
     Ok(())
 }
 
-pub fn save_image_resize<P: AsRef<std::path::Path>>(
-    img: &Tensor,
-    p: P,
-    h: usize,
-    w: usize,
-) -> Result<()> {
-    let p = p.as_ref();
-    let (channel, height, width) = img.dims3()?;
-    if channel != 3 {
-        candle_core::bail!("save_image expects an input of shape (3, height, width)")
-    }
-    let img = img.permute((1, 2, 0))?.flatten_all()?;
-    let pixels = img.to_vec1::<u8>()?;
-    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
-        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
-            Some(image) => image,
-            None => candle_core::bail!("error saving image {p:?}"),
-        };
-    let image = image::DynamicImage::from(image);
-    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
-    image.save(p).map_err(candle_core::Error::wrap)?;
-    Ok(())
-}
-
 /// Loads the safetensors files for a model from the hub based on a json index file.
 pub fn hub_load_safetensors(
     repo: &hf_hub::api::sync::ApiRepo,
     json_file: &str,
-) -> Result<Vec<std::path::PathBuf>> {
+) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
     let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
     let json_file = std::fs::File::open(json_file)?;
     let json: serde_json::Value =
         serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
-        None => candle_core::bail!("no weight map in {json_file:?}"),
+        None => anyhow::bail!("no weight map in {json_file:?}"),
         Some(serde_json::Value::Object(map)) => map,
-        Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
+        Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
     };
     let mut safetensors_files = std::collections::HashSet::new();
     for value in weight_map.values() {
@@ -136,22 +122,25 @@ pub fn hub_load_safetensors(
     }
     let safetensors_files = safetensors_files
         .iter()
-        .map(|v| repo.get(v).map_err(candle_core::Error::wrap))
-        .collect::<Result<Vec<_>>>()?;
+        .map(|v| {
+            repo.get(v)
+                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
+        })
+        .collect::<Result<Vec<_>, std::io::Error, >>()?;
     Ok(safetensors_files)
 }
 
 pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
     path: P,
     json_file: &str,
-) -> Result<Vec<std::path::PathBuf>> {
+) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
     let path = path.as_ref();
     let jsfile = std::fs::File::open(path.join(json_file))?;
     let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
-        None => candle_core::bail!("no weight map in {json_file:?}"),
+        None => anyhow::bail!("no weight map in {json_file:?}"),
         Some(serde_json::Value::Object(map)) => map,
-        Some(_) => candle_core::bail!("weight map in {json_file:?} is not a map"),
+        Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
    };
     let mut safetensors_files = std::collections::HashSet::new();
     for value in weight_map.values() {
@@ -164,4 +153,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
         .map(|v| path.join(v))
         .collect();
     Ok(safetensors_files)
 }
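
These hunks move the helpers from `candle_core::Result` to explicit `anyhow::Error` returns, with `anyhow::bail!` on the error paths. A minimal caller sketch; the crate name `utils` is assumed from the directory layout and is not shown in the diff itself:

// Hypothetical caller, assuming this crate is imported as `utils`.
fn main() -> Result<(), anyhow::Error> {
    // Picks CUDA or Metal when available, otherwise falls back to CPU.
    let device = utils::device(false)?;
    println!("running on {device:?}");
    Ok(())
}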

crates/utils/src/main.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
fn main() {
    println!("Hello, world!");
}

@@ -1,7 +1,6 @@
 use candle_core::Result;
+use tokenizers::Tokenizer;
 
-/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
-/// streaming way rather than having to wait for the full decoding.
 pub struct TokenOutputStream {
     tokenizer: tokenizers::Tokenizer,
     tokens: Vec<u32>,
@@ -40,8 +39,7 @@ impl TokenOutputStream {
         };
         self.tokens.push(token);
         let text = self.decode(&self.tokens[self.prev_index..])?;
-        if text.len() > prev_text.len() {
-            // Modified to include all tokens, not just alphanumeric ones
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
             let text = text.split_at(prev_text.len());
             self.prev_index = self.current_index;
             self.current_index = self.tokens.len();
@@ -84,4 +82,4 @@ impl TokenOutputStream {
         self.prev_index = 0;
         self.current_index = 0;
     }
 }
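
The second hunk reverts the flush condition: a decoded piece is only emitted once the text ends in an alphanumeric character, which holds back fragments that have not yet decoded to a clean boundary (the deleted comment documented the previous, emit-everything behaviour). A minimal consumer sketch, assuming the rest of the wrapper matches the upstream candle helper it is based on (`new`, `next_token`, and `decode_rest` are not shown in this hunk):

// Hypothetical streaming loop; `token_ids` stands in for model output.
fn stream_tokens(tokenizer: tokenizers::Tokenizer, token_ids: &[u32]) -> candle_core::Result<()> {
    let mut stream = TokenOutputStream::new(tokenizer);
    for &id in token_ids {
        if let Some(piece) = stream.next_token(id)? {
            print!("{piece}");
        }
    }
    // Flush whatever was held back by the alphanumeric check.
    if let Some(rest) = stream.decode_rest()? {
        print!("{rest}");
    }
    Ok(())
}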

crates/utils/src/wav.rs (new file, 56 lines)
@@ -0,0 +1,56 @@
use std::io::prelude::*;

pub trait Sample {
    fn to_i16(&self) -> i16;
}

impl Sample for f32 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for f64 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for i16 {
    fn to_i16(&self) -> i16 {
        *self
    }
}

pub fn write_pcm_as_wav<W: Write, S: Sample>(
    w: &mut W,
    samples: &[S],
    sample_rate: u32,
) -> std::io::Result<()> {
    let len = 12u32; // header
    let len = len + 24u32; // fmt
    let len = len + samples.len() as u32 * 2 + 8; // data
    let n_channels = 1u16;
    let bytes_per_second = sample_rate * 2 * n_channels as u32;
    w.write_all(b"RIFF")?;
    w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
    w.write_all(b"WAVE")?;

    // Format block
    w.write_all(b"fmt ")?;
    w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
    w.write_all(&1u16.to_le_bytes())?; // PCM
    w.write_all(&n_channels.to_le_bytes())?; // one channel
    w.write_all(&sample_rate.to_le_bytes())?;
    w.write_all(&bytes_per_second.to_le_bytes())?;
    w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
    w.write_all(&16u16.to_le_bytes())?; // bits per sample

    // Data block
    w.write_all(b"data")?;
    w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
    for sample in samples.iter() {
        w.write_all(&sample.to_i16().to_le_bytes())?
    }
    Ok(())
}
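
`write_pcm_as_wav` always writes a mono, 16-bit PCM header, so the RIFF length fields depend only on the sample count. A minimal usage sketch writing one second of a 440 Hz tone (the output file name is arbitrary):

use std::f32::consts::PI;
use std::fs::File;

fn write_test_tone() -> std::io::Result<()> {
    let sample_rate = 44_100u32;
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| (2.0 * PI * 440.0 * i as f32 / sample_rate as f32).sin() * 0.5)
        .collect();
    let mut file = File::create("tone.wav")?;
    // The f32 `Sample` impl clamps to [-1.0, 1.0] and scales to i16.
    write_pcm_as_wav(&mut file, &samples, sample_rate)
}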

@@ -52,7 +52,7 @@ graph TB
 
 ## Workspace Structure
 
-The project uses a 7-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
+The project uses a 9-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
 
 ```mermaid
 graph TD
@@ -69,18 +69,15 @@ graph TD
     end
 
     subgraph "Frontend"
-        D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
+        D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
     end
 
     subgraph "Tooling"
         L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
+        E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
     end
     end
 
-    subgraph "External Tooling"
-        E[scripts/cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
-    end
-
     subgraph "Dependencies"
         A --> B
         A --> C
@@ -193,7 +190,7 @@ graph TB
     end
 
     subgraph "Frontend"
-        D[leptos-app Pod<br/>:8788<br/>ClusterIP Service]
+        D[chat-ui Pod<br/>:8788<br/>ClusterIP Service]
     end
 
     subgraph "Ingress"

@@ -1,8 +1,8 @@
 {
-  "dependencies": {
-    "openai": "^5.16.0"
-  },
+  "name": "predict-otron-9000",
+  "workspaces": ["crates/cli/package"],
   "scripts": {
-    "cli": "./scripts/cli.ts"
+    "# WORKSPACE ALIASES": "#",
+    "cli": "bun --filter crates/cli/package"
   }
 }

predict-otron-9000.png (new binary file, 248 KiB; binary content not shown)

scripts/build_ui.sh (new executable file, 14 lines)
@@ -0,0 +1,14 @@
#!/usr/bin/env sh

# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

# Move into the chat-ui crate
cd "$PROJECT_ROOT/crates/chat-ui" || exit 1

# Build with cargo leptos
cargo leptos build --release

# Move the wasm file, keeping paths relative to the project root
mv "$PROJECT_ROOT/target/site/pkg/chat-ui.wasm" \
   "$PROJECT_ROOT/target/site/pkg/chat-ui_bg.wasm"

@@ -15,7 +15,7 @@ CONNECT_TIMEOUT=${CONNECT_TIMEOUT:-10}
 MAX_TIME=${MAX_TIME:-30}
 
 cat <<EOF
-[info] POST $SERVER_URL/v1/chat/completions/stream (SSE)
+[info] POST $SERVER_URL/v1/chat/completions (SSE)
 [info] model=$MODEL_ID, max_tokens=$MAX_TOKENS
 [info] prompt=$PROMPT
 [info] timeouts: connect=${CONNECT_TIMEOUT}s, max=${MAX_TIME}s
@@ -35,7 +35,7 @@ curl -N -sS -X POST \
   --connect-timeout "$CONNECT_TIMEOUT" \
   --max-time "$MAX_TIME" \
   -H "Content-Type: application/json" \
-  "$SERVER_URL/v1/chat/completions/stream" \
+  "$SERVER_URL/v1/chat/completions" \
   -d @- <<JSON
 {
   "model": "${MODEL_ID}",

scripts/run.sh (new executable file, 17 lines)
@@ -0,0 +1,17 @@
#!/bin/bash

set -e

# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

# todo, conditionally run this only when those files change
"$PROJECT_ROOT/scripts/build_ui.sh"

# build the frontend first
# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}

cd "$PROJECT_ROOT" || exit 1
cargo run --bin predict-otron-9000 --release

@@ -1,30 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-PROMPT=${1:-"Say hello in one short sentence."}
-MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
-MAX_NEW=${3:-64}
-FORCE_CPU=${FORCE_CPU:-0}
-
-# Optional: keep HF cache local to repo if not already set
-export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}
-
-BIN="$(dirname "$0")/../target/release/llama_infer"
-
-if [[ ! -x "$BIN" ]]; then
-  echo "Building llama-runner (release)..."
-  cargo build -p llama-runner --release
-fi
-
-echo "Running llama inference..." >&2
-ARGS=(
-  --model-id "$MODEL"
-  --prompt "$PROMPT"
-  --max-new-tokens "$MAX_NEW"
-)
-
-if [[ "$FORCE_CPU" == "1" || "$FORCE_CPU" == "true" ]]; then
-  ARGS+=( --force-cpu )
-fi
-
-"$BIN" "${ARGS[@]}"

@@ -1,7 +0,0 @@
-#!/bin/bash
-
-# Start the unified predict-otron-9000 server on port 8080
-export SERVER_PORT=${SERVER_PORT:-8080}
-export RUST_LOG=${RUST_LOG:-info}
-
-cargo run --bin predict-otron-9000 --release