Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)
Compare commits (22 commits):

4380ac69d3, e6f3351ebb, 3992532f15, 3ecdd9ffa0, 296d4dbe7e, fb5098eba6, c1c583faab, 1e02b12cda, ff55d882c7, 400c70f17d, bcbc6c4693, 21f20470de, 2deecb5e51, 545e0c9831, eca61c51ad, d1a7d5b28e, 8d2b85b0b9, 4570780666, 44e4f9e5e1, 64daa77c6b, 2b4a8a9df8, 38d51722f2
.dockerignore (new file, 35 lines)
@@ -0,0 +1,35 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
|
||||
# Rust
|
||||
target/
|
||||
|
||||
# Documentation
|
||||
README.md
|
||||
*.md
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Dependencies
|
||||
node_modules/
|
||||
|
||||
# Build artifacts
|
||||
dist/
|
||||
build/
|
||||
.fastembed_cache
|
.github/workflows/ci.yml (vendored, 13 lines changed)
@@ -25,7 +25,16 @@ jobs:
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Setup Rust
|
||||
run: rustup update stable && rustup default stable
|
||||
run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cargo install --locked cargo-leptos
|
||||
cd crates/chat-ui && cargo leptos build --release
|
||||
cargo build --release -p predict-otron-9000 -p cli
|
||||
|
||||
- name: Install clippy and rustfmt
|
||||
run: rustup component add clippy rustfmt
|
||||
@@ -35,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Clippy
|
||||
shell: bash
|
||||
run: cargo clippy --all-targets
|
||||
run: cargo clippy --all
|
||||
|
||||
- name: Tests
|
||||
shell: bash
|
||||
|
.github/workflows/docker.yml (vendored, new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
name: Build and Push Docker Image
|
||||
|
||||
on:
|
||||
tags:
|
||||
- 'v*'
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=sha
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
.github/workflows/release.yml (vendored, 13 lines changed)
@@ -32,7 +32,7 @@ jobs:
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Setup Rust
|
||||
run: rustup update stable && rustup default stable
|
||||
run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
|
||||
- name: Clippy
|
||||
shell: bash
|
||||
run: cargo clippy --all-targets
|
||||
run: cargo clippy --all
|
||||
|
||||
- name: Tests
|
||||
shell: bash
|
||||
@@ -129,12 +129,17 @@ jobs:
|
||||
key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Setup Rust
|
||||
run: rustup update stable && rustup default stable
|
||||
run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add target
|
||||
run: rustup target add ${{ matrix.target }}
|
||||
|
||||
- name: Build binary
|
||||
- name: Build UI
|
||||
run: cargo install --locked cargo-leptos && cd crates/chat-ui && cargo leptos build --release
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
- name: Build Binary
|
||||
run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000 -p cli
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
.gitignore (vendored, 5 lines changed)
@@ -23,7 +23,6 @@ package-lock.json
|
||||
|
||||
# Web frontend build outputs
|
||||
dist/
|
||||
.trunk/
|
||||
|
||||
# ML model and embedding caches
|
||||
.fastembed_cache/
|
||||
@@ -75,7 +74,7 @@ venv/
|
||||
# Backup files
|
||||
*.bak
|
||||
*.backup
|
||||
*~
|
||||
/scripts/cli
|
||||
!/scripts/cli.ts
|
||||
/**/.*.bun-build
|
||||
/AGENTS.md
|
||||
.claude
|
||||
|
Cargo.lock (generated, 1001 lines changed): file diff suppressed because it is too large.
Cargo.toml (19 lines changed)
@@ -3,17 +3,17 @@ members = [
|
||||
"crates/predict-otron-9000",
|
||||
"crates/inference-engine",
|
||||
"crates/embeddings-engine",
|
||||
"crates/leptos-app",
|
||||
"crates/helm-chart-tool",
|
||||
"crates/llama-runner",
|
||||
"crates/gemma-runner",
|
||||
"crates/cli"
|
||||
]
|
||||
"integration/helm-chart-tool",
|
||||
"integration/llama-runner",
|
||||
"integration/gemma-runner",
|
||||
"integration/cli",
|
||||
"crates/chat-ui"
|
||||
, "integration/utils"]
|
||||
default-members = ["crates/predict-otron-9000"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.2"
|
||||
version = "0.1.6"
|
||||
|
||||
# Compiler optimization profiles for the workspace
|
||||
[profile.release]
|
||||
@@ -42,8 +42,3 @@ overflow-checks = true
|
||||
opt-level = 3
|
||||
debug = true
|
||||
lto = "thin"
|
||||
|
||||
[[workspace.metadata.leptos]]
|
||||
# project name
|
||||
bin-package = "leptos-app"
|
||||
lib-package = "leptos-app"
|
||||
|
Dockerfile (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
# Multi-stage build for predict-otron-9000 workspace
|
||||
FROM rust:1 AS builder
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy workspace files
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
COPY crates/ ./crates/
|
||||
COPY integration/ ./integration/
|
||||
|
||||
# Build all 3 main server binaries in release mode
|
||||
RUN cargo build --release -p predict-otron-9000 --bin predict-otron-9000 --no-default-features -p embeddings-engine --bin embeddings-engine -p inference-engine --bin inference-engine
|
||||
|
||||
# Runtime stage
|
||||
FROM debian:bookworm-slim AS runtime
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
ca-certificates \
|
||||
libssl3 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app user
|
||||
RUN useradd -r -s /bin/false -m -d /app appuser
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy binaries from builder stage
|
||||
COPY --from=builder /app/target/release/predict-otron-9000 ./bin/
|
||||
COPY --from=builder /app/target/release/embeddings-engine ./bin/
|
||||
COPY --from=builder /app/target/release/inference-engine ./bin/
|
||||
# Make binaries executable and change ownership
|
||||
RUN chmod +x ./bin/* && chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Expose ports (adjust as needed based on your services)
|
||||
EXPOSE 8080 8081 8082
|
||||
|
||||
# Default command (can be overridden)
|
||||
CMD ["./bin/predict-otron-9000"]
|
README.md (65 lines changed)
@@ -2,19 +2,26 @@
|
||||
predict-otron-9000
|
||||
</h1>
|
||||
<p align="center">
|
||||
Powerful local AI inference with OpenAI-compatible APIs
|
||||
AI inference Server with OpenAI-compatible API (Limited Features)
|
||||
</p>
|
||||
<p align="center">
|
||||
<img src="https://github.com/geoffsee/predict-otron-9001/blob/master/predict-otron-9000.png?raw=true" width="90%" />
|
||||
</p>
|
||||
|
||||
<br/>
|
||||
|
||||
> This project is an educational aid for bootstrapping my understanding of language-model inference at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes-based, performance-oriented inference on air-gapped networks.
|
||||
|
||||
> By isolating application behavior into crate-level components, development reduces to a short feedback loop of validation and integration, which smooths the learning curve for building scalable AI systems.
|
||||
Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
|
||||
Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
|
||||
|
||||
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.
|
||||
|
||||
|
||||
~~~shell
|
||||
./scripts/run.sh
|
||||
~~~
|
||||
|
||||
|
||||
## Project Overview
|
||||
|
||||
The predict-otron-9000 is a flexible AI platform that provides:
|
||||
@@ -40,26 +47,30 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
|
||||
|
||||
### Workspace Structure
|
||||
|
||||
The project uses a 7-crate Rust workspace plus TypeScript components:
|
||||
The project uses a 9-crate Rust workspace plus TypeScript components:
|
||||
|
||||
```
|
||||
crates/
|
||||
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
|
||||
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
|
||||
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
|
||||
└── chat-ui/ # WASM web frontend (Rust 2021)
|
||||
|
||||
integration/
|
||||
├── cli/ # CLI client crate (Rust 2024)
|
||||
│ └── package/
|
||||
│ └── cli.ts # TypeScript/Bun CLI client
|
||||
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
|
||||
├── llama-runner/ # Llama model inference via Candle (Rust 2021)
|
||||
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
|
||||
├── leptos-app/ # WASM web frontend (Rust 2021)
|
||||
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
|
||||
└── scripts/
|
||||
└── cli.ts # TypeScript/Bun CLI client
|
||||
└── utils/ # Shared utilities (Rust 2021)
|
||||
```
|
||||
|
||||
### Service Architecture
|
||||
|
||||
- **Main Server** (port 8080): Orchestrates inference and embeddings services
|
||||
- **Embeddings Service** (port 8080): Standalone FastEmbed service with OpenAI API compatibility
|
||||
- **Web Frontend** (port 8788): cargo leptos SSR app
|
||||
- **Web Frontend** (port 8788): chat-ui WASM app
|
||||
- **CLI Client**: TypeScript/Bun client for testing and automation
|
||||
|
||||
### Deployment Modes
|
||||
@@ -85,11 +96,6 @@ The architecture supports multiple deployment patterns:
|
||||
- **Bun**: Required for TypeScript CLI client: `curl -fsSL https://bun.sh/install | bash`
|
||||
- **Node.js**: Alternative to Bun; supports the OpenAI SDK v5.16.0+ (see the sketch below)
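
As a quick illustration of that compatibility, here is a minimal sketch (not part of the repository) that points the official `openai` client at the local server; it assumes the main server is listening on port 8080 and that the `gemma-3-1b-it` model is available:

```typescript
// Minimal sketch: call the OpenAI-compatible chat endpoint from Bun or an ESM Node script.
// Assumes predict-otron-9000 is running on http://localhost:8080 and `openai` is installed.
import OpenAI from "openai";

const client = new OpenAI({
  baseURL: "http://localhost:8080/v1",
  apiKey: "dummy", // placeholder; the SDK requires a key even though the local examples send none
});

const response = await client.chat.completions.create({
  model: "gemma-3-1b-it",
  messages: [{ role: "user", content: "What is the capital of France?" }],
  max_tokens: 256,
});

console.log(response.choices[0]?.message?.content);
```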
|
||||
|
||||
#### WASM Frontend Toolchain
|
||||
- **Trunk**: Required for Leptos frontend builds: `cargo install trunk`
|
||||
- **wasm-pack**: `cargo install wasm-pack`
|
||||
- **WASM target**: `rustup target add wasm32-unknown-unknown`
|
||||
|
||||
#### ML Framework Dependencies
|
||||
- **Candle**: Version 0.9.1 with conditional compilation:
|
||||
- macOS: Metal support with CPU fallback for stability
|
||||
@@ -134,11 +140,6 @@ cargo build --bin cli --package inference-engine --release
|
||||
cargo build --bin embeddings-engine --release
|
||||
```
|
||||
|
||||
**Web Frontend:**
|
||||
```bash
|
||||
cd crates/leptos-app
|
||||
trunk build --release
|
||||
```
|
||||
|
||||
### Running Services
|
||||
|
||||
@@ -152,26 +153,26 @@ trunk build --release
|
||||
|
||||
#### Web Frontend (Port 8788)
|
||||
```bash
|
||||
cd crates/leptos-app
|
||||
cd crates/chat-ui
|
||||
./run.sh
|
||||
```
|
||||
- Serves Leptos WASM frontend on port 8788
|
||||
- Serves chat-ui WASM frontend on port 8788
|
||||
- Sets required RUSTFLAGS for WebAssembly getrandom support
|
||||
- Auto-reloads during development
|
||||
|
||||
#### TypeScript CLI Client
|
||||
```bash
|
||||
# List available models
|
||||
bun run scripts/cli.ts --list-models
|
||||
cd integration/cli/package && bun run cli.ts --list-models
|
||||
|
||||
# Chat completion
|
||||
bun run scripts/cli.ts "What is the capital of France?"
|
||||
cd integration/cli/package && bun run cli.ts "What is the capital of France?"
|
||||
|
||||
# With specific model
|
||||
bun run scripts/cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
|
||||
# Show help
|
||||
bun run scripts/cli.ts --help
|
||||
cd integration/cli/package && bun run cli.ts --help
|
||||
```
|
||||
|
||||
## API Usage
|
||||
@@ -287,7 +288,7 @@ cargo test --workspace
|
||||
|
||||
**End-to-end test script:**
|
||||
```bash
|
||||
./smoke_test.sh
|
||||
./scripts/smoke_test.sh
|
||||
```
|
||||
|
||||
This script:
|
||||
@@ -376,7 +377,7 @@ All services include Docker metadata in `Cargo.toml`:
|
||||
- Port: 8080
|
||||
|
||||
**Web Frontend:**
|
||||
- Image: `ghcr.io/geoffsee/leptos-app:latest`
|
||||
- Image: `ghcr.io/geoffsee/chat-ui:latest`
|
||||
- Port: 8788
|
||||
|
||||
**Docker Compose:**
|
||||
@@ -435,8 +436,7 @@ For Kubernetes deployment details, see the [ARCHITECTURE.md](docs/ARCHITECTURE.m
|
||||
**Symptom:** WASM compilation failures
|
||||
**Solution:**
|
||||
1. Install required targets: `rustup target add wasm32-unknown-unknown`
|
||||
2. Install trunk: `cargo install trunk`
|
||||
3. Check RUSTFLAGS in leptos-app/run.sh
|
||||
2. Check RUSTFLAGS in chat-ui/run.sh
|
||||
|
||||
### Network/Timeout Issues
|
||||
**Symptom:** First-time model downloads timing out
|
||||
@@ -467,24 +467,23 @@ curl -s http://localhost:8080/v1/models | jq
|
||||
|
||||
**CLI client test:**
|
||||
```bash
|
||||
bun run scripts/cli.ts "What is 2+2?"
|
||||
cd integration/cli/package && bun run cli.ts "What is 2+2?"
|
||||
```
|
||||
|
||||
**Web frontend:**
|
||||
```bash
|
||||
cd crates/leptos-app && ./run.sh &
|
||||
cd crates/chat-ui && ./run.sh &
|
||||
# Navigate to http://localhost:8788
|
||||
```
|
||||
|
||||
**Integration test:**
|
||||
```bash
|
||||
./smoke_test.sh
|
||||
./scripts/smoke_test.sh
|
||||
```
|
||||
|
||||
**Cleanup:**
|
||||
```bash
|
||||
pkill -f "predict-otron-9000"
|
||||
pkill -f "trunk"
|
||||
```
|
||||
|
||||
For networked tests and full functionality, ensure Hugging Face authentication is configured as described above.
|
||||
|
bun.lock (4 lines changed)
@@ -4,7 +4,7 @@
|
||||
"": {
|
||||
"name": "predict-otron-9000",
|
||||
},
|
||||
"crates/cli/package": {
|
||||
"integration/cli/package": {
|
||||
"name": "cli",
|
||||
"dependencies": {
|
||||
"install": "^0.13.0",
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
},
|
||||
"packages": {
|
||||
"cli": ["cli@workspace:crates/cli/package"],
|
||||
"cli": ["cli@workspace:integration/cli/package"],
|
||||
|
||||
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
|
||||
|
||||
|
@@ -1,8 +1,9 @@
|
||||
[package]
|
||||
name = "leptos-app"
|
||||
version.workspace = true
|
||||
name = "chat-ui"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
@@ -15,45 +16,33 @@ leptos_axum = { version = "0.8.0", optional = true }
|
||||
leptos_meta = { version = "0.8.0" }
|
||||
tokio = { version = "1", features = ["rt-multi-thread"], optional = true }
|
||||
wasm-bindgen = { version = "=0.2.100", optional = true }
|
||||
|
||||
# Chat interface dependencies
|
||||
wasm-bindgen-futures = "0.4"
|
||||
js-sys = "0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
async-openai-wasm = { version = "0.29", default-features = false }
|
||||
futures-util = "0.3"
|
||||
js-sys = { version = "0.3", optional = true }
|
||||
either = { version = "1.9", features = ["serde"] }
|
||||
|
||||
web-sys = { version = "0.3", optional = true, features = [
|
||||
"console",
|
||||
"Window",
|
||||
"Document",
|
||||
"Element",
|
||||
"HtmlElement",
|
||||
"HtmlInputElement",
|
||||
"HtmlSelectElement",
|
||||
"HtmlTextAreaElement",
|
||||
"Event",
|
||||
"EventTarget",
|
||||
"KeyboardEvent",
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
web-sys = { version = "0.3", features = [
|
||||
"console",
|
||||
"EventSource",
|
||||
"MessageEvent",
|
||||
"Window",
|
||||
"Request",
|
||||
"RequestInit",
|
||||
"Response",
|
||||
"Headers",
|
||||
"ReadableStream",
|
||||
"ReadableStreamDefaultReader",
|
||||
"TextDecoder",
|
||||
"TextDecoderOptions",
|
||||
"HtmlInputElement"
|
||||
] }
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1.0"
|
||||
features = [
|
||||
"v4",
|
||||
"fast-rng",
|
||||
"macro-diagnostics",
|
||||
"js",
|
||||
]
|
||||
gloo-net = { version = "0.6", features = ["http"] }
|
||||
|
||||
[features]
|
||||
hydrate = [
|
||||
"leptos/hydrate",
|
||||
"dep:console_error_panic_hook",
|
||||
"dep:wasm-bindgen",
|
||||
"dep:js-sys",
|
||||
"dep:web-sys",
|
||||
]
|
||||
ssr = [
|
||||
"dep:axum",
|
||||
@@ -73,8 +62,9 @@ codegen-units = 1
|
||||
panic = "abort"
|
||||
|
||||
[package.metadata.leptos]
|
||||
name = "chat-ui"
|
||||
# The name used by wasm-bindgen/cargo-leptos for the JS/WASM bundle. Defaults to the crate name
|
||||
output-name = "leptos-app"
|
||||
output-name = "chat-ui"
|
||||
|
||||
# The site root folder is where cargo-leptos generate all output. WARNING: all content of this folder will be erased on a rebuild. Use it in your server setup.
|
||||
site-root = "target/site"
|
||||
@@ -84,7 +74,7 @@ site-root = "target/site"
|
||||
site-pkg-dir = "pkg"
|
||||
|
||||
# [Optional] The source CSS file. If it ends with .sass or .scss then it will be compiled by dart-sass into CSS. The CSS is optimized by Lightning CSS before being written to <site-root>/<site-pkg>/app.css
|
||||
style-file = "style/main.scss"
|
||||
style-file = "./style/main.scss"
|
||||
# Assets source dir. All files found here will be copied and synchronized to site-root.
|
||||
# The assets-dir cannot have a sub directory with the same name/path as site-pkg-dir.
|
||||
#
|
||||
@@ -132,4 +122,8 @@ lib-default-features = false
|
||||
# The profile to use for the lib target when compiling for release
|
||||
#
|
||||
# Optional. Defaults to "release".
|
||||
lib-profile-release = "wasm-release"
|
||||
lib-profile-release = "release"
|
||||
|
||||
[[bin]]
|
||||
name = "chat-ui"
|
||||
path = "src/main.rs"
|
crates/chat-ui/README.md (new file, 41 lines)
@@ -0,0 +1,41 @@
# chat-ui

A WASM-based web chat interface for the predict-otron-9000 AI platform.

## Overview

The chat-ui provides a real-time web interface for interacting with language models through the predict-otron-9000 server. Built with Leptos and compiled to WebAssembly, it offers a modern chat experience with streaming response support.

## Features

- Real-time chat interface with the inference server
- Streaming response support
- Conversation history
- Responsive web design
- WebAssembly-powered for optimal performance

## Building and Running

### Prerequisites
- Rust toolchain with WASM target: `rustup target add wasm32-unknown-unknown`
- The predict-otron-9000 server must be running on port 8080

### Development Server
```bash
cd crates/chat-ui
./run.sh
```

This starts the development server on port 8788 with auto-reload capabilities.

### Usage
1. Start the predict-otron-9000 server: `./scripts/run.sh`
2. Start the chat-ui: `cd crates/chat-ui && ./run.sh`
3. Navigate to `http://localhost:8788`
4. Start chatting with your AI models!

## Technical Details
- Built with Leptos framework
- Compiled to WebAssembly for browser execution
- Communicates with the predict-otron-9000 API via HTTP (see the sketch below)
- Sets required RUSTFLAGS for WebAssembly getrandom support
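
Below is a minimal TypeScript sketch (for illustration; not part of this crate) of the same streaming request the chat-ui issues from the browser: a POST to `/v1/chat/completions` with `stream: true`, followed by reading the SSE body through a `ReadableStream` reader. It assumes the server is reachable at the same origin and that the `gemma-3-1b-it` model is available.

```typescript
// Sketch of the chat-ui's streaming request: POST with stream:true, then parse SSE chunks.
async function streamChat(prompt: string, onChunk: (text: string) => void): Promise<void> {
  const resp = await fetch("/v1/chat/completions", {
    method: "POST",
    headers: { "Content-Type": "application/json", Accept: "text/event-stream" },
    body: JSON.stringify({
      model: "gemma-3-1b-it",
      messages: [{ role: "user", content: prompt }],
      max_tokens: 1024,
      stream: true,
    }),
  });
  if (!resp.ok || !resp.body) throw new Error(`Server error: ${resp.status}`);

  const reader = resp.body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    if (value) buffer += decoder.decode(value, { stream: true });

    // SSE events are separated by a blank line; payload lines start with "data: ".
    let eventEnd: number;
    while ((eventEnd = buffer.indexOf("\n\n")) !== -1) {
      const event = buffer.slice(0, eventEnd);
      buffer = buffer.slice(eventEnd + 2);
      for (const line of event.split("\n")) {
        if (!line.startsWith("data: ")) continue;
        const data = line.slice(6);
        if (data === "[DONE]") return;
        const chunk = JSON.parse(data);
        const content = chunk.choices?.[0]?.delta?.content;
        if (content) onChunk(content);
      }
    }
  }
}
```

The Rust implementation in `src/app.rs` follows the same shape: buffer the decoded bytes, split on blank lines, strip the `data: ` prefix, and stop at `[DONE]`.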
(image diff: 15 KiB before, 15 KiB after)
crates/chat-ui/src/app.rs (new file, 617 lines)
@@ -0,0 +1,617 @@
|
||||
#[cfg(feature = "ssr")]
|
||||
use axum::Router;
|
||||
#[cfg(feature = "ssr")]
|
||||
use leptos::prelude::LeptosOptions;
|
||||
#[cfg(feature = "ssr")]
|
||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||
|
||||
pub struct AppConfig {
|
||||
pub config: ConfFile,
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
let conf = get_configuration(Some(concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml")))
|
||||
.expect("failed to read config");
|
||||
|
||||
let addr = conf.leptos_options.site_addr;
|
||||
|
||||
AppConfig {
|
||||
config: conf, // or whichever field/string representation you need
|
||||
address: addr.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the Axum router for this app, including routes, fallback, and state.
|
||||
/// Call this from another crate (or your bin) when running the server.
|
||||
#[cfg(feature = "ssr")]
|
||||
pub fn create_router(leptos_options: LeptosOptions) -> Router {
|
||||
// Generate the list of routes in your Leptos App
|
||||
let routes = generate_route_list(App);
|
||||
|
||||
Router::new()
|
||||
.leptos_routes(&leptos_options, routes, {
|
||||
let leptos_options = leptos_options.clone();
|
||||
move || shell(leptos_options.clone())
|
||||
})
|
||||
.fallback(leptos_axum::file_and_error_handler(shell))
|
||||
.with_state(leptos_options)
|
||||
}
|
||||
|
||||
use gloo_net::http::Request;
|
||||
use leptos::prelude::*;
|
||||
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
|
||||
use leptos_router::{
|
||||
components::{Route, Router, Routes},
|
||||
StaticSegment,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use web_sys::console;
|
||||
|
||||
// Data structures for OpenAI-compatible API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatChoice {
|
||||
pub message: ChatMessage,
|
||||
pub index: u32,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
// Streaming response structures
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamDelta {
|
||||
pub role: Option<String>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamChoice {
|
||||
pub index: u32,
|
||||
pub delta: StreamDelta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
}
|
||||
|
||||
// Data structures for models API
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ModelsResponse {
|
||||
pub object: String,
|
||||
pub data: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
// API client function to fetch available models
|
||||
pub async fn fetch_models() -> Result<Vec<ModelInfo>, String> {
|
||||
let response = Request::get("/v1/models")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch models: {:?}", e))?;
|
||||
|
||||
if response.ok() {
|
||||
let models_response: ModelsResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse models response: {:?}", e))?;
|
||||
Ok(models_response.data)
|
||||
} else {
|
||||
let status = response.status();
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
Err(format!("Failed to fetch models {}: {}", status, error_text))
|
||||
}
|
||||
}
|
||||
|
||||
// API client function to send chat completion requests
|
||||
pub async fn send_chat_completion(
|
||||
messages: Vec<ChatMessage>,
|
||||
model: String,
|
||||
) -> Result<String, String> {
|
||||
let request = ChatRequest {
|
||||
model,
|
||||
messages,
|
||||
max_tokens: Some(1024),
|
||||
stream: Some(false),
|
||||
};
|
||||
|
||||
let response = Request::post("/v1/chat/completions")
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.map_err(|e| format!("Failed to create request: {:?}", e))?
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to send request: {:?}", e))?;
|
||||
|
||||
if response.ok() {
|
||||
let chat_response: ChatResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse response: {:?}", e))?;
|
||||
|
||||
if let Some(choice) = chat_response.choices.first() {
|
||||
Ok(choice.message.content.clone())
|
||||
} else {
|
||||
Err("No response choices available".to_string())
|
||||
}
|
||||
} else {
|
||||
let status = response.status();
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
Err(format!("Server error {}: {}", status, error_text))
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming chat completion using EventSource
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
pub fn send_chat_completion_stream(
|
||||
messages: Vec<ChatMessage>,
|
||||
model: String,
|
||||
on_chunk: impl Fn(String) + 'static,
|
||||
on_complete: impl Fn() + 'static,
|
||||
on_error: impl Fn(String) + 'static,
|
||||
) {
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen::JsCast;
|
||||
|
||||
let request = ChatRequest {
|
||||
model,
|
||||
messages,
|
||||
max_tokens: Some(1024),
|
||||
stream: Some(true),
|
||||
};
|
||||
|
||||
// We need to send a POST request but EventSource only supports GET
|
||||
// So we'll use fetch with a readable stream instead
|
||||
let window = web_sys::window().unwrap();
|
||||
let request_json = serde_json::to_string(&request).unwrap();
|
||||
|
||||
let opts = web_sys::RequestInit::new();
|
||||
opts.set_method("POST");
|
||||
opts.set_body(&JsValue::from_str(&request_json));
|
||||
|
||||
let headers = web_sys::Headers::new().unwrap();
|
||||
headers.set("Content-Type", "application/json").unwrap();
|
||||
headers.set("Accept", "text/event-stream").unwrap();
|
||||
opts.set_headers(&headers);
|
||||
|
||||
let request = web_sys::Request::new_with_str_and_init("/v1/chat/completions", &opts).unwrap();
|
||||
|
||||
let promise = window.fetch_with_request(&request);
|
||||
|
||||
wasm_bindgen_futures::spawn_local(async move {
|
||||
match wasm_bindgen_futures::JsFuture::from(promise).await {
|
||||
Ok(resp_value) => {
|
||||
let resp: web_sys::Response = resp_value.dyn_into().unwrap();
|
||||
|
||||
if !resp.ok() {
|
||||
on_error(format!("Server error: {}", resp.status()));
|
||||
return;
|
||||
}
|
||||
|
||||
let body = resp.body();
|
||||
if body.is_none() {
|
||||
on_error("No response body".to_string());
|
||||
return;
|
||||
}
|
||||
|
||||
let reader = body
|
||||
.unwrap()
|
||||
.get_reader()
|
||||
.dyn_into::<web_sys::ReadableStreamDefaultReader>()
|
||||
.unwrap();
|
||||
|
||||
let decoder = web_sys::TextDecoder::new().unwrap();
|
||||
let mut buffer = String::new();
|
||||
|
||||
loop {
|
||||
match wasm_bindgen_futures::JsFuture::from(reader.read()).await {
|
||||
Ok(result) => {
|
||||
let done = js_sys::Reflect::get(&result, &JsValue::from_str("done"))
|
||||
.unwrap()
|
||||
.as_bool()
|
||||
.unwrap_or(false);
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
|
||||
let value =
|
||||
js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
|
||||
let array = js_sys::Uint8Array::new(&value);
|
||||
let mut bytes = vec![0; array.length() as usize];
|
||||
array.copy_to(&mut bytes);
|
||||
let text = decoder.decode_with_u8_array(&bytes).unwrap();
|
||||
|
||||
buffer.push_str(&text);
|
||||
|
||||
// Process complete SSE events from buffer
|
||||
while let Some(event_end) = buffer.find("\n\n") {
|
||||
let event = buffer[..event_end].to_string();
|
||||
buffer = buffer[event_end + 2..].to_string();
|
||||
|
||||
// Parse SSE event
|
||||
for line in event.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
on_complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse JSON chunk
|
||||
if let Ok(chunk) =
|
||||
serde_json::from_str::<StreamChatResponse>(data)
|
||||
{
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
on_chunk(content.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
on_error(format!("Read error: {:?}", e));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
on_complete();
|
||||
}
|
||||
Err(e) => {
|
||||
on_error(format!("Fetch error: {:?}", e));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn shell(options: LeptosOptions) -> impl IntoView {
|
||||
view! {
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<AutoReload options=options.clone() />
|
||||
<HydrationScripts options/>
|
||||
<MetaTags/>
|
||||
</head>
|
||||
<body>
|
||||
<App/>
|
||||
</body>
|
||||
</html>
|
||||
}
|
||||
}
|
||||
|
||||
#[component]
|
||||
pub fn App() -> impl IntoView {
|
||||
// Provides context that manages stylesheets, titles, meta tags, etc.
|
||||
provide_meta_context();
|
||||
|
||||
view! {
|
||||
// injects a stylesheet into the document <head>
|
||||
// id=leptos means cargo-leptos will hot-reload this stylesheet
|
||||
<Stylesheet id="leptos" href="/pkg/chat-ui.css"/>
|
||||
|
||||
// sets the document title
|
||||
<Title text="Predict-Otron-9000 Chat"/>
|
||||
|
||||
// content for this welcome page
|
||||
<Router>
|
||||
<main>
|
||||
<Routes fallback=|| "Page not found.".into_view()>
|
||||
<Route path=StaticSegment("") view=ChatPage/>
|
||||
</Routes>
|
||||
</main>
|
||||
</Router>
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders the chat interface page
|
||||
#[component]
|
||||
fn ChatPage() -> impl IntoView {
|
||||
// State for conversation messages
|
||||
let messages = RwSignal::new(Vec::<ChatMessage>::new());
|
||||
|
||||
// State for current user input
|
||||
let input_text = RwSignal::new(String::new());
|
||||
|
||||
// State for loading indicator
|
||||
let is_loading = RwSignal::new(false);
|
||||
|
||||
// State for error messages
|
||||
let error_message = RwSignal::new(Option::<String>::None);
|
||||
|
||||
// State for available models and selected model
|
||||
let available_models = RwSignal::new(Vec::<ModelInfo>::new());
|
||||
let selected_model = RwSignal::new(String::from("")); // Default model
|
||||
|
||||
// State for streaming response
|
||||
let streaming_content = RwSignal::new(String::new());
|
||||
let is_streaming = RwSignal::new(false);
|
||||
|
||||
// State for streaming mode toggle
|
||||
let use_streaming = RwSignal::new(true); // Default to streaming
|
||||
|
||||
// Client-side only: Fetch models on component mount
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
{
|
||||
use leptos::task::spawn_local;
|
||||
spawn_local(async move {
|
||||
match fetch_models().await {
|
||||
Ok(models) => {
|
||||
available_models.set(models);
|
||||
selected_model.set(String::from("gemma-3-1b-it"));
|
||||
}
|
||||
Err(error) => {
|
||||
console::log_1(&format!("Failed to fetch models: {}", error).into());
|
||||
error_message.set(Some(format!("Failed to load models: {}", error)));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Shared logic for sending a message
|
||||
let send_message_logic = move || {
|
||||
let user_input = input_text.get();
|
||||
if user_input.trim().is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Add user message to conversation
|
||||
let user_message = ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: user_input.clone(),
|
||||
};
|
||||
|
||||
messages.update(|msgs| msgs.push(user_message.clone()));
|
||||
input_text.set(String::new());
|
||||
is_loading.set(true);
|
||||
error_message.set(None);
|
||||
|
||||
// Client-side only: Send chat completion request
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
{
|
||||
use leptos::task::spawn_local;
|
||||
|
||||
// Prepare messages for API call
|
||||
let current_messages = messages.get();
|
||||
let current_model = selected_model.get();
|
||||
let should_stream = use_streaming.get();
|
||||
|
||||
if should_stream {
|
||||
// Clear streaming content and set streaming flag
|
||||
streaming_content.set(String::new());
|
||||
is_streaming.set(true);
|
||||
|
||||
// Use streaming API
|
||||
send_chat_completion_stream(
|
||||
current_messages,
|
||||
current_model,
|
||||
move |chunk| {
|
||||
// Append chunk to streaming content
|
||||
streaming_content.update(|content| content.push_str(&chunk));
|
||||
},
|
||||
move || {
|
||||
// On complete, move streaming content to messages
|
||||
let final_content = streaming_content.get();
|
||||
if !final_content.is_empty() {
|
||||
let assistant_message = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: final_content,
|
||||
};
|
||||
messages.update(|msgs| msgs.push(assistant_message));
|
||||
}
|
||||
streaming_content.set(String::new());
|
||||
is_streaming.set(false);
|
||||
is_loading.set(false);
|
||||
},
|
||||
move |error| {
|
||||
console::log_1(&format!("Streaming Error: {}", error).into());
|
||||
error_message.set(Some(error));
|
||||
is_streaming.set(false);
|
||||
is_loading.set(false);
|
||||
streaming_content.set(String::new());
|
||||
},
|
||||
);
|
||||
} else {
|
||||
// Use non-streaming API
|
||||
spawn_local(async move {
|
||||
match send_chat_completion(current_messages, current_model).await {
|
||||
Ok(response_content) => {
|
||||
let assistant_message = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: response_content,
|
||||
};
|
||||
messages.update(|msgs| msgs.push(assistant_message));
|
||||
is_loading.set(false);
|
||||
}
|
||||
Err(error) => {
|
||||
console::log_1(&format!("API Error: {}", error).into());
|
||||
error_message.set(Some(error));
|
||||
is_loading.set(false);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Button click handler
|
||||
let on_button_click = {
|
||||
let send_logic = send_message_logic.clone();
|
||||
move |_: web_sys::MouseEvent| {
|
||||
send_logic();
|
||||
}
|
||||
};
|
||||
|
||||
// Handle enter key press in input field
|
||||
let on_key_down = move |ev: web_sys::KeyboardEvent| {
|
||||
if ev.key() == "Enter" && !ev.shift_key() {
|
||||
ev.prevent_default();
|
||||
send_message_logic();
|
||||
}
|
||||
};
|
||||
|
||||
view! {
|
||||
<div class="chat-container">
|
||||
<div class="chat-header">
|
||||
<h1>"Predict-Otron-9000 Chat"</h1>
|
||||
<div class="model-selector">
|
||||
<label for="model-select">"Model:"</label>
|
||||
<select
|
||||
id="model-select"
|
||||
prop:value=move || selected_model.get()
|
||||
on:change=move |ev| {
|
||||
let new_model = event_target_value(&ev);
|
||||
selected_model.set(new_model);
|
||||
}
|
||||
>
|
||||
<For
|
||||
each=move || available_models.get().into_iter()
|
||||
key=|model| model.id.clone()
|
||||
children=move |model| {
|
||||
view! {
|
||||
<option value=model.id.clone()>
|
||||
{format!("{} ({})", model.id, model.owned_by)}
|
||||
</option>
|
||||
}
|
||||
}
|
||||
/>
|
||||
</select>
|
||||
<div class="streaming-toggle">
|
||||
<label>
|
||||
<input
|
||||
type="checkbox"
|
||||
prop:checked=move || use_streaming.get()
|
||||
on:change=move |ev| {
|
||||
let target = event_target::<web_sys::HtmlInputElement>(&ev);
|
||||
use_streaming.set(target.checked());
|
||||
}
|
||||
/>
|
||||
" Use streaming"
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="chat-messages">
|
||||
<For
|
||||
each=move || messages.get().into_iter().enumerate()
|
||||
key=|(i, _)| *i
|
||||
children=move |(_, message)| {
|
||||
let role_class = if message.role == "user" { "user-message" } else { "assistant-message" };
|
||||
view! {
|
||||
<div class=format!("message {}", role_class)>
|
||||
<div class="message-role">{message.role.clone()}</div>
|
||||
<div class="message-content">{message.content.clone()}</div>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
/>
|
||||
|
||||
{move || {
|
||||
if is_streaming.get() {
|
||||
let content = streaming_content.get();
|
||||
if !content.is_empty() {
|
||||
view! {
|
||||
<div class="message assistant-message streaming">
|
||||
<div class="message-role">"assistant"</div>
|
||||
<div class="message-content">{content}<span class="cursor">"▊"</span></div>
|
||||
</div>
|
||||
}.into_any()
|
||||
} else {
|
||||
view! {
|
||||
<div class="message assistant-message loading">
|
||||
<div class="message-role">"assistant"</div>
|
||||
<div class="message-content">"Thinking..."</div>
|
||||
</div>
|
||||
}.into_any()
|
||||
}
|
||||
} else if is_loading.get() && !use_streaming.get() {
|
||||
view! {
|
||||
<div class="message assistant-message loading">
|
||||
<div class="message-role">"assistant"</div>
|
||||
<div class="message-content">"Thinking..."</div>
|
||||
</div>
|
||||
}.into_any()
|
||||
} else {
|
||||
view! {}.into_any()
|
||||
}
|
||||
}}
|
||||
</div>
|
||||
|
||||
{move || {
|
||||
if let Some(error) = error_message.get() {
|
||||
view! {
|
||||
<div class="error-message">
|
||||
"Error: " {error}
|
||||
</div>
|
||||
}.into_any()
|
||||
} else {
|
||||
view! {}.into_any()
|
||||
}
|
||||
}}
|
||||
|
||||
<div class="chat-input">
|
||||
<textarea
|
||||
placeholder="Type your message here... (Press Enter to send, Shift+Enter for new line)"
|
||||
prop:value=move || input_text.get()
|
||||
on:input=move |ev| input_text.set(event_target_value(&ev))
|
||||
on:keydown=on_key_down
|
||||
class:disabled=move || is_loading.get()
|
||||
/>
|
||||
<button
|
||||
on:click=on_button_click
|
||||
class:disabled=move || is_loading.get() || input_text.get().trim().is_empty()
|
||||
>
|
||||
"Send"
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
}
|
crates/chat-ui/src/lib.rs (new file, 9 lines)
@@ -0,0 +1,9 @@
|
||||
pub mod app;
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[wasm_bindgen::prelude::wasm_bindgen]
|
||||
pub fn hydrate() {
|
||||
use crate::app::*;
|
||||
console_error_panic_hook::set_once();
|
||||
leptos::mount::hydrate_body(App);
|
||||
}
|
crates/chat-ui/src/main.rs (new file, 26 lines)
@@ -0,0 +1,26 @@
|
||||
#[cfg(feature = "ssr")]
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
use axum::Router;
|
||||
use chat_ui::app::*;
|
||||
use leptos::logging::log;
|
||||
use leptos::prelude::*;
|
||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||
|
||||
let conf = get_configuration(None).expect("failed to read config");
|
||||
let addr = conf.leptos_options.site_addr;
|
||||
|
||||
// Build the app router with your extracted function
|
||||
let app: Router = create_router(conf.leptos_options);
|
||||
|
||||
log!("listening on http://{}", &addr);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ssr"))]
|
||||
pub fn main() {
|
||||
// no client-side main function
|
||||
}
|
crates/chat-ui/style/main.scss (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
background-color: #f5f5f5;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.chat-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100vh;
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
background-color: white;
|
||||
box-shadow: 0 0 20px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.chat-header {
|
||||
background-color: #000000;
|
||||
color: white;
|
||||
padding: 1rem;
|
||||
text-align: center;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
|
||||
h1 {
|
||||
margin: 0;
|
||||
font-size: 1.5rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.model-selector {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
|
||||
label {
|
||||
font-weight: 500;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
|
||||
select {
|
||||
background-color: white;
|
||||
color: #374151;
|
||||
border: 1px solid #d1d5db;
|
||||
border-radius: 6px;
|
||||
padding: 0.5rem 0.75rem;
|
||||
font-size: 0.9rem;
|
||||
font-family: inherit;
|
||||
cursor: pointer;
|
||||
min-width: 200px;
|
||||
|
||||
&:focus {
|
||||
outline: none;
|
||||
border-color: #663c99;
|
||||
box-shadow: 0 0 0 2px rgba(29, 78, 216, 0.2);
|
||||
}
|
||||
|
||||
option {
|
||||
padding: 0.5rem;
|
||||
}
|
||||
}
|
||||
|
||||
.streaming-toggle {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-left: 1rem;
|
||||
|
||||
label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.9rem;
|
||||
|
||||
input[type="checkbox"] {
|
||||
cursor: pointer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.chat-messages {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 1rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
background-color: #fafafa;
|
||||
}
|
||||
|
||||
.message {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
padding: 1rem;
|
||||
border-radius: 12px;
|
||||
max-width: 80%;
|
||||
word-wrap: break-word;
|
||||
|
||||
&.user-message {
|
||||
align-self: flex-end;
|
||||
background-color: #2563eb;
|
||||
color: white;
|
||||
|
||||
.message-role {
|
||||
font-weight: 600;
|
||||
font-size: 0.8rem;
|
||||
opacity: 0.8;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.message-content {
|
||||
line-height: 1.5;
|
||||
}
|
||||
}
|
||||
|
||||
&.assistant-message {
|
||||
align-self: flex-start;
|
||||
background-color: #646873;
|
||||
border: 1px solid #e5e7eb;
|
||||
color: #f3f3f3;
|
||||
|
||||
.message-role {
|
||||
font-weight: 600;
|
||||
font-size: 0.8rem;
|
||||
color: #c4c5cd;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.message-content {
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
&.loading {
|
||||
background-color: #f3f4f6;
|
||||
border-color: #d1d5db;
|
||||
|
||||
.message-content {
|
||||
font-style: italic;
|
||||
color: #6b7280;
|
||||
}
|
||||
}
|
||||
|
||||
&.streaming {
|
||||
.message-content {
|
||||
.cursor {
|
||||
display: inline-block;
|
||||
animation: blink 1s infinite;
|
||||
color: #9ca3af;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background-color: #fef2f2;
|
||||
border: 1px solid #fca5a5;
|
||||
color: #dc2626;
|
||||
padding: 1rem;
|
||||
margin: 0 1rem;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.chat-input {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
padding: 1rem;
|
||||
background-color: white;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
|
||||
textarea {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid #d1d5db;
|
||||
border-radius: 8px;
|
||||
resize: none;
|
||||
min-height: 60px;
|
||||
max-height: 120px;
|
||||
font-family: inherit;
|
||||
font-size: 1rem;
|
||||
line-height: 1.5;
|
||||
|
||||
&:focus {
|
||||
outline: none;
|
||||
border-color: #663c99;
|
||||
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
||||
}
|
||||
|
||||
&.disabled {
|
||||
background-color: #f9fafb;
|
||||
color: #6b7280;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 0.75rem 1.5rem;
|
||||
background-color: #663c99;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s ease;
|
||||
align-self: flex-end;
|
||||
|
||||
&:hover:not(.disabled) {
|
||||
background-color: #663c99;
|
||||
}
|
||||
|
||||
&.disabled {
|
||||
background-color: #9ca3af;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
&:focus {
|
||||
outline: none;
|
||||
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Scrollbar styling for webkit browsers */
|
||||
.chat-messages::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.chat-messages::-webkit-scrollbar-track {
|
||||
background: #f1f1f1;
|
||||
}
|
||||
|
||||
.chat-messages::-webkit-scrollbar-thumb {
|
||||
background: #c1c1c1;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.chat-messages::-webkit-scrollbar-thumb:hover {
|
||||
background: #a8a8a8;
|
||||
}
|
||||
|
||||
/* Cursor blink animation */
|
||||
@keyframes blink {
|
||||
0%, 50% {
|
||||
opacity: 1;
|
||||
}
|
||||
51%, 100% {
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
@@ -25,15 +25,9 @@ rand = "0.8.5"
|
||||
async-openai = "0.28.3"
|
||||
once_cell = "1.19.0"
|
||||
|
||||
|
||||
|
||||
[package.metadata.compose]
|
||||
image = "ghcr.io/geoffsee/embeddings-service:latest"
|
||||
port = 8080
|
||||
|
||||
|
||||
# generates kubernetes manifests
|
||||
[package.metadata.kube]
|
||||
image = "ghcr.io/geoffsee/embeddings-service:latest"
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
cmd = ["./bin/embeddings-engine"]
|
||||
replicas = 1
|
||||
port = 8080
|
@@ -1,42 +0,0 @@
|
||||
# ---- Build stage ----
|
||||
FROM rust:1-slim-bullseye AS builder
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Cache deps first
|
||||
COPY . ./
|
||||
RUN rm -rf src
|
||||
RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "// lib" > src/lib.rs && cargo build --release
|
||||
RUN rm -rf src
|
||||
|
||||
# Copy real sources and build
|
||||
COPY . .
|
||||
RUN cargo build --release
|
||||
|
||||
# ---- Runtime stage ----
|
||||
FROM debian:bullseye-slim
|
||||
|
||||
# Install only what the compiled binary needs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libssl1.1 \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /usr/src/app/target/release/embeddings-engine /usr/local/bin/
|
||||
|
||||
# Run as non-root user for safety
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8080
|
||||
CMD ["embeddings-engine"]
|
@@ -1,4 +1,100 @@
|
||||
# Embeddings Engine
|
||||
|
||||
A high-performance text embeddings service that generates vector representations of text using state-of-the-art models.
|
||||
This crate wraps the fastembed crate to provide embeddings and partially adapts the openai specification.
|
||||
A high-performance text embeddings service that generates vector representations of text using state-of-the-art models. This crate wraps the FastEmbed library to provide embeddings with OpenAI-compatible API endpoints.
|
||||
|
||||
## Overview
|
||||
|
||||
The embeddings-engine provides a standalone service for generating text embeddings that can be used for semantic search, similarity comparisons, and other NLP tasks. It's designed to be compatible with OpenAI's embeddings API format.
|
||||
|
||||
## Features
|
||||
|
||||
- **OpenAI-Compatible API**: `/v1/embeddings` endpoint matching OpenAI's specification
|
||||
- **FastEmbed Integration**: Powered by the FastEmbed library for high-quality embeddings
|
||||
- **Multiple Model Support**: Support for various embedding models
|
||||
- **High Performance**: Optimized for fast embedding generation
|
||||
- **Standalone Service**: Can run independently or as part of the predict-otron-9000 platform
|
||||
|
||||
## Building and Running
|
||||
|
||||
### Prerequisites
|
||||
- Rust toolchain
|
||||
- Internet connection for initial model downloads
|
||||
|
||||
### Standalone Server
|
||||
```bash
|
||||
cargo run --bin embeddings-engine --release
|
||||
```
|
||||
|
||||
The service will start on port 8080 by default.
|
||||
|
||||
## API Usage
|
||||
|
||||
### Generate Embeddings
|
||||
|
||||
**Endpoint**: `POST /v1/embeddings`
|
||||
|
||||
**Request Body**:
|
||||
```json
|
||||
{
|
||||
"input": "Your text to embed",
|
||||
"model": "nomic-embed-text-v1.5"
|
||||
}
|
||||
```
|
||||
|
||||
**Response**:
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [0.1, 0.2, 0.3, ...]
|
||||
}
|
||||
],
|
||||
"model": "nomic-embed-text-v1.5",
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example Usage
|
||||
|
||||
**Using cURL**:
|
||||
```bash
|
||||
curl -s http://localhost:8080/v1/embeddings \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input": "The quick brown fox jumps over the lazy dog",
|
||||
"model": "nomic-embed-text-v1.5"
|
||||
}' | jq
|
||||
```
|
||||
|
||||
**Using Python OpenAI Client**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="dummy" # Not validated but required by client
|
||||
)
|
||||
|
||||
response = client.embeddings.create(
|
||||
input="Your text here",
|
||||
model="nomic-embed-text-v1.5"
|
||||
)
|
||||
|
||||
print(response.data[0].embedding)
|
||||
```
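
**Using TypeScript (Bun or Node 18+)**: a minimal `fetch`-based sketch of the same request, shown here for illustration:

```typescript
// Sketch: the /v1/embeddings request issued from TypeScript (Bun or Node 18+).
// Assumes the embeddings service is listening on http://localhost:8080.
const resp = await fetch("http://localhost:8080/v1/embeddings", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    input: "The quick brown fox jumps over the lazy dog",
    model: "nomic-embed-text-v1.5",
  }),
});

const { data } = await resp.json();
console.log(`dimensions: ${data[0].embedding.length}`);
```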
|
||||
|
||||
## Configuration
|
||||
|
||||
The service can be configured through environment variables:
|
||||
- `SERVER_PORT`: Port to run on (default: 8080)
|
||||
- `RUST_LOG`: Logging level (default: info)
|
||||
|
||||
## Integration
|
||||
|
||||
This service is designed to work seamlessly with the predict-otron-9000 main server, but can also be deployed independently for dedicated embeddings workloads.
|
@@ -1,43 +1,225 @@
|
||||
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
|
||||
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
|
||||
use axum::{Json, Router, http::StatusCode, response::Json as ResponseJson, routing::post};
|
||||
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing;
|
||||
|
||||
// Persistent model instance (singleton pattern)
|
||||
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
|
||||
tracing::info!("Initializing persistent embedding model (singleton)");
|
||||
// Cache for multiple embedding models
|
||||
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
|
||||
Lazy::new(|| RwLock::new(HashMap::new()));
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub owned_by: String,
|
||||
pub description: String,
|
||||
pub dimensions: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ModelsResponse {
|
||||
pub object: String,
|
||||
pub data: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
    match model_name {
        // Sentence Transformers models
        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
            Ok(EmbeddingModel::AllMiniLML6V2)
        }
        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
            Ok(EmbeddingModel::AllMiniLML6V2Q)
        }
        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
            Ok(EmbeddingModel::AllMiniLML12V2)
        }
        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
            Ok(EmbeddingModel::AllMiniLML12V2Q)
        }

        // BGE models
        "BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
        "BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
        "BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
        "BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
        "BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
        "BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
        "BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
        "BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),

        // Nomic models
        "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
            Ok(EmbeddingModel::NomicEmbedTextV1)
        }
        "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
            Ok(EmbeddingModel::NomicEmbedTextV15)
        }
        "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
            Ok(EmbeddingModel::NomicEmbedTextV15Q)
        }

        // Paraphrase models
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
        | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
        "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),

        // ModernBert
        "lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
            Ok(EmbeddingModel::ModernBertEmbedLarge)
        }

        // Multilingual E5 models
        "intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
            Ok(EmbeddingModel::MultilingualE5Small)
        }
        "intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
            Ok(EmbeddingModel::MultilingualE5Base)
        }
        "intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
            Ok(EmbeddingModel::MultilingualE5Large)
        }

        // Mixedbread models
        "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
            Ok(EmbeddingModel::MxbaiEmbedLargeV1)
        }
        "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
            Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
        }

        // GTE models
        "Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
        "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
            Ok(EmbeddingModel::GTEBaseENV15Q)
        }
        "Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
        "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
            Ok(EmbeddingModel::GTELargeENV15Q)
        }

        // CLIP model
        "Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),

        // Jina model
        "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
            Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
        }

        _ => Err(format!("Unsupported embedding model: {}", model_name)),
    }
}

// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
    match model {
        EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
        EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
        EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
        EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
        EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
        EmbeddingModel::BGESmallZHV15 => 512,
        EmbeddingModel::BGELargeZHV15 => 1024,
        EmbeddingModel::NomicEmbedTextV1
        | EmbeddingModel::NomicEmbedTextV15
        | EmbeddingModel::NomicEmbedTextV15Q => 768,
        EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
        EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
        EmbeddingModel::ModernBertEmbedLarge => 1024,
        EmbeddingModel::MultilingualE5Small => 384,
        EmbeddingModel::MultilingualE5Base => 768,
        EmbeddingModel::MultilingualE5Large => 1024,
        EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
        EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
        EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
        EmbeddingModel::ClipVitB32 => 512,
        EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
    }
}
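
Taken together, parse_embedding_model and get_model_dimensions let the handler resolve a client-supplied model name (full Hugging Face id or short alias) and know the expected vector width up front. A minimal illustrative sketch of that flow (the helper itself is not part of the diff):

// Illustrative only: resolve a requested model name and report its dimensions.
fn describe_model(requested: &str) -> Result<(EmbeddingModel, usize), String> {
    let model = parse_embedding_model(requested)?; // accepts "BAAI/bge-small-en-v1.5" or "bge-small-en-v1.5"
    let dims = get_model_dimensions(&model);
    Ok((model, dims))
}

// describe_model("bge-small-en-v1.5")      -> Ok((EmbeddingModel::BGESmallENV15, 384))
// describe_model("text-embedding-3-small") -> Err("Unsupported embedding model: ...")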

// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
    // First try to get from cache (read lock)
    {
        let cache = MODEL_CACHE
            .read()
            .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
        if let Some(model) = cache.get(&embedding_model) {
            tracing::debug!("Using cached model: {:?}", embedding_model);
            return Ok(Arc::clone(model));
        }
    }

    // Model not in cache, create it (write lock)
    let mut cache = MODEL_CACHE
        .write()
        .map_err(|e| format!("Failed to acquire write lock: {}", e))?;

    // Double-check after acquiring write lock
    if let Some(model) = cache.get(&embedding_model) {
        tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
        return Ok(Arc::clone(model));
    }

    tracing::info!("Initializing new embedding model: {:?}", embedding_model);
    let model_start_time = std::time::Instant::now();

    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
        InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
    )
    .expect("Failed to initialize persistent embedding model");
    .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;

    let model_init_time = model_start_time.elapsed();
    tracing::info!(
        "Persistent embedding model initialized in {:.2?}",
        "Embedding model {:?} initialized in {:.2?}",
        embedding_model,
        model_init_time
    );

    model
    });
    let model_arc = Arc::new(model);
    cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
    Ok(model_arc)
}
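
The read-then-write, check-again sequence above is the usual double-checked idiom for std::sync::RwLock: most requests take only the cheap read path, and the second lookup under the write lock keeps two threads from initializing the same model at once. The MODEL_CACHE it locks is not shown in this hunk; a plausible declaration, assuming once_cell::sync::Lazy and a HashMap keyed by EmbeddingModel (the crate choice and naming are assumptions):

use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Assumed shape of the cache read and written by get_or_create_model: a lazily
// created, lock-protected map from model variant to a shared TextEmbedding handle.
// (Assumes EmbeddingModel implements Eq + Hash so it can serve as a HashMap key.)
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));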
|
||||
|
||||
pub async fn embeddings_create(
|
||||
Json(payload): Json<CreateEmbeddingRequest>,
|
||||
) -> ResponseJson<serde_json::Value> {
|
||||
) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
|
||||
// Start timing the entire process
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
// Phase 1: Access persistent model instance
|
||||
// Phase 1: Parse and get the embedding model
|
||||
let model_start_time = std::time::Instant::now();
|
||||
|
||||
// Access the lazy-initialized persistent model instance
|
||||
// This will only initialize the model on the first request
|
||||
let embedding_model = match parse_embedding_model(&payload.model) {
|
||||
Ok(model) => model,
|
||||
Err(e) => {
|
||||
tracing::error!("Invalid model requested: {}", e);
|
||||
return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
|
||||
}
|
||||
};
|
||||
|
||||
let model = match get_or_create_model(embedding_model.clone()) {
|
||||
Ok(model) => model,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get/create model: {}", e);
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Model initialization failed: {}", e),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let model_access_time = model_start_time.elapsed();
|
||||
tracing::debug!(
|
||||
"Persistent model access completed in {:.2?}",
|
||||
"Model access/creation completed in {:.2?}",
|
||||
model_access_time
|
||||
);
|
||||
|
||||
@@ -65,9 +247,13 @@ pub async fn embeddings_create(
|
||||
// Phase 3: Generate embeddings
|
||||
let embedding_start_time = std::time::Instant::now();
|
||||
|
||||
let embeddings = EMBEDDING_MODEL
|
||||
.embed(texts_from_embedding_input, None)
|
||||
.expect("failed to embed document");
|
||||
let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
|
||||
tracing::error!("Failed to generate embeddings: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Embedding generation failed: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
let embedding_generation_time = embedding_start_time.elapsed();
|
||||
tracing::info!(
|
||||
@@ -117,8 +303,9 @@ pub async fn embeddings_create(
|
||||
// Generate a random non-zero embedding
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut random_embedding = Vec::with_capacity(768);
|
||||
for _ in 0..768 {
|
||||
let expected_dimensions = get_model_dimensions(&embedding_model);
|
||||
let mut random_embedding = Vec::with_capacity(expected_dimensions);
|
||||
for _ in 0..expected_dimensions {
|
||||
// Generate random values between -1.0 and 1.0, excluding 0
|
||||
let mut val = 0.0;
|
||||
while val == 0.0 {
|
||||
@@ -138,18 +325,19 @@ pub async fn embeddings_create(
|
||||
random_embedding
|
||||
} else {
|
||||
// Check if dimensions parameter is provided and pad the embeddings if necessary
|
||||
let mut padded_embedding = embeddings[0].clone();
|
||||
let padded_embedding = embeddings[0].clone();
|
||||
|
||||
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
|
||||
let target_dimension = 768;
|
||||
if padded_embedding.len() < target_dimension {
|
||||
let padding_needed = target_dimension - padded_embedding.len();
|
||||
tracing::trace!(
|
||||
"Padding embedding with {} zeros to reach {} dimensions",
|
||||
padding_needed,
|
||||
target_dimension
|
||||
// Use the actual model dimensions instead of hardcoded 768
|
||||
let actual_dimensions = padded_embedding.len();
|
||||
let expected_dimensions = get_model_dimensions(&embedding_model);
|
||||
|
||||
if actual_dimensions != expected_dimensions {
|
||||
tracing::warn!(
|
||||
"Model {:?} produced {} dimensions but expected {}",
|
||||
embedding_model,
|
||||
actual_dimensions,
|
||||
expected_dimensions
|
||||
);
|
||||
padded_embedding.extend(vec![0.0; padding_needed]);
|
||||
}
|
||||
|
||||
padded_embedding
@@ -203,11 +391,234 @@ pub async fn embeddings_create
        postprocessing_time
    );

    ResponseJson(response)
    Ok(ResponseJson(response))
}
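
For reference, the handler accepts an OpenAI-style CreateEmbeddingRequest and returns the OpenAI list shape built later in this diff. An illustrative request payload (field names follow the OpenAI embeddings API; the snippet is not part of the diff):

// Illustrative payload for POST /v1/embeddings, built with serde_json
// (the same crate the handler already uses for its response body).
let body = serde_json::json!({
    "model": "nomic-ai/nomic-embed-text-v1.5",
    "input": "The quick brown fox jumps over the lazy dog"
});
// An unknown "model" now yields 400 Bad Request instead of a panic, and
// embedding failures surface as 500 with a descriptive message.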
|
||||
|
||||
pub async fn models_list() -> ResponseJson<ModelsResponse> {
|
||||
let models = vec![
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Sentence Transformer model, MiniLM-L6-v2".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/all-MiniLM-L12-v2".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Sentence Transformer model, MiniLM-L12-v2".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-base-en-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "v1.5 release of the base English model".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-base-en-v1.5-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "Quantized v1.5 release of the base English model".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-large-en-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "v1.5 release of the large English model".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-large-en-v1.5-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "Quantized v1.5 release of the large English model".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-small-en-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "v1.5 release of the fast and default English model".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-small-en-v1.5-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "Quantized v1.5 release of the fast and default English model".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-small-zh-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "v1.5 release of the small Chinese model".to_string(),
|
||||
dimensions: 512,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "BAAI/bge-large-zh-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "BAAI".to_string(),
|
||||
description: "v1.5 release of the large Chinese model".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "nomic-ai/nomic-embed-text-v1".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "nomic-ai".to_string(),
|
||||
description: "8192 context length english model".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "nomic-ai/nomic-embed-text-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "nomic-ai".to_string(),
|
||||
description: "v1.5 release of the 8192 context length english model".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "nomic-ai".to_string(),
|
||||
description: "Quantized v1.5 release of the 8192 context length english model"
|
||||
.to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Multi-lingual model".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Quantized Multi-lingual model".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "sentence-transformers".to_string(),
|
||||
description: "Sentence-transformers model for tasks like clustering or semantic search"
|
||||
.to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "lightonai/modernbert-embed-large".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "lightonai".to_string(),
|
||||
description: "Large model of ModernBert Text Embeddings".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "intfloat/multilingual-e5-small".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "intfloat".to_string(),
|
||||
description: "Small model of multilingual E5 Text Embeddings".to_string(),
|
||||
dimensions: 384,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "intfloat/multilingual-e5-base".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "intfloat".to_string(),
|
||||
description: "Base model of multilingual E5 Text Embeddings".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "intfloat/multilingual-e5-large".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "intfloat".to_string(),
|
||||
description: "Large model of multilingual E5 Text Embeddings".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "mixedbread-ai/mxbai-embed-large-v1".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "mixedbread-ai".to_string(),
|
||||
description: "Large English embedding model from MixedBread.ai".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "mixedbread-ai".to_string(),
|
||||
description: "Quantized Large English embedding model from MixedBread.ai".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "Alibaba-NLP/gte-base-en-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "Alibaba-NLP".to_string(),
|
||||
description: "Base multilingual embedding model from Alibaba".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "Alibaba-NLP".to_string(),
|
||||
description: "Quantized Base multilingual embedding model from Alibaba".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "Alibaba-NLP/gte-large-en-v1.5".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "Alibaba-NLP".to_string(),
|
||||
description: "Large multilingual embedding model from Alibaba".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "Alibaba-NLP".to_string(),
|
||||
description: "Quantized Large multilingual embedding model from Alibaba".to_string(),
|
||||
dimensions: 1024,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "Qdrant/clip-ViT-B-32-text".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "Qdrant".to_string(),
|
||||
description: "CLIP text encoder based on ViT-B/32".to_string(),
|
||||
dimensions: 512,
|
||||
},
|
||||
ModelInfo {
|
||||
id: "jinaai/jina-embeddings-v2-base-code".to_string(),
|
||||
object: "model".to_string(),
|
||||
owned_by: "jinaai".to_string(),
|
||||
description: "Jina embeddings v2 base code".to_string(),
|
||||
dimensions: 768,
|
||||
},
|
||||
];
|
||||
|
||||
ResponseJson(ModelsResponse {
|
||||
object: "list".to_string(),
|
||||
data: models,
|
||||
})
|
||||
}

pub fn create_embeddings_router() -> Router {
    Router::new()
        .route("/v1/embeddings", post(embeddings_create))
        // .route("/v1/models", get(models_list))
        .layer(TraceLayer::new_for_http())
}
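
The /v1/models route is commented out in this router; in this change it is exposed by the consuming server instead (see the create_app diff below). A minimal sketch of serving the router standalone, mirroring the axum/tokio wiring used in the inference-engine main.rs later in this diff (the standalone binary itself is an assumption):

// Illustrative: run the embeddings router on its own listener.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let app = create_embeddings_router();
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
    axum::serve(listener, app).await?;
    Ok(())
}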
|
||||
|
@@ -4,8 +4,6 @@ use axum::{
|
||||
response::Json as ResponseJson,
|
||||
routing::{get, post},
|
||||
};
|
||||
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing;
|
||||
@@ -13,127 +11,28 @@ use tracing;
|
||||
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
|
||||
const DEFAULT_SERVER_PORT: &str = "8080";
|
||||
|
||||
use embeddings_engine;
|
||||
|
||||
async fn embeddings_create(
|
||||
Json(payload): Json<CreateEmbeddingRequest>,
|
||||
) -> ResponseJson<serde_json::Value> {
|
||||
let model = TextEmbedding::try_new(
|
||||
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
|
||||
)
|
||||
.expect("Failed to initialize model");
|
||||
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
|
||||
match embeddings_engine::embeddings_create(Json(payload)).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err((status_code, message)) => Err(axum::response::Response::builder()
|
||||
.status(status_code)
|
||||
.body(axum::body::Body::from(message))
|
||||
.unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
let embedding_input = payload.input;
|
||||
|
||||
let texts_from_embedding_input = match embedding_input {
|
||||
EmbeddingInput::String(text) => vec![text],
|
||||
EmbeddingInput::StringArray(texts) => texts,
|
||||
EmbeddingInput::IntegerArray(_) => {
|
||||
panic!("Integer array input not supported for text embeddings");
|
||||
}
|
||||
EmbeddingInput::ArrayOfIntegerArray(_) => {
|
||||
panic!("Array of integer arrays not supported for text embeddings");
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = model
|
||||
.embed(texts_from_embedding_input, None)
|
||||
.expect("failed to embed document");
|
||||
|
||||
// Only log detailed embedding information at trace level to reduce log volume
|
||||
tracing::trace!("Embeddings length: {}", embeddings.len());
|
||||
tracing::info!("Embedding dimension: {}", embeddings[0].len());
|
||||
|
||||
// Log the first 10 values of the original embedding at trace level
|
||||
tracing::trace!(
|
||||
"Original embedding preview: {:?}",
|
||||
&embeddings[0][..10.min(embeddings[0].len())]
|
||||
);
|
||||
|
||||
// Check if there are any NaN or zero values in the original embedding
|
||||
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
|
||||
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
|
||||
tracing::trace!(
|
||||
"Original embedding stats: NaN count={}, zero count={}",
|
||||
nan_count,
|
||||
zero_count
|
||||
);
|
||||
|
||||
// Create the final embedding
|
||||
let final_embedding = {
|
||||
// Check if the embedding is all zeros
|
||||
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
|
||||
if all_zeros {
|
||||
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
|
||||
|
||||
// Generate a random non-zero embedding
|
||||
use rand::Rng;
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut random_embedding = Vec::with_capacity(768);
|
||||
for _ in 0..768 {
|
||||
// Generate random values between -1.0 and 1.0, excluding 0
|
||||
let mut val = 0.0;
|
||||
while val == 0.0 {
|
||||
val = rng.gen_range(-1.0..1.0);
|
||||
}
|
||||
random_embedding.push(val);
|
||||
}
|
||||
|
||||
// Normalize the random embedding
|
||||
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for i in 0..random_embedding.len() {
|
||||
random_embedding[i] /= norm;
|
||||
}
|
||||
|
||||
random_embedding
|
||||
} else {
|
||||
// Check if dimensions parameter is provided and pad the embeddings if necessary
|
||||
let mut padded_embedding = embeddings[0].clone();
|
||||
|
||||
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
|
||||
let target_dimension = 768;
|
||||
if padded_embedding.len() < target_dimension {
|
||||
let padding_needed = target_dimension - padded_embedding.len();
|
||||
tracing::trace!(
|
||||
"Padding embedding with {} zeros to reach {} dimensions",
|
||||
padding_needed,
|
||||
target_dimension
|
||||
);
|
||||
padded_embedding.extend(vec![0.0; padding_needed]);
|
||||
}
|
||||
|
||||
padded_embedding
|
||||
}
|
||||
};
|
||||
|
||||
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
|
||||
|
||||
// Log the first 10 values of the final embedding at trace level
|
||||
tracing::trace!(
|
||||
"Final embedding preview: {:?}",
|
||||
&final_embedding[..10.min(final_embedding.len())]
|
||||
);
|
||||
|
||||
// Return a response that matches the OpenAI API format
|
||||
let response = serde_json::json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": final_embedding
|
||||
}
|
||||
],
|
||||
"model": payload.model,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
});
|
||||
ResponseJson(response)
|
||||
async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
|
||||
embeddings_engine::models_list().await
|
||||
}
|
||||
|
||||
fn create_app() -> Router {
|
||||
Router::new()
|
||||
.route("/v1/embeddings", post(embeddings_create))
|
||||
.route("/v1/models", get(models_list))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
}
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "inference-engine"
|
||||
version.workspace = true
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git" }
|
||||
@@ -31,13 +31,21 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
reborrow = "0.5.5"
|
||||
futures-util = "0.3.31"
|
||||
gemma-runner = { path = "../gemma-runner" }
|
||||
llama-runner = { path = "../llama-runner" }
|
||||
gemma-runner = { path = "../../integration/gemma-runner" }
|
||||
llama-runner = { path = "../../integration/llama-runner" }
|
||||
embeddings-engine = { path = "../embeddings-engine" }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git", default-features = false }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git", default-features = false }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git", default-features = false }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||
gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
|
||||
llama-runner = { path = "../../integration/llama-runner", features = ["metal"] }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -61,15 +69,13 @@ bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
[features]
|
||||
bin = []
|
||||
|
||||
|
||||
|
||||
[package.metadata.compose]
|
||||
image = "ghcr.io/geoffsee/inference-engine:latest"
|
||||
port = 8080
|
||||
|
||||
[[bin]]
|
||||
name = "inference-engine"
|
||||
path = "src/main.rs"
|
||||
|
||||
# generates kubernetes manifests
|
||||
[package.metadata.kube]
|
||||
image = "ghcr.io/geoffsee/inference-service:latest"
|
||||
replicas = 1
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
cmd = ["./bin/inference-engine"]
|
||||
port = 8080
|
||||
replicas = 1
|
||||
|
@@ -1,86 +0,0 @@
|
||||
# ---- Build stage ----
|
||||
FROM rust:1-slim-bullseye AS builder
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
# Install build dependencies including CUDA toolkit for GPU support
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
build-essential \
|
||||
wget \
|
||||
gnupg2 \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CUDA toolkit (optional, for GPU support)
|
||||
# This is a minimal CUDA installation for building
|
||||
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-minimal-build-11-8 \
|
||||
libcublas-dev-11-8 \
|
||||
libcurand-dev-11-8 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& rm cuda-keyring_1.0-1_all.deb
|
||||
|
||||
# Set CUDA environment variables
|
||||
ENV CUDA_HOME=/usr/local/cuda
|
||||
ENV PATH=${CUDA_HOME}/bin:${PATH}
|
||||
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
|
||||
|
||||
# Copy the entire workspace to get access to all crates
|
||||
COPY . ./
|
||||
|
||||
# Cache dependencies first - create dummy source files
|
||||
RUN rm -rf crates/inference-engine/src
|
||||
RUN mkdir -p crates/inference-engine/src && \
|
||||
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
|
||||
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
|
||||
echo "// lib" > crates/inference-engine/src/lib.rs && \
|
||||
cargo build --release --bin cli --package inference-engine
|
||||
|
||||
# Remove dummy source and copy real sources
|
||||
RUN rm -rf crates/inference-engine/src
|
||||
COPY . .
|
||||
|
||||
# Build the actual CLI binary
|
||||
RUN cargo build --release --bin cli --package inference-engine
|
||||
|
||||
# ---- Runtime stage ----
|
||||
FROM debian:bullseye-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libssl1.1 \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CUDA runtime libraries (optional, for GPU support at runtime)
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
gnupg2 \
|
||||
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
|
||||
&& dpkg -i cuda-keyring_1.0-1_all.deb \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-cudart-11-8 \
|
||||
libcublas11 \
|
||||
libcurand10 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& rm cuda-keyring_1.0-1_all.deb \
|
||||
&& apt-get purge -y wget gnupg2
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /usr/src/app/target/release/cli /usr/local/bin/inference-cli
|
||||
|
||||
# Run as non-root user for safety
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8080
|
||||
CMD ["inference-cli"]
|
@@ -8,7 +8,7 @@ pub mod server;
|
||||
// Re-export key components for easier access
|
||||
pub use inference::ModelInference;
|
||||
pub use model::{Model, Which};
|
||||
pub use server::{create_router, AppState};
|
||||
pub use server::{AppState, create_router};
|
||||
|
||||
use std::env;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
26
crates/inference-engine/src/main.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use inference_engine::{AppState, create_router, get_server_config, init_tracing};
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
init_tracing();
|
||||
|
||||
let app_state = AppState::default();
|
||||
let app = create_router(app_state);
|
||||
|
||||
let (server_host, server_port, server_address) = get_server_config();
|
||||
let listener = TcpListener::bind(&server_address).await?;
|
||||
|
||||
info!(
|
||||
"Inference Engine server starting on http://{}",
|
||||
server_address
|
||||
);
|
||||
info!("Available endpoints:");
|
||||
info!(" POST /v1/chat/completions - OpenAI-compatible chat completions");
|
||||
info!(" GET /v1/models - List available models");
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
@@ -1,49 +1,9 @@
|
||||
// use candle_core::Tensor;
|
||||
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum Which {
|
||||
#[value(name = "2b")]
|
||||
Base2B,
|
||||
#[value(name = "7b")]
|
||||
Base7B,
|
||||
#[value(name = "2b-it")]
|
||||
Instruct2B,
|
||||
#[value(name = "7b-it")]
|
||||
Instruct7B,
|
||||
#[value(name = "1.1-2b-it")]
|
||||
InstructV1_1_2B,
|
||||
#[value(name = "1.1-7b-it")]
|
||||
InstructV1_1_7B,
|
||||
#[value(name = "code-2b")]
|
||||
CodeBase2B,
|
||||
#[value(name = "code-7b")]
|
||||
CodeBase7B,
|
||||
#[value(name = "code-2b-it")]
|
||||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
#[value(name = "3-1b")]
|
||||
BaseV3_1B,
|
||||
#[value(name = "3-1b-it")]
|
||||
InstructV3_1B,
|
||||
#[value(name = "llama-3.2-1b-it")]
|
||||
LlamaInstruct3_2_1B,
|
||||
#[value(name = "llama-3.2-3b-it")]
|
||||
LlamaInstruct3_2_3B,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
@@ -66,48 +26,131 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Family {
|
||||
GemmaV1,
|
||||
GemmaV2,
|
||||
GemmaV3,
|
||||
Llama,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ModelMeta {
|
||||
pub id: &'static str,
|
||||
pub family: Family,
|
||||
pub instruct: bool,
|
||||
}
|
||||
|
||||
const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
|
||||
ModelMeta {
|
||||
id,
|
||||
family,
|
||||
instruct,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum Which {
|
||||
// Gemma 1.x
|
||||
#[value(name = "2b")]
|
||||
Base2B,
|
||||
#[value(name = "7b")]
|
||||
Base7B,
|
||||
#[value(name = "2b-it")]
|
||||
Instruct2B,
|
||||
#[value(name = "7b-it")]
|
||||
Instruct7B,
|
||||
#[value(name = "1.1-2b-it")]
|
||||
InstructV1_1_2B,
|
||||
#[value(name = "1.1-7b-it")]
|
||||
InstructV1_1_7B,
|
||||
|
||||
// CodeGemma
|
||||
#[value(name = "code-2b")]
|
||||
CodeBase2B,
|
||||
#[value(name = "code-7b")]
|
||||
CodeBase7B,
|
||||
#[value(name = "code-2b-it")]
|
||||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
|
||||
// Gemma 2
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
|
||||
// Gemma 3
|
||||
#[value(name = "3-1b")]
|
||||
BaseV3_1B,
|
||||
#[value(name = "3-1b-it")]
|
||||
InstructV3_1B,
|
||||
|
||||
// Llama 3.2 (use aliases instead of duplicate variants)
|
||||
#[value(name = "llama-3.2-1b")]
|
||||
Llama32_1B,
|
||||
#[value(name = "llama-3.2-1b-it", alias = "llama-3.2-1b-instruct")]
|
||||
Llama32_1BInstruct,
|
||||
#[value(name = "llama-3.2-3b")]
|
||||
Llama32_3B,
|
||||
#[value(name = "llama-3.2-3b-it", alias = "llama-3.2-3b-instruct")]
|
||||
Llama32_3BInstruct,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
pub fn to_model_id(&self) -> String {
|
||||
pub const fn meta(&self) -> ModelMeta {
|
||||
use Family::*;
|
||||
match self {
|
||||
Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
|
||||
Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
|
||||
Self::Base2B => "google/gemma-2b".to_string(),
|
||||
Self::Base7B => "google/gemma-7b".to_string(),
|
||||
Self::Instruct2B => "google/gemma-2b-it".to_string(),
|
||||
Self::Instruct7B => "google/gemma-7b-it".to_string(),
|
||||
Self::CodeBase2B => "google/codegemma-2b".to_string(),
|
||||
Self::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||
Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||
Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||
Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||
Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||
Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||
Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
|
||||
Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
|
||||
Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
|
||||
Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
|
||||
// Gemma 1.x
|
||||
Self::Base2B => m("google/gemma-2b", GemmaV1, false),
|
||||
Self::Base7B => m("google/gemma-7b", GemmaV1, false),
|
||||
Self::Instruct2B => m("google/gemma-2b-it", GemmaV1, true),
|
||||
Self::Instruct7B => m("google/gemma-7b-it", GemmaV1, true),
|
||||
Self::InstructV1_1_2B => m("google/gemma-1.1-2b-it", GemmaV1, true),
|
||||
Self::InstructV1_1_7B => m("google/gemma-1.1-7b-it", GemmaV1, true),
|
||||
|
||||
// CodeGemma
|
||||
Self::CodeBase2B => m("google/codegemma-2b", GemmaV1, false),
|
||||
Self::CodeBase7B => m("google/codegemma-7b", GemmaV1, false),
|
||||
Self::CodeInstruct2B => m("google/codegemma-2b-it", GemmaV1, true),
|
||||
Self::CodeInstruct7B => m("google/codegemma-7b-it", GemmaV1, true),
|
||||
|
||||
// Gemma 2
|
||||
Self::BaseV2_2B => m("google/gemma-2-2b", GemmaV2, false),
|
||||
Self::InstructV2_2B => m("google/gemma-2-2b-it", GemmaV2, true),
|
||||
Self::BaseV2_9B => m("google/gemma-2-9b", GemmaV2, false),
|
||||
Self::InstructV2_9B => m("google/gemma-2-9b-it", GemmaV2, true),
|
||||
|
||||
// Gemma 3
|
||||
Self::BaseV3_1B => m("google/gemma-3-1b-pt", GemmaV3, false),
|
||||
Self::InstructV3_1B => m("google/gemma-3-1b-it", GemmaV3, true),
|
||||
|
||||
// Llama 3.2
|
||||
Self::Llama32_1B => m("meta-llama/Llama-3.2-1B", Llama, false),
|
||||
Self::Llama32_1BInstruct => m("meta-llama/Llama-3.2-1B-Instruct", Llama, true),
|
||||
Self::Llama32_3B => m("meta-llama/Llama-3.2-3B", Llama, false),
|
||||
Self::Llama32_3BInstruct => m("meta-llama/Llama-3.2-3B-Instruct", Llama, true),
|
||||
}
|
||||
}
|
||||

    pub fn to_model_id(&self) -> String {
        self.meta().id.to_string()
    }

    pub fn is_instruct_model(&self) -> bool {
        match self {
            Self::Base2B
            | Self::Base7B
            | Self::CodeBase2B
            | Self::CodeBase7B
            | Self::BaseV2_2B
            | Self::BaseV2_9B
            | Self::BaseV3_1B => false,
            _ => true,
        }
        self.meta().instruct
    }

    pub fn is_v3_model(&self) -> bool {
        matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
        matches!(self.meta().family, Family::GemmaV3)
    }

    pub fn is_llama_model(&self) -> bool {
        matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
        matches!(self.meta().family, Family::Llama)
    }
}
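
With the per-variant logic collapsed into the meta() table, the helper predicates become one-liners and adding a model is a single table entry. A small illustrative check of the new behaviour (the test module is an assumption, not part of the diff):

#[cfg(test)]
mod meta_tests {
    use super::*;

    #[test]
    fn meta_drives_the_helper_predicates() {
        // Instruct-tuned Gemma 3 variant.
        let w = Which::InstructV3_1B;
        assert_eq!(w.to_model_id(), "google/gemma-3-1b-it");
        assert!(w.is_instruct_model());
        assert!(w.is_v3_model());
        assert!(!w.is_llama_model());

        // Base Llama 3.2 variant added by this change.
        let l = Which::Llama32_1B;
        assert_eq!(l.to_model_id(), "meta-llama/Llama-3.2-1B");
        assert!(!l.is_instruct_model());
        assert!(l.is_llama_model());
    }
}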
|
||||
|
@@ -1,26 +1,28 @@
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{sse::Event, sse::Sse, IntoResponse},
|
||||
response::{IntoResponse, sse::Event, sse::Sse},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures_util::stream::{self, Stream};
|
||||
use std::convert::Infallible;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Which;
|
||||
use crate::openai_types::{
|
||||
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
|
||||
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
|
||||
};
|
||||
use crate::Which;
|
||||
use either::Either;
|
||||
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
|
||||
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
|
||||
use embeddings_engine::models_list;
|
||||
use gemma_runner::{GemmaInferenceConfig, WhichModel, run_gemma_api};
|
||||
use llama_runner::{LlamaInferenceConfig, run_llama_inference};
|
||||
use serde_json::Value;
|
||||
// -------------------------
|
||||
// Shared app state
|
||||
@@ -34,7 +36,7 @@ pub enum ModelType {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub model_type: ModelType,
|
||||
pub model_type: Option<ModelType>,
|
||||
pub model_id: String,
|
||||
pub gemma_config: Option<GemmaInferenceConfig>,
|
||||
pub llama_config: Option<LlamaInferenceConfig>,
|
||||
@@ -42,13 +44,19 @@ pub struct AppState {
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
// Configure a default model to prevent 503 errors from the chat-ui
|
||||
// This can be overridden by environment variables if needed
|
||||
let default_model_id =
|
||||
std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
|
||||
|
||||
let gemma_config = GemmaInferenceConfig {
|
||||
model: gemma_runner::WhichModel::InstructV3_1B,
|
||||
model: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Self {
|
||||
model_type: ModelType::Gemma,
|
||||
model_id: "gemma-3-1b-it".to_string(),
|
||||
model_type: None,
|
||||
model_id: default_model_id,
|
||||
gemma_config: Some(gemma_config),
|
||||
llama_config: None,
|
||||
}
|
||||
@@ -59,6 +67,33 @@ impl Default for AppState {
|
||||
// Helper functions
|
||||
// -------------------------
|
||||

fn model_id_to_which(model_id: &str) -> Option<Which> {
    let normalized = normalize_model_id(model_id);
    match normalized.as_str() {
        "gemma-2b" => Some(Which::Base2B),
        "gemma-7b" => Some(Which::Base7B),
        "gemma-2b-it" => Some(Which::Instruct2B),
        "gemma-7b-it" => Some(Which::Instruct7B),
        "gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B),
        "gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B),
        "codegemma-2b" => Some(Which::CodeBase2B),
        "codegemma-7b" => Some(Which::CodeBase7B),
        "codegemma-2b-it" => Some(Which::CodeInstruct2B),
        "codegemma-7b-it" => Some(Which::CodeInstruct7B),
        "gemma-2-2b" => Some(Which::BaseV2_2B),
        "gemma-2-2b-it" => Some(Which::InstructV2_2B),
        "gemma-2-9b" => Some(Which::BaseV2_9B),
        "gemma-2-9b-it" => Some(Which::InstructV2_9B),
        "gemma-3-1b" => Some(Which::BaseV3_1B),
        "gemma-3-1b-it" => Some(Which::InstructV3_1B),
        "llama-3.2-1b" => Some(Which::Llama32_1B),
        "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
        "llama-3.2-3b" => Some(Which::Llama32_3B),
        "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
        _ => None,
    }
}

fn normalize_model_id(model_id: &str) -> String {
    model_id.to_lowercase().replace("_", "-")
}
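
These helpers let the handlers below pick a backend per request instead of rejecting anything other than the single configured model. An illustrative resolution (values follow the table above):

// Illustrative: lookup is case-insensitive and treats '_' and '-' alike.
assert_eq!(normalize_model_id("Gemma_3_1b_it"), "gemma-3-1b-it");
assert_eq!(model_id_to_which("Gemma_3_1b_it"), Some(Which::InstructV3_1B));
assert_eq!(model_id_to_which("gpt-4o"), None); // handlers turn this into a 400 response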
|
||||
@@ -116,90 +151,115 @@ pub async fn chat_completions_non_streaming_proxy(
|
||||
state: AppState,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
|
||||
// Enforce model selection behavior: reject if a different model is requested
|
||||
let configured_model = state.model_id.clone();
|
||||
let requested_model = request.model.clone();
|
||||
if requested_model.to_lowercase() != "default" {
|
||||
let normalized_requested = normalize_model_id(&requested_model);
|
||||
let normalized_configured = normalize_model_id(&configured_model);
|
||||
if normalized_requested != normalized_configured {
|
||||
// Use the model specified in the request
|
||||
let model_id = request.model.clone();
|
||||
let which_model = model_id_to_which(&model_id);
|
||||
|
||||
// Validate that the requested model is supported
|
||||
let which_model = match which_model {
|
||||
Some(model) => model,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!(
|
||||
"Requested model '{}' is not available. This server is running '{}' only.",
|
||||
requested_model, configured_model
|
||||
),
|
||||
"type": "model_mismatch"
|
||||
"message": format!("Unsupported model: {}", model_id),
|
||||
"type": "model_not_supported"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let model_id = state.model_id.clone();
|
||||
};
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
|
||||
// Build prompt based on model type
|
||||
let prompt = match state.model_type {
|
||||
ModelType::Gemma => build_gemma_prompt(&request.messages),
|
||||
ModelType::Llama => {
|
||||
// For Llama, just use the last user message for now
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
let prompt = if which_model.is_llama_model() {
|
||||
// For Llama, just use the last user message for now
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
build_gemma_prompt(&request.messages)
|
||||
};
|
||||
|
||||
// Get streaming receiver based on model type
|
||||
let rx =
|
||||
match state.model_type {
|
||||
ModelType::Gemma => {
|
||||
if let Some(mut config) = state.gemma_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_gemma_api(config).map_err(|e| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
let rx = if which_model.is_llama_model() {
|
||||
// Create Llama configuration dynamically
|
||||
let llama_model = match which_model {
|
||||
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
|
||||
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
|
||||
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
|
||||
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
}))
|
||||
))?
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
ModelType::Llama => {
|
||||
if let Some(mut config) = state.llama_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_llama_inference(config).map_err(|e| (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
}))
|
||||
))?
|
||||
} else {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
})),
|
||||
));
|
||||
}
|
||||
"error": { "message": format!("Model {} is not a Llama model", model_id) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
let mut config = LlamaInferenceConfig::new(llama_model);
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_llama_inference(config).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
})),
|
||||
)
|
||||
})?
|
||||
} else {
|
||||
// Create Gemma configuration dynamically
|
||||
let gemma_model = match which_model {
|
||||
Which::Base2B => gemma_runner::WhichModel::Base2B,
|
||||
Which::Base7B => gemma_runner::WhichModel::Base7B,
|
||||
Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
|
||||
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
|
||||
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
|
||||
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
|
||||
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
|
||||
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
|
||||
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
|
||||
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
|
||||
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
|
||||
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
|
||||
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
|
||||
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
|
||||
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
|
||||
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut config = GemmaInferenceConfig {
|
||||
model: Some(gemma_model),
|
||||
..Default::default()
|
||||
};
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
run_gemma_api(config).map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
})),
|
||||
)
|
||||
})?
|
||||
};
|
||||
|
||||
// Collect all tokens from the stream
|
||||
let mut completion = String::new();
|
||||
@@ -258,27 +318,25 @@ async fn handle_streaming_request(
|
||||
state: AppState,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
|
||||
// Validate requested model vs configured model
|
||||
let configured_model = state.model_id.clone();
|
||||
let requested_model = request.model.clone();
|
||||
if requested_model.to_lowercase() != "default" {
|
||||
let normalized_requested = normalize_model_id(&requested_model);
|
||||
let normalized_configured = normalize_model_id(&configured_model);
|
||||
if normalized_requested != normalized_configured {
|
||||
// Use the model specified in the request
|
||||
let model_id = request.model.clone();
|
||||
let which_model = model_id_to_which(&model_id);
|
||||
|
||||
// Validate that the requested model is supported
|
||||
let which_model = match which_model {
|
||||
Some(model) => model,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!(
|
||||
"Requested model '{}' is not available. This server is running '{}' only.",
|
||||
requested_model, configured_model
|
||||
),
|
||||
"type": "model_mismatch"
|
||||
"message": format!("Unsupported model: {}", model_id),
|
||||
"type": "model_not_supported"
|
||||
}
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Generate a unique ID and metadata
|
||||
let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
|
||||
@@ -286,24 +344,22 @@ async fn handle_streaming_request(
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
let model_id = state.model_id.clone();
|
||||
let max_tokens = request.max_tokens.unwrap_or(1000);
|
||||
|
||||
// Build prompt based on model type
|
||||
let prompt = match state.model_type {
|
||||
ModelType::Gemma => build_gemma_prompt(&request.messages),
|
||||
ModelType::Llama => {
|
||||
// For Llama, just use the last user message for now
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
let prompt = if which_model.is_llama_model() {
|
||||
// For Llama, just use the last user message for now
|
||||
request
|
||||
.messages
|
||||
.last()
|
||||
.and_then(|m| m.content.as_ref())
|
||||
.and_then(|c| match c {
|
||||
MessageContent(Either::Left(text)) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
build_gemma_prompt(&request.messages)
|
||||
};
|
||||
tracing::debug!("Formatted prompt: {}", prompt);
|
||||
|
||||
@@ -330,51 +386,78 @@ async fn handle_streaming_request(
|
||||
}
|
||||
|
||||
// Get streaming receiver based on model type
|
||||
let model_rx = match state.model_type {
|
||||
ModelType::Gemma => {
|
||||
if let Some(mut config) = state.gemma_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_gemma_api(config) {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let model_rx = if which_model.is_llama_model() {
|
||||
// Create Llama configuration dynamically
|
||||
let llama_model = match which_model {
|
||||
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
|
||||
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
|
||||
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
|
||||
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Gemma configuration not available" }
|
||||
"error": { "message": format!("Model {} is not a Llama model", model_id) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
let mut config = LlamaInferenceConfig::new(llama_model);
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_llama_inference(config) {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
ModelType::Llama => {
|
||||
if let Some(mut config) = state.llama_config {
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_llama_inference(config) {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Llama model: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
} else {
|
||||
// Create Gemma configuration dynamically
|
||||
let gemma_model = match which_model {
|
||||
Which::Base2B => gemma_runner::WhichModel::Base2B,
|
||||
Which::Base7B => gemma_runner::WhichModel::Base7B,
|
||||
Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
|
||||
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
|
||||
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
|
||||
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
|
||||
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
|
||||
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
|
||||
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
|
||||
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
|
||||
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
|
||||
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
|
||||
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
|
||||
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
|
||||
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
|
||||
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "Llama configuration not available" }
|
||||
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut config = GemmaInferenceConfig {
|
||||
model: Some(gemma_model),
|
||||
..Default::default()
|
||||
};
|
||||
config.prompt = prompt.clone();
|
||||
config.max_tokens = max_tokens;
|
||||
match run_gemma_api(config) {
|
||||
Ok(rx) => rx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": format!("Error initializing Gemma model: {}", e) }
|
||||
})),
|
||||
));
|
||||
}
|
||||
@@ -500,173 +583,93 @@ pub fn create_router(app_state: AppState) -> Router {
|
||||
/// Handler for GET /v1/models - returns list of available models
|
||||
pub async fn list_models() -> Json<ModelListResponse> {
|
||||
// Get all available model variants from the Which enum
|
||||
let models = vec![
|
||||
// Gemma models
|
||||
Model {
|
||||
id: "gemma-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002, // Using same timestamp as OpenAI example
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-7b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-7b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-1.1-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-1.1-7b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-7b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "codegemma-7b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-2b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-2b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-9b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-2-9b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-3-1b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "gemma-3-1b-it".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "google".to_string(),
|
||||
},
|
||||
// Llama models
|
||||
Model {
|
||||
id: "llama-3.2-1b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "llama-3.2-1b-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "llama-3.2-3b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "llama-3.2-3b-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "meta".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-135m".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-135m-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-360m".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-360m-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-1.7b".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "smollm2-1.7b-instruct".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "huggingface".to_string(),
|
||||
},
|
||||
Model {
|
||||
id: "tinyllama-1.1b-chat".to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: "tinyllama".to_string(),
|
||||
},
|
||||
let which_variants = vec![
|
||||
Which::Base2B,
|
||||
Which::Base7B,
|
||||
Which::Instruct2B,
|
||||
Which::Instruct7B,
|
||||
Which::InstructV1_1_2B,
|
||||
Which::InstructV1_1_7B,
|
||||
Which::CodeBase2B,
|
||||
Which::CodeBase7B,
|
||||
Which::CodeInstruct2B,
|
||||
Which::CodeInstruct7B,
|
||||
Which::BaseV2_2B,
|
||||
Which::InstructV2_2B,
|
||||
Which::BaseV2_9B,
|
||||
Which::InstructV2_9B,
|
||||
Which::BaseV3_1B,
|
||||
Which::InstructV3_1B,
|
||||
Which::Llama32_1B,
|
||||
Which::Llama32_1BInstruct,
|
||||
Which::Llama32_3B,
|
||||
Which::Llama32_3BInstruct,
|
||||
];
|
||||
|
||||
let mut models: Vec<Model> = which_variants
|
||||
.into_iter()
|
||||
.map(|which| {
|
||||
let meta = which.meta();
|
||||
let model_id = match which {
|
||||
Which::Base2B => "gemma-2b",
|
||||
Which::Base7B => "gemma-7b",
|
||||
Which::Instruct2B => "gemma-2b-it",
|
||||
Which::Instruct7B => "gemma-7b-it",
|
||||
Which::InstructV1_1_2B => "gemma-1.1-2b-it",
|
||||
Which::InstructV1_1_7B => "gemma-1.1-7b-it",
|
||||
Which::CodeBase2B => "codegemma-2b",
|
||||
Which::CodeBase7B => "codegemma-7b",
|
||||
Which::CodeInstruct2B => "codegemma-2b-it",
|
||||
Which::CodeInstruct7B => "codegemma-7b-it",
|
||||
Which::BaseV2_2B => "gemma-2-2b",
|
||||
Which::InstructV2_2B => "gemma-2-2b-it",
|
||||
Which::BaseV2_9B => "gemma-2-9b",
|
||||
Which::InstructV2_9B => "gemma-2-9b-it",
|
||||
Which::BaseV3_1B => "gemma-3-1b",
|
||||
Which::InstructV3_1B => "gemma-3-1b-it",
|
||||
Which::Llama32_1B => "llama-3.2-1b",
|
||||
Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
|
||||
Which::Llama32_3B => "llama-3.2-3b",
|
||||
Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
|
||||
};
|
||||
|
||||
let owned_by = if meta.id.starts_with("google/") {
|
||||
"google"
|
||||
} else if meta.id.starts_with("meta-llama/") {
|
||||
"meta"
|
||||
} else {
|
||||
"unknown"
|
||||
};
|
||||
|
||||
Model {
|
||||
id: model_id.to_string(),
|
||||
object: "model".to_string(),
|
||||
created: 1686935002,
|
||||
owned_by: owned_by.to_string(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Get embeddings models and convert them to inference Model format
|
||||
let embeddings_response = models_list().await;
|
||||
let embeddings_models: Vec<Model> = embeddings_response
|
||||
.0
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding_model| Model {
|
||||
id: embedding_model.id,
|
||||
object: embedding_model.object,
|
||||
created: 1686935002,
|
||||
owned_by: format!(
|
||||
"{} - {}",
|
||||
embedding_model.owned_by, embedding_model.description
|
||||
),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add embeddings models to the main models list
|
||||
models.extend(embeddings_models);
|
||||
|
||||
Json(ModelListResponse {
|
||||
object: "list".to_string(),
|
||||
data: models,
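Taken together, a client hitting `GET /v1/models` sees the Gemma/Llama variants enumerated above merged with the embeddings models. A rough sketch of reading that list with `reqwest` (a workspace dependency); the struct names here are local to the example and only mirror the fields used above:

```rust
// Illustrative client for GET /v1/models; field names mirror the Model struct above.
#[derive(serde::Deserialize)]
struct ModelEntry {
    id: String,
    owned_by: String,
}

#[derive(serde::Deserialize)]
struct ModelList {
    data: Vec<ModelEntry>,
}

async fn print_models() -> Result<(), reqwest::Error> {
    let list: ModelList = reqwest::get("http://localhost:8080/v1/models")
        .await?
        .json()
        .await?;
    for m in list.data {
        println!("{:>24}  owned_by: {}", m.id, m.owned_by);
    }
    Ok(())
}
```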
|
||||
|
@@ -1,3 +0,0 @@
|
||||
# Ensure getrandom works on wasm32-unknown-unknown without needing manual RUSTFLAGS
|
||||
[target.wasm32-unknown-unknown]
|
||||
rustflags = ["--cfg", "getrandom_backend=\"wasm_js\""]
|
@@ -1,21 +0,0 @@
|
||||
# Build stage
|
||||
FROM rust:1-alpine AS builder
|
||||
|
||||
# Install build dependencies
|
||||
RUN apk add --no-cache npm nodejs musl-dev pkgconfig openssl-dev git curl bash
|
||||
|
||||
RUN curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy manifest first (cache deps)
|
||||
COPY . .
|
||||
|
||||
# Install cargo-leptos
|
||||
RUN cargo binstall cargo-leptos
|
||||
|
||||
# Build release artifacts
|
||||
RUN cargo leptos build --release
|
||||
|
||||
EXPOSE 8788
|
||||
CMD ["cargo", "leptos", "serve", "--release"]
|
@@ -1,520 +0,0 @@
|
||||
use leptos::prelude::*;
|
||||
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
|
||||
use leptos_router::{
|
||||
components::{Route, Router, Routes},
|
||||
StaticSegment,
|
||||
};
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::config::OpenAIConfig;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::types::{FinishReason, Role};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use async_openai_wasm::{
|
||||
types::{
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
|
||||
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
|
||||
Model as OpenAIModel,
|
||||
},
|
||||
Client,
|
||||
};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use futures_util::StreamExt;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use js_sys::Date;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use leptos::task::spawn_local;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[cfg(feature = "hydrate")]
|
||||
use std::collections::VecDeque;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use uuid::Uuid;
|
||||
#[cfg(feature = "hydrate")]
|
||||
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub id: String,
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub timestamp: f64,
|
||||
}
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageContent(
|
||||
pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>,
|
||||
);
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageInnerContent(
|
||||
pub either::Either<String, std::collections::HashMap<String, String>>,
|
||||
);
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: Option<MessageContent>,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
const DEFAULT_MODEL: &str = "default";
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"
|
||||
);
|
||||
|
||||
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
|
||||
let client = Client::with_config(config);
|
||||
|
||||
match client.models().list().await {
|
||||
Ok(response) => {
|
||||
let model_count = response.data.len();
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] fetch_available_models: Successfully fetched {} models",
|
||||
model_count
|
||||
);
|
||||
|
||||
if model_count > 0 {
|
||||
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] fetch_available_models: Available models: {:?}",
|
||||
model_names
|
||||
);
|
||||
} else {
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] fetch_available_models: No models returned by server"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(response.data)
|
||||
}
|
||||
Err(e) => {
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}",
|
||||
e
|
||||
);
|
||||
Err(format!("Failed to fetch models: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shell(options: LeptosOptions) -> impl IntoView {
|
||||
view! {
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<AutoReload options=options.clone() />
|
||||
<HydrationScripts options/>
|
||||
<MetaTags/>
|
||||
</head>
|
||||
<body>
|
||||
<App/>
|
||||
</body>
|
||||
</html>
|
||||
}
|
||||
}
|
||||
|
||||
#[component]
|
||||
pub fn App() -> impl IntoView {
|
||||
// Provides context that manages stylesheets, titles, meta tags, etc.
|
||||
provide_meta_context();
|
||||
|
||||
view! {
|
||||
// injects a stylesheet into the document <head>
|
||||
// id=leptos means cargo-leptos will hot-reload this stylesheet
|
||||
<Stylesheet id="leptos" href="/pkg/leptos-app.css"/>
|
||||
|
||||
// sets the document title
|
||||
<Title text="Chat Interface"/>
|
||||
|
||||
// content for this chat interface
|
||||
<Router>
|
||||
<main>
|
||||
<Routes fallback=|| "Page not found.".into_view()>
|
||||
<Route path=StaticSegment("") view=ChatInterface/>
|
||||
</Routes>
|
||||
</main>
|
||||
</Router>
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders the home page of your application.
|
||||
#[component]
|
||||
fn HomePage() -> impl IntoView {
|
||||
// Creates a reactive value to update the button
|
||||
let count = RwSignal::new(0);
|
||||
let on_click = move |_| *count.write() += 1;
|
||||
|
||||
view! {
|
||||
<h1>"Welcome to Leptos!"</h1>
|
||||
<button on:click=on_click>"Click Me: " {count}</button>
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders the chat interface
|
||||
#[component]
|
||||
fn ChatInterface() -> impl IntoView {
|
||||
#[cfg(feature = "hydrate")]
|
||||
{
|
||||
ChatInterfaceImpl()
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "hydrate"))]
|
||||
{
|
||||
view! {
|
||||
<div class="chat-container">
|
||||
<h1>"Chat Interface"</h1>
|
||||
<p>"Loading chat interface..."</p>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[component]
|
||||
fn ChatInterfaceImpl() -> impl IntoView {
|
||||
let (messages, set_messages) = RwSignal::new(VecDeque::<Message>::new()).split();
|
||||
let (input_value, set_input_value) = RwSignal::new(String::new()).split();
|
||||
let (is_loading, set_is_loading) = RwSignal::new(false).split();
|
||||
let (available_models, set_available_models) = RwSignal::new(Vec::<OpenAIModel>::new()).split();
|
||||
let (selected_model, set_selected_model) = RwSignal::new(DEFAULT_MODEL.to_string()).split();
|
||||
let (models_loading, set_models_loading) = RwSignal::new(false).split();
|
||||
|
||||
// Fetch models on component initialization
|
||||
Effect::new(move |_| {
|
||||
spawn_local(async move {
|
||||
set_models_loading.set(true);
|
||||
match fetch_available_models().await {
|
||||
Ok(models) => {
|
||||
set_available_models.set(models);
|
||||
set_models_loading.set(false);
|
||||
}
|
||||
Err(e) => {
|
||||
leptos::logging::log!("Failed to fetch models: {}", e);
|
||||
set_available_models.set(vec![]);
|
||||
set_models_loading.set(false);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
let send_message = Action::new_unsync(move |content: &String| {
|
||||
let content = content.clone();
|
||||
async move {
|
||||
if content.trim().is_empty() {
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Empty content, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Starting message send process");
|
||||
set_is_loading.set(true);
|
||||
|
||||
// Add user message to chat
|
||||
let user_message = Message {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
role: "user".to_string(),
|
||||
content: content.clone(),
|
||||
timestamp: Date::now(),
|
||||
};
|
||||
|
||||
set_messages.update(|msgs| msgs.push_back(user_message.clone()));
|
||||
set_input_value.set(String::new());
|
||||
|
||||
let mut chat_messages = Vec::new();
|
||||
|
||||
// Add system message
|
||||
let system_message = ChatCompletionRequestSystemMessageArgs::default()
|
||||
.content("You are a helpful assistant.")
|
||||
.build()
|
||||
.expect("failed to build system message");
|
||||
chat_messages.push(system_message.into());
|
||||
|
||||
// Add history messages
|
||||
let history_count = messages.get_untracked().len();
|
||||
for msg in messages.get_untracked().iter() {
|
||||
match msg.role.as_str() {
|
||||
"user" => {
|
||||
let message = ChatCompletionRequestUserMessageArgs::default()
|
||||
.content(msg.content.clone())
|
||||
.build()
|
||||
.expect("failed to build user message");
|
||||
chat_messages.push(message.into());
|
||||
}
|
||||
"assistant" => {
|
||||
let message = ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.content(msg.content.clone())
|
||||
.build()
|
||||
.expect("failed to build assistant message");
|
||||
chat_messages.push(message.into());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Add current user message
|
||||
let message = ChatCompletionRequestUserMessageArgs::default()
|
||||
.content(user_message.content.clone())
|
||||
.build()
|
||||
.expect("failed to build user message");
|
||||
chat_messages.push(message.into());
|
||||
|
||||
let current_model = selected_model.get_untracked();
|
||||
let total_messages = chat_messages.len();
|
||||
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
|
||||
current_model, history_count, total_messages);
|
||||
|
||||
let request = CreateChatCompletionRequestArgs::default()
|
||||
.model(current_model.as_str())
|
||||
.max_tokens(512u32)
|
||||
.messages(chat_messages)
|
||||
.stream(true)
|
||||
.build()
|
||||
.expect("failed to build request");
|
||||
|
||||
// Send request
|
||||
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
|
||||
let client = Client::with_config(config);
|
||||
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);
|
||||
|
||||
match client.chat().create_stream(request).await {
|
||||
Ok(mut stream) => {
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream");
|
||||
|
||||
let mut assistant_created = false;
|
||||
let mut content_appended = false;
|
||||
let mut chunks_received = 0;
|
||||
|
||||
while let Some(next) = stream.next().await {
|
||||
match next {
|
||||
Ok(chunk) => {
|
||||
chunks_received += 1;
|
||||
if let Some(choice) = chunk.choices.get(0) {
|
||||
if !assistant_created {
|
||||
if let Some(role) = &choice.delta.role {
|
||||
if role == &Role::Assistant {
|
||||
assistant_created = true;
|
||||
let assistant_id = Uuid::new_v4().to_string();
|
||||
set_messages.update(|msgs| {
|
||||
msgs.push_back(Message {
|
||||
id: assistant_id,
|
||||
role: "assistant".to_string(),
|
||||
content: String::new(),
|
||||
timestamp: Date::now(),
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(content) = &choice.delta.content {
|
||||
if !content.is_empty() {
|
||||
if !assistant_created {
|
||||
assistant_created = true;
|
||||
let assistant_id = Uuid::new_v4().to_string();
|
||||
set_messages.update(|msgs| {
|
||||
msgs.push_back(Message {
|
||||
id: assistant_id,
|
||||
role: "assistant".to_string(),
|
||||
content: String::new(),
|
||||
timestamp: Date::now(),
|
||||
});
|
||||
});
|
||||
}
|
||||
content_appended = true;
|
||||
set_messages.update(|msgs| {
|
||||
if let Some(last) = msgs.back_mut() {
|
||||
if last.role == "assistant" {
|
||||
last.content.push_str(content);
|
||||
last.timestamp = Date::now();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(reason) = &choice.finish_reason {
|
||||
if reason == &FinishReason::Stop {
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Received finish_reason=stop after {} chunks", chunks_received);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}",
|
||||
chunks_received,
|
||||
e
|
||||
);
|
||||
set_messages.update(|msgs| {
|
||||
msgs.push_back(Message {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
role: "system".to_string(),
|
||||
content: format!("Stream error: {}", e),
|
||||
timestamp: Date::now(),
|
||||
});
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if assistant_created && !content_appended {
|
||||
set_messages.update(|msgs| {
|
||||
let should_pop = msgs
|
||||
.back()
|
||||
.map(|m| m.role == "assistant" && m.content.is_empty())
|
||||
.unwrap_or(false);
|
||||
if should_pop {
|
||||
msgs.pop_back();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
|
||||
}
|
||||
Err(e) => {
|
||||
leptos::logging::log!(
|
||||
"[DEBUG_LOG] send_message: Request failed with error: {:?}",
|
||||
e
|
||||
);
|
||||
let error_message = Message {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
role: "system".to_string(),
|
||||
content: format!("Error: Request failed - {}", e),
|
||||
timestamp: Date::now(),
|
||||
};
|
||||
set_messages.update(|msgs| msgs.push_back(error_message));
|
||||
}
|
||||
}
|
||||
|
||||
set_is_loading.set(false);
|
||||
}
|
||||
});
|
||||
|
||||
let on_input = move |ev| {
|
||||
let input = event_target::<HtmlInputElement>(&ev);
|
||||
set_input_value.set(input.value());
|
||||
};
|
||||
|
||||
let on_submit = move |ev: SubmitEvent| {
|
||||
ev.prevent_default();
|
||||
let content = input_value.get();
|
||||
send_message.dispatch(content);
|
||||
};
|
||||
|
||||
let on_keypress = move |ev: KeyboardEvent| {
|
||||
if ev.key() == "Enter" && !ev.shift_key() {
|
||||
ev.prevent_default();
|
||||
let content = input_value.get();
|
||||
send_message.dispatch(content);
|
||||
}
|
||||
};
|
||||
|
||||
let on_model_change = move |ev| {
|
||||
let select = event_target::<web_sys::HtmlSelectElement>(&ev);
|
||||
set_selected_model.set(select.value());
|
||||
};
|
||||
|
||||
let messages_list = move || {
|
||||
messages
|
||||
.get()
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
let role_class = match message.role.as_str() {
|
||||
"user" => "user-message",
|
||||
"assistant" => "assistant-message",
|
||||
_ => "system-message",
|
||||
};
|
||||
|
||||
view! {
|
||||
<div class=format!("message {}", role_class)>
|
||||
<div class="message-role">{message.role}</div>
|
||||
<div class="message-content">{message.content}</div>
|
||||
</div>
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let loading_indicator = move || {
|
||||
is_loading.get().then(|| {
|
||||
view! {
|
||||
<div class="message assistant-message">
|
||||
<div class="message-role">"assistant"</div>
|
||||
<div class="message-content">"Thinking..."</div>
|
||||
</div>
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
view! {
|
||||
<div class="chat-container">
|
||||
<h1>"Chat Interface"</h1>
|
||||
<div class="model-selector">
|
||||
<label for="model-select">"Model: "</label>
|
||||
<select
|
||||
id="model-select"
|
||||
on:change=on_model_change
|
||||
prop:value=selected_model
|
||||
prop:disabled=models_loading
|
||||
>
|
||||
{move || {
|
||||
if models_loading.get() {
|
||||
vec![view! {
|
||||
<option value={String::from("")} selected=false>{String::from("Loading models...")}</option>
|
||||
}]
|
||||
} else {
|
||||
let models = available_models.get();
|
||||
if models.is_empty() {
|
||||
vec![view! {
|
||||
<option value={String::from("default")} selected=true>{String::from("default")}</option>
|
||||
}]
|
||||
} else {
|
||||
models.into_iter().map(|model| {
|
||||
view! {
|
||||
<option value=model.id.clone() selected={model.id == DEFAULT_MODEL}>{model.id.clone()}</option>
|
||||
}
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
}}
|
||||
</select>
|
||||
</div>
|
||||
<div class="messages-container">
|
||||
{messages_list}
|
||||
{loading_indicator}
|
||||
</div>
|
||||
<form class="input-form" on:submit=on_submit>
|
||||
<input
|
||||
type="text"
|
||||
class="message-input"
|
||||
placeholder="Type your message here..."
|
||||
prop:value=input_value
|
||||
on:input=on_input
|
||||
on:keypress=on_keypress
|
||||
prop:disabled=is_loading
|
||||
/>
|
||||
<button
|
||||
type="submit"
|
||||
class="send-button"
|
||||
prop:disabled=move || is_loading.get() || input_value.get().trim().is_empty()
|
||||
>
|
||||
"Send"
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
}
|
||||
}
|
@@ -1,30 +0,0 @@
|
||||
pub mod app;
|
||||
|
||||
#[cfg(feature = "hydrate")]
|
||||
#[wasm_bindgen::prelude::wasm_bindgen]
|
||||
pub fn hydrate() {
|
||||
use crate::app::*;
|
||||
console_error_panic_hook::set_once();
|
||||
leptos::mount::hydrate_body(App);
|
||||
}
|
||||
|
||||
#[cfg(feature = "ssr")]
|
||||
pub fn create_leptos_router() -> axum::Router {
|
||||
use crate::app::*;
|
||||
use axum::Router;
|
||||
use leptos::prelude::*;
|
||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||
|
||||
let conf = get_configuration(None).unwrap();
|
||||
let leptos_options = conf.leptos_options;
|
||||
// Generate the list of routes in your Leptos App
|
||||
let routes = generate_route_list(App);
|
||||
|
||||
Router::new()
|
||||
.leptos_routes(&leptos_options, routes, {
|
||||
let leptos_options = leptos_options.clone();
|
||||
move || shell(leptos_options.clone())
|
||||
})
|
||||
.fallback(leptos_axum::file_and_error_handler(shell))
|
||||
.with_state(leptos_options)
|
||||
}
|
@@ -1,38 +0,0 @@
|
||||
#[cfg(feature = "ssr")]
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
use axum::Router;
|
||||
use leptos::logging::log;
|
||||
use leptos::prelude::*;
|
||||
use leptos_app::app::*;
|
||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||
|
||||
let conf = get_configuration(None).unwrap();
|
||||
let addr = conf.leptos_options.site_addr;
|
||||
let leptos_options = conf.leptos_options;
|
||||
// Generate the list of routes in your Leptos App
|
||||
let routes = generate_route_list(App);
|
||||
|
||||
let app = Router::new()
|
||||
.leptos_routes(&leptos_options, routes, {
|
||||
let leptos_options = leptos_options.clone();
|
||||
move || shell(leptos_options.clone())
|
||||
})
|
||||
.fallback(leptos_axum::file_and_error_handler(shell))
|
||||
.with_state(leptos_options);
|
||||
|
||||
// run our app with hyper
|
||||
// `axum::Server` is a re-export of `hyper::Server`
|
||||
log!("listening on http://{}", &addr);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ssr"))]
|
||||
pub fn main() {
|
||||
// no client-side main function
|
||||
// unless we want this to work with e.g., Trunk for pure client-side testing
|
||||
// see lib.rs for hydration function instead
|
||||
}
|
@@ -1,4 +0,0 @@
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
text-align: center;
|
||||
}
|
@@ -19,7 +19,7 @@ tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
rust-embed = { version = "8.7.2", features = ["include-exclude"] }
|
||||
rust-embed = { version = "8.7.2", features = ["include-exclude", "axum"] }
|
||||
|
||||
# Dependencies for embeddings functionality
|
||||
embeddings-engine = { path = "../embeddings-engine" }
|
||||
@@ -28,24 +28,24 @@ embeddings-engine = { path = "../embeddings-engine" }
|
||||
inference-engine = { path = "../inference-engine" }
|
||||
|
||||
# Dependencies for leptos web app
|
||||
leptos-app = { path = "../leptos-app", features = ["ssr"] }
|
||||
#leptos-app = { path = "../leptos-app", features = ["ssr"] }
|
||||
chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = true }
|
||||
|
||||
mime_guess = "2.0.5"
|
||||
|
||||
|
||||
[package.metadata.compose]
|
||||
name = "predict-otron-9000"
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
port = 8080
|
||||
|
||||
log = "0.4.27"
|
||||
|
||||
# generates kubernetes manifests
|
||||
[package.metadata.kube]
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
replicas = 1
|
||||
port = 8080
|
||||
cmd = ["./bin/predict-otron-9000"]
|
||||
# SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
|
||||
# you can generate this via node to avoid toil
|
||||
# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
|
||||
# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
|
||||
env = { SERVER_CONFIG = "<your-json-value-here>" }
|
||||
|
||||
[features]
|
||||
default = ["ui"]
|
||||
ui = ["dep:chat-ui"]
|
||||
|
@@ -1,89 +0,0 @@
|
||||
# ---- Build stage ----
|
||||
FROM rust:1-slim-bullseye AS builder
|
||||
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
# Install build dependencies including CUDA toolkit for GPU support (needed for inference-engine dependency)
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
build-essential \
|
||||
wget \
|
||||
gnupg2 \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CUDA toolkit (required for inference-engine dependency)
|
||||
# This is a minimal CUDA installation for building
|
||||
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cuda-minimal-build-11-8 \
|
||||
libcublas-dev-11-8 \
|
||||
libcurand-dev-11-8 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& rm cuda-keyring_1.0-1_all.deb
|
||||
|
||||
# Set CUDA environment variables
|
||||
ENV CUDA_HOME=/usr/local/cuda
|
||||
ENV PATH=${CUDA_HOME}/bin:${PATH}
|
||||
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
|
||||
|
||||
# Copy the entire workspace to get access to all crates (needed for local dependencies)
|
||||
COPY . ./
|
||||
|
||||
# Cache dependencies first - create dummy source files for all crates
|
||||
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
|
||||
RUN mkdir -p crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src && \
|
||||
echo "fn main() {}" > crates/predict-otron-9000/src/main.rs && \
|
||||
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
|
||||
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
|
||||
echo "// lib" > crates/inference-engine/src/lib.rs && \
|
||||
echo "fn main() {}" > crates/embeddings-engine/src/main.rs && \
|
||||
echo "// lib" > crates/embeddings-engine/src/lib.rs && \
|
||||
cargo build --release --bin predict-otron-9000 --package predict-otron-9000
|
||||
|
||||
# Remove dummy sources and copy real sources
|
||||
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
|
||||
COPY . .
|
||||
|
||||
# Build the actual binary
|
||||
RUN cargo build --release --bin predict-otron-9000 --package predict-otron-9000
|
||||
|
||||
# ---- Runtime stage ----
|
||||
FROM debian:bullseye-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libssl1.1 \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install CUDA runtime libraries (required for inference-engine dependency)
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
gnupg2 \
|
||||
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
|
||||
&& dpkg -i cuda-keyring_1.0-1_all.deb \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-cudart-11-8 \
|
||||
libcublas11 \
|
||||
libcurand10 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& rm cuda-keyring_1.0-1_all.deb \
|
||||
&& apt-get purge -y wget gnupg2
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /usr/src/app/target/release/predict-otron-9000 /usr/local/bin/
|
||||
|
||||
# Run as non-root user for safety
|
||||
RUN useradd -m appuser
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8080
|
||||
CMD ["predict-otron-9000"]
|
@@ -39,29 +39,12 @@ impl Default for ServerMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||
pub struct Services {
|
||||
pub inference_url: Option<String>,
|
||||
pub embeddings_url: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Services {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inference_url: None,
|
||||
embeddings_url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn inference_service_url() -> String {
|
||||
"http://inference-service:8080".to_string()
|
||||
}
|
||||
|
||||
fn embeddings_service_url() -> String {
|
||||
"http://embeddings-service:8080".to_string()
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -118,8 +101,7 @@ impl ServerConfig {
|
||||
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
|
||||
config_string
|
||||
);
|
||||
let err = std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
let err = std::io::Error::other(
|
||||
"HighAvailability mode configured but services not well defined!",
|
||||
);
|
||||
return Err(err);
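Since `Services` now derives `Deserialize` (and `Default`), the `services` object from the `SERVER_CONFIG` JSON shown earlier in the crate metadata maps directly onto it. A small sketch of that mapping in isolation; the surrounding env-var plumbing in `ServerConfig` is omitted:

```rust
// Sketch: deserialize just the `services` portion of SERVER_CONFIG.
fn parse_services_example() -> serde_json::Result<()> {
    let raw = r#"{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}"#;
    let services: Services = serde_json::from_str(raw)?;
    assert_eq!(
        services.inference_url.as_deref(),
        Some("http://custom-inference:9000")
    );
    assert!(services.embeddings_url.is_some());
    Ok(())
}
```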
|
||||
|
@@ -126,7 +126,7 @@ use crate::config::ServerConfig;
|
||||
/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
|
||||
/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
|
||||
|
||||
/// HTTP client configured for proxying requests
|
||||
/// HTTP client configured for proxying requests
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyClient {
|
||||
client: Client,
|
||||
|
@@ -4,22 +4,60 @@ mod middleware;
|
||||
mod standalone_mode;
|
||||
|
||||
use crate::standalone_mode::create_standalone_router;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::{Router, http::Uri, response::Html, serve};
|
||||
use axum::{Router, serve};
|
||||
use config::ServerConfig;
|
||||
use ha_mode::create_ha_router;
|
||||
use inference_engine::AppState;
|
||||
use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
|
||||
use rust_embed::Embed;
|
||||
use std::env;
|
||||
use std::path::Component::ParentDir;
|
||||
|
||||
#[cfg(feature = "ui")]
|
||||
use axum::http::StatusCode as AxumStatusCode;
|
||||
#[cfg(feature = "ui")]
|
||||
use axum::http::Uri;
|
||||
#[cfg(feature = "ui")]
|
||||
use axum::http::header;
|
||||
#[cfg(feature = "ui")]
|
||||
use axum::response::IntoResponse;
|
||||
#[cfg(feature = "ui")]
|
||||
use mime_guess::from_path;
|
||||
#[cfg(feature = "ui")]
|
||||
use rust_embed::Embed;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::classify::ServerErrorsFailureClass::StatusCode;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
#[cfg(feature = "ui")]
|
||||
#[derive(Embed)]
|
||||
#[folder = "../../target/site"]
|
||||
#[include = "*.js"]
|
||||
#[include = "*.wasm"]
|
||||
#[include = "*.css"]
|
||||
#[include = "*.ico"]
|
||||
struct Asset;
|
||||
|
||||
#[cfg(feature = "ui")]
|
||||
async fn static_handler(uri: Uri) -> axum::response::Response {
|
||||
// Strip the leading `/`
|
||||
let path = uri.path().trim_start_matches('/');
|
||||
|
||||
tracing::info!("Static file: {}", &path);
|
||||
|
||||
// If root is requested, serve index.html
|
||||
let path = if path.is_empty() { "index.html" } else { path };
|
||||
|
||||
match Asset::get(path) {
|
||||
Some(content) => {
|
||||
let body = content.data.into_owned();
|
||||
let mime = from_path(path).first_or_octet_stream();
|
||||
|
||||
([(header::CONTENT_TYPE, mime.as_ref())], body).into_response()
|
||||
}
|
||||
None => (AxumStatusCode::NOT_FOUND, "404 Not Found").into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Initialize tracing
|
||||
@@ -77,20 +115,28 @@ async fn main() {
|
||||
// Create metrics layer
|
||||
let metrics_layer = MetricsLayer::new(metrics_store);
|
||||
|
||||
// Create the leptos router for the web frontend
|
||||
let leptos_router = leptos_app::create_leptos_router();
|
||||
|
||||
// Merge the service router with base routes and add middleware layers
|
||||
let app = Router::new()
|
||||
let mut app = Router::new()
|
||||
.route("/health", get(|| async { "ok" }))
|
||||
.merge(service_router)
|
||||
.merge(leptos_router) // Add leptos web frontend routes
|
||||
.merge(service_router);
|
||||
|
||||
// Add UI routes if the UI feature is enabled
|
||||
#[cfg(feature = "ui")]
|
||||
{
|
||||
let leptos_config = chat_ui::app::AppConfig::default();
|
||||
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
|
||||
app = app
|
||||
.route("/pkg/{*path}", get(static_handler))
|
||||
.merge(leptos_router);
|
||||
}
|
||||
|
||||
let app = app
|
||||
.layer(metrics_layer) // Add metrics tracking
|
||||
.layer(cors)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
// Server configuration
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| String::from(default_host));
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| default_host.to_string());
|
||||
|
||||
let server_port = env::var("SERVER_PORT")
|
||||
.map(|v| v.parse::<u16>().unwrap_or(default_port))
|
||||
@@ -105,12 +151,14 @@ async fn main() {
|
||||
);
|
||||
tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
|
||||
tracing::info!("Available endpoints:");
|
||||
#[cfg(feature = "ui")]
|
||||
tracing::info!(" GET / - Leptos chat web application");
|
||||
tracing::info!(" GET /health - Health check");
|
||||
    tracing::info!("  GET  /v1/models - List Models");
|
||||
tracing::info!(" POST /v1/embeddings - Text embeddings API");
|
||||
tracing::info!(" POST /v1/chat/completions - Chat completions API");
|
||||
|
||||
serve(listener, app).await.unwrap();
|
||||
serve(listener, app.into_make_service()).await.unwrap();
|
||||
}
|
||||
|
||||
fn log_config(config: ServerConfig) {
|
||||
|
@@ -2,11 +2,12 @@ use crate::config::ServerConfig;
|
||||
use axum::Router;
|
||||
use inference_engine::AppState;
|
||||
|
||||
pub fn create_standalone_router(server_config: ServerConfig) -> Router {
|
||||
pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
|
||||
// Create unified router by merging embeddings and inference routers (existing behavior)
|
||||
let embeddings_router = embeddings_engine::create_embeddings_router();
|
||||
|
||||
// Create AppState with correct model configuration
|
||||
// Create AppState - no default model, must be configured explicitly
|
||||
// This removes the hardcoded gemma-3-1b-it default behavior
|
||||
let app_state = AppState::default();
|
||||
|
||||
// Get the inference router directly from the inference engine
|
||||
|
@@ -52,7 +52,7 @@ graph TB
|
||||
|
||||
## Workspace Structure
|
||||
|
||||
The project uses a 7-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
|
||||
The project uses a 9-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
@@ -61,34 +61,33 @@ graph TD
|
||||
A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
|
||||
end
|
||||
|
||||
subgraph "AI Services"
|
||||
subgraph "AI Services (crates/)"
|
||||
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
|
||||
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
|
||||
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
|
||||
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
|
||||
end
|
||||
|
||||
subgraph "Frontend"
|
||||
D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
|
||||
subgraph "Frontend (crates/)"
|
||||
D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
|
||||
end
|
||||
|
||||
subgraph "Tooling"
|
||||
|
||||
subgraph "Integration Tools (integration/)"
|
||||
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
|
||||
E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
|
||||
M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
|
||||
N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
|
||||
O[utils<br/>Edition: 2021<br/>Shared utilities]
|
||||
end
|
||||
end
|
||||
|
||||
subgraph "External Tooling"
|
||||
E[scripts/cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
|
||||
end
|
||||
|
||||
subgraph "Dependencies"
|
||||
A --> B
|
||||
A --> C
|
||||
A --> D
|
||||
B --> J
|
||||
B --> K
|
||||
J -.-> F[Candle 0.9.1]
|
||||
K -.-> F
|
||||
B --> M
|
||||
B --> N
|
||||
M -.-> F[Candle 0.9.1]
|
||||
N -.-> F
|
||||
C -.-> G[FastEmbed 4.x]
|
||||
D -.-> H[Leptos 0.8.0]
|
||||
E -.-> I[OpenAI SDK 5.16+]
|
||||
@@ -96,12 +95,13 @@ graph TD
|
||||
|
||||
style A fill:#e1f5fe
|
||||
style B fill:#f3e5f5
|
||||
style J fill:#f3e5f5
|
||||
style K fill:#f3e5f5
|
||||
style C fill:#e8f5e8
|
||||
style D fill:#fff3e0
|
||||
style E fill:#fce4ec
|
||||
style L fill:#fff9c4
|
||||
style M fill:#f3e5f5
|
||||
style N fill:#f3e5f5
|
||||
style O fill:#fff9c4
|
||||
```
|
||||
|
||||
## Deployment Configurations
|
||||
@@ -193,7 +193,7 @@ graph TB
|
||||
end
|
||||
|
||||
subgraph "Frontend"
|
||||
D[leptos-app Pod<br/>:8788<br/>ClusterIP Service]
|
||||
D[chat-ui Pod<br/>:8788<br/>ClusterIP Service]
|
||||
end
|
||||
|
||||
subgraph "Ingress"
|
||||
|
@@ -3,7 +3,7 @@
|
||||
A Rust/Typescript Hybrid
|
||||
|
||||
```console
|
||||
./cli [options] [prompt]
|
||||
bun run cli.ts [options] [prompt]
|
||||
|
||||
Simple CLI tool for testing the local OpenAI-compatible API server.
|
||||
|
||||
@@ -14,10 +14,11 @@ Options:
|
||||
--help Show this help message
|
||||
|
||||
Examples:
|
||||
./cli "What is the capital of France?"
|
||||
./cli --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
./cli --prompt "Who was the 16th president of the United States?"
|
||||
./cli --list-models
|
||||
cd integration/cli/package
|
||||
bun run cli.ts "What is the capital of France?"
|
||||
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
|
||||
bun run cli.ts --prompt "Who was the 16th president of the United States?"
|
||||
bun run cli.ts --list-models
|
||||
|
||||
The server must be running at http://localhost:8080
|
||||
```
|
@@ -24,8 +24,7 @@ fn run_build() -> io::Result<()> {
|
||||
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
|
||||
let output_path = out_dir.join("client-cli");
|
||||
|
||||
let bun_tgt = BunTarget::from_cargo_env()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
|
||||
let bun_tgt = BunTarget::from_cargo_env().map_err(|e| io::Error::other(e.to_string()))?;
|
||||
|
||||
// Optional: warn if using a Bun target that’s marked unsupported in your chart
|
||||
if matches!(bun_tgt, BunTarget::WindowsArm64) {
|
||||
@@ -54,13 +53,12 @@ fn run_build() -> io::Result<()> {
|
||||
|
||||
if !install_status.success() {
|
||||
let code = install_status.code().unwrap_or(1);
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("bun install failed with status {code}"),
|
||||
));
|
||||
return Err(io::Error::other(format!(
|
||||
"bun install failed with status {code}"
|
||||
)));
|
||||
}
|
||||
|
||||
let target = env::var("TARGET").unwrap();
|
||||
let _target = env::var("TARGET").unwrap();
|
||||
|
||||
// --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
|
||||
let mut build = Command::new("bun")
|
||||
@@ -87,7 +85,7 @@ fn run_build() -> io::Result<()> {
|
||||
} else {
|
||||
let code = status.code().unwrap_or(1);
|
||||
warn(&format!("bun build failed with status: {code}"));
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
|
||||
return Err(io::Error::other("bun build failed"));
|
||||
}
|
||||
|
||||
// Ensure the output is executable (after it exists)
|
17
integration/cli/package/bun.lock
Normal file
17
integration/cli/package/bun.lock
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"lockfileVersion": 1,
|
||||
"workspaces": {
|
||||
"": {
|
||||
"name": "cli",
|
||||
"dependencies": {
|
||||
"install": "^0.13.0",
|
||||
"openai": "^5.16.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
"packages": {
|
||||
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
|
||||
|
||||
"openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
|
||||
}
|
||||
}
|
@@ -25,7 +25,7 @@ fn main() -> io::Result<()> {
|
||||
// Run it
|
||||
let status = Command::new(&tmp).arg("--version").status()?;
|
||||
if !status.success() {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
|
||||
return Err(io::Error::other("client-cli failed"));
|
||||
}
|
||||
|
||||
Ok(())
|
@@ -10,15 +10,15 @@ edition = "2021"
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
|
||||
candle-examples = { git = "https://github.com/huggingface/candle.git" }
|
||||
hf-hub = "0.4"
|
||||
tokenizers = "0.21"
|
||||
tokenizers = "0.22.0"
|
||||
anyhow = "1.0"
|
||||
clap = { version = "4.0", features = ["derive", "string"] }
|
||||
serde_json = "1.0"
|
||||
tracing = "0.1"
|
||||
tracing-chrome = "0.7"
|
||||
tracing-subscriber = "0.3"
|
||||
utils = {path = "../utils" }
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
@@ -1,27 +1,24 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
|
||||
use clap::ValueEnum;
|
||||
|
||||
// Removed gemma_cli import as it's not needed for the API
|
||||
use candle_core::{utils, DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
use std::sync::mpsc::{self, Receiver, Sender};
|
||||
use std::thread;
|
||||
use tokenizers::Tokenizer;
|
||||
use utils::hub_load_safetensors;
|
||||
use utils::token_output_stream::TokenOutputStream;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
pub enum WhichModel {
|
||||
#[value(name = "gemma-2b")]
|
||||
Base2B,
|
||||
@@ -57,6 +54,56 @@ pub enum WhichModel {
|
||||
InstructV3_1B,
|
||||
}
|
||||
|
||||
impl FromStr for WhichModel {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"gemma-2b" => Ok(Self::Base2B),
|
||||
"gemma-7b" => Ok(Self::Base7B),
|
||||
"gemma-2b-it" => Ok(Self::Instruct2B),
|
||||
"gemma-7b-it" => Ok(Self::Instruct7B),
|
||||
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
|
||||
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
|
||||
"codegemma-2b" => Ok(Self::CodeBase2B),
|
||||
"codegemma-7b" => Ok(Self::CodeBase7B),
|
||||
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
|
||||
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
|
||||
"gemma-2-2b" => Ok(Self::BaseV2_2B),
|
||||
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
|
||||
"gemma-2-9b" => Ok(Self::BaseV2_9B),
|
||||
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
|
||||
"gemma-3-1b" => Ok(Self::BaseV3_1B),
|
||||
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
|
||||
_ => Err(format!("Unknown model: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for WhichModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let name = match self {
|
||||
Self::Base2B => "gemma-2b",
|
||||
Self::Base7B => "gemma-7b",
|
||||
Self::Instruct2B => "gemma-2b-it",
|
||||
Self::Instruct7B => "gemma-7b-it",
|
||||
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
|
||||
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
|
||||
Self::CodeBase2B => "codegemma-2b",
|
||||
Self::CodeBase7B => "codegemma-7b",
|
||||
Self::CodeInstruct2B => "codegemma-2b-it",
|
||||
Self::CodeInstruct7B => "codegemma-7b-it",
|
||||
Self::BaseV2_2B => "gemma-2-2b",
|
||||
Self::InstructV2_2B => "gemma-2-2b-it",
|
||||
Self::BaseV2_9B => "gemma-2-9b",
|
||||
Self::InstructV2_9B => "gemma-2-9b-it",
|
||||
Self::BaseV3_1B => "gemma-3-1b",
|
||||
Self::InstructV3_1B => "gemma-3-1b-it",
|
||||
};
|
||||
write!(f, "{}", name)
|
||||
}
|
||||
}
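A quick round-trip of the new conversions; `parse` goes through the `FromStr` impl and `to_string` through the `Display` impl, both defined above:

```rust
// Illustrative round-trip between model-id strings and WhichModel.
fn roundtrip_example() {
    let model: WhichModel = "gemma-3-1b-it".parse().expect("known model id");
    assert_eq!(model, WhichModel::InstructV3_1B);
    assert_eq!(model.to_string(), "gemma-3-1b-it");
}
```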
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
@@ -85,9 +132,9 @@ pub struct TextGeneration {
|
||||
fn device(cpu: bool) -> Result<Device> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if utils::cuda_is_available() {
|
||||
} else if candle_core::utils::cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if utils::metal_is_available() {
|
||||
} else if candle_core::utils::metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
Ok(Device::Cpu)
|
||||
@@ -98,7 +145,7 @@ impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
@@ -144,8 +191,6 @@ impl TextGeneration {
|
||||
// Make sure stdout isn't holding anything (if caller also prints).
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
@@ -182,7 +227,6 @@ impl TextGeneration {
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
|
||||
if next_token == eos_token || next_token == eot_token {
|
||||
break;
|
||||
@@ -209,7 +253,7 @@ impl TextGeneration {
|
||||
pub struct GemmaInferenceConfig {
|
||||
pub tracing: bool,
|
||||
pub prompt: String,
|
||||
pub model: WhichModel,
|
||||
pub model: Option<WhichModel>,
|
||||
pub cpu: bool,
|
||||
pub dtype: Option<String>,
|
||||
pub model_id: Option<String>,
|
||||
@@ -228,7 +272,7 @@ impl Default for GemmaInferenceConfig {
|
||||
Self {
|
||||
tracing: false,
|
||||
prompt: "Hello".to_string(),
|
||||
model: WhichModel::InstructV2_2B,
|
||||
model: Some(WhichModel::InstructV2_2B),
|
||||
cpu: false,
|
||||
dtype: None,
|
||||
model_id: None,
|
||||
@@ -262,10 +306,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
utils::with_avx(),
|
||||
utils::with_neon(),
|
||||
utils::with_simd128(),
|
||||
utils::with_f16c()
|
||||
candle_core::utils::with_avx(),
|
||||
candle_core::utils::with_neon(),
|
||||
candle_core::utils::with_simd128(),
|
||||
candle_core::utils::with_f16c()
|
||||
);
|
||||
|
||||
let device = device(cfg.cpu)?;
|
||||
@@ -285,28 +329,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
}
|
||||
};
|
||||
println!("Using dtype: {:?}", dtype);
|
||||
println!("Raw model string: {:?}", cfg.model_id);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
|
||||
let model_id = cfg.model_id.unwrap_or_else(|| {
|
||||
match cfg.model {
|
||||
WhichModel::Base2B => "google/gemma-2b",
|
||||
WhichModel::Base7B => "google/gemma-7b",
|
||||
WhichModel::Instruct2B => "google/gemma-2b-it",
|
||||
WhichModel::Instruct7B => "google/gemma-7b-it",
|
||||
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
|
||||
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
|
||||
WhichModel::CodeBase2B => "google/codegemma-2b",
|
||||
WhichModel::CodeBase7B => "google/codegemma-7b",
|
||||
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
|
||||
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
|
||||
WhichModel::BaseV2_2B => "google/gemma-2-2b",
|
||||
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
|
||||
WhichModel::BaseV2_9B => "google/gemma-2-9b",
|
||||
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
|
||||
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
|
||||
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
|
||||
Some(WhichModel::Base2B) => "google/gemma-2b",
|
||||
Some(WhichModel::Base7B) => "google/gemma-7b",
|
||||
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
|
||||
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
|
||||
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
|
||||
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
|
||||
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
|
||||
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
|
||||
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
|
||||
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
|
||||
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
|
||||
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
|
||||
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
|
||||
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
|
||||
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
|
||||
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
|
||||
None => "google/gemma-2-2b-it", // default fallback
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
@@ -317,8 +363,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let filenames = match cfg.model {
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
|
||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("Retrieved files in {:?}", start.elapsed());
|
||||
|
||||
@@ -328,29 +376,31 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
|
||||
let model: Model = match cfg.model {
|
||||
WhichModel::Base2B
|
||||
| WhichModel::Base7B
|
||||
| WhichModel::Instruct2B
|
||||
| WhichModel::Instruct7B
|
||||
| WhichModel::InstructV1_1_2B
|
||||
| WhichModel::InstructV1_1_7B
|
||||
| WhichModel::CodeBase2B
|
||||
| WhichModel::CodeBase7B
|
||||
| WhichModel::CodeInstruct2B
|
||||
| WhichModel::CodeInstruct7B => {
|
||||
Some(WhichModel::Base2B)
|
||||
| Some(WhichModel::Base7B)
|
||||
| Some(WhichModel::Instruct2B)
|
||||
| Some(WhichModel::Instruct7B)
|
||||
| Some(WhichModel::InstructV1_1_2B)
|
||||
| Some(WhichModel::InstructV1_1_7B)
|
||||
| Some(WhichModel::CodeBase2B)
|
||||
| Some(WhichModel::CodeBase7B)
|
||||
| Some(WhichModel::CodeInstruct2B)
|
||||
| Some(WhichModel::CodeInstruct7B) => {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
}
|
||||
WhichModel::BaseV2_2B
|
||||
| WhichModel::InstructV2_2B
|
||||
| WhichModel::BaseV2_9B
|
||||
| WhichModel::InstructV2_9B => {
|
||||
Some(WhichModel::BaseV2_2B)
|
||||
| Some(WhichModel::InstructV2_2B)
|
||||
| Some(WhichModel::BaseV2_9B)
|
||||
| Some(WhichModel::InstructV2_9B)
|
||||
| None => {
|
||||
// default to V2 model
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
}
|
||||
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
|
||||
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
|
||||
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
|
||||
Model::V3(model)
|
||||
@@ -370,7 +420,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
|
||||
);
|
||||
|
||||
let prompt = match cfg.model {
|
||||
WhichModel::InstructV3_1B => {
|
||||
Some(WhichModel::InstructV3_1B) => {
|
||||
format!(
|
||||
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
|
||||
cfg.prompt
|
@@ -67,7 +67,7 @@ pub fn run_cli() -> anyhow::Result<()> {
|
||||
let cfg = GemmaInferenceConfig {
|
||||
tracing: args.tracing,
|
||||
prompt: args.prompt,
|
||||
model: args.model,
|
||||
model: Some(args.model),
|
||||
cpu: args.cpu,
|
||||
dtype: args.dtype,
|
||||
model_id: args.model_id,
|
@@ -6,10 +6,8 @@ mod gemma_api;
|
||||
mod gemma_cli;
|
||||
|
||||
use anyhow::Error;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use crate::gemma_cli::run_cli;
|
||||
use std::io::Write;
|
||||
|
||||
/// just a placeholder, not used for anything
|
||||
fn main() -> std::result::Result<(), Error> {
|
@@ -64,14 +64,9 @@ version = "0.1.0"

# Required: Kubernetes metadata
[package.metadata.kube]
image = "ghcr.io/myorg/my-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080

# Optional: Docker Compose metadata (currently not used but parsed)
[package.metadata.compose]
image = "ghcr.io/myorg/my-service:latest"
port = 8080
```

### Required Fields

@@ -137,7 +132,7 @@ Parsing workspace at: ..
Output directory: ../generated-helm-chart
Chart name: predict-otron-9000
Found 4 services:
- leptos-app: ghcr.io/geoffsee/leptos-app:latest (port 8788)
- chat-ui: ghcr.io/geoffsee/chat-ui:latest (port 8788)
- inference-engine: ghcr.io/geoffsee/inference-service:latest (port 8080)
- embeddings-engine: ghcr.io/geoffsee/embeddings-service:latest (port 8080)
- predict-otron-9000: ghcr.io/geoffsee/predict-otron-9000:latest (port 8080)
@@ -1,9 +1,8 @@
use anyhow::{Context, Result};
use clap::{Arg, Command};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::Deserialize;
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use walkdir::WalkDir;

#[derive(Debug, Deserialize)]

@@ -20,7 +19,6 @@ struct Package {
#[derive(Debug, Deserialize)]
struct Metadata {
    kube: Option<KubeMetadata>,
    compose: Option<ComposeMetadata>,
}

#[derive(Debug, Deserialize)]

@@ -30,12 +28,6 @@ struct KubeMetadata {
    port: u16,
}

#[derive(Debug, Deserialize)]
struct ComposeMetadata {
    image: Option<String>,
    port: Option<u16>,
}

#[derive(Debug, Clone)]
struct ServiceInfo {
    name: String,

@@ -105,7 +97,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
        .into_iter()
        .filter_map(|e| e.ok())
    {
        if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
        if entry.file_name() == "Cargo.toml"
            && entry.path() != workspace_root.join("../../../Cargo.toml")
        {
            if let Ok(service_info) = parse_cargo_toml(entry.path()) {
                services.push(service_info);
            }

@@ -375,7 +369,7 @@ spec:
    Ok(())
}

fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
fn generate_ingress_template(templates_dir: &Path, _services: &[ServiceInfo]) -> Result<()> {
    let ingress_template = r#"{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
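For context on how the generator consumes the `[package.metadata.kube]` table shown earlier, the sketch below shows one way such metadata could be deserialized with `serde` and the `toml` crate. The struct and field names here are illustrative mirrors of the README example, not the crate's actual definitions, and the snippet assumes `toml`, `serde`, and `anyhow` are available as dependencies.

```rust
use serde::Deserialize;

// Hypothetical mirror of the metadata table shown above; field names follow the README.
#[derive(Debug, Deserialize)]
struct KubeMeta {
    image: String,
    replicas: u32,
    port: u16,
}

#[derive(Debug, Deserialize)]
struct MetadataTable {
    kube: Option<KubeMeta>,
}

#[derive(Debug, Deserialize)]
struct PackageTable {
    name: String,
    metadata: Option<MetadataTable>,
}

#[derive(Debug, Deserialize)]
struct Manifest {
    package: PackageTable,
}

fn main() -> anyhow::Result<()> {
    // Parse an inline Cargo.toml fragment carrying the kube metadata.
    let manifest: Manifest = toml::from_str(
        r#"
        [package]
        name = "predict-otron-9000"
        [package.metadata.kube]
        image = "ghcr.io/geoffsee/predict-otron-9000:latest"
        replicas = 1
        port = 8080
        "#,
    )?;

    let name = manifest.package.name.clone();
    if let Some(kube) = manifest.package.metadata.and_then(|m| m.kube) {
        println!("{name} -> {} (port {})", kube.image, kube.port);
    }
    Ok(())
}
```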
@@ -5,8 +5,8 @@ edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git"}
hf-hub = "0.3"
tokenizers = "0.20"
anyhow = "1.0"
@@ -1,6 +1,5 @@
pub mod llama_api;

use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Re-export constants and types that might be needed

@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
    pub repeat_last_n: usize,
}

impl LlamaInferenceConfig {
    pub fn new(model: WhichModel) -> Self {
        Self {
            prompt: String::new(),
            model,
            cpu: false,
            temperature: 1.0,
            top_p: None,
            top_k: None,
            seed: 42,
            max_tokens: 512,
            no_kv_cache: false,
            dtype: None,
            model_id: None,
            revision: None,
            use_flash_attn: true,
            repeat_penalty: 1.1,
            repeat_last_n: 64,
        }
    }
}
impl Default for LlamaInferenceConfig {
    fn default() -> Self {
        Self {

@@ -81,8 +102,8 @@ impl Default for LlamaInferenceConfig {
            max_tokens: 512,

            // Performance flags
            no_kv_cache: false, // keep cache ON for speed
            use_flash_attn: true, // great speed boost if supported
            no_kv_cache: false, // keep cache ON for speed
            use_flash_attn: false, // great speed boost if supported

            // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
            dtype: Some("bf16".to_string()),
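The new constructor makes the common case one line; note that flash attention stays on in `new` but is now off in `Default`. A minimal sketch of the intended call pattern, assuming `run_llama_inference` mirrors `run_gemma_api` and returns a receiver of streamed text chunks (its actual signature is not shown in this diff), and leaving the `WhichModel` variant to the caller since the llama variants are defined in `llama_api.rs`:

```rust
use anyhow::Result;

fn run_llama_example(model: WhichModel, prompt: &str) -> Result<()> {
    // Start from the constructor defaults, then override what differs.
    let mut cfg = LlamaInferenceConfig::new(model);
    cfg.prompt = prompt.to_string();
    cfg.max_tokens = 256;
    cfg.temperature = 0.7;

    // Assumed: run_llama_inference streams text chunks over a channel receiver.
    let rx = run_llama_inference(cfg)?;
    for chunk in rx {
        print!("{}", chunk?);
    }
    Ok(())
}
```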
@@ -6,9 +6,6 @@ mod llama_api;
mod llama_cli;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use std::io::Write;

use crate::llama_cli::run_cli;
96 integration/utils/Cargo.toml (new file)
@@ -0,0 +1,96 @@
|
||||
[package]
|
||||
name = "utils"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = {version = "0.3.2", optional = true }
|
||||
candle-flash-attn = {version = "0.9.1", optional = true }
|
||||
candle-onnx = {version = "0.9.1", optional = true }
|
||||
csv = "1.3.0"
|
||||
anyhow = "1.0.99"
|
||||
cudarc = {version = "0.17.3", optional = true }
|
||||
half = {version = "2.6.0", optional = true }
|
||||
hf-hub = {version = "0.4.3", features = ["tokio"] }
|
||||
image = {version = "0.25.6" }
|
||||
intel-mkl-src = {version = "0.8.1", optional = true }
|
||||
num-traits = {version = "0.2.19" }
|
||||
palette = { version = "0.7.6", optional = true }
|
||||
enterpolation = { version = "0.2.1", optional = true }
|
||||
pyo3 = { version = "0.22.0", features = [
|
||||
"auto-initialize",
|
||||
"abi3-py311",
|
||||
], optional = true }
|
||||
rayon = {version = "1.11.0" }
|
||||
rubato = { version = "0.15.0", optional = true }
|
||||
safetensors = {version = "0.6.2" }
|
||||
serde = {version = "1.0.219" }
|
||||
serde_json = {version = "1.0.143" }
|
||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||
tokenizers = {version = "0.22.0", features = ["onig"] }
|
||||
cpal = { version = "0.15.2", optional = true }
|
||||
pdf2image = { version = "0.1.2", optional = true }
|
||||
tekken-rs = { version = "0.1.1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = {version = "1.0.99" }
|
||||
byteorder = {version = "1.5.0" }
|
||||
clap = {version = "4.5.46" }
|
||||
imageproc = {version = "0.25.0" }
|
||||
memmap2 = {version = "0.9.8" }
|
||||
rand = {version = "0.9.2" }
|
||||
ab_glyph = {version = "0.2.31" }
|
||||
tracing = {version = "0.1.41" }
|
||||
tracing-chrome = {version = "0.7.2" }
|
||||
tracing-subscriber = {version = "0.3.20" }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.43.0"
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = {version = "1.0.99" }
|
||||
bindgen_cuda = { version = "0.1.1", optional = true }
|
||||
#
|
||||
[features]
|
||||
default = []
|
||||
accelerate = [
|
||||
"dep:accelerate-src",
|
||||
"candle-core/accelerate",
|
||||
"candle-nn/accelerate",
|
||||
"candle-transformers/accelerate",
|
||||
]
|
||||
cuda = [
|
||||
"candle-core/cuda",
|
||||
"candle-nn/cuda",
|
||||
"candle-transformers/cuda",
|
||||
"dep:bindgen_cuda",
|
||||
]
|
||||
cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = [
|
||||
"dep:intel-mkl-src",
|
||||
"candle-core/mkl",
|
||||
"candle-nn/mkl",
|
||||
"candle-transformers/mkl",
|
||||
]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle-core/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal", "rubato"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
snac = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
tekken = ["tekken-rs"]
|
||||
|
||||
# Platform-specific candle dependencies
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
candle-nn = {version = "0.9.1", default-features = false }
|
||||
candle-transformers = {version = "0.9.1", default-features = false }
|
||||
candle-core = {version = "0.9.1", default-features = false }
|
||||
|
||||
[target.'cfg(not(target_os = "linux"))'.dependencies]
|
||||
candle-nn = {version = "0.9.1" }
|
||||
candle-transformers = {version = "0.9.1" }
|
||||
candle-core = {version = "0.9.1" }
|
138 integration/utils/src/audio.rs (new file)
@@ -0,0 +1,138 @@
|
||||
use candle_core::{Result, Tensor};
|
||||
|
||||
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
|
||||
pub fn normalize_loudness(
|
||||
wav: &Tensor,
|
||||
sample_rate: u32,
|
||||
loudness_compressor: bool,
|
||||
) -> Result<Tensor> {
|
||||
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
|
||||
if energy < 2e-3 {
|
||||
return Ok(wav.clone());
|
||||
}
|
||||
let wav_array = wav.to_vec1::<f32>()?;
|
||||
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
|
||||
meter.push(wav_array.into_iter());
|
||||
let power = meter.as_100ms_windows();
|
||||
let loudness = match crate::bs1770::gated_mean(power) {
|
||||
None => return Ok(wav.clone()),
|
||||
Some(gp) => gp.loudness_lkfs() as f64,
|
||||
};
|
||||
let delta_loudness = -14. - loudness;
|
||||
let gain = 10f64.powf(delta_loudness / 20.);
|
||||
let wav = (wav * gain)?;
|
||||
if loudness_compressor {
|
||||
wav.tanh()
|
||||
} else {
|
||||
Ok(wav)
|
||||
}
|
||||
}
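/// Example (sketch, not part of the original file): wrap a mono f32 buffer in a
/// tensor and normalize it toward the -14 LUFS target used above. The buffer
/// and sample rate here are illustrative.
pub fn example_normalize_buffer(samples: Vec<f32>, sample_rate: u32) -> Result<Vec<f32>> {
    let len = samples.len();
    let wav = Tensor::from_vec(samples, len, &candle_core::Device::Cpu)?;
    let normalized = normalize_loudness(&wav, sample_rate, false)?;
    normalized.to_vec1::<f32>()
}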
|
||||
|
||||
#[cfg(feature = "symphonia")]
|
||||
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
|
||||
use symphonia::core::conv::FromSample;
|
||||
|
||||
fn conv<T>(
|
||||
samples: &mut Vec<f32>,
|
||||
data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
|
||||
) where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
// Open the media source.
|
||||
let src = std::fs::File::open(path).map_err(candle::Error::wrap)?;
|
||||
|
||||
// Create the media source stream.
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
|
||||
// Create a probe hint using the file's extension. [Optional]
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
|
||||
// Use the default options for metadata and format readers.
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
|
||||
// Probe the media source.
|
||||
let probed = symphonia::default::get_probe()
|
||||
.format(&hint, mss, &fmt_opts, &meta_opts)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
// Get the instantiated format reader.
|
||||
let mut format = probed.format;
|
||||
|
||||
// Find the first audio track with a known (decodeable) codec.
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
|
||||
.ok_or_else(|| candle::Error::Msg("no supported audio tracks".to_string()))?;
|
||||
|
||||
// Use the default options for the decoder.
|
||||
let dec_opts: DecoderOptions = Default::default();
|
||||
|
||||
// Create a decoder for the track.
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &dec_opts)
|
||||
.map_err(|_| candle::Error::Msg("unsupported codec".to_string()))?;
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
// The decode loop.
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
// Consume any new metadata that has been read since the last packet.
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
|
||||
// If the packet does not belong to the selected track, skip over it.
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet).map_err(candle::Error::wrap)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
#[cfg(feature = "rubato")]
|
||||
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pcm_out =
|
||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||
|
||||
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||
let mut pos_in = 0;
|
||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||
let (in_len, out_len) = resampler
|
||||
.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
pos_in += in_len;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
if pos_in < pcm_in.len() {
|
||||
let (_in_len, out_len) = resampler
|
||||
.process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
Ok(pcm_out)
|
||||
}
|
506 integration/utils/src/bs1770.rs (new file)
@@ -0,0 +1,506 @@
|
||||
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
|
||||
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
|
||||
// Copyright 2020 Ruud van Asseldonk
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// A copy of the License has been included in the root of the repository.
|
||||
|
||||
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
|
||||
//!
|
||||
//! This library offers the building blocks to perform BS.1770 loudness
|
||||
//! measurements, but you need to put the pieces together yourself.
|
||||
//!
|
||||
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
|
||||
//!
|
||||
//! # Stereo integrated loudness example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
|
||||
//! # [vec![0; 48_000], vec![0; 48_000]]
|
||||
//! # }
|
||||
//! #
|
||||
//! let sample_rate_hz = 44_100;
|
||||
//! let bits_per_sample = 16;
|
||||
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
|
||||
//!
|
||||
//! // When converting integer samples to float, note that the maximum amplitude
|
||||
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
|
||||
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
|
||||
//!
|
||||
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
|
||||
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
|
||||
//! meter.into_100ms_windows()
|
||||
//! }).collect();
|
||||
//!
|
||||
//! let stereo_power = bs1770::reduce_stereo(
|
||||
//! channel_power[0].as_ref(),
|
||||
//! channel_power[1].as_ref(),
|
||||
//! );
|
||||
//!
|
||||
//! let gated_power = bs1770::gated_mean(
|
||||
//! stereo_power.as_ref()
|
||||
//! ).unwrap_or(bs1770::Power(0.0));
|
||||
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
|
||||
//! ```
|
||||
|
||||
use std::f32;
|
||||
|
||||
/// Coefficients for a 2nd-degree infinite impulse response filter.
|
||||
///
|
||||
/// Coefficient a0 is implicitly 1.0.
|
||||
#[derive(Clone)]
|
||||
struct Filter {
|
||||
a1: f32,
|
||||
a2: f32,
|
||||
b0: f32,
|
||||
b1: f32,
|
||||
b2: f32,
|
||||
|
||||
// The past two input and output samples.
|
||||
x1: f32,
|
||||
x2: f32,
|
||||
y1: f32,
|
||||
y2: f32,
|
||||
}
|
||||
|
||||
impl Filter {
|
||||
/// Stage 1 of the BS.1770-4 pre-filter.
|
||||
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
|
||||
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
|
||||
let gain_db = 3.999_843_8;
|
||||
let q = 0.707_175_25;
|
||||
let center_hz = 1_681.974_5;
|
||||
|
||||
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
|
||||
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
|
||||
let vh = 10.0_f32.powf(gain_db / 20.0);
|
||||
let vb = vh.powf(0.499_666_78);
|
||||
let a0 = 1.0 + k / q + k * k;
|
||||
Filter {
|
||||
b0: (vh + vb * k / q + k * k) / a0,
|
||||
b1: 2.0 * (k * k - vh) / a0,
|
||||
b2: (vh - vb * k / q + k * k) / a0,
|
||||
a1: 2.0 * (k * k - 1.0) / a0,
|
||||
a2: (1.0 - k / q + k * k) / a0,
|
||||
|
||||
x1: 0.0,
|
||||
x2: 0.0,
|
||||
y1: 0.0,
|
||||
y2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stage 2 of th BS.1770-4 pre-filter.
|
||||
pub fn high_pass(sample_rate_hz: f32) -> Filter {
|
||||
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
|
||||
let q = 0.500_327_05;
|
||||
let center_hz = 38.135_47;
|
||||
|
||||
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
|
||||
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
|
||||
Filter {
|
||||
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
|
||||
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
|
||||
b0: 1.0,
|
||||
b1: -2.0,
|
||||
b2: 1.0,
|
||||
|
||||
x1: 0.0,
|
||||
x2: 0.0,
|
||||
y1: 0.0,
|
||||
y2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed the next input sample, get the next output sample.
|
||||
#[inline(always)]
|
||||
pub fn apply(&mut self, x0: f32) -> f32 {
|
||||
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
|
||||
- self.a1 * self.y1
|
||||
- self.a2 * self.y2;
|
||||
|
||||
self.x2 = self.x1;
|
||||
self.x1 = x0;
|
||||
self.y2 = self.y1;
|
||||
self.y1 = y0;
|
||||
|
||||
y0
|
||||
}
|
||||
}
|
||||
|
||||
/// Compensated sum, for summing many values of different orders of magnitude
|
||||
/// accurately.
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
struct Sum {
|
||||
sum: f32,
|
||||
residue: f32,
|
||||
}
|
||||
|
||||
impl Sum {
|
||||
#[inline(always)]
|
||||
fn zero() -> Sum {
|
||||
Sum {
|
||||
sum: 0.0,
|
||||
residue: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn add(&mut self, x: f32) {
|
||||
let sum = self.sum + (self.residue + x);
|
||||
self.residue = (self.residue + x) - (sum - self.sum);
|
||||
self.sum = sum;
|
||||
}
|
||||
}
|
||||
|
||||
/// The mean of the squares of the K-weighted samples in a window of time.
|
||||
///
|
||||
/// K-weighted power is equivalent to K-weighted loudness, the only difference
|
||||
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
|
||||
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
|
||||
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
|
||||
///
|
||||
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
|
||||
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
|
||||
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
|
||||
/// relative to Full Scale). Loudness units are related to decibels in the
|
||||
/// following sense: boosting a signal that has a loudness of
|
||||
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
|
||||
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
|
||||
/// bring the loudness to 0 LUFS.
|
||||
///
|
||||
/// K-weighting refers to a high-shelf and high-pass filter that model the
|
||||
/// effect that humans perceive a certain amount of power in low frequencies to
|
||||
/// be less loud than the same amount of power in higher frequencies. In this
|
||||
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
|
||||
///
|
||||
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
|
||||
/// mean square of the samples, if no input samples exceeded the full scale, the
|
||||
/// power will be in the range [0.0, 1.0]. However, the power delivered by
|
||||
/// multiple channels, which is a weighted sum over individual channel powers,
|
||||
/// can exceed this range, because the weighted sum is not normalized.
|
||||
#[derive(Copy, Clone, PartialEq, PartialOrd)]
|
||||
pub struct Power(pub f32);
|
||||
|
||||
impl Power {
|
||||
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
|
||||
///
|
||||
/// This is the inverse of `loudness_lkfs`.
|
||||
pub fn from_lkfs(lkfs: f32) -> Power {
|
||||
// The inverse of the formula below.
|
||||
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
|
||||
}
|
||||
|
||||
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
|
||||
///
|
||||
/// This is the inverse of `from_lkfs`.
|
||||
pub fn loudness_lkfs(&self) -> f32 {
|
||||
// Equation 2 (p.5) of BS.1770-4.
|
||||
-0.691 + 10.0 * self.0.log10()
|
||||
}
|
||||
}
|
||||
|
||||
/// A `T` value for non-overlapping windows of audio, 100ms in length.
|
||||
///
|
||||
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
|
||||
/// for non-overlapping windows of 100ms duration.
|
||||
///
|
||||
/// These non-overlapping 100ms windows can later be combined into overlapping
|
||||
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
|
||||
/// to perform a gated measurement, or they can be combined into even larger
|
||||
/// windows for a momentary loudness measurement.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Windows100ms<T> {
|
||||
pub inner: T,
|
||||
}
|
||||
|
||||
impl<T> Windows100ms<T> {
|
||||
/// Wrap a new empty vector.
|
||||
pub fn new() -> Windows100ms<Vec<T>> {
|
||||
Windows100ms { inner: Vec::new() }
|
||||
}
|
||||
|
||||
/// Apply `as_ref` to the inner value.
|
||||
pub fn as_ref(&self) -> Windows100ms<&[Power]>
|
||||
where
|
||||
T: AsRef<[Power]>,
|
||||
{
|
||||
Windows100ms {
|
||||
inner: self.inner.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply `as_mut` to the inner value.
|
||||
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
|
||||
where
|
||||
T: AsMut<[Power]>,
|
||||
{
|
||||
Windows100ms {
|
||||
inner: self.inner.as_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
/// Apply `len` to the inner value.
|
||||
pub fn len(&self) -> usize
|
||||
where
|
||||
T: AsRef<[Power]>,
|
||||
{
|
||||
self.inner.as_ref().len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
|
||||
///
|
||||
/// # Output
|
||||
///
|
||||
/// The output of the meter is an intermediate result in the form of power for
|
||||
/// 100ms non-overlapping windows. The windows need to be processed further to
|
||||
/// get one of the instantaneous, momentary, and integrated loudness
|
||||
/// measurements defined in BS.1770.
|
||||
///
|
||||
/// The windows can also be inspected directly; the data is meaningful
|
||||
/// on its own (the K-weighted power delivered in that window of time), but it
|
||||
/// is not something that BS.1770 defines a term for.
|
||||
///
|
||||
/// # Multichannel audio
|
||||
///
|
||||
/// To perform a loudness measurement of multichannel audio, construct a
|
||||
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
|
||||
/// with e.g. `reduce_stereo`.
|
||||
///
|
||||
/// # Instantaneous loudness
|
||||
///
|
||||
/// The instantaneous loudness is the power over a 400ms window, so you can
|
||||
/// average four 100ms windows. No special functionality is implemented to help
|
||||
/// with that at this time. ([Pull requests would be accepted.][contribute])
|
||||
///
|
||||
/// # Momentary loudness
|
||||
///
|
||||
/// The momentary loudness is the power over a 3-second window, so you can
|
||||
/// average thirty 100ms windows. No special functionality is implemented to
|
||||
/// help with that at this time. ([Pull requests would be accepted.][contribute])
|
||||
///
|
||||
/// # Integrated loudness
|
||||
///
|
||||
/// Use `gated_mean` to perform an integrated loudness measurement:
|
||||
///
|
||||
/// ```ignore
|
||||
/// # use std::iter;
|
||||
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
|
||||
/// # let sample_rate_hz = 44_100;
|
||||
/// # let samples_per_100ms = sample_rate_hz / 10;
|
||||
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
|
||||
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
|
||||
/// .unwrap_or(bs1770::Power(0.0))
|
||||
/// .loudness_lkfs();
|
||||
/// ```
|
||||
///
|
||||
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelLoudnessMeter {
|
||||
/// The number of samples that fit in 100ms of audio.
|
||||
samples_per_100ms: u32,
|
||||
|
||||
/// Stage 1 filter (head effects, high shelf).
|
||||
filter_stage1: Filter,
|
||||
|
||||
/// Stage 2 filter (high-pass).
|
||||
filter_stage2: Filter,
|
||||
|
||||
/// Sum of the squares over non-overlapping windows of 100ms.
|
||||
windows: Windows100ms<Vec<Power>>,
|
||||
|
||||
/// The number of samples in the current unfinished window.
|
||||
count: u32,
|
||||
|
||||
/// The sum of the squares of the samples in the current unfinished window.
|
||||
square_sum: Sum,
|
||||
}
|
||||
|
||||
impl ChannelLoudnessMeter {
|
||||
/// Construct a new loudness meter for the given sample rate.
|
||||
pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
|
||||
ChannelLoudnessMeter {
|
||||
samples_per_100ms: sample_rate_hz / 10,
|
||||
filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
|
||||
filter_stage2: Filter::high_pass(sample_rate_hz as f32),
|
||||
windows: Windows100ms::new(),
|
||||
count: 0,
|
||||
square_sum: Sum::zero(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed input samples for loudness analysis.
|
||||
///
|
||||
/// # Full scale
|
||||
///
|
||||
/// Full scale for the input samples is the interval [-1.0, 1.0]. If your
|
||||
/// input consists of signed integer samples, you can convert as follows:
|
||||
///
|
||||
/// ```ignore
|
||||
/// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
|
||||
/// # let bits_per_sample = 16_usize;
|
||||
/// # let samples = &[0_i16];
|
||||
/// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
|
||||
/// // one bit is the sign bit.
|
||||
/// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
|
||||
/// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
|
||||
/// ```
|
||||
///
|
||||
/// # Repeated calls
|
||||
///
|
||||
/// You can call `push` multiple times to feed multiple batches of samples.
|
||||
/// This is equivalent to feeding a single chained iterator. The leftover of
|
||||
/// samples that did not fill a full 100ms window is not discarded:
|
||||
///
|
||||
/// ```ignore
|
||||
/// # use std::iter;
|
||||
/// # use bs1770::ChannelLoudnessMeter;
|
||||
/// let sample_rate_hz = 44_100;
|
||||
/// let samples_per_100ms = sample_rate_hz / 10;
|
||||
/// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
///
|
||||
/// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
|
||||
/// assert_eq!(meter.as_100ms_windows().len(), 0);
|
||||
///
|
||||
/// meter.push(iter::once(0.0));
|
||||
/// assert_eq!(meter.as_100ms_windows().len(), 1);
|
||||
/// ```
|
||||
pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
|
||||
let normalizer = 1.0 / self.samples_per_100ms as f32;
|
||||
|
||||
// LLVM, if you could go ahead and inline those apply calls, and then
|
||||
// unroll and vectorize the loop, that'd be terrific.
|
||||
for x in samples {
|
||||
let y = self.filter_stage1.apply(x);
|
||||
let z = self.filter_stage2.apply(y);
|
||||
|
||||
self.square_sum.add(z * z);
|
||||
self.count += 1;
|
||||
|
||||
// TODO: Should this branch be marked cold?
|
||||
if self.count == self.samples_per_100ms {
|
||||
let mean_squares = Power(self.square_sum.sum * normalizer);
|
||||
self.windows.inner.push(mean_squares);
|
||||
// We intentionally do not reset the residue. That way, leftover
|
||||
// energy from this window is not lost, so for the file overall,
|
||||
// the sum remains more accurate.
|
||||
self.square_sum.sum = 0.0;
|
||||
self.count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a reference to the 100ms windows analyzed so far.
|
||||
pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
|
||||
self.windows.as_ref()
|
||||
}
|
||||
|
||||
/// Return all 100ms windows analyzed so far.
|
||||
pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
|
||||
self.windows
|
||||
}
|
||||
}
|
||||
|
||||
/// Combine power for multiple channels by taking a weighted sum.
|
||||
///
|
||||
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
|
||||
/// sum over channels which is not normalized. This means that a stereo signal
|
||||
/// is inherently louder than a mono signal. For a mono signal played back on
|
||||
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
|
||||
/// in the same signal for both channels.
|
||||
pub fn reduce_stereo(
|
||||
left: Windows100ms<&[Power]>,
|
||||
right: Windows100ms<&[Power]>,
|
||||
) -> Windows100ms<Vec<Power>> {
|
||||
assert_eq!(
|
||||
left.len(),
|
||||
right.len(),
|
||||
"Channels must have the same length."
|
||||
);
|
||||
let mut result = Vec::with_capacity(left.len());
|
||||
for (l, r) in left.inner.iter().zip(right.inner) {
|
||||
result.push(Power(l.0 + r.0));
|
||||
}
|
||||
Windows100ms { inner: result }
|
||||
}
|
||||
|
||||
/// In-place version of `reduce_stereo` that stores the result in the former left channel.
|
||||
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
|
||||
assert_eq!(
|
||||
left.len(),
|
||||
right.len(),
|
||||
"Channels must have the same length."
|
||||
);
|
||||
for (l, r) in left.inner.iter_mut().zip(right.inner) {
|
||||
l.0 += r.0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
|
||||
///
|
||||
/// The integrated loudness measurement is not just the average power over the
|
||||
/// entire signal. BS.1770-4 defines two stages of gating that exclude
|
||||
/// parts of the signal, to ensure that silent parts do not contribute to the
|
||||
/// loudness measurement. This function performs that gating, and returns the
|
||||
/// average power over the windows that were not excluded.
|
||||
///
|
||||
/// The result of this function is the integrated loudness measurement.
|
||||
///
|
||||
/// When no signal remains after applying the gate, this function returns
|
||||
/// `None`. In particular, this happens when all of the signal is softer than
|
||||
/// -70 LKFS, including a signal that consists of pure silence.
|
||||
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
|
||||
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
|
||||
|
||||
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
|
||||
let absolute_threshold = Power::from_lkfs(-70.0);
|
||||
|
||||
// Iterate over all 400ms windows.
|
||||
for window in windows_100ms.inner.windows(4) {
|
||||
// Note that the sum over channels has already been performed at this point.
|
||||
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
|
||||
|
||||
if gating_block_power > absolute_threshold {
|
||||
gating_blocks.push(gating_block_power);
|
||||
}
|
||||
}
|
||||
|
||||
if gating_blocks.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Compute the loudness after applying the absolute gate, in order to
|
||||
// determine the threshold for the relative gate.
|
||||
let mut sum_power = Sum::zero();
|
||||
for &gating_block_power in &gating_blocks {
|
||||
sum_power.add(gating_block_power.0);
|
||||
}
|
||||
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
|
||||
|
||||
// Stage 2: Apply the relative gate.
|
||||
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
|
||||
let mut sum_power = Sum::zero();
|
||||
let mut n_blocks = 0_usize;
|
||||
for &gating_block_power in &gating_blocks {
|
||||
if gating_block_power > relative_threshold {
|
||||
sum_power.add(gating_block_power.0);
|
||||
n_blocks += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if n_blocks == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
|
||||
Some(relative_gated_power)
|
||||
}
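/// Example (sketch, not part of the original file): integrated loudness of a
/// mono f32 signal, wiring `ChannelLoudnessMeter` and `gated_mean` together.
/// Returns `None` when everything is gated away (e.g. pure silence).
pub fn example_integrated_loudness_lkfs(samples: &[f32], sample_rate_hz: u32) -> Option<f32> {
    let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    meter.push(samples.iter().copied());
    gated_mean(meter.as_100ms_windows()).map(|power| power.loudness_lkfs())
}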
|
82 integration/utils/src/coco_classes.rs (new file)
@@ -0,0 +1,82 @@
|
||||
pub const NAMES: [&str; 80] = [
|
||||
"person",
|
||||
"bicycle",
|
||||
"car",
|
||||
"motorbike",
|
||||
"aeroplane",
|
||||
"bus",
|
||||
"train",
|
||||
"truck",
|
||||
"boat",
|
||||
"traffic light",
|
||||
"fire hydrant",
|
||||
"stop sign",
|
||||
"parking meter",
|
||||
"bench",
|
||||
"bird",
|
||||
"cat",
|
||||
"dog",
|
||||
"horse",
|
||||
"sheep",
|
||||
"cow",
|
||||
"elephant",
|
||||
"bear",
|
||||
"zebra",
|
||||
"giraffe",
|
||||
"backpack",
|
||||
"umbrella",
|
||||
"handbag",
|
||||
"tie",
|
||||
"suitcase",
|
||||
"frisbee",
|
||||
"skis",
|
||||
"snowboard",
|
||||
"sports ball",
|
||||
"kite",
|
||||
"baseball bat",
|
||||
"baseball glove",
|
||||
"skateboard",
|
||||
"surfboard",
|
||||
"tennis racket",
|
||||
"bottle",
|
||||
"wine glass",
|
||||
"cup",
|
||||
"fork",
|
||||
"knife",
|
||||
"spoon",
|
||||
"bowl",
|
||||
"banana",
|
||||
"apple",
|
||||
"sandwich",
|
||||
"orange",
|
||||
"broccoli",
|
||||
"carrot",
|
||||
"hot dog",
|
||||
"pizza",
|
||||
"donut",
|
||||
"cake",
|
||||
"chair",
|
||||
"sofa",
|
||||
"pottedplant",
|
||||
"bed",
|
||||
"diningtable",
|
||||
"toilet",
|
||||
"tvmonitor",
|
||||
"laptop",
|
||||
"mouse",
|
||||
"remote",
|
||||
"keyboard",
|
||||
"cell phone",
|
||||
"microwave",
|
||||
"oven",
|
||||
"toaster",
|
||||
"sink",
|
||||
"refrigerator",
|
||||
"book",
|
||||
"clock",
|
||||
"vase",
|
||||
"scissors",
|
||||
"teddy bear",
|
||||
"hair drier",
|
||||
"toothbrush",
|
||||
];
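/// Example (sketch, not part of the original file): map a detector's class
/// index to its COCO label, falling back for out-of-range ids.
pub fn example_class_name(class_idx: usize) -> &'static str {
    NAMES.get(class_idx).copied().unwrap_or("unknown")
}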
|
1056 integration/utils/src/imagenet.rs (new file; diff suppressed because it is too large)
156 integration/utils/src/lib.rs (new file)
@@ -0,0 +1,156 @@
|
||||
extern crate candle_core;
|
||||
extern crate candle_transformers;
|
||||
extern crate tokenizers;
|
||||
|
||||
pub mod audio;
|
||||
pub mod bs1770;
|
||||
pub mod coco_classes;
|
||||
pub mod imagenet;
|
||||
pub mod token_output_stream;
|
||||
pub mod wav;
|
||||
use candle_core::{
|
||||
utils::{cuda_is_available, metal_is_available},
|
||||
Device, Tensor,
|
||||
};
|
||||
|
||||
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
|
||||
if cpu {
|
||||
Ok(Device::Cpu)
|
||||
} else if cuda_is_available() {
|
||||
Ok(Device::new_cuda(0)?)
|
||||
} else if metal_is_available() {
|
||||
Ok(Device::new_metal(0)?)
|
||||
} else {
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
{
|
||||
println!(
|
||||
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
|
||||
);
|
||||
}
|
||||
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
|
||||
{
|
||||
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
||||
}
|
||||
Ok(Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
resize_longest: Option<usize>,
|
||||
) -> Result<(Tensor, usize, usize), anyhow::Error> {
|
||||
let img = image::ImageReader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle_core::Error::wrap)?;
|
||||
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
|
||||
let img = match resize_longest {
|
||||
None => img,
|
||||
Some(resize_longest) => {
|
||||
let (height, width) = (img.height(), img.width());
|
||||
let resize_longest = resize_longest as u32;
|
||||
let (height, width) = if height < width {
|
||||
let h = (resize_longest * height) / width;
|
||||
(h, resize_longest)
|
||||
} else {
|
||||
let w = (resize_longest * width) / height;
|
||||
(resize_longest, w)
|
||||
};
|
||||
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
|
||||
}
|
||||
};
|
||||
let (height, width) = (img.height() as usize, img.width() as usize);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
Ok((data, initial_h, initial_w))
|
||||
}
|
||||
|
||||
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
width: usize,
|
||||
height: usize,
|
||||
) -> candle_core::Result<Tensor> {
|
||||
let img = image::ImageReader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle_core::Error::wrap)?
|
||||
.resize_to_fill(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
/// Saves an image to disk using the image crate, this expects an input with shape
|
||||
/// (c, height, width).
|
||||
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> {
|
||||
let p = p.as_ref();
|
||||
let (channel, height, width) = img.dims3()?;
|
||||
if channel != 3 {
|
||||
anyhow::bail!("save_image expects an input of shape (3, height, width)")
|
||||
}
|
||||
let img = img.permute((1, 2, 0))?.flatten_all()?;
|
||||
let pixels = img.to_vec1::<u8>()?;
|
||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||
Some(image) => image,
|
||||
None => anyhow::bail!("error saving image {p:?}"),
|
||||
};
|
||||
image.save(p).map_err(candle_core::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Loads the safetensors files for a model from the hub based on a json index file.
|
||||
pub fn hub_load_safetensors(
|
||||
repo: &hf_hub::api::sync::ApiRepo,
|
||||
json_file: &str,
|
||||
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
|
||||
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
|
||||
let json_file = std::fs::File::open(json_file)?;
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => anyhow::bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file.to_string());
|
||||
}
|
||||
}
|
||||
let safetensors_files = safetensors_files
|
||||
.iter()
|
||||
.map(|v| repo.get(v).map_err(std::io::Error::other))
|
||||
.collect::<Result<Vec<_>, std::io::Error>>()?;
|
||||
Ok(safetensors_files)
|
||||
}
|
||||
|
||||
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
||||
path: P,
|
||||
json_file: &str,
|
||||
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
|
||||
let path = path.as_ref();
|
||||
let jsfile = std::fs::File::open(path.join(json_file))?;
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
None => anyhow::bail!("no weight map in {json_file:?}"),
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
if let Some(file) = value.as_str() {
|
||||
safetensors_files.insert(file);
|
||||
}
|
||||
}
|
||||
let safetensors_files: Vec<_> = safetensors_files
|
||||
.into_iter()
|
||||
.map(|v| path.join(v))
|
||||
.collect();
|
||||
Ok(safetensors_files)
|
||||
}
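/// Example (sketch, not part of the original file): pick the best available
/// device and fetch the sharded safetensors for a hub model using the helpers
/// above. The model id is only an illustration.
pub fn example_fetch_model_weights() -> Result<(Device, Vec<std::path::PathBuf>), anyhow::Error> {
    let device = device(false)?; // prefers CUDA/Metal when available, else CPU
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("google/gemma-2b".to_string());
    let weights = hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    Ok((device, weights))
}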
|
3 integration/utils/src/main.rs (new file)
@@ -0,0 +1,3 @@
fn main() {
    println!("Hello, world!");
}
85 integration/utils/src/token_output_stream.rs (new file)
@@ -0,0 +1,85 @@
|
||||
use candle_core::Result;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
pub struct TokenOutputStream {
|
||||
tokenizer: tokenizers::Tokenizer,
|
||||
tokens: Vec<u32>,
|
||||
prev_index: usize,
|
||||
current_index: usize,
|
||||
}
|
||||
|
||||
impl TokenOutputStream {
|
||||
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
tokens: Vec::new(),
|
||||
prev_index: 0,
|
||||
current_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> tokenizers::Tokenizer {
|
||||
self.tokenizer
|
||||
}
|
||||
|
||||
fn decode(&self, tokens: &[u32]) -> Result<String> {
|
||||
match self.tokenizer.decode(tokens, true) {
|
||||
Ok(str) => Ok(str),
|
||||
Err(err) => candle_core::bail!("cannot decode: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
|
||||
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
self.tokens.push(token);
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
self.prev_index = self.current_index;
|
||||
self.current_index = self.tokens.len();
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_rest(&self) -> Result<Option<String>> {
|
||||
let prev_text = if self.tokens.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let tokens = &self.tokens[self.prev_index..self.current_index];
|
||||
self.decode(tokens)?
|
||||
};
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
Ok(Some(text.1.to_string()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_all(&self) -> Result<String> {
|
||||
self.decode(&self.tokens)
|
||||
}
|
||||
|
||||
pub fn get_token(&self, token_s: &str) -> Option<u32> {
|
||||
self.tokenizer.get_vocab(true).get(token_s).copied()
|
||||
}
|
||||
|
||||
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
|
||||
&self.tokenizer
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.tokens.clear();
|
||||
self.prev_index = 0;
|
||||
self.current_index = 0;
|
||||
}
|
||||
}
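/// Example (sketch, not part of the original file): stream decoded text for a
/// run of token ids produced by a model. The tokenizer path is illustrative.
pub fn example_stream_tokens(token_ids: &[u32]) -> Result<String> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")
        .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
    let mut stream = TokenOutputStream::new(tokenizer);
    let mut out = String::new();
    for &id in token_ids {
        if let Some(piece) = stream.next_token(id)? {
            out.push_str(&piece);
        }
    }
    if let Some(rest) = stream.decode_rest()? {
        out.push_str(&rest);
    }
    Ok(out)
}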
|
56 integration/utils/src/wav.rs (new file)
@@ -0,0 +1,56 @@
|
||||
use std::io::prelude::*;
|
||||
|
||||
pub trait Sample {
|
||||
fn to_i16(&self) -> i16;
|
||||
}
|
||||
|
||||
impl Sample for f32 {
|
||||
fn to_i16(&self) -> i16 {
|
||||
(self.clamp(-1.0, 1.0) * 32767.0) as i16
|
||||
}
|
||||
}
|
||||
|
||||
impl Sample for f64 {
|
||||
fn to_i16(&self) -> i16 {
|
||||
(self.clamp(-1.0, 1.0) * 32767.0) as i16
|
||||
}
|
||||
}
|
||||
|
||||
impl Sample for i16 {
|
||||
fn to_i16(&self) -> i16 {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_pcm_as_wav<W: Write, S: Sample>(
|
||||
w: &mut W,
|
||||
samples: &[S],
|
||||
sample_rate: u32,
|
||||
) -> std::io::Result<()> {
|
||||
let len = 12u32; // header
|
||||
let len = len + 24u32; // fmt
|
||||
let len = len + samples.len() as u32 * 2 + 8; // data
|
||||
let n_channels = 1u16;
|
||||
let bytes_per_second = sample_rate * 2 * n_channels as u32;
|
||||
w.write_all(b"RIFF")?;
|
||||
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
|
||||
w.write_all(b"WAVE")?;
|
||||
|
||||
// Format block
|
||||
w.write_all(b"fmt ")?;
|
||||
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
|
||||
w.write_all(&1u16.to_le_bytes())?; // PCM
|
||||
w.write_all(&n_channels.to_le_bytes())?; // one channel
|
||||
w.write_all(&sample_rate.to_le_bytes())?;
|
||||
w.write_all(&bytes_per_second.to_le_bytes())?;
|
||||
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
|
||||
w.write_all(&16u16.to_le_bytes())?; // bits per sample
|
||||
|
||||
// Data block
|
||||
w.write_all(b"data")?;
|
||||
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
|
||||
for sample in samples.iter() {
|
||||
w.write_all(&sample.to_i16().to_le_bytes())?
|
||||
}
|
||||
Ok(())
|
||||
}
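/// Example (sketch, not part of the original file): render one second of a
/// 440 Hz sine wave and write it out as a 16-bit mono WAV file.
pub fn example_write_sine_wav(path: &str) -> std::io::Result<()> {
    let sample_rate = 16_000u32;
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
        .collect();
    let mut file = std::fs::File::create(path)?;
    write_pcm_as_wav(&mut file, &samples, sample_rate)
}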
|
@@ -1,8 +1,8 @@
{
  "name": "predict-otron-9000",
  "workspaces": ["crates/cli/package"],
  "workspaces": ["integration/cli/package"],
  "scripts": {
    "# WORKSPACE ALIASES": "#",
    "cli": "bun --filter crates/cli/package"
    "cli": "bun --filter integration/cli/package"
  }
}
BIN predict-otron-9000.png (new binary file, 248 KiB; not shown)
14 scripts/build_ui.sh (new executable file)
@@ -0,0 +1,14 @@
#!/usr/bin/env sh

# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

# Move into the chat-ui crate
cd "$PROJECT_ROOT/crates/chat-ui" || exit 1

# Build with cargo leptos
cargo leptos build --release

# Move the wasm file, keeping paths relative to the project root
mv "$PROJECT_ROOT/target/site/pkg/chat-ui.wasm" \
   "$PROJECT_ROOT/target/site/pkg/chat-ui_bg.wasm"
@@ -15,7 +15,7 @@ CONNECT_TIMEOUT=${CONNECT_TIMEOUT:-10}
MAX_TIME=${MAX_TIME:-30}

cat <<EOF
[info] POST $SERVER_URL/v1/chat/completions/stream (SSE)
[info] POST $SERVER_URL/v1/chat/completions (SSE)
[info] model=$MODEL_ID, max_tokens=$MAX_TOKENS
[info] prompt=$PROMPT
[info] timeouts: connect=${CONNECT_TIMEOUT}s, max=${MAX_TIME}s

@@ -35,7 +35,7 @@ curl -N -sS -X POST \
  --connect-timeout "$CONNECT_TIMEOUT" \
  --max-time "$MAX_TIME" \
  -H "Content-Type: application/json" \
  "$SERVER_URL/v1/chat/completions/stream" \
  "$SERVER_URL/v1/chat/completions" \
  -d @- <<JSON
{
  "model": "${MODEL_ID}",
17 scripts/run.sh (new executable file)
@@ -0,0 +1,17 @@
#!/bin/bash

set -e

# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

# todo, conditionally run this only when those files change
"$PROJECT_ROOT/scripts/build_ui.sh"

# build the frontend first
# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}

cd "$PROJECT_ROOT" || exit 1
cargo run --bin predict-otron-9000 --release
@@ -1,7 +0,0 @@
#!/bin/bash

# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}

cargo run --bin predict-otron-9000 --release