9 Commits

SHA1        Author    Date                        Message
4380ac69d3  geoffsee  2025-09-04 15:09:30 -04:00  v0.1.5 already exists
e6f3351ebb  geoffsee  2025-09-04 15:08:43 -04:00  minor version
3992532f15  geoffsee  2025-09-04 15:07:49 -04:00  fmt and clippy
3ecdd9ffa0  geoffsee  2025-09-04 15:03:17 -04:00  update deployment tooling to remove dependencies on unused metadata
296d4dbe7e  geoffsee  2025-09-04 14:54:20 -04:00  add root dockerfile that contains binaries for all services
fb5098eba6  geoffsee  2025-09-04 13:53:00 -04:00  fix clippy errors
c1c583faab  geoffsee  2025-09-04 13:45:25 -04:00  run cargo fmt
1e02b12cda  geoffsee  2025-09-04 13:42:30 -04:00  fixes issue with model selection
ff55d882c7  geoffsee  2025-09-04 12:40:59 -04:00  reorg + update docs with new paths
62 changed files with 1109 additions and 672 deletions

.dockerignore (new file)

@@ -0,0 +1,35 @@
# Git
.git
.gitignore
# Rust
target/
# Documentation
README.md
*.md
# IDE
.vscode/
.idea/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
# Logs
*.log
# Environment
.env
.env.local
# Dependencies
node_modules/
# Build artifacts
dist/
build/
.fastembed_cache


@@ -44,7 +44,7 @@ jobs:
       - name: Clippy
         shell: bash
-        run: cargo clippy --all-targets
+        run: cargo clippy --all
       - name: Tests
         shell: bash

.github/workflows/docker.yml (new file)

@@ -0,0 +1,46 @@
name: Build and Push Docker Image

on:
  tags:
    - 'v*'

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Log in to Container Registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Extract metadata
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=sha

      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}

@@ -45,7 +45,7 @@ jobs:
       - name: Clippy
         shell: bash
-        run: cargo clippy --all-targets
+        run: cargo clippy --all
       - name: Tests
         shell: bash

.gitignore

@@ -77,3 +77,4 @@ venv/
 !/scripts/cli.ts
 /**/.*.bun-build
 /AGENTS.md
+.claude

Cargo.lock (generated)

@@ -2905,6 +2905,7 @@ dependencies = [
  "clap",
  "cpal",
  "either",
+ "embeddings-engine",
  "futures-util",
  "gemma-runner",
  "imageproc 0.24.0",
@@ -7040,7 +7041,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
 [[package]]
 name = "utils"
-version = "0.0.0"
+version = "0.1.4"
 dependencies = [
  "ab_glyph",
  "accelerate-src",


@@ -3,17 +3,17 @@ members = [
     "crates/predict-otron-9000",
     "crates/inference-engine",
     "crates/embeddings-engine",
-    "crates/helm-chart-tool",
-    "crates/llama-runner",
-    "crates/gemma-runner",
-    "crates/cli",
+    "integration/helm-chart-tool",
+    "integration/llama-runner",
+    "integration/gemma-runner",
+    "integration/cli",
     "crates/chat-ui"
-, "crates/utils"]
+, "integration/utils"]
 default-members = ["crates/predict-otron-9000"]
 resolver = "2"
 [workspace.package]
-version = "0.1.4"
+version = "0.1.6"
 # Compiler optimization profiles for the workspace
 [profile.release]

Dockerfile (new file)

@@ -0,0 +1,50 @@
# Multi-stage build for predict-otron-9000 workspace
FROM rust:1 AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY crates/ ./crates/
COPY integration/ ./integration/
# Build all 3 main server binaries in release mode
RUN cargo build --release -p predict-otron-9000 --bin predict-otron-9000 --no-default-features -p embeddings-engine --bin embeddings-engine -p inference-engine --bin inference-engine
# Runtime stage
FROM debian:bookworm-slim AS runtime
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
ca-certificates \
libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Create app user
RUN useradd -r -s /bin/false -m -d /app appuser
# Set working directory
WORKDIR /app
# Copy binaries from builder stage
COPY --from=builder /app/target/release/predict-otron-9000 ./bin/
COPY --from=builder /app/target/release/embeddings-engine ./bin/
COPY --from=builder /app/target/release/inference-engine ./bin/
# Make binaries executable and change ownership
RUN chmod +x ./bin/* && chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Expose ports (adjust as needed based on your services)
EXPOSE 8080 8081 8082
# Default command (can be overridden)
CMD ["./bin/predict-otron-9000"]


@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
 > This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
 > By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
-Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
+Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
 A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.
@@ -53,14 +53,17 @@ The project uses a 9-crate Rust workspace plus TypeScript components:
 crates/
 ├── predict-otron-9000/    # Main orchestration server (Rust 2024)
 ├── inference-engine/      # Multi-model inference orchestrator (Rust 2021)
+├── embeddings-engine/     # FastEmbed embeddings service (Rust 2024)
+└── chat-ui/               # WASM web frontend (Rust 2021)
+integration/
+├── cli/                   # CLI client crate (Rust 2024)
+│   └── package/
+│       └── cli.ts         # TypeScript/Bun CLI client
 ├── gemma-runner/          # Gemma model inference via Candle (Rust 2021)
 ├── llama-runner/          # Llama model inference via Candle (Rust 2021)
-├── embeddings-engine/     # FastEmbed embeddings service (Rust 2024)
-├── chat-ui/               # WASM web frontend (Rust 2021)
 ├── helm-chart-tool/       # Kubernetes deployment tooling (Rust 2024)
-└── cli/                   # CLI client crate (Rust 2024)
-    └── package/
-        └── cli.ts         # TypeScript/Bun CLI client
+└── utils/                 # Shared utilities (Rust 2021)
 ```
 ### Service Architecture
@@ -160,16 +163,16 @@ cd crates/chat-ui
 #### TypeScript CLI Client
 ```bash
 # List available models
-cd crates/cli/package && bun run cli.ts --list-models
+cd integration/cli/package && bun run cli.ts --list-models
 # Chat completion
-cd crates/cli/package && bun run cli.ts "What is the capital of France?"
+cd integration/cli/package && bun run cli.ts "What is the capital of France?"
 # With specific model
-cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
 # Show help
-cd crates/cli/package && bun run cli.ts --help
+cd integration/cli/package && bun run cli.ts --help
 ```
 ## API Usage
@@ -464,7 +467,7 @@ curl -s http://localhost:8080/v1/models | jq
 **CLI client test:**
 ```bash
-cd crates/cli/package && bun run cli.ts "What is 2+2?"
+cd integration/cli/package && bun run cli.ts "What is 2+2?"
 ```
 **Web frontend:**


@@ -4,7 +4,7 @@
     "": {
       "name": "predict-otron-9000",
     },
-    "crates/cli/package": {
+    "integration/cli/package": {
       "name": "cli",
       "dependencies": {
         "install": "^0.13.0",
@@ -13,7 +13,7 @@
     },
   },
   "packages": {
-    "cli": ["cli@workspace:crates/cli/package"],
+    "cli": ["cli@workspace:integration/cli/package"],
     "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],


@@ -3,6 +3,7 @@ name = "chat-ui"
 version = "0.1.0"
 edition = "2021"
 [lib]
 crate-type = ["cdylib", "rlib"]
@@ -122,3 +123,7 @@ lib-default-features = false
 #
 # Optional. Defaults to "release".
 lib-profile-release = "release"
+[[bin]]
+name = "chat-ui"
+path = "src/main.rs"


@@ -257,7 +257,8 @@ pub fn send_chat_completion_stream(
                 break;
             }
-            let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
+            let value =
+                js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
             let array = js_sys::Uint8Array::new(&value);
             let mut bytes = vec![0; array.length() as usize];
             array.copy_to(&mut bytes);
@@ -279,7 +280,9 @@
                     }
                     // Parse JSON chunk
-                    if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
+                    if let Ok(chunk) =
+                        serde_json::from_str::<StreamChatResponse>(data)
+                    {
                         if let Some(choice) = chunk.choices.first() {
                             if let Some(content) = &choice.delta.content {
                                 on_chunk(content.clone());
@@ -365,7 +368,7 @@ fn ChatPage() -> impl IntoView {
     // State for available models and selected model
     let available_models = RwSignal::new(Vec::<ModelInfo>::new());
-    let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
+    let selected_model = RwSignal::new(String::from("")); // Default model
     // State for streaming response
     let streaming_content = RwSignal::new(String::new());
@@ -382,6 +385,7 @@ fn ChatPage() -> impl IntoView {
         match fetch_models().await {
             Ok(models) => {
                 available_models.set(models);
+                selected_model.set(String::from("gemma-3-1b-it"));
             }
             Err(error) => {
                 console::log_1(&format!("Failed to fetch models: {}", error).into());


@@ -25,15 +25,9 @@ rand = "0.8.5"
 async-openai = "0.28.3"
 once_cell = "1.19.0"
-[package.metadata.compose]
-image = "ghcr.io/geoffsee/embeddings-service:latest"
-port = 8080
 # generates kubernetes manifests
 [package.metadata.kube]
-image = "ghcr.io/geoffsee/embeddings-service:latest"
+image = "ghcr.io/geoffsee/predict-otron-9000:latest"
+cmd = ["./bin/embeddings-engine"]
 replicas = 1
 port = 8080


@@ -1,42 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Cache deps first
COPY . ./
RUN rm -rf src
RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "// lib" > src/lib.rs && cargo build --release
RUN rm -rf src
# Copy real sources and build
COPY . .
RUN cargo build --release
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install only what the compiled binary needs
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/embeddings-engine /usr/local/bin/
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["embeddings-engine"]


@@ -1,43 +1,225 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
-use axum::{Json, Router, response::Json as ResponseJson, routing::post};
+use axum::{Json, Router, http::StatusCode, response::Json as ResponseJson, routing::post};
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use once_cell::sync::Lazy;
+use serde::Serialize;
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
 use tower_http::trace::TraceLayer;
+use tracing;
-// Persistent model instance (singleton pattern)
-static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
-    tracing::info!("Initializing persistent embedding model (singleton)");
+// Cache for multiple embedding models
+static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
+    Lazy::new(|| RwLock::new(HashMap::new()));
#[derive(Serialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub owned_by: String,
pub description: String,
pub dimensions: usize,
}
#[derive(Serialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
match model_name {
// Sentence Transformers models
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
Ok(EmbeddingModel::AllMiniLML6V2)
}
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
Ok(EmbeddingModel::AllMiniLML6V2Q)
}
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
Ok(EmbeddingModel::AllMiniLML12V2)
}
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
Ok(EmbeddingModel::AllMiniLML12V2Q)
}
// BGE models
"BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
"BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
"BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
"BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
"BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
"BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
"BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
"BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
// Nomic models
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
Ok(EmbeddingModel::NomicEmbedTextV1)
}
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
Ok(EmbeddingModel::NomicEmbedTextV15)
}
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
Ok(EmbeddingModel::NomicEmbedTextV15Q)
}
// Paraphrase models
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
| "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
| "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
| "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
// ModernBert
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
Ok(EmbeddingModel::ModernBertEmbedLarge)
}
// Multilingual E5 models
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
Ok(EmbeddingModel::MultilingualE5Small)
}
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
Ok(EmbeddingModel::MultilingualE5Base)
}
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
Ok(EmbeddingModel::MultilingualE5Large)
}
// Mixedbread models
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1)
}
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
}
// GTE models
"Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
Ok(EmbeddingModel::GTEBaseENV15Q)
}
"Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
Ok(EmbeddingModel::GTELargeENV15Q)
}
// CLIP model
"Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
// Jina model
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
}
_ => Err(format!("Unsupported embedding model: {}", model_name)),
}
}
// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
match model {
EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
EmbeddingModel::BGESmallZHV15 => 512,
EmbeddingModel::BGELargeZHV15 => 1024,
EmbeddingModel::NomicEmbedTextV1
| EmbeddingModel::NomicEmbedTextV15
| EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
EmbeddingModel::ModernBertEmbedLarge => 1024,
EmbeddingModel::MultilingualE5Small => 384,
EmbeddingModel::MultilingualE5Base => 768,
EmbeddingModel::MultilingualE5Large => 1024,
EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
EmbeddingModel::ClipVitB32 => 512,
EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
}
}
// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
// First try to get from cache (read lock)
{
let cache = MODEL_CACHE
.read()
.map_err(|e| format!("Failed to acquire read lock: {}", e))?;
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model: {:?}", embedding_model);
return Ok(Arc::clone(model));
}
}
// Model not in cache, create it (write lock)
let mut cache = MODEL_CACHE
.write()
.map_err(|e| format!("Failed to acquire write lock: {}", e))?;
// Double-check after acquiring write lock
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
return Ok(Arc::clone(model));
}
tracing::info!("Initializing new embedding model: {:?}", embedding_model);
     let model_start_time = std::time::Instant::now();
     let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
+        InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
     )
-    .expect("Failed to initialize persistent embedding model");
+    .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
     let model_init_time = model_start_time.elapsed();
     tracing::info!(
-        "Persistent embedding model initialized in {:.2?}",
+        "Embedding model {:?} initialized in {:.2?}",
+        embedding_model,
         model_init_time
     );
-    model
-});
+    let model_arc = Arc::new(model);
+    cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
+    Ok(model_arc)
+}
 pub async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
-) -> ResponseJson<serde_json::Value> {
+) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
     // Start timing the entire process
     let start_time = std::time::Instant::now();
-    // Phase 1: Access persistent model instance
+    // Phase 1: Parse and get the embedding model
     let model_start_time = std::time::Instant::now();
-    // Access the lazy-initialized persistent model instance
-    // This will only initialize the model on the first request
+    let embedding_model = match parse_embedding_model(&payload.model) {
+        Ok(model) => model,
+        Err(e) => {
+            tracing::error!("Invalid model requested: {}", e);
+            return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
+        }
+    };
+    let model = match get_or_create_model(embedding_model.clone()) {
+        Ok(model) => model,
+        Err(e) => {
+            tracing::error!("Failed to get/create model: {}", e);
+            return Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("Model initialization failed: {}", e),
+            ));
+        }
+    };
     let model_access_time = model_start_time.elapsed();
     tracing::debug!(
-        "Persistent model access completed in {:.2?}",
+        "Model access/creation completed in {:.2?}",
         model_access_time
     );
@@ -65,9 +247,13 @@ pub async fn embeddings_create(
     // Phase 3: Generate embeddings
     let embedding_start_time = std::time::Instant::now();
-    let embeddings = EMBEDDING_MODEL
-        .embed(texts_from_embedding_input, None)
-        .expect("failed to embed document");
+    let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
+        tracing::error!("Failed to generate embeddings: {}", e);
+        (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            format!("Embedding generation failed: {}", e),
+        )
+    })?;
     let embedding_generation_time = embedding_start_time.elapsed();
     tracing::info!(
@@ -117,8 +303,9 @@ pub async fn embeddings_create(
         // Generate a random non-zero embedding
         use rand::Rng;
         let mut rng = rand::thread_rng();
-        let mut random_embedding = Vec::with_capacity(768);
-        for _ in 0..768 {
+        let expected_dimensions = get_model_dimensions(&embedding_model);
+        let mut random_embedding = Vec::with_capacity(expected_dimensions);
+        for _ in 0..expected_dimensions {
             // Generate random values between -1.0 and 1.0, excluding 0
             let mut val = 0.0;
             while val == 0.0 {
@@ -138,18 +325,19 @@ pub async fn embeddings_create(
         random_embedding
     } else {
         // Check if dimensions parameter is provided and pad the embeddings if necessary
-        let mut padded_embedding = embeddings[0].clone();
-        // If the client expects 768 dimensions but our model produces fewer, pad with zeros
-        let target_dimension = 768;
-        if padded_embedding.len() < target_dimension {
-            let padding_needed = target_dimension - padded_embedding.len();
-            tracing::trace!(
-                "Padding embedding with {} zeros to reach {} dimensions",
-                padding_needed,
-                target_dimension
-            );
-            padded_embedding.extend(vec![0.0; padding_needed]);
-        }
+        let padded_embedding = embeddings[0].clone();
+        // Use the actual model dimensions instead of hardcoded 768
+        let actual_dimensions = padded_embedding.len();
+        let expected_dimensions = get_model_dimensions(&embedding_model);
+        if actual_dimensions != expected_dimensions {
+            tracing::warn!(
+                "Model {:?} produced {} dimensions but expected {}",
+                embedding_model,
+                actual_dimensions,
+                expected_dimensions
+            );
+        }
         padded_embedding
@@ -203,11 +391,234 @@ pub async fn embeddings_create(
         postprocessing_time
     );
-    ResponseJson(response)
+    Ok(ResponseJson(response))
+}
pub async fn models_list() -> ResponseJson<ModelsResponse> {
let models = vec![
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the small Chinese model".to_string(),
dimensions: 512,
},
ModelInfo {
id: "BAAI/bge-large-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large Chinese model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model"
.to_string(),
dimensions: 768,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search"
.to_string(),
dimensions: 768,
},
ModelInfo {
id: "lightonai/modernbert-embed-large".to_string(),
object: "model".to_string(),
owned_by: "lightonai".to_string(),
description: "Large model of ModernBert Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "intfloat/multilingual-e5-small".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Small model of multilingual E5 Text Embeddings".to_string(),
dimensions: 384,
},
ModelInfo {
id: "intfloat/multilingual-e5-base".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Base model of multilingual E5 Text Embeddings".to_string(),
dimensions: 768,
},
ModelInfo {
id: "intfloat/multilingual-e5-large".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Large model of multilingual E5 Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Quantized Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Qdrant/clip-ViT-B-32-text".to_string(),
object: "model".to_string(),
owned_by: "Qdrant".to_string(),
description: "CLIP text encoder based on ViT-B/32".to_string(),
dimensions: 512,
},
ModelInfo {
id: "jinaai/jina-embeddings-v2-base-code".to_string(),
object: "model".to_string(),
owned_by: "jinaai".to_string(),
description: "Jina embeddings v2 base code".to_string(),
dimensions: 768,
},
];
ResponseJson(ModelsResponse {
object: "list".to_string(),
data: models,
})
 }
 pub fn create_embeddings_router() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
+        // .route("/v1/models", get(models_list))
         .layer(TraceLayer::new_for_http())
 }
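
As an illustration of the multi-model support introduced above (an editorial sketch, not part of the commit), a request naming one of the ids handled by parse_embedding_model might look like the following, assuming the service listens on localhost:8080.

```bash
curl -s http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
        "model": "BAAI/bge-small-en-v1.5",
        "input": "The quick brown fox jumps over the lazy dog"
      }' | jq '.data[0].embedding | length'   # expected: 384 for this model
```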


@@ -4,8 +4,6 @@ use axum::{
     response::Json as ResponseJson,
     routing::{get, post},
 };
-use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
-use serde::{Deserialize, Serialize};
 use std::env;
 use tower_http::trace::TraceLayer;
 use tracing;
@@ -13,127 +11,28 @@ use tracing;
 const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
 const DEFAULT_SERVER_PORT: &str = "8080";
+use embeddings_engine;
 async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
-) -> ResponseJson<serde_json::Value> {
+) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
+    match embeddings_engine::embeddings_create(Json(payload)).await {
+        Ok(response) => Ok(response),
+        Err((status_code, message)) => Err(axum::response::Response::builder()
+            .status(status_code)
+            .body(axum::body::Body::from(message))
+            .unwrap()),
+    }
+}
+async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
+    embeddings_engine::models_list().await
-    let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
-    )
-    .expect("Failed to initialize model");
-    let embedding_input = payload.input;
-    let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => {
panic!("Integer array input not supported for text embeddings");
}
EmbeddingInput::ArrayOfIntegerArray(_) => {
panic!("Array of integer arrays not supported for text embeddings");
}
};
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
if all_zeros {
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
val = rng.gen_range(-1.0..1.0);
}
random_embedding.push(val);
}
// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
}
};
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": final_embedding
}
],
"model": payload.model,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
});
ResponseJson(response)
 }
 fn create_app() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
+        .route("/v1/models", get(models_list))
         .layer(TraceLayer::new_for_http())
 }
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};


@@ -1,7 +1,7 @@
 [package]
 name = "inference-engine"
 version.workspace = true
-edition = "2021"
+edition = "2024"
 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
@@ -31,13 +31,21 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 reborrow = "0.5.5"
 futures-util = "0.3.31"
-gemma-runner = { path = "../gemma-runner", features = ["metal"] }
-llama-runner = { path = "../llama-runner", features = ["metal"]}
+gemma-runner = { path = "../../integration/gemma-runner" }
+llama-runner = { path = "../../integration/llama-runner" }
+embeddings-engine = { path = "../embeddings-engine" }
+[target.'cfg(target_os = "linux")'.dependencies]
+candle-core = { git = "https://github.com/huggingface/candle.git", default-features = false }
+candle-nn = { git = "https://github.com/huggingface/candle.git", default-features = false }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", default-features = false }
 [target.'cfg(target_os = "macos")'.dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
+llama-runner = { path = "../../integration/llama-runner", features = ["metal"] }
 [dev-dependencies]
@@ -61,15 +69,13 @@ bindgen_cuda = { version = "0.1.1", optional = true }
 [features]
 bin = []
+[[bin]]
+name = "inference-engine"
+path = "src/main.rs"
-[package.metadata.compose]
-image = "ghcr.io/geoffsee/inference-engine:latest"
-port = 8080
 # generates kubernetes manifests
 [package.metadata.kube]
-image = "ghcr.io/geoffsee/inference-service:latest"
+image = "ghcr.io/geoffsee/predict-otron-9000:latest"
-replicas = 1
+cmd = ["./bin/inference-engine"]
 port = 8080
+replicas = 1
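
With the new [[bin]] target in place, the standalone server can presumably be built and run directly from the workspace root, mirroring the build invocation in the root Dockerfile; a sketch:

```bash
# Build only the inference-engine server binary in release mode
cargo build --release -p inference-engine --bin inference-engine
./target/release/inference-engine
```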


@@ -1,86 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies including CUDA toolkit for GPU support
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA toolkit (optional, for GPU support)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
# Copy the entire workspace to get access to all crates
COPY . ./
# Cache dependencies first - create dummy source files
RUN rm -rf crates/inference-engine/src
RUN mkdir -p crates/inference-engine/src && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
cargo build --release --bin cli --package inference-engine
# Remove dummy source and copy real sources
RUN rm -rf crates/inference-engine/src
COPY . .
# Build the actual CLI binary
RUN cargo build --release --bin cli --package inference-engine
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA runtime libraries (optional, for GPU support at runtime)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/cli /usr/local/bin/inference-cli
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["inference-cli"]


@@ -8,7 +8,7 @@ pub mod server;
 // Re-export key components for easier access
 pub use inference::ModelInference;
 pub use model::{Model, Which};
-pub use server::{create_router, AppState};
+pub use server::{AppState, create_router};
 use std::env;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};


@@ -0,0 +1,26 @@
use inference_engine::{AppState, create_router, get_server_config, init_tracing};
use tokio::net::TcpListener;
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
init_tracing();
let app_state = AppState::default();
let app = create_router(app_state);
let (server_host, server_port, server_address) = get_server_config();
let listener = TcpListener::bind(&server_address).await?;
info!(
"Inference Engine server starting on http://{}",
server_address
);
info!("Available endpoints:");
info!(" POST /v1/chat/completions - OpenAI-compatible chat completions");
info!(" GET /v1/models - List available models");
axum::serve(listener, app).await?;
Ok(())
}


@@ -42,7 +42,11 @@ pub struct ModelMeta {
 }
 const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-    ModelMeta { id, family, instruct }
+    ModelMeta {
+        id,
+        family,
+        instruct,
+    }
 }
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]


@@ -1,26 +1,28 @@
 use axum::{
+    Json, Router,
     extract::State,
     http::StatusCode,
-    response::{sse::Event, sse::Sse, IntoResponse},
+    response::{IntoResponse, sse::Event, sse::Sse},
     routing::{get, post},
-    Json, Router,
 };
 use futures_util::stream::{self, Stream};
 use std::convert::Infallible;
-use std::str::FromStr;
 use std::sync::Arc;
-use tokio::sync::{mpsc, Mutex};
+use tokio::sync::{Mutex, mpsc};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
+use crate::Which;
 use crate::openai_types::{
     ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
     ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
 };
-use crate::Which;
-use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
-use llama_runner::{run_llama_inference, LlamaInferenceConfig};
+use embeddings_engine::models_list;
+use gemma_runner::{GemmaInferenceConfig, WhichModel, run_gemma_api};
+use llama_runner::{LlamaInferenceConfig, run_llama_inference};
 use serde_json::Value;
 // -------------------------
 // Shared app state
@@ -34,7 +36,7 @@ pub enum ModelType {
 #[derive(Clone)]
 pub struct AppState {
-    pub model_type: ModelType,
+    pub model_type: Option<ModelType>,
     pub model_id: String,
     pub gemma_config: Option<GemmaInferenceConfig>,
     pub llama_config: Option<LlamaInferenceConfig>,
@@ -44,15 +46,16 @@ impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
         // This can be overridden by environment variables if needed
-        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+        let default_model_id =
+            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
         let gemma_config = GemmaInferenceConfig {
-            model: gemma_runner::WhichModel::InstructV3_1B,
+            model: None,
             ..Default::default()
         };
         Self {
-            model_type: ModelType::Gemma,
+            model_type: None,
             model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
@@ -83,15 +86,14 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
         "gemma-2-9b-it" => Some(Which::InstructV2_9B),
         "gemma-3-1b" => Some(Which::BaseV3_1B),
         "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+        "llama-3.2-1b" => Some(Which::Llama32_1B),
         "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+        "llama-3.2-3b" => Some(Which::Llama32_3B),
         "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
         _ => None,
     }
 }
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
@@ -189,35 +191,74 @@
     // Get streaming receiver based on model type
     let rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
_ => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Llama model", model_id) }
})),
));
}
};
let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_llama_inference(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Llama model: {}", e) }
-            }))
-        ))?
+        run_llama_inference(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Llama model: {}", e) }
+                })),
+            )
+        })?
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
_ => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
})),
));
}
         };
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_gemma_api(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-            }))
-        ))?
+        run_gemma_api(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                })),
+            )
+        })?
     };
     // Collect all tokens from the stream
@@ -347,7 +388,21 @@ async fn handle_streaming_request(
     // Get streaming receiver based on model type
     let model_rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Llama model", model_id) }
})),
));
}
};
let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         match run_llama_inference(config) {
@@ -363,14 +418,35 @@
         }
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
})),
));
}
         };
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
@@ -530,46 +606,69 @@ pub async fn list_models() -> Json<ModelListResponse> {
         Which::Llama32_3BInstruct,
     ];
-    let models: Vec<Model> = which_variants.into_iter().map(|which| {
+    let mut models: Vec<Model> = which_variants
+        .into_iter()
+        .map(|which| {
            let meta = which.meta();
            let model_id = match which {
                Which::Base2B => "gemma-2b",
                Which::Base7B => "gemma-7b",
                Which::Instruct2B => "gemma-2b-it",
                Which::Instruct7B => "gemma-7b-it",
                Which::InstructV1_1_2B => "gemma-1.1-2b-it",
                Which::InstructV1_1_7B => "gemma-1.1-7b-it",
                Which::CodeBase2B => "codegemma-2b",
                Which::CodeBase7B => "codegemma-7b",
                Which::CodeInstruct2B => "codegemma-2b-it",
                Which::CodeInstruct7B => "codegemma-7b-it",
                Which::BaseV2_2B => "gemma-2-2b",
                Which::InstructV2_2B => "gemma-2-2b-it",
                Which::BaseV2_9B => "gemma-2-9b",
                Which::InstructV2_9B => "gemma-2-9b-it",
                Which::BaseV3_1B => "gemma-3-1b",
                Which::InstructV3_1B => "gemma-3-1b-it",
                Which::Llama32_1B => "llama-3.2-1b",
                Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
                Which::Llama32_3B => "llama-3.2-3b",
                Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
            };
            let owned_by = if meta.id.starts_with("google/") {
                "google"
            } else if meta.id.starts_with("meta-llama/") {
                "meta"
            } else {
                "unknown"
            };
            Model {
                id: model_id.to_string(),
                object: "model".to_string(),
-               created: 1686935002, // Using same timestamp as OpenAI example
+               created: 1686935002,
                owned_by: owned_by.to_string(),
            }
-    }).collect();
+        })
+        .collect();
// Get embeddings models and convert them to inference Model format
let embeddings_response = models_list().await;
let embeddings_models: Vec<Model> = embeddings_response
.0
.data
.into_iter()
.map(|embedding_model| Model {
id: embedding_model.id,
object: embedding_model.object,
created: 1686935002,
owned_by: format!(
"{} - {}",
embedding_model.owned_by, embedding_model.description
),
})
.collect();
// Add embeddings models to the main models list
models.extend(embeddings_models);
Json(ModelListResponse { Json(ModelListResponse {
object: "list".to_string(), object: "list".to_string(),
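The hunk above replaces the old "always fall back to gemma-3-1b-it" branch with an exhaustive `Which` → `gemma_runner::WhichModel` match plus an error arm for non-Gemma ids. A minimal sketch of the same idea, using trimmed stand-in enums rather than the crate's real types:

```rust
// Minimal sketch (assumed, simplified enums; not the crate's actual types).
#[derive(Debug, Clone, Copy)]
enum Which {
    Instruct2B,
    InstructV3_1B,
    Llama32_1BInstruct, // not a Gemma model
}

#[derive(Debug, Clone, Copy)]
enum GemmaModel {
    Instruct2B,
    InstructV3_1B,
}

fn to_gemma(which: Which) -> Result<GemmaModel, String> {
    match which {
        Which::Instruct2B => Ok(GemmaModel::Instruct2B),
        Which::InstructV3_1B => Ok(GemmaModel::InstructV3_1B),
        // Anything else (e.g. a Llama variant) is rejected instead of silently
        // falling back to a default Gemma model, as the old branch did.
        other => Err(format!("Model {:?} is not a Gemma model", other)),
    }
}

fn main() {
    assert!(to_gemma(Which::Instruct2B).is_ok());
    assert!(to_gemma(Which::Llama32_1BInstruct).is_err());
}
```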

View File

@@ -29,25 +29,23 @@ inference-engine = { path = "../inference-engine" }
# Dependencies for leptos web app
#leptos-app = { path = "../leptos-app", features = ["ssr"] }
-chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = false }
+chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = true }
mime_guess = "2.0.5"
log = "0.4.27"
-[package.metadata.compose]
-name = "predict-otron-9000"
-image = "ghcr.io/geoffsee/predict-otron-9000:latest"
-port = 8080
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
-cmd = ["./bin/predict-otron-9000"]
# SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
# you can generate this via node to avoid toil
# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
env = { SERVER_CONFIG = "<your-json-value-here>" }
+[features]
+default = ["ui"]
+ui = ["dep:chat-ui"]
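The commented `SERVER_CONFIG` example above generates the escaped JSON with node; the same string can be produced from Rust with `serde_json`. A sketch assuming only the shape shown in the comment, not any crate-defined config types:

```rust
// Sketch: builds the SERVER_CONFIG JSON shown in the comment above and the
// quote-escaped variant needed when embedding it in a TOML string.
use serde_json::json;

fn main() {
    let server_config = json!({
        "serverMode": "HighAvailability",
        "services": {
            "inference_url": "http://custom-inference:9000",
            "embeddings_url": "http://custom-embeddings:9001"
        }
    });
    let raw = server_config.to_string();
    let toml_escaped = raw.replace('"', "\\\"");
    println!("{raw}");
    println!("{toml_escaped}");
}
```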

View File

@@ -1,89 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies including CUDA toolkit for GPU support (needed for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA toolkit (required for inference-engine dependency)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
# Copy the entire workspace to get access to all crates (needed for local dependencies)
COPY . ./
# Cache dependencies first - create dummy source files for all crates
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
RUN mkdir -p crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src && \
echo "fn main() {}" > crates/predict-otron-9000/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
echo "fn main() {}" > crates/embeddings-engine/src/main.rs && \
echo "// lib" > crates/embeddings-engine/src/lib.rs && \
cargo build --release --bin predict-otron-9000 --package predict-otron-9000
# Remove dummy sources and copy real sources
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
COPY . .
# Build the actual binary
RUN cargo build --release --bin predict-otron-9000 --package predict-otron-9000
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA runtime libraries (required for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/predict-otron-9000 /usr/local/bin/
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["predict-otron-9000"]

View File

@@ -39,29 +39,12 @@ impl Default for ServerMode
    }
}
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Services {
    pub inference_url: Option<String>,
    pub embeddings_url: Option<String>,
}
-impl Default for Services {
-    fn default() -> Self {
-        Self {
-            inference_url: None,
-            embeddings_url: None,
-        }
-    }
-}
-fn inference_service_url() -> String {
-    "http://inference-service:8080".to_string()
-}
-fn embeddings_service_url() -> String {
-    "http://embeddings-service:8080".to_string()
-}
impl Default for ServerConfig {
    fn default() -> Self {
        Self {
@@ -118,8 +101,7 @@ impl ServerConfig {
                "HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
                config_string
            );
-           let err = std::io::Error::new(
-               std::io::ErrorKind::Other,
+           let err = std::io::Error::other(
                "HighAvailability mode configured but services not well defined!",
            );
            return Err(err);
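Deriving `Default` here is behaviour-preserving: both fields are `Option`s, so the derived impl yields the same `None`/`None` value as the deleted manual impl. A minimal sketch (serde derives omitted to keep it dependency-free):

```rust
// Sketch mirroring the diff: both Option fields default to None, so the derived
// impl matches the deleted manual `impl Default for Services`.
#[derive(Debug, Clone, Default)]
struct Services {
    inference_url: Option<String>,
    embeddings_url: Option<String>,
}

fn main() {
    let s = Services::default();
    assert!(s.inference_url.is_none() && s.embeddings_url.is_none());
}
```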

View File

@@ -126,7 +126,7 @@ use crate::config::ServerConfig;
/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
/// HTTP client configured for proxying requests
#[derive(Clone)]
pub struct ProxyClient {
    client: Client,

View File

@@ -4,28 +4,31 @@ mod middleware;
mod standalone_mode; mod standalone_mode;
use crate::standalone_mode::create_standalone_router; use crate::standalone_mode::create_standalone_router;
use axum::handler::Handler;
use axum::http::StatusCode as AxumStatusCode;
use axum::http::header;
use axum::response::IntoResponse;
use axum::routing::get; use axum::routing::get;
use axum::{Router, ServiceExt, http::Uri, response::Html, serve}; use axum::{Router, serve};
use config::ServerConfig; use config::ServerConfig;
use ha_mode::create_ha_router; use ha_mode::create_ha_router;
use inference_engine::AppState;
use log::info;
use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore}; use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use mime_guess::from_path;
use rust_embed::Embed;
use std::env; use std::env;
use std::path::Component::ParentDir;
#[cfg(feature = "ui")]
use axum::http::StatusCode as AxumStatusCode;
#[cfg(feature = "ui")]
use axum::http::Uri;
#[cfg(feature = "ui")]
use axum::http::header;
#[cfg(feature = "ui")]
use axum::response::IntoResponse;
#[cfg(feature = "ui")]
use mime_guess::from_path;
#[cfg(feature = "ui")]
use rust_embed::Embed;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower::MakeService;
use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[cfg(feature = "ui")]
#[derive(Embed)] #[derive(Embed)]
#[folder = "../../target/site"] #[folder = "../../target/site"]
#[include = "*.js"] #[include = "*.js"]
@@ -34,6 +37,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[include = "*.ico"] #[include = "*.ico"]
struct Asset; struct Asset;
#[cfg(feature = "ui")]
async fn static_handler(uri: Uri) -> axum::response::Response { async fn static_handler(uri: Uri) -> axum::response::Response {
// Strip the leading `/` // Strip the leading `/`
let path = uri.path().trim_start_matches('/'); let path = uri.path().trim_start_matches('/');
@@ -111,23 +115,28 @@ async fn main() {
// Create metrics layer // Create metrics layer
let metrics_layer = MetricsLayer::new(metrics_store); let metrics_layer = MetricsLayer::new(metrics_store);
let leptos_config = chat_ui::app::AppConfig::default();
// Create the leptos router for the web frontend
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
// Merge the service router with base routes and add middleware layers // Merge the service router with base routes and add middleware layers
let app = Router::new() let mut app = Router::new()
.route("/pkg/{*path}", get(static_handler))
.route("/health", get(|| async { "ok" })) .route("/health", get(|| async { "ok" }))
.merge(service_router) .merge(service_router);
.merge(leptos_router)
// Add UI routes if the UI feature is enabled
#[cfg(feature = "ui")]
{
let leptos_config = chat_ui::app::AppConfig::default();
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
app = app
.route("/pkg/{*path}", get(static_handler))
.merge(leptos_router);
}
let app = app
.layer(metrics_layer) // Add metrics tracking .layer(metrics_layer) // Add metrics tracking
.layer(cors) .layer(cors)
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());
// Server configuration // Server configuration
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| String::from(default_host)); let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| default_host.to_string());
let server_port = env::var("SERVER_PORT") let server_port = env::var("SERVER_PORT")
.map(|v| v.parse::<u16>().unwrap_or(default_port)) .map(|v| v.parse::<u16>().unwrap_or(default_port))
@@ -142,8 +151,10 @@ async fn main() {
); );
tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds"); tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
tracing::info!("Available endpoints:"); tracing::info!("Available endpoints:");
#[cfg(feature = "ui")]
tracing::info!(" GET / - Leptos chat web application"); tracing::info!(" GET / - Leptos chat web application");
tracing::info!(" GET /health - Health check"); tracing::info!(" GET /health - Health check");
tracing::info!(" POST /v1/models - List Models");
tracing::info!(" POST /v1/embeddings - Text embeddings API"); tracing::info!(" POST /v1/embeddings - Text embeddings API");
tracing::info!(" POST /v1/chat/completions - Chat completions API"); tracing::info!(" POST /v1/chat/completions - Chat completions API");
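The main.rs hunk above gates all UI wiring behind a `ui` cargo feature. A reduced sketch of that pattern, with handler bodies stubbed out and the `chat_ui` router plus static asset handler omitted, shows how the router is assembled in two stages:

```rust
// Reduced sketch of feature-gated route assembly (axum 0.8-style paths);
// the real server merges the Leptos router and an embedded-asset handler here.
use axum::{routing::get, Router};

fn build_router() -> Router {
    let mut app = Router::new().route("/health", get(|| async { "ok" }));

    // Only compiled in when the crate is built with `--features ui` (the default).
    #[cfg(feature = "ui")]
    {
        app = app.route("/pkg/{*path}", get(|| async { "static asset stub" }));
    }

    app
}

fn main() {
    let _app = build_router();
}
```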

View File

@@ -2,7 +2,7 @@ use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;
-pub fn create_standalone_router(server_config: ServerConfig) -> Router {
+pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
    // Create unified router by merging embeddings and inference routers (existing behavior)
    let embeddings_router = embeddings_engine::create_embeddings_router();

View File

@@ -61,20 +61,22 @@ graph TD
A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080] A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
end end
subgraph "AI Services" subgraph "AI Services (crates/)"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator] B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed] C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
end end
subgraph "Frontend" subgraph "Frontend (crates/)"
D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI] D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
end end
subgraph "Tooling"
subgraph "Integration Tools (integration/)"
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment] L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI] E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
O[utils<br/>Edition: 2021<br/>Shared utilities]
end end
end end
@@ -82,10 +84,10 @@ graph TD
A --> B A --> B
A --> C A --> C
A --> D A --> D
B --> J B --> M
B --> K B --> N
J -.-> F[Candle 0.9.1] M -.-> F[Candle 0.9.1]
K -.-> F N -.-> F
C -.-> G[FastEmbed 4.x] C -.-> G[FastEmbed 4.x]
D -.-> H[Leptos 0.8.0] D -.-> H[Leptos 0.8.0]
E -.-> I[OpenAI SDK 5.16+] E -.-> I[OpenAI SDK 5.16+]
@@ -93,12 +95,13 @@ graph TD
style A fill:#e1f5fe style A fill:#e1f5fe
style B fill:#f3e5f5 style B fill:#f3e5f5
style J fill:#f3e5f5
style K fill:#f3e5f5
style C fill:#e8f5e8 style C fill:#e8f5e8
style D fill:#fff3e0 style D fill:#fff3e0
style E fill:#fce4ec style E fill:#fce4ec
style L fill:#fff9c4 style L fill:#fff9c4
style M fill:#f3e5f5
style N fill:#f3e5f5
style O fill:#fff9c4
``` ```
## Deployment Configurations ## Deployment Configurations

View File

@@ -14,7 +14,7 @@ Options:
  --help Show this help message
Examples:
-  cd crates/cli/package
+  cd integration/cli/package
  bun run cli.ts "What is the capital of France?"
  bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
  bun run cli.ts --prompt "Who was the 16th president of the United States?"

View File

@@ -24,8 +24,7 @@ fn run_build() -> io::Result<()> {
    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
    let output_path = out_dir.join("client-cli");
-   let bun_tgt = BunTarget::from_cargo_env()
-       .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
+   let bun_tgt = BunTarget::from_cargo_env().map_err(|e| io::Error::other(e.to_string()))?;
    // Optional: warn if using a Bun target that's marked unsupported in your chart
    if matches!(bun_tgt, BunTarget::WindowsArm64) {
@@ -54,13 +53,12 @@ fn run_build() -> io::Result<()> {
    if !install_status.success() {
        let code = install_status.code().unwrap_or(1);
-       return Err(io::Error::new(
-           io::ErrorKind::Other,
-           format!("bun install failed with status {code}"),
-       ));
+       return Err(io::Error::other(format!(
+           "bun install failed with status {code}"
+       )));
    }
-   let target = env::var("TARGET").unwrap();
+   let _target = env::var("TARGET").unwrap();
    // --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
    let mut build = Command::new("bun")
@@ -87,7 +85,7 @@ fn run_build() -> io::Result<()> {
    } else {
        let code = status.code().unwrap_or(1);
        warn(&format!("bun build failed with status: {code}"));
-       return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
+       return Err(io::Error::other("bun build failed"));
    }
    // Ensure the output is executable (after it exists)
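The error-construction changes in this file are mechanical: `io::Error::other(x)` (stable since Rust 1.74) is shorthand for `io::Error::new(io::ErrorKind::Other, x)`, which is presumably what the "fix clippy errors" commit was chasing. A tiny sketch of the equivalence:

```rust
use std::io;

// Both constructors produce an error with ErrorKind::Other; `other` is just shorter.
fn old_style(msg: &str) -> io::Error {
    io::Error::new(io::ErrorKind::Other, msg.to_string())
}

fn new_style(msg: &str) -> io::Error {
    io::Error::other(msg.to_string())
}

fn main() {
    assert_eq!(old_style("boom").kind(), new_style("boom").kind());
}
```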

View File

@@ -0,0 +1,17 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
"openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
}
}

View File

@@ -25,7 +25,7 @@ fn main() -> io::Result<()> {
    // Run it
    let status = Command::new(&tmp).arg("--version").status()?;
    if !status.success() {
-       return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
+       return Err(io::Error::other("client-cli failed"));
    }
    Ok(())

View File

@@ -18,7 +18,7 @@ serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
-utils = {path = "../utils"}
+utils = {path = "../utils" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

View File

@@ -1,13 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API // Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor}; use candle_core::{DType, Device, Tensor};
@@ -16,13 +10,15 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write; use std::io::Write;
use std::fmt;
use std::str::FromStr;
use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::mpsc::{self, Receiver, Sender};
use std::thread; use std::thread;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use utils::hub_load_safetensors; use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream; use utils::token_output_stream::TokenOutputStream;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel { pub enum WhichModel {
#[value(name = "gemma-2b")] #[value(name = "gemma-2b")]
Base2B, Base2B,
@@ -58,6 +54,56 @@ pub enum WhichModel {
InstructV3_1B, InstructV3_1B,
} }
impl FromStr for WhichModel {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gemma-2b" => Ok(Self::Base2B),
"gemma-7b" => Ok(Self::Base7B),
"gemma-2b-it" => Ok(Self::Instruct2B),
"gemma-7b-it" => Ok(Self::Instruct7B),
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
"codegemma-2b" => Ok(Self::CodeBase2B),
"codegemma-7b" => Ok(Self::CodeBase7B),
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
"gemma-2-2b" => Ok(Self::BaseV2_2B),
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
"gemma-2-9b" => Ok(Self::BaseV2_9B),
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
"gemma-3-1b" => Ok(Self::BaseV3_1B),
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
_ => Err(format!("Unknown model: {}", s)),
}
}
}
impl fmt::Display for WhichModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = match self {
Self::Base2B => "gemma-2b",
Self::Base7B => "gemma-7b",
Self::Instruct2B => "gemma-2b-it",
Self::Instruct7B => "gemma-7b-it",
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
Self::CodeBase2B => "codegemma-2b",
Self::CodeBase7B => "codegemma-7b",
Self::CodeInstruct2B => "codegemma-2b-it",
Self::CodeInstruct7B => "codegemma-7b-it",
Self::BaseV2_2B => "gemma-2-2b",
Self::InstructV2_2B => "gemma-2-2b-it",
Self::BaseV2_9B => "gemma-2-9b",
Self::InstructV2_9B => "gemma-2-9b-it",
Self::BaseV3_1B => "gemma-3-1b",
Self::InstructV3_1B => "gemma-3-1b-it",
};
write!(f, "{}", name)
}
}
enum Model { enum Model {
V1(Model1), V1(Model1),
V2(Model2), V2(Model2),
@@ -145,8 +191,6 @@ impl TextGeneration {
// Make sure stdout isn't holding anything (if caller also prints). // Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?; std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") { let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token, Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"), None => anyhow::bail!("cannot find the <eos> token"),
@@ -183,7 +227,6 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?; let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token { if next_token == eos_token || next_token == eot_token {
break; break;
@@ -210,7 +253,7 @@ impl TextGeneration {
pub struct GemmaInferenceConfig { pub struct GemmaInferenceConfig {
pub tracing: bool, pub tracing: bool,
pub prompt: String, pub prompt: String,
pub model: WhichModel, pub model: Option<WhichModel>,
pub cpu: bool, pub cpu: bool,
pub dtype: Option<String>, pub dtype: Option<String>,
pub model_id: Option<String>, pub model_id: Option<String>,
@@ -229,7 +272,7 @@ impl Default for GemmaInferenceConfig {
Self { Self {
tracing: false, tracing: false,
prompt: "Hello".to_string(), prompt: "Hello".to_string(),
model: WhichModel::InstructV2_2B, model: Some(WhichModel::InstructV2_2B),
cpu: false, cpu: false,
dtype: None, dtype: None,
model_id: None, model_id: None,
@@ -286,28 +329,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
} }
}; };
println!("Using dtype: {:?}", dtype); println!("Using dtype: {:?}", dtype);
println!("Raw model string: {:?}", cfg.model_id);
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let api = Api::new()?; let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| { let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model { match cfg.model {
WhichModel::Base2B => "google/gemma-2b", Some(WhichModel::Base2B) => "google/gemma-2b",
WhichModel::Base7B => "google/gemma-7b", Some(WhichModel::Base7B) => "google/gemma-7b",
WhichModel::Instruct2B => "google/gemma-2b-it", Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
WhichModel::Instruct7B => "google/gemma-7b-it", Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it", Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it", Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
WhichModel::CodeBase2B => "google/codegemma-2b", Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
WhichModel::CodeBase7B => "google/codegemma-7b", Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
WhichModel::CodeInstruct2B => "google/codegemma-2b-it", Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
WhichModel::CodeInstruct7B => "google/codegemma-7b-it", Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
WhichModel::BaseV2_2B => "google/gemma-2-2b", Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
WhichModel::InstructV2_2B => "google/gemma-2-2b-it", Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
WhichModel::BaseV2_9B => "google/gemma-2-9b", Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
WhichModel::InstructV2_9B => "google/gemma-2-9b-it", Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt", Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it", Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
None => "google/gemma-2-2b-it", // default fallback
} }
.to_string() .to_string()
}); });
@@ -318,7 +363,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let tokenizer_filename = repo.get("tokenizer.json")?; let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?; let config_filename = repo.get("config.json")?;
let filenames = match cfg.model { let filenames = match cfg.model {
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?], Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
vec![repo.get("model.safetensors")?]
}
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?, _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
}; };
println!("Retrieved files in {:?}", start.elapsed()); println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +376,31 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model { let model: Model = match cfg.model {
WhichModel::Base2B Some(WhichModel::Base2B)
| WhichModel::Base7B | Some(WhichModel::Base7B)
| WhichModel::Instruct2B | Some(WhichModel::Instruct2B)
| WhichModel::Instruct7B | Some(WhichModel::Instruct7B)
| WhichModel::InstructV1_1_2B | Some(WhichModel::InstructV1_1_2B)
| WhichModel::InstructV1_1_7B | Some(WhichModel::InstructV1_1_7B)
| WhichModel::CodeBase2B | Some(WhichModel::CodeBase2B)
| WhichModel::CodeBase7B | Some(WhichModel::CodeBase7B)
| WhichModel::CodeInstruct2B | Some(WhichModel::CodeInstruct2B)
| WhichModel::CodeInstruct7B => { | Some(WhichModel::CodeInstruct7B) => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?; let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model) Model::V1(model)
} }
WhichModel::BaseV2_2B Some(WhichModel::BaseV2_2B)
| WhichModel::InstructV2_2B | Some(WhichModel::InstructV2_2B)
| WhichModel::BaseV2_9B | Some(WhichModel::BaseV2_9B)
| WhichModel::InstructV2_9B => { | Some(WhichModel::InstructV2_9B)
| None => {
// default to V2 model
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?; let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model) Model::V2(model)
} }
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => { Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?; let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model) Model::V3(model)
@@ -371,7 +420,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
); );
let prompt = match cfg.model { let prompt = match cfg.model {
WhichModel::InstructV3_1B => { Some(WhichModel::InstructV3_1B) => {
format!( format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n", "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt cfg.prompt
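The new `FromStr`/`Display` pair above lets the HTTP layer map OpenAI-style model ids to `WhichModel` and back. A round-trip sketch with a trimmed stand-in enum (only two variants kept; the full enum and string names are in the hunk above):

```rust
// Sketch of how callers are expected to use the FromStr/Display pair;
// the enum here is a cut-down stand-in for gemma_runner::WhichModel.
use std::fmt;
use std::str::FromStr;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WhichModel {
    InstructV2_2B,
    InstructV3_1B,
}

impl FromStr for WhichModel {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "gemma-2-2b-it" => Ok(Self::InstructV2_2B),
            "gemma-3-1b-it" => Ok(Self::InstructV3_1B),
            _ => Err(format!("Unknown model: {s}")),
        }
    }
}

impl fmt::Display for WhichModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::InstructV2_2B => "gemma-2-2b-it",
            Self::InstructV3_1B => "gemma-3-1b-it",
        };
        write!(f, "{name}")
    }
}

fn main() {
    // Round-trip: the API-facing model id parses back to the same variant.
    let m: WhichModel = "gemma-3-1b-it".parse().unwrap();
    assert_eq!(m.to_string(), "gemma-3-1b-it");
}
```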

View File

@@ -67,7 +67,7 @@ pub fn run_cli() -> anyhow::Result<()> {
    let cfg = GemmaInferenceConfig {
        tracing: args.tracing,
        prompt: args.prompt,
-       model: args.model,
+       model: Some(args.model),
        cpu: args.cpu,
        dtype: args.dtype,
        model_id: args.model_id,

View File

@@ -6,10 +6,8 @@ mod gemma_api;
mod gemma_cli;
use anyhow::Error;
-use clap::{Parser, ValueEnum};
use crate::gemma_cli::run_cli;
-use std::io::Write;
/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {

View File

@@ -64,14 +64,9 @@ version = "0.1.0"
# Required: Kubernetes metadata
[package.metadata.kube]
-image = "ghcr.io/myorg/my-service:latest"
+image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
-# Optional: Docker Compose metadata (currently not used but parsed)
-[package.metadata.compose]
-image = "ghcr.io/myorg/my-service:latest"
-port = 8080
```
### Required Fields

View File

@@ -1,9 +1,8 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use clap::{Arg, Command}; use clap::{Arg, Command};
use serde::{Deserialize, Serialize}; use serde::Deserialize;
use std::collections::HashMap;
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::Path;
use walkdir::WalkDir; use walkdir::WalkDir;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -20,7 +19,6 @@ struct Package {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Metadata { struct Metadata {
kube: Option<KubeMetadata>, kube: Option<KubeMetadata>,
compose: Option<ComposeMetadata>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -30,12 +28,6 @@ struct KubeMetadata {
port: u16, port: u16,
} }
#[derive(Debug, Deserialize)]
struct ComposeMetadata {
image: Option<String>,
port: Option<u16>,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct ServiceInfo { struct ServiceInfo {
name: String, name: String,
@@ -105,7 +97,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
.into_iter() .into_iter()
.filter_map(|e| e.ok()) .filter_map(|e| e.ok())
{ {
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") { if entry.file_name() == "Cargo.toml"
&& entry.path() != workspace_root.join("../../../Cargo.toml")
{
if let Ok(service_info) = parse_cargo_toml(entry.path()) { if let Ok(service_info) = parse_cargo_toml(entry.path()) {
services.push(service_info); services.push(service_info);
} }
@@ -375,7 +369,7 @@ spec:
Ok(()) Ok(())
} }
fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> { fn generate_ingress_template(templates_dir: &Path, _services: &[ServiceInfo]) -> Result<()> {
let ingress_template = r#"{{- if .Values.ingress.enabled -}} let ingress_template = r#"{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1 apiVersion: networking.k8s.io/v1
kind: Ingress kind: Ingress

View File

@@ -1,6 +1,5 @@
pub mod llama_api;
-use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed

View File

@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
    pub repeat_last_n: usize,
}
impl LlamaInferenceConfig {
pub fn new(model: WhichModel) -> Self {
Self {
prompt: String::new(),
model,
cpu: false,
temperature: 1.0,
top_p: None,
top_k: None,
seed: 42,
max_tokens: 512,
no_kv_cache: false,
dtype: None,
model_id: None,
revision: None,
use_flash_attn: true,
repeat_penalty: 1.1,
repeat_last_n: 64,
}
}
}
impl Default for LlamaInferenceConfig { impl Default for LlamaInferenceConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
@@ -81,7 +102,7 @@ impl Default for LlamaInferenceConfig {
        max_tokens: 512,
        // Performance flags
        no_kv_cache: false,    // keep cache ON for speed
        use_flash_attn: false, // great speed boost if supported
        // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
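`LlamaInferenceConfig::new(model)` added above seeds the same defaults as `Default` but takes the model up front, so callers only override what they need (as the handler does with `prompt` and `max_tokens`). A usage sketch with stand-in types and most fields elided:

```rust
// Usage sketch only: mirrors the constructor pattern from the diff with
// unrelated fields dropped; not a drop-in copy of the crate's types.
#[derive(Debug, Clone, Copy)]
enum WhichModel {
    Llama32_1BInstruct,
}

#[derive(Debug)]
struct LlamaInferenceConfig {
    prompt: String,
    model: WhichModel,
    max_tokens: usize,
    temperature: f64,
}

impl LlamaInferenceConfig {
    fn new(model: WhichModel) -> Self {
        Self {
            prompt: String::new(),
            model,
            max_tokens: 512,
            temperature: 1.0,
        }
    }
}

fn main() {
    let mut config = LlamaInferenceConfig::new(WhichModel::Llama32_1BInstruct);
    config.prompt = "Hello".to_string();
    config.max_tokens = 128;
    println!("{config:?}");
}
```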

View File

@@ -6,9 +6,6 @@ mod llama_api;
mod llama_cli;
use anyhow::Result;
-use clap::{Parser, ValueEnum};
-use std::io::Write;
use crate::llama_cli::run_cli;

View File

@@ -1,17 +1,14 @@
[package]
name = "utils"
-edition = "2021"
[lib]
path = "src/lib.rs"
[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
-candle-nn = {version = "0.9.1" }
-candle-transformers = {version = "0.9.1" }
candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
-candle-core="0.9.1"
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
@@ -86,3 +83,14 @@ mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]
+# Platform-specific candle dependencies
+[target.'cfg(target_os = "linux")'.dependencies]
+candle-nn = {version = "0.9.1", default-features = false }
+candle-transformers = {version = "0.9.1", default-features = false }
+candle-core = {version = "0.9.1", default-features = false }
+[target.'cfg(not(target_os = "linux"))'.dependencies]
+candle-nn = {version = "0.9.1" }
+candle-transformers = {version = "0.9.1" }
+candle-core = {version = "0.9.1" }

View File

@@ -1,5 +1,5 @@
-use candle_transformers::models::mimi::candle;
use candle_core::{Device, Result, Tensor};
+use candle_transformers::models::mimi::candle;
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];

View File

@@ -8,8 +8,10 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
-use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
+use candle_core::{
+    utils::{cuda_is_available, metal_is_available},
+    Device, Tensor,
+};
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
    if cpu {
@@ -122,11 +124,8 @@ pub fn hub_load_safetensors(
    }
    let safetensors_files = safetensors_files
        .iter()
-       .map(|v| {
-           repo.get(v)
-               .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
-       })
-       .collect::<Result<Vec<_>, std::io::Error, >>()?;
+       .map(|v| repo.get(v).map_err(std::io::Error::other))
+       .collect::<Result<Vec<_>, std::io::Error>>()?;
    Ok(safetensors_files)
}
@@ -136,7 +135,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
    let path = path.as_ref();
    let jsfile = std::fs::File::open(path.join(json_file))?;
-   let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
+   let json: serde_json::Value =
+       serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => anyhow::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
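The `hub_load_safetensors` change works because `io::Error::other` accepts any boxable error type, so the closure collapses to `map_err(std::io::Error::other)` before collecting into `Result<Vec<_>, _>`. The same shape, with a plain function standing in for the hub lookup:

```rust
use std::io;

// Same pattern as the hub_load_safetensors change: map each fallible lookup,
// convert the error with io::Error::other, and collect into Result<Vec<_>, _>.
fn fetch(name: &str) -> Result<String, String> {
    if name.is_empty() {
        Err("empty file name".to_string())
    } else {
        Ok(format!("/cache/{name}"))
    }
}

fn resolve_all(names: &[&str]) -> Result<Vec<String>, io::Error> {
    names
        .iter()
        .map(|n| fetch(n).map_err(io::Error::other))
        .collect::<Result<Vec<_>, io::Error>>()
}

fn main() {
    assert!(resolve_all(&["model-00001.safetensors"]).is_ok());
    assert!(resolve_all(&[""]).is_err());
}
```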

View File

@@ -1,8 +1,8 @@
{
  "name": "predict-otron-9000",
- "workspaces": ["crates/cli/package"],
+ "workspaces": ["integration/cli/package"],
  "scripts": {
    "# WORKSPACE ALIASES": "#",
-   "cli": "bun --filter crates/cli/package"
+   "cli": "bun --filter integration/cli/package"
  }
}