11 Commits

Author      SHA1        Message                                         Date
geoffsee    21f20470de  patch version                                   2025-09-01 22:55:59 -04:00
geoffsee    2deecb5e51  chat client only displays available models      2025-09-01 22:29:54 -04:00
geoffsee    545e0c9831  make wasm32 available for all builds in ci      2025-08-31 20:22:12 -04:00
geoffsee    eca61c51ad  add build step to ci                            2025-08-31 20:08:54 -04:00
geoffsee    d1a7d5b28e  fix format error                                2025-08-31 19:59:09 -04:00
geoffsee    8d2b85b0b9  update docs                                     2025-08-31 19:27:15 -04:00
geoffsee    4570780666  release 0.1.3                                   2025-08-31 18:55:37 -04:00
geoffsee    44e4f9e5e1  put proof in the pudding                        2025-08-31 18:54:20 -04:00
geoffsee    64daa77c6b  leptos chat ui renders                          2025-08-31 18:50:25 -04:00
geoffsee    2b4a8a9df8  chat-ui not functional yet but builds           2025-08-31 18:18:56 -04:00
38d51722f2 Update configuration loading with Cargo.toml path and clean up .gitignore
---

This commit message concisely communicates the key changes:

1. The code now builds an absolute path to the `Cargo.toml` file, enhancing clarity in configuration loading.
2. The addition of `PathBuf` usage improves type safety.
3. The removal of unnecessary entries from `.gitignore` helps maintain a clean project structure.

These updates reflect improvements in both functionality and project organization.
2025-08-31 14:06:44 -04:00
53 changed files with 4300 additions and 1277 deletions


@@ -25,7 +25,16 @@ jobs:
           key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
       - name: Setup Rust
-        run: rustup update stable && rustup default stable
+        run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
+      - name: Setup Bun
+        uses: oven-sh/setup-bun@v2
+      - name: Build
+        run: |
+          cargo install --locked cargo-leptos
+          cd crates/chat-ui && cargo leptos build --release
+          cargo build --release -p predict-otron-9000 -p cli
       - name: Install clippy and rustfmt
         run: rustup component add clippy rustfmt
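The build commands added above can also be run locally. A minimal sketch, assuming a stable Rust toolchain and Bun are already installed (commands taken directly from the workflow):

```bash
# Local equivalent of the CI "Build" step above (sketch, not part of the diff)
rustup target add wasm32-unknown-unknown      # WASM target used by cargo-leptos
cargo install --locked cargo-leptos           # build tool for the chat-ui crate
(cd crates/chat-ui && cargo leptos build --release)
cargo build --release -p predict-otron-9000 -p cli
```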


@@ -32,7 +32,7 @@ jobs:
           key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
       - name: Setup Rust
-        run: rustup update stable && rustup default stable
+        run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
       - name: Setup Bun
         uses: oven-sh/setup-bun@v2
@@ -129,12 +129,17 @@ jobs:
           key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}
       - name: Setup Rust
-        run: rustup update stable && rustup default stable
+        run: rustup update stable && rustup default stable && rustup target add wasm32-unknown-unknown
       - name: Add target
         run: rustup target add ${{ matrix.target }}
-      - name: Build binary
+      - name: Build UI
+        run: cargo install --locked cargo-leptos && cd crates/chat-ui && cargo leptos build --release
+        env:
+          CARGO_TERM_COLOR: always
+      - name: Build Binary
         run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000 -p cli
         env:
           CARGO_TERM_COLOR: always

.gitignore (vendored, 4 changes)

@@ -23,7 +23,6 @@ package-lock.json
 # Web frontend build outputs
 dist/
-.trunk/

 # ML model and embedding caches
 .fastembed_cache/
@@ -75,7 +74,6 @@ venv/
 # Backup files
 *.bak
 *.backup
-*~
-/scripts/cli
 !/scripts/cli.ts
 /**/.*.bun-build
+/AGENTS.md

Cargo.lock (generated, 998 changes): file diff suppressed because it is too large


@@ -3,17 +3,17 @@ members = [
     "crates/predict-otron-9000",
     "crates/inference-engine",
     "crates/embeddings-engine",
-    "crates/leptos-app",
     "crates/helm-chart-tool",
     "crates/llama-runner",
     "crates/gemma-runner",
-    "crates/cli"
-]
+    "crates/cli",
+    "crates/chat-ui"
+, "crates/utils"]

 default-members = ["crates/predict-otron-9000"]
 resolver = "2"

 [workspace.package]
-version = "0.1.2"
+version = "0.1.4"

 # Compiler optimization profiles for the workspace
 [profile.release]
@@ -42,8 +42,3 @@ overflow-checks = true
 opt-level = 3
 debug = true
 lto = "thin"
-
-[[workspace.metadata.leptos]]
-# project name
-bin-package = "leptos-app"
-lib-package = "leptos-app"


@@ -2,11 +2,13 @@
   predict-otron-9000
 </h1>
 <p align="center">
-    Powerful local AI inference with OpenAI-compatible APIs
+    AI inference Server with OpenAI-compatible API (Limited Features)
+</p>
+<p align="center">
+    <img src="https://github.com/geoffsee/predict-otron-9001/blob/master/predict-otron-9000.png?raw=true" width="90%" />
 </p>
 <br/>
 > This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
 > By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
@@ -15,6 +17,11 @@ Stability is currently best effort. Many models require unique configuration. Wh
 A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.

+~~~shell
+./scripts/run.sh
+~~~

 ## Project Overview
 The predict-otron-9000 is a flexible AI platform that provides:
@@ -40,7 +47,7 @@ The system supports both CPU and GPU acceleration (CUDA/Metal), with intelligent
 ### Workspace Structure
-The project uses a 7-crate Rust workspace plus TypeScript components:
+The project uses a 9-crate Rust workspace plus TypeScript components:
 ```
 crates/
@@ -49,17 +56,18 @@ crates/
 ├── gemma-runner/       # Gemma model inference via Candle (Rust 2021)
 ├── llama-runner/       # Llama model inference via Candle (Rust 2021)
 ├── embeddings-engine/  # FastEmbed embeddings service (Rust 2024)
-├── leptos-app/         # WASM web frontend (Rust 2021)
+├── chat-ui/            # WASM web frontend (Rust 2021)
 ├── helm-chart-tool/    # Kubernetes deployment tooling (Rust 2024)
-└── scripts/
-    └── cli.ts          # TypeScript/Bun CLI client
+└── cli/                # CLI client crate (Rust 2024)
+    └── package/
+        └── cli.ts      # TypeScript/Bun CLI client
 ```
 ### Service Architecture
 - **Main Server** (port 8080): Orchestrates inference and embeddings services
 - **Embeddings Service** (port 8080): Standalone FastEmbed service with OpenAI API compatibility
-- **Web Frontend** (port 8788): cargo leptos SSR app
+- **Web Frontend** (port 8788): chat-ui WASM app
 - **CLI Client**: TypeScript/Bun client for testing and automation
 ### Deployment Modes
@@ -85,11 +93,6 @@ The architecture supports multiple deployment patterns:
 - **Bun**: Required for TypeScript CLI client: `curl -fsSL https://bun.sh/install | bash`
 - **Node.js**: Alternative to Bun, supports OpenAI SDK v5.16.0+
-#### WASM Frontend Toolchain
-- **Trunk**: Required for Leptos frontend builds: `cargo install trunk`
-- **wasm-pack**: `cargo install wasm-pack`
-- **WASM target**: `rustup target add wasm32-unknown-unknown`
 #### ML Framework Dependencies
 - **Candle**: Version 0.9.1 with conditional compilation:
   - macOS: Metal support with CPU fallback for stability
@@ -134,11 +137,6 @@ cargo build --bin cli --package inference-engine --release
 cargo build --bin embeddings-engine --release
 ```
-**Web Frontend:**
-```bash
-cd crates/leptos-app
-trunk build --release
-```
 ### Running Services
@@ -152,26 +150,26 @@ trunk build --release
 #### Web Frontend (Port 8788)
 ```bash
-cd crates/leptos-app
+cd crates/chat-ui
 ./run.sh
 ```
-- Serves Leptos WASM frontend on port 8788
+- Serves chat-ui WASM frontend on port 8788
 - Sets required RUSTFLAGS for WebAssembly getrandom support
 - Auto-reloads during development
 #### TypeScript CLI Client
 ```bash
 # List available models
-bun run scripts/cli.ts --list-models
+cd crates/cli/package && bun run cli.ts --list-models
 # Chat completion
-bun run scripts/cli.ts "What is the capital of France?"
+cd crates/cli/package && bun run cli.ts "What is the capital of France?"
 # With specific model
-bun run scripts/cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
 # Show help
-bun run scripts/cli.ts --help
+cd crates/cli/package && bun run cli.ts --help
 ```
 ## API Usage
@@ -287,7 +285,7 @@ cargo test --workspace
 **End-to-end test script:**
 ```bash
-./smoke_test.sh
+./scripts/smoke_test.sh
 ```
 This script:
@@ -376,7 +374,7 @@ All services include Docker metadata in `Cargo.toml`:
 - Port: 8080
 **Web Frontend:**
-- Image: `ghcr.io/geoffsee/leptos-app:latest`
+- Image: `ghcr.io/geoffsee/chat-ui:latest`
 - Port: 8788
 **Docker Compose:**
@@ -435,8 +433,7 @@ For Kubernetes deployment details, see the [ARCHITECTURE.md](docs/ARCHITECTURE.m
 **Symptom:** WASM compilation failures
 **Solution:**
 1. Install required targets: `rustup target add wasm32-unknown-unknown`
-2. Install trunk: `cargo install trunk`
-3. Check RUSTFLAGS in leptos-app/run.sh
+2. Check RUSTFLAGS in chat-ui/run.sh
 ### Network/Timeout Issues
 **Symptom:** First-time model downloads timing out
@@ -467,24 +464,23 @@ curl -s http://localhost:8080/v1/models | jq
 **CLI client test:**
 ```bash
-bun run scripts/cli.ts "What is 2+2?"
+cd crates/cli/package && bun run cli.ts "What is 2+2?"
 ```
 **Web frontend:**
 ```bash
-cd crates/leptos-app && ./run.sh &
+cd crates/chat-ui && ./run.sh &
 # Navigate to http://localhost:8788
 ```
 **Integration test:**
 ```bash
-./smoke_test.sh
+./scripts/smoke_test.sh
 ```
 **Cleanup:**
 ```bash
 pkill -f "predict-otron-9000"
-pkill -f "trunk"
 ```
 For networked tests and full functionality, ensure Hugging Face authentication is configured as described above.


@@ -1,6 +1,6 @@
 [package]
-name = "leptos-app"
-version.workspace = true
+name = "chat-ui"
+version = "0.1.0"
 edition = "2021"

 [lib]
@@ -15,45 +15,17 @@ leptos_axum = { version = "0.8.0", optional = true }
 leptos_meta = { version = "0.8.0" }
 tokio = { version = "1", features = ["rt-multi-thread"], optional = true }
 wasm-bindgen = { version = "=0.2.100", optional = true }
-# Chat interface dependencies
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
-async-openai-wasm = { version = "0.29", default-features = false }
-futures-util = "0.3"
-js-sys = { version = "0.3", optional = true }
-either = { version = "1.9", features = ["serde"] }
-web-sys = { version = "0.3", optional = true, features = [
-    "console",
-    "Window",
-    "Document",
-    "Element",
-    "HtmlElement",
-    "HtmlInputElement",
-    "HtmlSelectElement",
-    "HtmlTextAreaElement",
-    "Event",
-    "EventTarget",
-    "KeyboardEvent",
-] }
-
-[dependencies.uuid]
-version = "1.0"
-features = [
-    "v4",
-    "fast-rng",
-    "macro-diagnostics",
-    "js",
-]
+reqwest = { version = "0.12", features = ["json"] }
+web-sys = { version = "0.3", features = ["console"] }
+gloo-net = { version = "0.6", features = ["http"] }

 [features]
 hydrate = [
     "leptos/hydrate",
     "dep:console_error_panic_hook",
     "dep:wasm-bindgen",
-    "dep:js-sys",
-    "dep:web-sys",
 ]
 ssr = [
     "dep:axum",
@@ -73,8 +45,9 @@ codegen-units = 1
 panic = "abort"

 [package.metadata.leptos]
+name = "chat-ui"
 # The name used by wasm-bindgen/cargo-leptos for the JS/WASM bundle. Defaults to the crate name
-output-name = "leptos-app"
+output-name = "chat-ui"
 # The site root folder is where cargo-leptos generate all output. WARNING: all content of this folder will be erased on a rebuild. Use it in your server setup.
 site-root = "target/site"
@@ -84,7 +57,7 @@ site-root = "target/site"
 site-pkg-dir = "pkg"
 # [Optional] The source CSS file. If it ends with .sass or .scss then it will be compiled by dart-sass into CSS. The CSS is optimized by Lightning CSS before being written to <site-root>/<site-pkg>/app.css
-style-file = "style/main.scss"
+style-file = "./style/main.scss"
 # Assets source dir. All files found here will be copied and synchronized to site-root.
 # The assets-dir cannot have a sub directory with the same name/path as site-pkg-dir.
 #
@@ -132,4 +105,4 @@ lib-default-features = false
 # The profile to use for the lib target when compiling for release
 #
 # Optional. Defaults to "release".
-lib-profile-release = "wasm-release"
+lib-profile-release = "release"

crates/chat-ui/LICENSE (new file, 24 lines)

@@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <https://unlicense.org>

crates/chat-ui/README.md (new file, 41 lines)

@@ -0,0 +1,41 @@
# chat-ui
A WASM-based web chat interface for the predict-otron-9000 AI platform.
## Overview
The chat-ui provides a real-time web interface for interacting with language models through the predict-otron-9000 server. Built with Leptos and compiled to WebAssembly, it offers a modern chat experience with streaming response support.
## Features
- Real-time chat interface with the inference server
- Streaming response support
- Conversation history
- Responsive web design
- WebAssembly-powered for optimal performance
## Building and Running
### Prerequisites
- Rust toolchain with WASM target: `rustup target add wasm32-unknown-unknown`
- The predict-otron-9000 server must be running on port 8080
### Development Server
```bash
cd crates/chat-ui
./run.sh
```
This starts the development server on port 8788 with auto-reload capabilities.
### Usage
1. Start the predict-otron-9000 server: `./scripts/run.sh`
2. Start the chat-ui: `cd crates/chat-ui && ./run.sh`
3. Navigate to `http://localhost:8788`
4. Start chatting with your AI models!
## Technical Details
- Built with Leptos framework
- Compiled to WebAssembly for browser execution
- Communicates with predict-otron-9000 API via HTTP
- Sets required RUSTFLAGS for WebAssembly getrandom support
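As a quick end-to-end check of the API this UI depends on, the same call the Send button issues can be made from the command line. A sketch, with the request shape mirroring the `ChatRequest` struct in `src/app.rs`:

```bash
# Terminal 1: API server (port 8080); Terminal 2: chat-ui (port 8788)
./scripts/run.sh
cd crates/chat-ui && ./run.sh

# Exercise the chat completion endpoint the UI posts to
curl -s http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "gemma-3-1b-it",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 1024,
        "stream": false
      }' | jq
```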

(binary image changed; size 15 KiB both before and after)

crates/chat-ui/src/app.rs (new file, 393 lines)

@@ -0,0 +1,393 @@
#[cfg(feature = "ssr")]
use axum::Router;
#[cfg(feature = "ssr")]
use leptos::prelude::LeptosOptions;
#[cfg(feature = "ssr")]
use leptos_axum::{generate_route_list, LeptosRoutes};
pub struct AppConfig {
pub config: ConfFile,
pub address: String,
}
impl Default for AppConfig {
fn default() -> Self {
let conf = get_configuration(Some(concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml")))
.expect("failed to read config");
let addr = conf.leptos_options.site_addr;
AppConfig {
config: conf, // or whichever field/string representation you need
address: addr.to_string(),
}
}
}
/// Build the Axum router for this app, including routes, fallback, and state.
/// Call this from another crate (or your bin) when running the server.
#[cfg(feature = "ssr")]
pub fn create_router(leptos_options: LeptosOptions) -> Router {
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
Router::new()
.leptos_routes(&leptos_options, routes, {
let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options)
}
use gloo_net::http::Request;
use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
components::{Route, Router, Routes},
StaticSegment,
};
use serde::{Deserialize, Serialize};
use web_sys::console;
// Remove spawn_local import as we'll use different approach
// Data structures for OpenAI-compatible API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
#[derive(Debug, Serialize)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
pub message: ChatMessage,
pub index: u32,
pub finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
}
// Data structures for models API
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
#[derive(Debug, Deserialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// API client function to fetch available models
pub async fn fetch_models() -> Result<Vec<ModelInfo>, String> {
let response = Request::get("/v1/models")
.send()
.await
.map_err(|e| format!("Failed to fetch models: {:?}", e))?;
if response.ok() {
let models_response: ModelsResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse models response: {:?}", e))?;
Ok(models_response.data)
} else {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(format!("Failed to fetch models {}: {}", status, error_text))
}
}
// API client function to send chat completion requests
pub async fn send_chat_completion(
messages: Vec<ChatMessage>,
model: String,
) -> Result<String, String> {
let request = ChatRequest {
model,
messages,
max_tokens: Some(1024),
stream: Some(false),
};
let response = Request::post("/v1/chat/completions")
.header("Content-Type", "application/json")
.json(&request)
.map_err(|e| format!("Failed to create request: {:?}", e))?
.send()
.await
.map_err(|e| format!("Failed to send request: {:?}", e))?;
if response.ok() {
let chat_response: ChatResponse = response
.json()
.await
.map_err(|e| format!("Failed to parse response: {:?}", e))?;
if let Some(choice) = chat_response.choices.first() {
Ok(choice.message.content.clone())
} else {
Err("No response choices available".to_string())
}
} else {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(format!("Server error {}: {}", status, error_text))
}
}
pub fn shell(options: LeptosOptions) -> impl IntoView {
view! {
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<AutoReload options=options.clone() />
<HydrationScripts options/>
<MetaTags/>
</head>
<body>
<App/>
</body>
</html>
}
}
#[component]
pub fn App() -> impl IntoView {
// Provides context that manages stylesheets, titles, meta tags, etc.
provide_meta_context();
view! {
// injects a stylesheet into the document <head>
// id=leptos means cargo-leptos will hot-reload this stylesheet
<Stylesheet id="leptos" href="/pkg/chat-ui.css"/>
// sets the document title
<Title text="Predict-Otron-9000 Chat"/>
// content for this welcome page
<Router>
<main>
<Routes fallback=|| "Page not found.".into_view()>
<Route path=StaticSegment("") view=ChatPage/>
</Routes>
</main>
</Router>
}
}
/// Renders the chat interface page
#[component]
fn ChatPage() -> impl IntoView {
// State for conversation messages
let messages = RwSignal::new(Vec::<ChatMessage>::new());
// State for current user input
let input_text = RwSignal::new(String::new());
// State for loading indicator
let is_loading = RwSignal::new(false);
// State for error messages
let error_message = RwSignal::new(Option::<String>::None);
// State for available models and selected model
let available_models = RwSignal::new(Vec::<ModelInfo>::new());
let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
// Client-side only: Fetch models on component mount
#[cfg(target_arch = "wasm32")]
{
use leptos::task::spawn_local;
spawn_local(async move {
match fetch_models().await {
Ok(models) => {
available_models.set(models);
}
Err(error) => {
console::log_1(&format!("Failed to fetch models: {}", error).into());
error_message.set(Some(format!("Failed to load models: {}", error)));
}
}
});
}
// Shared logic for sending a message
let send_message_logic = move || {
let user_input = input_text.get();
if user_input.trim().is_empty() {
return;
}
// Add user message to conversation
let user_message = ChatMessage {
role: "user".to_string(),
content: user_input.clone(),
};
messages.update(|msgs| msgs.push(user_message.clone()));
input_text.set(String::new());
is_loading.set(true);
error_message.set(None);
// Client-side only: Send chat completion request
#[cfg(target_arch = "wasm32")]
{
use leptos::task::spawn_local;
// Prepare messages for API call
let current_messages = messages.get();
let current_model = selected_model.get();
// Spawn async task to call API
spawn_local(async move {
match send_chat_completion(current_messages, current_model).await {
Ok(response_content) => {
let assistant_message = ChatMessage {
role: "assistant".to_string(),
content: response_content,
};
messages.update(|msgs| msgs.push(assistant_message));
is_loading.set(false);
}
Err(error) => {
console::log_1(&format!("API Error: {}", error).into());
error_message.set(Some(error));
is_loading.set(false);
}
}
});
}
};
// Button click handler
let on_button_click = {
let send_logic = send_message_logic.clone();
move |_: web_sys::MouseEvent| {
send_logic();
}
};
// Handle enter key press in input field
let on_key_down = move |ev: web_sys::KeyboardEvent| {
if ev.key() == "Enter" && !ev.shift_key() {
ev.prevent_default();
send_message_logic();
}
};
view! {
<div class="chat-container">
<div class="chat-header">
<h1>"Predict-Otron-9000 Chat"</h1>
<div class="model-selector">
<label for="model-select">"Model:"</label>
<select
id="model-select"
prop:value=move || selected_model.get()
on:change=move |ev| {
let new_model = event_target_value(&ev);
selected_model.set(new_model);
}
>
<For
each=move || available_models.get().into_iter()
key=|model| model.id.clone()
children=move |model| {
view! {
<option value=model.id.clone()>
{format!("{} ({})", model.id, model.owned_by)}
</option>
}
}
/>
</select>
</div>
</div>
<div class="chat-messages">
<For
each=move || messages.get().into_iter().enumerate()
key=|(i, _)| *i
children=move |(_, message)| {
let role_class = if message.role == "user" { "user-message" } else { "assistant-message" };
view! {
<div class=format!("message {}", role_class)>
<div class="message-role">{message.role.clone()}</div>
<div class="message-content">{message.content.clone()}</div>
</div>
}
}
/>
{move || {
if is_loading.get() {
view! {
<div class="message assistant-message loading">
<div class="message-role">"assistant"</div>
<div class="message-content">"Thinking..."</div>
</div>
}.into_any()
} else {
view! {}.into_any()
}
}}
</div>
{move || {
if let Some(error) = error_message.get() {
view! {
<div class="error-message">
"Error: " {error}
</div>
}.into_any()
} else {
view! {}.into_any()
}
}}
<div class="chat-input">
<textarea
placeholder="Type your message here... (Press Enter to send, Shift+Enter for new line)"
prop:value=move || input_text.get()
on:input=move |ev| input_text.set(event_target_value(&ev))
on:keydown=on_key_down
class:disabled=move || is_loading.get()
/>
<button
on:click=on_button_click
class:disabled=move || is_loading.get() || input_text.get().trim().is_empty()
>
"Send"
</button>
</div>
</div>
}
}


@@ -0,0 +1,9 @@
pub mod app;
#[cfg(feature = "hydrate")]
#[wasm_bindgen::prelude::wasm_bindgen]
pub fn hydrate() {
use crate::app::*;
console_error_panic_hook::set_once();
leptos::mount::hydrate_body(App);
}


@@ -0,0 +1,26 @@
#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
use axum::Router;
use chat_ui::app::*;
use leptos::logging::log;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
let conf = get_configuration(None).expect("failed to read config");
let addr = conf.leptos_options.site_addr;
// Build the app router with your extracted function
let app: Router = create_router(conf.leptos_options);
log!("listening on http://{}", &addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
#[cfg(not(feature = "ssr"))]
pub fn main() {
// no client-side main function
}


@@ -0,0 +1,226 @@
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #f5f5f5;
height: 100vh;
overflow: hidden;
}
.chat-container {
display: flex;
flex-direction: column;
height: 100vh;
max-width: 800px;
margin: 0 auto;
background-color: white;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.1);
}
.chat-header {
background-color: #000000;
color: white;
padding: 1rem;
text-align: center;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
display: flex;
flex-direction: column;
gap: 1rem;
h1 {
margin: 0;
font-size: 1.5rem;
font-weight: 600;
}
.model-selector {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
label {
font-weight: 500;
font-size: 0.9rem;
}
select {
background-color: white;
color: #374151;
border: 1px solid #d1d5db;
border-radius: 6px;
padding: 0.5rem 0.75rem;
font-size: 0.9rem;
font-family: inherit;
cursor: pointer;
min-width: 200px;
&:focus {
outline: none;
border-color: #663c99;
box-shadow: 0 0 0 2px rgba(29, 78, 216, 0.2);
}
option {
padding: 0.5rem;
}
}
}
}
.chat-messages {
flex: 1;
overflow-y: auto;
padding: 1rem;
display: flex;
flex-direction: column;
gap: 1rem;
background-color: #fafafa;
}
.message {
display: flex;
flex-direction: column;
gap: 0.5rem;
padding: 1rem;
border-radius: 12px;
max-width: 80%;
word-wrap: break-word;
&.user-message {
align-self: flex-end;
background-color: #2563eb;
color: white;
.message-role {
font-weight: 600;
font-size: 0.8rem;
opacity: 0.8;
text-transform: uppercase;
}
.message-content {
line-height: 1.5;
}
}
&.assistant-message {
align-self: flex-start;
background-color: #646873;
border: 1px solid #e5e7eb;
color: #f3f3f3;
.message-role {
font-weight: 600;
font-size: 0.8rem;
color: #c4c5cd;
text-transform: uppercase;
}
.message-content {
line-height: 1.5;
}
&.loading {
background-color: #f3f4f6;
border-color: #d1d5db;
.message-content {
font-style: italic;
color: #6b7280;
}
}
}
}
.error-message {
background-color: #fef2f2;
border: 1px solid #fca5a5;
color: #dc2626;
padding: 1rem;
margin: 0 1rem;
border-radius: 8px;
text-align: center;
font-weight: 500;
}
.chat-input {
display: flex;
gap: 0.5rem;
padding: 1rem;
background-color: white;
border-top: 1px solid #e5e7eb;
textarea {
flex: 1;
padding: 0.75rem;
border: 1px solid #d1d5db;
border-radius: 8px;
resize: none;
min-height: 60px;
max-height: 120px;
font-family: inherit;
font-size: 1rem;
line-height: 1.5;
&:focus {
outline: none;
border-color: #663c99;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
}
&.disabled {
background-color: #f9fafb;
color: #6b7280;
cursor: not-allowed;
}
}
button {
padding: 0.75rem 1.5rem;
background-color: #663c99;
color: white;
border: none;
border-radius: 8px;
font-weight: 600;
cursor: pointer;
transition: background-color 0.2s ease;
align-self: flex-end;
&:hover:not(.disabled) {
background-color: #663c99;
}
&.disabled {
background-color: #9ca3af;
cursor: not-allowed;
}
&:focus {
outline: none;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.3);
}
}
}
/* Scrollbar styling for webkit browsers */
.chat-messages::-webkit-scrollbar {
width: 6px;
}
.chat-messages::-webkit-scrollbar-track {
background: #f1f1f1;
}
.chat-messages::-webkit-scrollbar-thumb {
background: #c1c1c1;
border-radius: 3px;
}
.chat-messages::-webkit-scrollbar-thumb:hover {
background: #a8a8a8;
}


@@ -3,7 +3,7 @@
 A Rust/Typescript Hybrid
 ```console
-./cli [options] [prompt]
+bun run cli.ts [options] [prompt]
 Simple CLI tool for testing the local OpenAI-compatible API server.
@@ -14,10 +14,11 @@ Options:
   --help           Show this help message
 Examples:
-  ./cli "What is the capital of France?"
-  ./cli --model gemma-3-1b-it --prompt "Hello, world!"
-  ./cli --prompt "Who was the 16th president of the United States?"
-  ./cli --list-models
+  cd crates/cli/package
+  bun run cli.ts "What is the capital of France?"
+  bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+  bun run cli.ts --prompt "Who was the 16th president of the United States?"
+  bun run cli.ts --list-models
 The server must be running at http://localhost:8080
 ```


@@ -1,4 +1,100 @@
 # Embeddings Engine

-A high-performance text embeddings service that generates vector representations of text using state-of-the-art models.
-This crate wraps the fastembed crate to provide embeddings and partially adapts the openai specification.
+A high-performance text embeddings service that generates vector representations of text using state-of-the-art models. This crate wraps the FastEmbed library to provide embeddings with OpenAI-compatible API endpoints.
## Overview
The embeddings-engine provides a standalone service for generating text embeddings that can be used for semantic search, similarity comparisons, and other NLP tasks. It's designed to be compatible with OpenAI's embeddings API format.
## Features
- **OpenAI-Compatible API**: `/v1/embeddings` endpoint matching OpenAI's specification
- **FastEmbed Integration**: Powered by the FastEmbed library for high-quality embeddings
- **Multiple Model Support**: Support for various embedding models
- **High Performance**: Optimized for fast embedding generation
- **Standalone Service**: Can run independently or as part of the predict-otron-9000 platform
## Building and Running
### Prerequisites
- Rust toolchain
- Internet connection for initial model downloads
### Standalone Server
```bash
cargo run --bin embeddings-engine --release
```
The service will start on port 8080 by default.
## API Usage
### Generate Embeddings
**Endpoint**: `POST /v1/embeddings`
**Request Body**:
```json
{
"input": "Your text to embed",
"model": "nomic-embed-text-v1.5"
}
```
**Response**:
```json
{
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": [0.1, 0.2, 0.3, ...]
}
],
"model": "nomic-embed-text-v1.5",
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
}
```
### Example Usage
**Using cURL**:
```bash
curl -s http://localhost:8080/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"input": "The quick brown fox jumps over the lazy dog",
"model": "nomic-embed-text-v1.5"
}' | jq
```
**Using Python OpenAI Client**:
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="dummy" # Not validated but required by client
)
response = client.embeddings.create(
input="Your text here",
model="nomic-embed-text-v1.5"
)
print(response.data[0].embedding)
```
## Configuration
The service can be configured through environment variables:
- `SERVER_PORT`: Port to run on (default: 8080)
- `RUST_LOG`: Logging level (default: info)
## Integration
This service is designed to work seamlessly with the predict-otron-9000 main server, but can also be deployed independently for dedicated embeddings workloads.
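For example, a standalone deployment on a non-default port can be exercised end to end; a minimal sketch using only the variables and endpoint documented above:

```bash
# Run on port 8081 with verbose logging (sketch; assumes the service reads SERVER_PORT as documented)
SERVER_PORT=8081 RUST_LOG=debug cargo run --bin embeddings-engine --release

# Verify the endpoint responds and report the embedding dimensionality
curl -s http://localhost:8081/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{"input": "hello", "model": "nomic-embed-text-v1.5"}' \
  | jq '.data[0].embedding | length'
```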


@@ -10,15 +10,15 @@ edition = "2021"
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
 candle-transformers = { git = "https://github.com/huggingface/candle.git" }
-candle-examples = { git = "https://github.com/huggingface/candle.git" }
 hf-hub = "0.4"
-tokenizers = "0.21"
+tokenizers = "0.22.0"
 anyhow = "1.0"
 clap = { version = "4.0", features = ["derive", "string"] }
 serde_json = "1.0"
 tracing = "0.1"
 tracing-chrome = "0.7"
 tracing-subscriber = "0.3"
+utils = {path = "../utils"}

 [target.'cfg(target_os = "macos")'.dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }


@@ -10,16 +10,17 @@ use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
 use clap::ValueEnum;
 // Removed gemma_cli import as it's not needed for the API
-use candle_core::{utils, DType, Device, Tensor};
-use candle_examples::token_output_stream::TokenOutputStream;
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
-use tokenizers::Tokenizer;
 use std::sync::mpsc::{self, Receiver, Sender};
 use std::thread;
+use tokenizers::Tokenizer;
+use utils::hub_load_safetensors;
+use utils::token_output_stream::TokenOutputStream;

 #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
 pub enum WhichModel {
@@ -85,9 +86,9 @@ pub struct TextGeneration {
 fn device(cpu: bool) -> Result<Device> {
     if cpu {
         Ok(Device::Cpu)
-    } else if utils::cuda_is_available() {
+    } else if candle_core::utils::cuda_is_available() {
         Ok(Device::new_cuda(0)?)
-    } else if utils::metal_is_available() {
+    } else if candle_core::utils::metal_is_available() {
         Ok(Device::new_metal(0)?)
     } else {
         Ok(Device::Cpu)
@@ -98,7 +99,7 @@ impl TextGeneration {
     #[allow(clippy::too_many_arguments)]
     fn new(
         model: Model,
-        tokenizer: Tokenizer,
+        tokenizer: tokenizers::Tokenizer,
         seed: u64,
         temp: Option<f64>,
         top_p: Option<f64>,
@@ -262,10 +263,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",
-        utils::with_avx(),
-        utils::with_neon(),
-        utils::with_simd128(),
-        utils::with_f16c()
+        candle_core::utils::with_avx(),
+        candle_core::utils::with_neon(),
+        candle_core::utils::with_simd128(),
+        candle_core::utils::with_f16c()
     );
     let device = device(cfg.cpu)?;
@@ -318,7 +319,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let config_filename = repo.get("config.json")?;
     let filenames = match cfg.model {
         WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
-        _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+        _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("Retrieved files in {:?}", start.elapsed());


@@ -137,7 +137,7 @@ Parsing workspace at: ..
 Output directory: ../generated-helm-chart
 Chart name: predict-otron-9000
 Found 4 services:
-  - leptos-app: ghcr.io/geoffsee/leptos-app:latest (port 8788)
+  - chat-ui: ghcr.io/geoffsee/chat-ui:latest (port 8788)
   - inference-engine: ghcr.io/geoffsee/inference-service:latest (port 8080)
   - embeddings-engine: ghcr.io/geoffsee/embeddings-service:latest (port 8080)
   - predict-otron-9000: ghcr.io/geoffsee/predict-otron-9000:latest (port 8080)


@@ -31,8 +31,8 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 reborrow = "0.5.5"
 futures-util = "0.3.31"
-gemma-runner = { path = "../gemma-runner" }
-llama-runner = { path = "../llama-runner" }
+gemma-runner = { path = "../gemma-runner", features = ["metal"] }
+llama-runner = { path = "../llama-runner", features = ["metal"]}

 [target.'cfg(target_os = "macos")'.dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }


@@ -1,49 +1,9 @@
-// use candle_core::Tensor;
 use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};

-#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
-pub enum Which {
-    #[value(name = "2b")]
-    Base2B,
-    #[value(name = "7b")]
-    Base7B,
-    #[value(name = "2b-it")]
-    Instruct2B,
-    #[value(name = "7b-it")]
-    Instruct7B,
-    #[value(name = "1.1-2b-it")]
-    InstructV1_1_2B,
-    #[value(name = "1.1-7b-it")]
-    InstructV1_1_7B,
-    #[value(name = "code-2b")]
-    CodeBase2B,
-    #[value(name = "code-7b")]
-    CodeBase7B,
-    #[value(name = "code-2b-it")]
-    CodeInstruct2B,
-    #[value(name = "code-7b-it")]
-    CodeInstruct7B,
-    #[value(name = "2-2b")]
-    BaseV2_2B,
-    #[value(name = "2-2b-it")]
-    InstructV2_2B,
-    #[value(name = "2-9b")]
-    BaseV2_9B,
-    #[value(name = "2-9b-it")]
-    InstructV2_9B,
-    #[value(name = "3-1b")]
-    BaseV3_1B,
-    #[value(name = "3-1b-it")]
-    InstructV3_1B,
-    #[value(name = "llama-3.2-1b-it")]
-    LlamaInstruct3_2_1B,
-    #[value(name = "llama-3.2-3b-it")]
-    LlamaInstruct3_2_3B,
-}
+#[derive(Clone, Debug)]
 pub enum Model {
     V1(Model1),
     V2(Model2),
@@ -66,48 +26,127 @@ impl Model {
     }
 }
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Family {
GemmaV1,
GemmaV2,
GemmaV3,
Llama,
}
#[derive(Clone, Copy, Debug)]
pub struct ModelMeta {
pub id: &'static str,
pub family: Family,
pub instruct: bool,
}
const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
ModelMeta { id, family, instruct }
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which {
// Gemma 1.x
#[value(name = "2b")]
Base2B,
#[value(name = "7b")]
Base7B,
#[value(name = "2b-it")]
Instruct2B,
#[value(name = "7b-it")]
Instruct7B,
#[value(name = "1.1-2b-it")]
InstructV1_1_2B,
#[value(name = "1.1-7b-it")]
InstructV1_1_7B,
// CodeGemma
#[value(name = "code-2b")]
CodeBase2B,
#[value(name = "code-7b")]
CodeBase7B,
#[value(name = "code-2b-it")]
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
// Gemma 2
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
// Gemma 3
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
// Llama 3.2 (use aliases instead of duplicate variants)
#[value(name = "llama-3.2-1b")]
Llama32_1B,
#[value(name = "llama-3.2-1b-it", alias = "llama-3.2-1b-instruct")]
Llama32_1BInstruct,
#[value(name = "llama-3.2-3b")]
Llama32_3B,
#[value(name = "llama-3.2-3b-it", alias = "llama-3.2-3b-instruct")]
Llama32_3BInstruct,
}
 impl Which {
-    pub fn to_model_id(&self) -> String {
-        match self {
-            Self::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
-            Self::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
-            Self::Base2B => "google/gemma-2b".to_string(),
-            Self::Base7B => "google/gemma-7b".to_string(),
-            Self::Instruct2B => "google/gemma-2b-it".to_string(),
-            Self::Instruct7B => "google/gemma-7b-it".to_string(),
-            Self::CodeBase2B => "google/codegemma-2b".to_string(),
-            Self::CodeBase7B => "google/codegemma-7b".to_string(),
-            Self::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
-            Self::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
-            Self::BaseV2_2B => "google/gemma-2-2b".to_string(),
-            Self::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
-            Self::BaseV2_9B => "google/gemma-2-9b".to_string(),
-            Self::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
-            Self::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
-            Self::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
-            Self::LlamaInstruct3_2_1B => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
-            Self::LlamaInstruct3_2_3B => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
-        }
-    }
+    pub const fn meta(&self) -> ModelMeta {
+        use Family::*;
+        match self {
+            // Gemma 1.x
+            Self::Base2B => m("google/gemma-2b", GemmaV1, false),
+            Self::Base7B => m("google/gemma-7b", GemmaV1, false),
+            Self::Instruct2B => m("google/gemma-2b-it", GemmaV1, true),
+            Self::Instruct7B => m("google/gemma-7b-it", GemmaV1, true),
+            Self::InstructV1_1_2B => m("google/gemma-1.1-2b-it", GemmaV1, true),
+            Self::InstructV1_1_7B => m("google/gemma-1.1-7b-it", GemmaV1, true),
+
+            // CodeGemma
+            Self::CodeBase2B => m("google/codegemma-2b", GemmaV1, false),
+            Self::CodeBase7B => m("google/codegemma-7b", GemmaV1, false),
+            Self::CodeInstruct2B => m("google/codegemma-2b-it", GemmaV1, true),
+            Self::CodeInstruct7B => m("google/codegemma-7b-it", GemmaV1, true),
+
+            // Gemma 2
+            Self::BaseV2_2B => m("google/gemma-2-2b", GemmaV2, false),
+            Self::InstructV2_2B => m("google/gemma-2-2b-it", GemmaV2, true),
+            Self::BaseV2_9B => m("google/gemma-2-9b", GemmaV2, false),
+            Self::InstructV2_9B => m("google/gemma-2-9b-it", GemmaV2, true),
+
+            // Gemma 3
+            Self::BaseV3_1B => m("google/gemma-3-1b-pt", GemmaV3, false),
+            Self::InstructV3_1B => m("google/gemma-3-1b-it", GemmaV3, true),
+
+            // Llama 3.2
+            Self::Llama32_1B => m("meta-llama/Llama-3.2-1B", Llama, false),
+            Self::Llama32_1BInstruct => m("meta-llama/Llama-3.2-1B-Instruct", Llama, true),
+            Self::Llama32_3B => m("meta-llama/Llama-3.2-3B", Llama, false),
+            Self::Llama32_3BInstruct => m("meta-llama/Llama-3.2-3B-Instruct", Llama, true),
+        }
+    }
+
+    pub fn to_model_id(&self) -> String {
+        self.meta().id.to_string()
+    }

     pub fn is_instruct_model(&self) -> bool {
-        match self {
-            Self::Base2B
-            | Self::Base7B
-            | Self::CodeBase2B
-            | Self::CodeBase7B
-            | Self::BaseV2_2B
-            | Self::BaseV2_9B
-            | Self::BaseV3_1B => false,
-            _ => true,
-        }
+        self.meta().instruct
     }

     pub fn is_v3_model(&self) -> bool {
-        matches!(self, Self::BaseV3_1B | Self::InstructV3_1B)
+        matches!(self.meta().family, Family::GemmaV3)
     }

     pub fn is_llama_model(&self) -> bool {
-        matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
+        matches!(self.meta().family, Family::Llama)
     }
 }


@@ -42,13 +42,18 @@ pub struct AppState {
 impl Default for AppState {
     fn default() -> Self {
+        // Configure a default model to prevent 503 errors from the chat-ui
+        // This can be overridden by environment variables if needed
+        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+
         let gemma_config = GemmaInferenceConfig {
             model: gemma_runner::WhichModel::InstructV3_1B,
             ..Default::default()
         };
         Self {
             model_type: ModelType::Gemma,
-            model_id: "gemma-3-1b-it".to_string(),
+            model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
         }
@@ -59,6 +64,34 @@ impl Default for AppState {
 // Helper functions
 // -------------------------
fn model_id_to_which(model_id: &str) -> Option<Which> {
let normalized = normalize_model_id(model_id);
match normalized.as_str() {
"gemma-2b" => Some(Which::Base2B),
"gemma-7b" => Some(Which::Base7B),
"gemma-2b-it" => Some(Which::Instruct2B),
"gemma-7b-it" => Some(Which::Instruct7B),
"gemma-1.1-2b-it" => Some(Which::InstructV1_1_2B),
"gemma-1.1-7b-it" => Some(Which::InstructV1_1_7B),
"codegemma-2b" => Some(Which::CodeBase2B),
"codegemma-7b" => Some(Which::CodeBase7B),
"codegemma-2b-it" => Some(Which::CodeInstruct2B),
"codegemma-7b-it" => Some(Which::CodeInstruct7B),
"gemma-2-2b" => Some(Which::BaseV2_2B),
"gemma-2-2b-it" => Some(Which::InstructV2_2B),
"gemma-2-9b" => Some(Which::BaseV2_9B),
"gemma-2-9b-it" => Some(Which::InstructV2_9B),
"gemma-3-1b" => Some(Which::BaseV3_1B),
"gemma-3-1b-it" => Some(Which::InstructV3_1B),
"llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
"llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
_ => None,
}
}
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
@@ -116,91 +149,77 @@ pub async fn chat_completions_non_streaming_proxy(
     state: AppState,
     request: ChatCompletionRequest,
 ) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
-    // Enforce model selection behavior: reject if a different model is requested
-    let configured_model = state.model_id.clone();
-    let requested_model = request.model.clone();
-    if requested_model.to_lowercase() != "default" {
-        let normalized_requested = normalize_model_id(&requested_model);
-        let normalized_configured = normalize_model_id(&configured_model);
-        if normalized_requested != normalized_configured {
-            return Err((
-                StatusCode::BAD_REQUEST,
-                Json(serde_json::json!({
-                    "error": {
-                        "message": format!(
-                            "Requested model '{}' is not available. This server is running '{}' only.",
-                            requested_model, configured_model
-                        ),
-                        "type": "model_mismatch"
-                    }
-                })),
-            ));
-        }
-    }
-
-    let model_id = state.model_id.clone();
+    // Use the model specified in the request
+    let model_id = request.model.clone();
+    let which_model = model_id_to_which(&model_id);
+
+    // Validate that the requested model is supported
+    let which_model = match which_model {
+        Some(model) => model,
+        None => {
+            return Err((
+                StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "error": {
+                        "message": format!("Unsupported model: {}", model_id),
+                        "type": "model_not_supported"
+                    }
+                })),
+            ));
+        }
+    };
+
     let max_tokens = request.max_tokens.unwrap_or(1000);

     // Build prompt based on model type
-    let prompt = match state.model_type {
-        ModelType::Gemma => build_gemma_prompt(&request.messages),
-        ModelType::Llama => {
-            // For Llama, just use the last user message for now
-            request
-                .messages
-                .last()
-                .and_then(|m| m.content.as_ref())
-                .and_then(|c| match c {
-                    MessageContent(Either::Left(text)) => Some(text.clone()),
-                    _ => None,
-                })
-                .unwrap_or_default()
-        }
+    let prompt = if which_model.is_llama_model() {
+        // For Llama, just use the last user message for now
+        request
+            .messages
+            .last()
+            .and_then(|m| m.content.as_ref())
+            .and_then(|c| match c {
+                MessageContent(Either::Left(text)) => Some(text.clone()),
+                _ => None,
+            })
+            .unwrap_or_default()
+    } else {
+        build_gemma_prompt(&request.messages)
     };

     // Get streaming receiver based on model type
-    let rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_gemma_api(config).map_err(|e| (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                    }))
-                ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    })),
-                ));
-            }
-        }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_llama_inference(config).map_err(|e| (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": format!("Error initializing Llama model: {}", e) }
-                    }))
-                ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    })),
-                ));
-            }
-        }
-    };
+    let rx = if which_model.is_llama_model() {
+        // Create Llama configuration dynamically
+        let mut config = LlamaInferenceConfig::default();
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        run_llama_inference(config).map_err(|e| (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(serde_json::json!({
+                "error": { "message": format!("Error initializing Llama model: {}", e) }
+            }))
+        ))?
+    } else {
+        // Create Gemma configuration dynamically
+        let gemma_model = if which_model.is_v3_model() {
+            gemma_runner::WhichModel::InstructV3_1B
+        } else {
+            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        };
+
+        let mut config = GemmaInferenceConfig {
+            model: gemma_model,
+            ..Default::default()
+        };
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        run_gemma_api(config).map_err(|e| (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            Json(serde_json::json!({
+                "error": { "message": format!("Error initializing Gemma model: {}", e) }
+            }))
+        ))?
+    };

     // Collect all tokens from the stream
     let mut completion = String::new();
     while let Ok(token_result) = rx.recv() {
@@ -258,27 +277,25 @@ async fn handle_streaming_request(
state: AppState, state: AppState,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> { ) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<Value>)> {
// Validate requested model vs configured model // Use the model specified in the request
let configured_model = state.model_id.clone(); let model_id = request.model.clone();
let requested_model = request.model.clone(); let which_model = model_id_to_which(&model_id);
if requested_model.to_lowercase() != "default" {
let normalized_requested = normalize_model_id(&requested_model); // Validate that the requested model is supported
let normalized_configured = normalize_model_id(&configured_model); let which_model = match which_model {
if normalized_requested != normalized_configured { Some(model) => model,
None => {
return Err(( return Err((
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "error": {
"message": format!( "message": format!("Unsupported model: {}", model_id),
"Requested model '{}' is not available. This server is running '{}' only.", "type": "model_not_supported"
requested_model, configured_model
),
"type": "model_mismatch"
} }
})), })),
)); ));
} }
} };
// Generate a unique ID and metadata // Generate a unique ID and metadata
let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', "")); let response_id = format!("chatcmpl-{}", Uuid::new_v4().to_string().replace('-', ""));
@@ -286,24 +303,22 @@ async fn handle_streaming_request(
         .duration_since(std::time::UNIX_EPOCH)
         .unwrap_or_default()
         .as_secs();
-    let model_id = state.model_id.clone();
     let max_tokens = request.max_tokens.unwrap_or(1000);

     // Build prompt based on model type
-    let prompt = match state.model_type {
-        ModelType::Gemma => build_gemma_prompt(&request.messages),
-        ModelType::Llama => {
-            // For Llama, just use the last user message for now
-            request
-                .messages
-                .last()
-                .and_then(|m| m.content.as_ref())
-                .and_then(|c| match c {
-                    MessageContent(Either::Left(text)) => Some(text.clone()),
-                    _ => None,
-                })
-                .unwrap_or_default()
-        }
-    };
+    let prompt = if which_model.is_llama_model() {
+        // For Llama, just use the last user message for now
+        request
+            .messages
+            .last()
+            .and_then(|m| m.content.as_ref())
+            .and_then(|c| match c {
+                MessageContent(Either::Left(text)) => Some(text.clone()),
+                _ => None,
+            })
+            .unwrap_or_default()
+    } else {
+        build_gemma_prompt(&request.messages)
+    };
     tracing::debug!("Formatted prompt: {}", prompt);
@@ -330,51 +345,43 @@ async fn handle_streaming_request(
     }

     // Get streaming receiver based on model type
-    let model_rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                match run_gemma_api(config) {
-                    Ok(rx) => rx,
-                    Err(e) => {
-                        return Err((
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(serde_json::json!({
-                                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                            })),
-                        ));
-                    }
-                }
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    })),
-                ));
-            }
-        }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                match run_llama_inference(config) {
-                    Ok(rx) => rx,
-                    Err(e) => {
-                        return Err((
-                            StatusCode::INTERNAL_SERVER_ERROR,
-                            Json(serde_json::json!({
-                                "error": { "message": format!("Error initializing Llama model: {}", e) }
-                            })),
-                        ));
-                    }
-                }
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    })),
-                ));
-            }
+    let model_rx = if which_model.is_llama_model() {
+        // Create Llama configuration dynamically
+        let mut config = LlamaInferenceConfig::default();
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        match run_llama_inference(config) {
+            Ok(rx) => rx,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Llama model: {}", e) }
+                    })),
+                ));
+            }
+        }
+    } else {
+        // Create Gemma configuration dynamically
+        let gemma_model = if which_model.is_v3_model() {
+            gemma_runner::WhichModel::InstructV3_1B
+        } else {
+            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        };
+
+        let mut config = GemmaInferenceConfig {
+            model: gemma_model,
+            ..Default::default()
+        };
+        config.prompt = prompt.clone();
+        config.max_tokens = max_tokens;
+        match run_gemma_api(config) {
+            Ok(rx) => rx,
+            Err(e) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                    })),
+                ));
+            }
@@ -500,172 +507,69 @@ pub fn create_router(app_state: AppState) -> Router {
 /// Handler for GET /v1/models - returns list of available models
 pub async fn list_models() -> Json<ModelListResponse> {
     // Get all available model variants from the Which enum
-    let models = vec![
-        // Gemma models
-        Model {
-            id: "gemma-2b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002, // Using same timestamp as OpenAI example
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-7b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-7b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-1.1-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-1.1-7b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-2b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-7b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "codegemma-7b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-2b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-2b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-9b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-2-9b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-3-1b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        Model {
-            id: "gemma-3-1b-it".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "google".to_string(),
-        },
-        // Llama models
-        Model {
-            id: "llama-3.2-1b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "llama-3.2-1b-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "llama-3.2-3b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "llama-3.2-3b-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "meta".to_string(),
-        },
-        Model {
-            id: "smollm2-135m".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-135m-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-360m".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-360m-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-1.7b".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "smollm2-1.7b-instruct".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "huggingface".to_string(),
-        },
-        Model {
-            id: "tinyllama-1.1b-chat".to_string(),
-            object: "model".to_string(),
-            created: 1686935002,
-            owned_by: "tinyllama".to_string(),
-        },
-    ];
+    let which_variants = vec![
+        Which::Base2B,
+        Which::Base7B,
+        Which::Instruct2B,
+        Which::Instruct7B,
+        Which::InstructV1_1_2B,
+        Which::InstructV1_1_7B,
+        Which::CodeBase2B,
+        Which::CodeBase7B,
+        Which::CodeInstruct2B,
+        Which::CodeInstruct7B,
+        Which::BaseV2_2B,
+        Which::InstructV2_2B,
+        Which::BaseV2_9B,
+        Which::InstructV2_9B,
+        Which::BaseV3_1B,
+        Which::InstructV3_1B,
+        Which::Llama32_1B,
+        Which::Llama32_1BInstruct,
+        Which::Llama32_3B,
+        Which::Llama32_3BInstruct,
+    ];
+
+    let models: Vec<Model> = which_variants.into_iter().map(|which| {
+        let meta = which.meta();
+        let model_id = match which {
+            Which::Base2B => "gemma-2b",
+            Which::Base7B => "gemma-7b",
+            Which::Instruct2B => "gemma-2b-it",
+            Which::Instruct7B => "gemma-7b-it",
+            Which::InstructV1_1_2B => "gemma-1.1-2b-it",
+            Which::InstructV1_1_7B => "gemma-1.1-7b-it",
+            Which::CodeBase2B => "codegemma-2b",
+            Which::CodeBase7B => "codegemma-7b",
+            Which::CodeInstruct2B => "codegemma-2b-it",
+            Which::CodeInstruct7B => "codegemma-7b-it",
+            Which::BaseV2_2B => "gemma-2-2b",
+            Which::InstructV2_2B => "gemma-2-2b-it",
+            Which::BaseV2_9B => "gemma-2-9b",
+            Which::InstructV2_9B => "gemma-2-9b-it",
+            Which::BaseV3_1B => "gemma-3-1b",
+            Which::InstructV3_1B => "gemma-3-1b-it",
+            Which::Llama32_1B => "llama-3.2-1b",
+            Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
+            Which::Llama32_3B => "llama-3.2-3b",
+            Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
+        };
+
+        let owned_by = if meta.id.starts_with("google/") {
+            "google"
+        } else if meta.id.starts_with("meta-llama/") {
+            "meta"
+        } else {
+            "unknown"
+        };
+
+        Model {
+            id: model_id.to_string(),
+            object: "model".to_string(),
+            created: 1686935002, // Using same timestamp as OpenAI example
+            owned_by: owned_by.to_string(),
+        }
+    }).collect();
     Json(ModelListResponse {
         object: "list".to_string(),


@@ -1,3 +0,0 @@
# Ensure getrandom works on wasm32-unknown-unknown without needing manual RUSTFLAGS
[target.wasm32-unknown-unknown]
rustflags = ["--cfg", "getrandom_backend=\"wasm_js\""]


@@ -1,21 +0,0 @@
# Build stage
FROM rust:1-alpine AS builder
# Install build dependencies
RUN apk add --no-cache npm nodejs musl-dev pkgconfig openssl-dev git curl bash
RUN curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash
WORKDIR /app
# Copy manifest first (cache deps)
COPY . .
# Install cargo-leptos
RUN cargo binstall cargo-leptos
# Build release artifacts
RUN cargo leptos build --release
EXPOSE 8788
CMD ["cargo", "leptos", "serve", "--release"]


@@ -1,520 +0,0 @@
use leptos::prelude::*;
use leptos_meta::{provide_meta_context, MetaTags, Stylesheet, Title};
use leptos_router::{
components::{Route, Router, Routes},
StaticSegment,
};
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{FinishReason, Role};
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
Model as OpenAIModel,
},
Client,
};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;
#[cfg(feature = "hydrate")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "hydrate")]
use std::collections::VecDeque;
#[cfg(feature = "hydrate")]
use uuid::Uuid;
#[cfg(feature = "hydrate")]
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub id: String,
pub role: String,
pub content: String,
pub timestamp: f64,
}
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent(
pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>,
);
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInnerContent(
pub either::Either<String, std::collections::HashMap<String, String>>,
);
#[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: Option<MessageContent>,
pub name: Option<String>,
}
#[cfg(feature = "hydrate")]
const DEFAULT_MODEL: &str = "default";
#[cfg(feature = "hydrate")]
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"
);
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);
match client.models().list().await {
Ok(response) => {
let model_count = response.data.len();
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Successfully fetched {} models",
model_count
);
if model_count > 0 {
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Available models: {:?}",
model_names
);
} else {
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: No models returned by server"
);
}
Ok(response.data)
}
Err(e) => {
leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}",
e
);
Err(format!("Failed to fetch models: {}", e))
}
}
}
pub fn shell(options: LeptosOptions) -> impl IntoView {
view! {
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<AutoReload options=options.clone() />
<HydrationScripts options/>
<MetaTags/>
</head>
<body>
<App/>
</body>
</html>
}
}
#[component]
pub fn App() -> impl IntoView {
// Provides context that manages stylesheets, titles, meta tags, etc.
provide_meta_context();
view! {
// injects a stylesheet into the document <head>
// id=leptos means cargo-leptos will hot-reload this stylesheet
<Stylesheet id="leptos" href="/pkg/leptos-app.css"/>
// sets the document title
<Title text="Chat Interface"/>
// content for this chat interface
<Router>
<main>
<Routes fallback=|| "Page not found.".into_view()>
<Route path=StaticSegment("") view=ChatInterface/>
</Routes>
</main>
</Router>
}
}
/// Renders the home page of your application.
#[component]
fn HomePage() -> impl IntoView {
// Creates a reactive value to update the button
let count = RwSignal::new(0);
let on_click = move |_| *count.write() += 1;
view! {
<h1>"Welcome to Leptos!"</h1>
<button on:click=on_click>"Click Me: " {count}</button>
}
}
/// Renders the chat interface
#[component]
fn ChatInterface() -> impl IntoView {
#[cfg(feature = "hydrate")]
{
ChatInterfaceImpl()
}
#[cfg(not(feature = "hydrate"))]
{
view! {
<div class="chat-container">
<h1>"Chat Interface"</h1>
<p>"Loading chat interface..."</p>
</div>
}
}
}
#[cfg(feature = "hydrate")]
#[component]
fn ChatInterfaceImpl() -> impl IntoView {
let (messages, set_messages) = RwSignal::new(VecDeque::<Message>::new()).split();
let (input_value, set_input_value) = RwSignal::new(String::new()).split();
let (is_loading, set_is_loading) = RwSignal::new(false).split();
let (available_models, set_available_models) = RwSignal::new(Vec::<OpenAIModel>::new()).split();
let (selected_model, set_selected_model) = RwSignal::new(DEFAULT_MODEL.to_string()).split();
let (models_loading, set_models_loading) = RwSignal::new(false).split();
// Fetch models on component initialization
Effect::new(move |_| {
spawn_local(async move {
set_models_loading.set(true);
match fetch_available_models().await {
Ok(models) => {
set_available_models.set(models);
set_models_loading.set(false);
}
Err(e) => {
leptos::logging::log!("Failed to fetch models: {}", e);
set_available_models.set(vec![]);
set_models_loading.set(false);
}
}
});
});
let send_message = Action::new_unsync(move |content: &String| {
let content = content.clone();
async move {
if content.trim().is_empty() {
leptos::logging::log!("[DEBUG_LOG] send_message: Empty content, skipping");
return;
}
leptos::logging::log!("[DEBUG_LOG] send_message: Starting message send process");
set_is_loading.set(true);
// Add user message to chat
let user_message = Message {
id: Uuid::new_v4().to_string(),
role: "user".to_string(),
content: content.clone(),
timestamp: Date::now(),
};
set_messages.update(|msgs| msgs.push_back(user_message.clone()));
set_input_value.set(String::new());
let mut chat_messages = Vec::new();
// Add system message
let system_message = ChatCompletionRequestSystemMessageArgs::default()
.content("You are a helpful assistant.")
.build()
.expect("failed to build system message");
chat_messages.push(system_message.into());
// Add history messages
let history_count = messages.get_untracked().len();
for msg in messages.get_untracked().iter() {
match msg.role.as_str() {
"user" => {
let message = ChatCompletionRequestUserMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build user message");
chat_messages.push(message.into());
}
"assistant" => {
let message = ChatCompletionRequestAssistantMessageArgs::default()
.content(msg.content.clone())
.build()
.expect("failed to build assistant message");
chat_messages.push(message.into());
}
_ => {}
}
}
// Add current user message
let message = ChatCompletionRequestUserMessageArgs::default()
.content(user_message.content.clone())
.build()
.expect("failed to build user message");
chat_messages.push(message.into());
let current_model = selected_model.get_untracked();
let total_messages = chat_messages.len();
leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
current_model, history_count, total_messages);
let request = CreateChatCompletionRequestArgs::default()
.model(current_model.as_str())
.max_tokens(512u32)
.messages(chat_messages)
.stream(true)
.build()
.expect("failed to build request");
// Send request
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config);
leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);
match client.chat().create_stream(request).await {
Ok(mut stream) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream");
let mut assistant_created = false;
let mut content_appended = false;
let mut chunks_received = 0;
while let Some(next) = stream.next().await {
match next {
Ok(chunk) => {
chunks_received += 1;
if let Some(choice) = chunk.choices.get(0) {
if !assistant_created {
if let Some(role) = &choice.delta.role {
if role == &Role::Assistant {
assistant_created = true;
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id,
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
}
}
}
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
if !assistant_created {
assistant_created = true;
let assistant_id = Uuid::new_v4().to_string();
set_messages.update(|msgs| {
msgs.push_back(Message {
id: assistant_id,
role: "assistant".to_string(),
content: String::new(),
timestamp: Date::now(),
});
});
}
content_appended = true;
set_messages.update(|msgs| {
if let Some(last) = msgs.back_mut() {
if last.role == "assistant" {
last.content.push_str(content);
last.timestamp = Date::now();
}
}
});
}
}
if let Some(reason) = &choice.finish_reason {
if reason == &FinishReason::Stop {
leptos::logging::log!("[DEBUG_LOG] send_message: Received finish_reason=stop after {} chunks", chunks_received);
break;
}
}
}
}
Err(e) => {
leptos::logging::log!(
"[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}",
chunks_received,
e
);
set_messages.update(|msgs| {
msgs.push_back(Message {
id: Uuid::new_v4().to_string(),
role: "system".to_string(),
content: format!("Stream error: {}", e),
timestamp: Date::now(),
});
});
break;
}
}
}
if assistant_created && !content_appended {
set_messages.update(|msgs| {
let should_pop = msgs
.back()
.map(|m| m.role == "assistant" && m.content.is_empty())
.unwrap_or(false);
if should_pop {
msgs.pop_back();
}
});
}
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
}
Err(e) => {
leptos::logging::log!(
"[DEBUG_LOG] send_message: Request failed with error: {:?}",
e
);
let error_message = Message {
id: Uuid::new_v4().to_string(),
role: "system".to_string(),
content: format!("Error: Request failed - {}", e),
timestamp: Date::now(),
};
set_messages.update(|msgs| msgs.push_back(error_message));
}
}
set_is_loading.set(false);
}
});
let on_input = move |ev| {
let input = event_target::<HtmlInputElement>(&ev);
set_input_value.set(input.value());
};
let on_submit = move |ev: SubmitEvent| {
ev.prevent_default();
let content = input_value.get();
send_message.dispatch(content);
};
let on_keypress = move |ev: KeyboardEvent| {
if ev.key() == "Enter" && !ev.shift_key() {
ev.prevent_default();
let content = input_value.get();
send_message.dispatch(content);
}
};
let on_model_change = move |ev| {
let select = event_target::<web_sys::HtmlSelectElement>(&ev);
set_selected_model.set(select.value());
};
let messages_list = move || {
messages
.get()
.into_iter()
.map(|message| {
let role_class = match message.role.as_str() {
"user" => "user-message",
"assistant" => "assistant-message",
_ => "system-message",
};
view! {
<div class=format!("message {}", role_class)>
<div class="message-role">{message.role}</div>
<div class="message-content">{message.content}</div>
</div>
}
})
.collect::<Vec<_>>()
};
let loading_indicator = move || {
is_loading.get().then(|| {
view! {
<div class="message assistant-message">
<div class="message-role">"assistant"</div>
<div class="message-content">"Thinking..."</div>
</div>
}
})
};
view! {
<div class="chat-container">
<h1>"Chat Interface"</h1>
<div class="model-selector">
<label for="model-select">"Model: "</label>
<select
id="model-select"
on:change=on_model_change
prop:value=selected_model
prop:disabled=models_loading
>
{move || {
if models_loading.get() {
vec![view! {
<option value={String::from("")} selected=false>{String::from("Loading models...")}</option>
}]
} else {
let models = available_models.get();
if models.is_empty() {
vec![view! {
<option value={String::from("default")} selected=true>{String::from("default")}</option>
}]
} else {
models.into_iter().map(|model| {
view! {
<option value=model.id.clone() selected={model.id == DEFAULT_MODEL}>{model.id.clone()}</option>
}
}).collect::<Vec<_>>()
}
}
}}
</select>
</div>
<div class="messages-container">
{messages_list}
{loading_indicator}
</div>
<form class="input-form" on:submit=on_submit>
<input
type="text"
class="message-input"
placeholder="Type your message here..."
prop:value=input_value
on:input=on_input
on:keypress=on_keypress
prop:disabled=is_loading
/>
<button
type="submit"
class="send-button"
prop:disabled=move || is_loading.get() || input_value.get().trim().is_empty()
>
"Send"
</button>
</form>
</div>
}
}


@@ -1,30 +0,0 @@
pub mod app;
#[cfg(feature = "hydrate")]
#[wasm_bindgen::prelude::wasm_bindgen]
pub fn hydrate() {
use crate::app::*;
console_error_panic_hook::set_once();
leptos::mount::hydrate_body(App);
}
#[cfg(feature = "ssr")]
pub fn create_leptos_router() -> axum::Router {
use crate::app::*;
use axum::Router;
use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
let conf = get_configuration(None).unwrap();
let leptos_options = conf.leptos_options;
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
Router::new()
.leptos_routes(&leptos_options, routes, {
let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options)
}


@@ -1,38 +0,0 @@
#[cfg(feature = "ssr")]
#[tokio::main]
async fn main() {
use axum::Router;
use leptos::logging::log;
use leptos::prelude::*;
use leptos_app::app::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
let conf = get_configuration(None).unwrap();
let addr = conf.leptos_options.site_addr;
let leptos_options = conf.leptos_options;
// Generate the list of routes in your Leptos App
let routes = generate_route_list(App);
let app = Router::new()
.leptos_routes(&leptos_options, routes, {
let leptos_options = leptos_options.clone();
move || shell(leptos_options.clone())
})
.fallback(leptos_axum::file_and_error_handler(shell))
.with_state(leptos_options);
// run our app with hyper
// `axum::Server` is a re-export of `hyper::Server`
log!("listening on http://{}", &addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
#[cfg(not(feature = "ssr"))]
pub fn main() {
// no client-side main function
// unless we want this to work with e.g., Trunk for pure client-side testing
// see lib.rs for hydration function instead
}


@@ -1,4 +0,0 @@
body {
font-family: sans-serif;
text-align: center;
}


@@ -6,7 +6,7 @@ edition = "2021"
 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
-candle-transformers = { git = "https://github.com/huggingface/candle.git" }
+candle-transformers = { git = "https://github.com/huggingface/candle.git"}
 hf-hub = "0.3"
 tokenizers = "0.20"
 anyhow = "1.0"


@@ -82,7 +82,7 @@ impl Default for LlamaInferenceConfig {
             // Performance flags
             no_kv_cache: false, // keep cache ON for speed
-            use_flash_attn: true, // great speed boost if supported
+            use_flash_attn: false, // great speed boost if supported
             // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
             dtype: Some("bf16".to_string()),


@@ -19,7 +19,7 @@ tracing = "0.1"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 reqwest = { version = "0.12", features = ["json"] }
-rust-embed = { version = "8.7.2", features = ["include-exclude"] }
+rust-embed = { version = "8.7.2", features = ["include-exclude", "axum"] }

 # Dependencies for embeddings functionality
 embeddings-engine = { path = "../embeddings-engine" }
@@ -28,9 +28,11 @@ embeddings-engine = { path = "../embeddings-engine" }
 inference-engine = { path = "../inference-engine" }

 # Dependencies for leptos web app
-leptos-app = { path = "../leptos-app", features = ["ssr"] }
+#leptos-app = { path = "../leptos-app", features = ["ssr"] }
+chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = false }

 mime_guess = "2.0.5"
+log = "0.4.27"

 [package.metadata.compose]


@@ -4,22 +4,56 @@ mod middleware;
 mod standalone_mode;

 use crate::standalone_mode::create_standalone_router;
+use axum::handler::Handler;
+use axum::http::StatusCode as AxumStatusCode;
+use axum::http::header;
 use axum::response::IntoResponse;
 use axum::routing::get;
-use axum::{Router, http::Uri, response::Html, serve};
+use axum::{Router, ServiceExt, http::Uri, response::Html, serve};
 use config::ServerConfig;
 use ha_mode::create_ha_router;
 use inference_engine::AppState;
+use log::info;
 use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
+use mime_guess::from_path;
 use rust_embed::Embed;
 use std::env;
 use std::path::Component::ParentDir;
 use tokio::net::TcpListener;
+use tower::MakeService;
 use tower_http::classify::ServerErrorsFailureClass::StatusCode;
 use tower_http::cors::{Any, CorsLayer};
 use tower_http::trace::TraceLayer;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

+#[derive(Embed)]
+#[folder = "../../target/site"]
+#[include = "*.js"]
+#[include = "*.wasm"]
+#[include = "*.css"]
+#[include = "*.ico"]
+struct Asset;
+
+async fn static_handler(uri: Uri) -> axum::response::Response {
+    // Strip the leading `/`
+    let path = uri.path().trim_start_matches('/');
+    tracing::info!("Static file: {}", &path);
+
+    // If root is requested, serve index.html
+    let path = if path.is_empty() { "index.html" } else { path };
+
+    match Asset::get(path) {
+        Some(content) => {
+            let body = content.data.into_owned();
+            let mime = from_path(path).first_or_octet_stream();
+            ([(header::CONTENT_TYPE, mime.as_ref())], body).into_response()
+        }
+        None => (AxumStatusCode::NOT_FOUND, "404 Not Found").into_response(),
+    }
+}
+
 #[tokio::main]
 async fn main() {
     // Initialize tracing
@@ -77,14 +111,17 @@ async fn main() {
     // Create metrics layer
     let metrics_layer = MetricsLayer::new(metrics_store);

+    let leptos_config = chat_ui::app::AppConfig::default();
+
     // Create the leptos router for the web frontend
-    let leptos_router = leptos_app::create_leptos_router();
+    let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);

     // Merge the service router with base routes and add middleware layers
     let app = Router::new()
+        .route("/pkg/{*path}", get(static_handler))
         .route("/health", get(|| async { "ok" }))
         .merge(service_router)
-        .merge(leptos_router) // Add leptos web frontend routes
+        .merge(leptos_router)
         .layer(metrics_layer) // Add metrics tracking
         .layer(cors)
         .layer(TraceLayer::new_for_http());
@@ -110,7 +147,7 @@ async fn main() {
     tracing::info!("  POST /v1/embeddings - Text embeddings API");
     tracing::info!("  POST /v1/chat/completions - Chat completions API");

-    serve(listener, app).await.unwrap();
+    serve(listener, app.into_make_service()).await.unwrap();
 }

 fn log_config(config: ServerConfig) {


@@ -6,7 +6,8 @@ pub fn create_standalone_router(server_config: ServerConfig) -> Router {
     // Create unified router by merging embeddings and inference routers (existing behavior)
     let embeddings_router = embeddings_engine::create_embeddings_router();

-    // Create AppState with correct model configuration
+    // Create AppState - no default model, must be configured explicitly
+    // This removes the hardcoded gemma-3-1b-it default behavior
     let app_state = AppState::default();

     // Get the inference router directly from the inference engine

crates/utils/Cargo.toml Normal file

@@ -0,0 +1,88 @@
[package]
name = "utils"
[lib]
path = "src/lib.rs"
[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }
candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
candle-core="0.9.1"
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
half = {version = "2.6.0", optional = true }
hf-hub = {version = "0.4.3", features = ["tokio"] }
image = {version = "0.25.6" }
intel-mkl-src = {version = "0.8.1", optional = true }
num-traits = {version = "0.2.19" }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.22.0", features = [
"auto-initialize",
"abi3-py311",
], optional = true }
rayon = {version = "1.11.0" }
rubato = { version = "0.15.0", optional = true }
safetensors = {version = "0.6.2" }
serde = {version = "1.0.219" }
serde_json = {version = "1.0.143" }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = {version = "0.22.0", features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2", optional = true }
tekken-rs = { version = "0.1.1", optional = true }
[dev-dependencies]
anyhow = {version = "1.0.99" }
byteorder = {version = "1.5.0" }
clap = {version = "4.5.46" }
imageproc = {version = "0.25.0" }
memmap2 = {version = "0.9.8" }
rand = {version = "0.9.2" }
ab_glyph = {version = "0.2.31" }
tracing = {version = "0.1.41" }
tracing-chrome = {version = "0.7.2" }
tracing-subscriber = {version = "0.3.20" }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"
[build-dependencies]
anyhow = {version = "1.0.99" }
bindgen_cuda = { version = "0.1.1", optional = true }
#
[features]
default = []
accelerate = [
"dep:accelerate-src",
"candle-core/accelerate",
"candle-nn/accelerate",
"candle-transformers/accelerate",
]
cuda = [
"candle-core/cuda",
"candle-nn/cuda",
"candle-transformers/cuda",
"dep:bindgen_cuda",
]
cudnn = ["candle-core/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = [
"dep:intel-mkl-src",
"candle-core/mkl",
"candle-nn/mkl",
"candle-transformers/mkl",
]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle-core/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]

crates/utils/src/audio.rs Normal file

@@ -0,0 +1,138 @@
use candle_core::{Result, Tensor};
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
pub fn normalize_loudness(
wav: &Tensor,
sample_rate: u32,
loudness_compressor: bool,
) -> Result<Tensor> {
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
if energy < 2e-3 {
return Ok(wav.clone());
}
let wav_array = wav.to_vec1::<f32>()?;
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
meter.push(wav_array.into_iter());
let power = meter.as_100ms_windows();
let loudness = match crate::bs1770::gated_mean(power) {
None => return Ok(wav.clone()),
Some(gp) => gp.loudness_lkfs() as f64,
};
let delta_loudness = -14. - loudness;
let gain = 10f64.powf(delta_loudness / 20.);
let wav = (wav * gain)?;
if loudness_compressor {
wav.tanh()
} else {
Ok(wav)
}
}
#[cfg(feature = "symphonia")]
pub fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;
fn conv<T>(
samples: &mut Vec<f32>,
data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>,
) where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
// Open the media source.
let src = std::fs::File::open(path).map_err(candle_core::Error::wrap)?;
// Create the media source stream.
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
// Create a probe hint using the file's extension. [Optional]
let hint = symphonia::core::probe::Hint::new();
// Use the default options for metadata and format readers.
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
// Probe the media source.
let probed = symphonia::default::get_probe()
.format(&hint, mss, &fmt_opts, &meta_opts)
.map_err(candle_core::Error::wrap)?;
// Get the instantiated format reader.
let mut format = probed.format;
// Find the first audio track with a known (decodeable) codec.
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.ok_or_else(|| candle_core::Error::Msg("no supported audio tracks".to_string()))?;
// Use the default options for the decoder.
let dec_opts: DecoderOptions = Default::default();
// Create a decoder for the track.
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.map_err(|_| candle_core::Error::Msg("unsupported codec".to_string()))?;
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
// The decode loop.
while let Ok(packet) = format.next_packet() {
// Consume any new metadata that has been read since the last packet.
while !format.metadata().is_latest() {
format.metadata().pop();
}
// If the packet does not belong to the selected track, skip over it.
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet).map_err(candle_core::Error::wrap)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[cfg(feature = "rubato")]
pub fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)
.map_err(candle_core::Error::wrap)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) = resampler
.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)
.map_err(candle_core::Error::wrap)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler
.process_partial_into_buffer(Some(&[&pcm_in[pos_in..]]), &mut output_buffer, None)
.map_err(candle_core::Error::wrap)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}
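// A minimal usage sketch tying the helpers above together, assuming the
// `symphonia` and `rubato` features are enabled; the function name is
// illustrative and not part of the diff above.
#[cfg(all(feature = "symphonia", feature = "rubato"))]
pub fn pcm_decode_16k<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<f32>> {
    // Decode to mono f32 samples plus the source sample rate.
    let (pcm, sr_in) = pcm_decode(path)?;
    if sr_in == 16_000 {
        return Ok(pcm);
    }
    // Convert to the target rate with the FFT-based resampler above.
    resample(&pcm, sr_in, 16_000)
}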

crates/utils/src/bs1770.rs Normal file

@@ -0,0 +1,506 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! # [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//! meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//! channel_power[0].as_ref(),
//! channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//! stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```
use std::f32;
/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
a1: f32,
a2: f32,
b0: f32,
b1: f32,
b2: f32,
// The past two input and output samples.
x1: f32,
x2: f32,
y1: f32,
y2: f32,
}
impl Filter {
/// Stage 1 of the BS.1770-4 pre-filter.
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let gain_db = 3.999_843_8;
let q = 0.707_175_25;
let center_hz = 1_681.974_5;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
let vh = 10.0_f32.powf(gain_db / 20.0);
let vb = vh.powf(0.499_666_78);
let a0 = 1.0 + k / q + k * k;
Filter {
b0: (vh + vb * k / q + k * k) / a0,
b1: 2.0 * (k * k - vh) / a0,
b2: (vh - vb * k / q + k * k) / a0,
a1: 2.0 * (k * k - 1.0) / a0,
a2: (1.0 - k / q + k * k) / a0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Stage 2 of the BS.1770-4 pre-filter.
pub fn high_pass(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let q = 0.500_327_05;
let center_hz = 38.135_47;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
Filter {
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
b0: 1.0,
b1: -2.0,
b2: 1.0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Feed the next input sample, get the next output sample.
#[inline(always)]
pub fn apply(&mut self, x0: f32) -> f32 {
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
- self.a1 * self.y1
- self.a2 * self.y2;
self.x2 = self.x1;
self.x1 = x0;
self.y2 = self.y1;
self.y1 = y0;
y0
}
}
/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
sum: f32,
residue: f32,
}
impl Sum {
#[inline(always)]
fn zero() -> Sum {
Sum {
sum: 0.0,
residue: 0.0,
}
}
#[inline(always)]
fn add(&mut self, x: f32) {
let sum = self.sum + (self.residue + x);
self.residue = (self.residue + x) - (sum - self.sum);
self.sum = sum;
}
}
/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);
impl Power {
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
///
/// This is the inverse of `loudness_lkfs`.
pub fn from_lkfs(lkfs: f32) -> Power {
// The inverse of the formula below.
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
}
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
///
/// This is the inverse of `from_lkfs`.
pub fn loudness_lkfs(&self) -> f32 {
// Equation 2 (p.5) of BS.1770-4.
-0.691 + 10.0 * self.0.log10()
}
}
/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
pub inner: T,
}
impl<T> Windows100ms<T> {
/// Wrap a new empty vector.
pub fn new() -> Windows100ms<Vec<T>> {
Windows100ms { inner: Vec::new() }
}
/// Apply `as_ref` to the inner value.
pub fn as_ref(&self) -> Windows100ms<&[Power]>
where
T: AsRef<[Power]>,
{
Windows100ms {
inner: self.inner.as_ref(),
}
}
/// Apply `as_mut` to the inner value.
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
where
T: AsMut<[Power]>,
{
Windows100ms {
inner: self.inner.as_mut(),
}
}
#[allow(clippy::len_without_is_empty)]
/// Apply `len` to the inner value.
pub fn len(&self) -> usize
where
T: AsRef<[Power]>,
{
self.inner.as_ref().len()
}
}
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
/// .unwrap_or(bs1770::Power(0.0))
/// .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
/// The number of samples that fit in 100ms of audio.
samples_per_100ms: u32,
/// Stage 1 filter (head effects, high shelf).
filter_stage1: Filter,
/// Stage 2 filter (high-pass).
filter_stage2: Filter,
/// Sum of the squares over non-overlapping windows of 100ms.
windows: Windows100ms<Vec<Power>>,
/// The number of samples in the current unfinished window.
count: u32,
/// The sum of the squares of the samples in the current unfinished window.
square_sum: Sum,
}
impl ChannelLoudnessMeter {
/// Construct a new loudness meter for the given sample rate.
pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
ChannelLoudnessMeter {
samples_per_100ms: sample_rate_hz / 10,
filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
filter_stage2: Filter::high_pass(sample_rate_hz as f32),
windows: Windows100ms::new(),
count: 0,
square_sum: Sum::zero(),
}
}
/// Feed input samples for loudness analysis.
///
/// # Full scale
///
/// Full scale for the input samples is the interval [-1.0, 1.0]. If your
/// input consists of signed integer samples, you can convert as follows:
///
/// ```ignore
/// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
/// # let bits_per_sample = 16_usize;
/// # let samples = &[0_i16];
/// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
/// // one bit is the sign bit.
/// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
/// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
/// ```
///
/// # Repeated calls
///
/// You can call `push` multiple times to feed multiple batches of samples.
/// This is equivalent to feeding a single chained iterator. The leftover of
/// samples that did not fill a full 100ms window is not discarded:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::ChannelLoudnessMeter;
/// let sample_rate_hz = 44_100;
/// let samples_per_100ms = sample_rate_hz / 10;
/// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
///
/// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
/// assert_eq!(meter.as_100ms_windows().len(), 0);
///
/// meter.push(iter::once(0.0));
/// assert_eq!(meter.as_100ms_windows().len(), 1);
/// ```
pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
let normalizer = 1.0 / self.samples_per_100ms as f32;
// LLVM, if you could go ahead and inline those apply calls, and then
// unroll and vectorize the loop, that'd be terrific.
for x in samples {
let y = self.filter_stage1.apply(x);
let z = self.filter_stage2.apply(y);
self.square_sum.add(z * z);
self.count += 1;
// TODO: Should this branch be marked cold?
if self.count == self.samples_per_100ms {
let mean_squares = Power(self.square_sum.sum * normalizer);
self.windows.inner.push(mean_squares);
// We intentionally do not reset the residue. That way, leftover
// energy from this window is not lost, so for the file overall,
// the sum remains more accurate.
self.square_sum.sum = 0.0;
self.count = 0;
}
}
}
/// Return a reference to the 100ms windows analyzed so far.
pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
self.windows.as_ref()
}
/// Return all 100ms windows analyzed so far.
pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
self.windows
}
}
/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
left: Windows100ms<&[Power]>,
right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
let mut result = Vec::with_capacity(left.len());
for (l, r) in left.inner.iter().zip(right.inner) {
result.push(Power(l.0 + r.0));
}
Windows100ms { inner: result }
}
/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
for (l, r) in left.inner.iter_mut().zip(right.inner) {
l.0 += r.0;
}
}
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
let absolute_threshold = Power::from_lkfs(-70.0);
// Iterate over all 400ms windows.
for window in windows_100ms.inner.windows(4) {
// Note that the sum over channels has already been performed at this point.
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
if gating_block_power > absolute_threshold {
gating_blocks.push(gating_block_power);
}
}
if gating_blocks.is_empty() {
return None;
}
// Compute the loudness after applying the absolute gate, in order to
// determine the threshold for the relative gate.
let mut sum_power = Sum::zero();
for &gating_block_power in &gating_blocks {
sum_power.add(gating_block_power.0);
}
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
// Stage 2: Apply the relative gate.
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
let mut sum_power = Sum::zero();
let mut n_blocks = 0_usize;
for &gating_block_power in &gating_blocks {
if gating_block_power > relative_threshold {
sum_power.add(gating_block_power.0);
n_blocks += 1;
}
}
if n_blocks == 0 {
return None;
}
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
Some(relative_gated_power)
}
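// A minimal mono usage sketch combining the pieces above, per the module
// docs; these helper names are illustrative and not part of the diff.
// Samples are assumed to be f32 in the nominal full scale [-1.0, 1.0].
pub fn integrated_loudness_lkfs(samples: &[f32], sample_rate_hz: u32) -> Option<f32> {
    let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    meter.push(samples.iter().copied());
    // `gated_mean` returns `None` when everything falls below the -70 LKFS gate.
    gated_mean(meter.as_100ms_windows()).map(|power| power.loudness_lkfs())
}

// Instantaneous loudness per the doc comment on `ChannelLoudnessMeter`:
// average four 100ms windows into one 400ms window and convert to LKFS.
pub fn instantaneous_lkfs(windows: Windows100ms<&[Power]>) -> Vec<f32> {
    windows
        .inner
        .windows(4)
        .map(|w| Power(0.25 * w.iter().map(|p| p.0).sum::<f32>()).loudness_lkfs())
        .collect()
}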


@@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

crates/utils/src/imagenet.rs Normal file

File diff suppressed because it is too large

crates/utils/src/lib.rs Normal file

@@ -0,0 +1,156 @@
extern crate candle_core;
extern crate candle_transformers;
extern crate tokenizers;
pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU; to run on the GPU (Metal), enable the `metal` feature"
);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU; to run on the GPU, enable the `cuda` feature");
}
Ok(Device::Cpu)
}
}
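A short sketch of how a caller might pick a device and allocate on it; the `false` argument (do not force CPU) and the tensor shape are arbitrary:

// Prefer CUDA, then Metal, then fall back to the CPU.
let device = device(false)?;
let t = Tensor::zeros((2, 3), candle_core::DType::F32, &device)?;
println!("allocated {:?} on {:?}", t.shape(), device);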
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<(Tensor, usize, usize), anyhow::Error> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?;
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
let img = match resize_longest {
None => img,
Some(resize_longest) => {
let (height, width) = (img.height(), img.width());
let resize_longest = resize_longest as u32;
let (height, width) = if height < width {
let h = (resize_longest * height) / width;
(h, resize_longest)
} else {
let w = (resize_longest * width) / height;
(resize_longest, w)
};
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
}
};
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
Ok((data, initial_h, initial_w))
}
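Example call with a placeholder file name; the returned tensor is CHW `u8` on the CPU, and the original dimensions allow mapping predictions back onto the source image:

// Resize so the longest side is 640 px, preserving the aspect ratio.
let (img, orig_h, orig_w) = load_image("input.jpg", Some(640))?;
println!("tensor {:?}, original {}x{}", img.shape(), orig_w, orig_h);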
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P,
width: usize,
height: usize,
) -> candle_core::Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle_core::Error::wrap)?
.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let data = img.into_raw();
// The raw RGB buffer is row-major, so the HWC shape is (height, width, 3);
// permuting then yields channels-first (3, height, width).
Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
}
/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), anyhow::Error> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
anyhow::bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => anyhow::bail!("error saving image {p:?}"),
};
image.save(p).map_err(candle_core::Error::wrap)?;
Ok(())
}
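A round-trip sketch (both paths are placeholders): resize to a fixed input size, then write the CHW tensor back out:

// load_image_and_resize yields (3, height, width) u8, which save_image expects.
let img = load_image_and_resize("input.jpg", 224, 224)?;
save_image(&img, "resized.png")?;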
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| {
repo.get(v)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
})
.collect::<Result<Vec<_>, std::io::Error>>()?;
Ok(safetensors_files)
}
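Typical usage against a sharded checkpoint on the Hugging Face Hub; the model id is illustrative, while `Api::new` and `Api::model` are the standard `hf-hub` sync API:

// Download every shard listed in the index file.
let api = hf_hub::api::sync::Api::new()?;
let repo = api.model("google/gemma-2b".to_string());
let files = hub_load_safetensors(&repo, "model.safetensors.index.json")?;
println!("fetched {} shard(s)", files.len());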
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => anyhow::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}
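And the local variant, assuming the checkpoint already sits in some directory on disk (the path is a placeholder); nothing is downloaded, the index is only used to resolve shard paths:

let files = hub_load_local_safetensors("./models/my-model", "model.safetensors.index.json")?;
for f in &files {
    println!("{}", f.display());
}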

3
crates/utils/src/main.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

85
crates/utils/src/token_output_stream.rs Normal file
View File

@@ -0,0 +1,85 @@
use candle_core::Result;
use tokenizers::Tokenizer;
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
prev_index: usize,
current_index: usize,
}
impl TokenOutputStream {
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
Self {
tokenizer,
tokens: Vec::new(),
prev_index: 0,
current_index: 0,
}
}
pub fn into_inner(self) -> tokenizers::Tokenizer {
self.tokenizer
}
fn decode(&self, tokens: &[u32]) -> Result<String> {
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => candle_core::bail!("cannot decode: {err}"),
}
}
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_rest(&self) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() {
let text = text.split_at(prev_text.len());
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_all(&self) -> Result<String> {
self.decode(&self.tokens)
}
pub fn get_token(&self, token_s: &str) -> Option<u32> {
self.tokenizer.get_vocab(true).get(token_s).copied()
}
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
&self.tokenizer
}
pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
self.current_index = 0;
}
}
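A streaming-decode sketch: feed sampled ids one at a time and print fragments as soon as they form a clean boundary. The token ids are placeholders, and `tokenizer` is assumed to be a `tokenizers::Tokenizer` loaded elsewhere; run inside a function returning `candle_core::Result<()>`:

let mut stream = TokenOutputStream::new(tokenizer);
for &token_id in &[101u32, 202, 303] {
    if let Some(fragment) = stream.next_token(token_id)? {
        print!("{fragment}");
    }
}
// Flush whatever the incremental decoder held back.
if let Some(rest) = stream.decode_rest()? {
    print!("{rest}");
}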

56
crates/utils/src/wav.rs Normal file
View File

@@ -0,0 +1,56 @@
use std::io::prelude::*;
pub trait Sample {
fn to_i16(&self) -> i16;
}
impl Sample for f32 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for f64 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for i16 {
fn to_i16(&self) -> i16 {
*self
}
}
pub fn write_pcm_as_wav<W: Write, S: Sample>(
w: &mut W,
samples: &[S],
sample_rate: u32,
) -> std::io::Result<()> {
let len = 12u32; // header
let len = len + 24u32; // fmt
let len = len + samples.len() as u32 * 2 + 8; // data
let n_channels = 1u16;
let bytes_per_second = sample_rate * 2 * n_channels as u32;
w.write_all(b"RIFF")?;
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
w.write_all(b"WAVE")?;
// Format block
w.write_all(b"fmt ")?;
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
w.write_all(&1u16.to_le_bytes())?; // PCM
w.write_all(&n_channels.to_le_bytes())?; // one channel
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&bytes_per_second.to_le_bytes())?;
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
w.write_all(&16u16.to_le_bytes())?; // bits per sample
// Data block
w.write_all(b"data")?;
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
for sample in samples.iter() {
w.write_all(&sample.to_i16().to_le_bytes())?
}
Ok(())
}
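For example, writing one second of a 440 Hz sine at 16 kHz (the file name is arbitrary); any `f32`, `f64`, or `i16` slice works thanks to the `Sample` trait:

// Generate one second of f32 samples and serialize them as 16-bit mono PCM.
let sample_rate = 16_000u32;
let samples: Vec<f32> = (0..sample_rate)
    .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
    .collect();
let mut file = std::fs::File::create("tone.wav")?;
write_pcm_as_wav(&mut file, &samples, sample_rate)?;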

View File

@@ -52,7 +52,7 @@ graph TB
 ## Workspace Structure
 
-The project uses a 7-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
+The project uses a 9-crate Rust workspace with TypeScript tooling, designed for maximum flexibility in deployment configurations.
 
 ```mermaid
 graph TD
@@ -69,18 +69,15 @@ graph TD
     end
 
     subgraph "Frontend"
-        D[leptos-app<br/>Edition: 2021<br/>Port: 3000/8788<br/>WASM/SSR]
+        D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
     end
 
     subgraph "Tooling"
         L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
+        E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
     end
 end
 
-subgraph "External Tooling"
-    E[scripts/cli.ts<br/>TypeScript/Bun<br/>OpenAI SDK]
-end
-
 subgraph "Dependencies"
     A --> B
     A --> C
@@ -193,7 +190,7 @@ graph TB
     end
 
     subgraph "Frontend"
-        D[leptos-app Pod<br/>:8788<br/>ClusterIP Service]
+        D[chat-ui Pod<br/>:8788<br/>ClusterIP Service]
     end
 
     subgraph "Ingress"

BIN
predict-otron-9000.png Normal file

Binary file not shown.

Size: 248 KiB

14
scripts/build_ui.sh Executable file
View File

@@ -0,0 +1,14 @@
#!/usr/bin/env sh
# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
# Move into the chat-ui crate
cd "$PROJECT_ROOT/crates/chat-ui" || exit 1
# Build with cargo leptos
cargo leptos build --release
# Rename the wasm artifact (paths relative to the project root);
# the wasm-bindgen JS glue typically expects the `_bg.wasm` suffix.
mv "$PROJECT_ROOT/target/site/pkg/chat-ui.wasm" \
"$PROJECT_ROOT/target/site/pkg/chat-ui_bg.wasm"

17
scripts/run.sh Executable file
View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
# Resolve the project root (script_dir/..)
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
# Build the frontend first.
# TODO: rebuild conditionally, only when the UI sources change.
"$PROJECT_ROOT/scripts/build_ui.sh"
# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}
cd "$PROJECT_ROOT" || exit 1
cargo run --bin predict-otron-9000 --release

View File

@@ -1,7 +0,0 @@
#!/bin/bash
# Start the unified predict-otron-9000 server on port 8080
export SERVER_PORT=${SERVER_PORT:-8080}
export RUST_LOG=${RUST_LOG:-info}
cargo run --bin predict-otron-9000 --release