9 Commits

Author    SHA1        Message                                                               Date
geoffsee  4380ac69d3  v0.1.5 already exists                                                 2025-09-04 15:09:30 -04:00
geoffsee  e6f3351ebb  minor version                                                         2025-09-04 15:08:43 -04:00
geoffsee  3992532f15  fmt and clippy                                                        2025-09-04 15:07:49 -04:00
geoffsee  3ecdd9ffa0  update deployment tooling to remove dependencies on unused metadata  2025-09-04 15:03:17 -04:00
geoffsee  296d4dbe7e  add root dockerfile that contains binaries for all services          2025-09-04 14:54:20 -04:00
geoffsee  fb5098eba6  fix clippy errors                                                     2025-09-04 13:53:00 -04:00
geoffsee  c1c583faab  run cargo fmt                                                         2025-09-04 13:45:25 -04:00
geoffsee  1e02b12cda  fixes issue with model selection                                      2025-09-04 13:42:30 -04:00
geoffsee  ff55d882c7  reorg + update docs with new paths                                    2025-09-04 12:40:59 -04:00
62 changed files with 1109 additions and 672 deletions

.dockerignore (new file, 35 lines)

@@ -0,0 +1,35 @@
# Git
.git
.gitignore
# Rust
target/
# Documentation
README.md
*.md
# IDE
.vscode/
.idea/
*.swp
*.swo
# OS
.DS_Store
Thumbs.db
# Logs
*.log
# Environment
.env
.env.local
# Dependencies
node_modules/
# Build artifacts
dist/
build/
.fastembed_cache


@@ -44,7 +44,7 @@ jobs:
- name: Clippy
shell: bash
run: cargo clippy --all-targets
run: cargo clippy --all
- name: Tests
shell: bash

.github/workflows/docker.yml (new vendored file, 46 lines)

@@ -0,0 +1,46 @@
name: Build and Push Docker Image
on:
tags:
- 'v*'
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to Container Registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
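This workflow only fires on version tags, so publishing an image amounts to pushing a tag that matches the v* pattern. A usage sketch (the tag name below is an assumption; any v-prefixed tag triggers the build):
git tag v0.1.6
git push origin v0.1.6
# docker/metadata-action then derives image tags from the git tag (full semver, major.minor, commit SHA)
# and the image is pushed to ghcr.io/<owner>/<repo>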


@@ -45,7 +45,7 @@ jobs:
- name: Clippy
shell: bash
run: cargo clippy --all-targets
run: cargo clippy --all
- name: Tests
shell: bash

.gitignore (vendored, 1 line changed)

@@ -77,3 +77,4 @@ venv/
!/scripts/cli.ts
/**/.*.bun-build
/AGENTS.md
.claude

Cargo.lock (generated, 3 lines changed)

@@ -2905,6 +2905,7 @@ dependencies = [
"clap",
"cpal",
"either",
"embeddings-engine",
"futures-util",
"gemma-runner",
"imageproc 0.24.0",
@@ -7040,7 +7041,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "utils"
version = "0.0.0"
version = "0.1.4"
dependencies = [
"ab_glyph",
"accelerate-src",


@@ -3,17 +3,17 @@ members = [
"crates/predict-otron-9000",
"crates/inference-engine",
"crates/embeddings-engine",
"crates/helm-chart-tool",
"crates/llama-runner",
"crates/gemma-runner",
"crates/cli",
"integration/helm-chart-tool",
"integration/llama-runner",
"integration/gemma-runner",
"integration/cli",
"crates/chat-ui"
, "crates/utils"]
, "integration/utils"]
default-members = ["crates/predict-otron-9000"]
resolver = "2"
[workspace.package]
version = "0.1.4"
version = "0.1.6"
# Compiler optimization profiles for the workspace
[profile.release]

Dockerfile (new file, 50 lines)

@@ -0,0 +1,50 @@
# Multi-stage build for predict-otron-9000 workspace
FROM rust:1 AS builder
# Install system dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY crates/ ./crates/
COPY integration/ ./integration/
# Build all 3 main server binaries in release mode
RUN cargo build --release -p predict-otron-9000 --bin predict-otron-9000 --no-default-features -p embeddings-engine --bin embeddings-engine -p inference-engine --bin inference-engine
# Runtime stage
FROM debian:bookworm-slim AS runtime
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
ca-certificates \
libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Create app user
RUN useradd -r -s /bin/false -m -d /app appuser
# Set working directory
WORKDIR /app
# Copy binaries from builder stage
COPY --from=builder /app/target/release/predict-otron-9000 ./bin/
COPY --from=builder /app/target/release/embeddings-engine ./bin/
COPY --from=builder /app/target/release/inference-engine ./bin/
# Make binaries executable and change ownership
RUN chmod +x ./bin/* && chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Expose ports (adjust as needed based on your services)
EXPOSE 8080 8081 8082
# Default command (can be overridden)
CMD ["./bin/predict-otron-9000"]
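A minimal sketch of building and running this image locally (the local tag name is an assumption; the port mapping follows the EXPOSE line above, and the default command starts the orchestration server):
docker build -t predict-otron-9000:local .
docker run --rm -p 8080:8080 predict-otron-9000:local
# the same image carries all three binaries, so the command can be overridden:
docker run --rm -p 8080:8080 predict-otron-9000:local ./bin/embeddings-engine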


@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.
> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.
@@ -53,14 +53,17 @@ The project uses a 9-crate Rust workspace plus TypeScript components:
crates/
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
└── chat-ui/ # WASM web frontend (Rust 2021)
integration/
├── cli/ # CLI client crate (Rust 2024)
│ └── package/
│ └── cli.ts # TypeScript/Bun CLI client
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
├── llama-runner/ # Llama model inference via Candle (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
├── chat-ui/ # WASM web frontend (Rust 2021)
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
└── cli/ # CLI client crate (Rust 2024)
└── package/
└── cli.ts # TypeScript/Bun CLI client
└── utils/ # Shared utilities (Rust 2021)
```
### Service Architecture
@@ -160,16 +163,16 @@ cd crates/chat-ui
#### TypeScript CLI Client
```bash
# List available models
cd crates/cli/package && bun run cli.ts --list-models
cd integration/cli/package && bun run cli.ts --list-models
# Chat completion
cd crates/cli/package && bun run cli.ts "What is the capital of France?"
cd integration/cli/package && bun run cli.ts "What is the capital of France?"
# With specific model
cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
# Show help
cd crates/cli/package && bun run cli.ts --help
cd integration/cli/package && bun run cli.ts --help
```
## API Usage
@@ -464,7 +467,7 @@ curl -s http://localhost:8080/v1/models | jq
**CLI client test:**
```bash
cd crates/cli/package && bun run cli.ts "What is 2+2?"
cd integration/cli/package && bun run cli.ts "What is 2+2?"
```
**Web frontend:**


@@ -4,7 +4,7 @@
"": {
"name": "predict-otron-9000",
},
"crates/cli/package": {
"integration/cli/package": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
@@ -13,7 +13,7 @@
},
},
"packages": {
"cli": ["cli@workspace:crates/cli/package"],
"cli": ["cli@workspace:integration/cli/package"],
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],


@@ -3,6 +3,7 @@ name = "chat-ui"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["cdylib", "rlib"]
@@ -122,3 +123,7 @@ lib-default-features = false
#
# Optional. Defaults to "release".
lib-profile-release = "release"
[[bin]]
name = "chat-ui"
path = "src/main.rs"


@@ -257,7 +257,8 @@ pub fn send_chat_completion_stream(
break;
}
let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
let value =
js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
let array = js_sys::Uint8Array::new(&value);
let mut bytes = vec![0; array.length() as usize];
array.copy_to(&mut bytes);
@@ -279,7 +280,9 @@ pub fn send_chat_completion_stream(
}
// Parse JSON chunk
if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
if let Ok(chunk) =
serde_json::from_str::<StreamChatResponse>(data)
{
if let Some(choice) = chunk.choices.first() {
if let Some(content) = &choice.delta.content {
on_chunk(content.clone());
@@ -365,7 +368,7 @@ fn ChatPage() -> impl IntoView {
// State for available models and selected model
let available_models = RwSignal::new(Vec::<ModelInfo>::new());
let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
let selected_model = RwSignal::new(String::from("")); // Default model
// State for streaming response
let streaming_content = RwSignal::new(String::new());
@@ -382,6 +385,7 @@ fn ChatPage() -> impl IntoView {
match fetch_models().await {
Ok(models) => {
available_models.set(models);
selected_model.set(String::from("gemma-3-1b-it"));
}
Err(error) => {
console::log_1(&format!("Failed to fetch models: {}", error).into());


@@ -25,15 +25,9 @@ rand = "0.8.5"
async-openai = "0.28.3"
once_cell = "1.19.0"
[package.metadata.compose]
image = "ghcr.io/geoffsee/embeddings-service:latest"
port = 8080
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/embeddings-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
cmd = ["./bin/embeddings-engine"]
replicas = 1
port = 8080


@@ -1,42 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Cache deps first
COPY . ./
RUN rm -rf src
RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "// lib" > src/lib.rs && cargo build --release
RUN rm -rf src
# Copy real sources and build
COPY . .
RUN cargo build --release
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install only what the compiled binary needs
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/embeddings-engine /usr/local/bin/
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["embeddings-engine"]


@@ -1,43 +1,225 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
use axum::{Json, Router, http::StatusCode, response::Json as ResponseJson, routing::post};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tower_http::trace::TraceLayer;
use tracing;
// Persistent model instance (singleton pattern)
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
tracing::info!("Initializing persistent embedding model (singleton)");
// Cache for multiple embedding models
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
Lazy::new(|| RwLock::new(HashMap::new()));
#[derive(Serialize)]
pub struct ModelInfo {
pub id: String,
pub object: String,
pub owned_by: String,
pub description: String,
pub dimensions: usize,
}
#[derive(Serialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelInfo>,
}
// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
match model_name {
// Sentence Transformers models
"sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => {
Ok(EmbeddingModel::AllMiniLML6V2)
}
"sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => {
Ok(EmbeddingModel::AllMiniLML6V2Q)
}
"sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => {
Ok(EmbeddingModel::AllMiniLML12V2)
}
"sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => {
Ok(EmbeddingModel::AllMiniLML12V2Q)
}
// BGE models
"BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
"BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
"BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
"BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
"BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
"BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
"BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
"BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
// Nomic models
"nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => {
Ok(EmbeddingModel::NomicEmbedTextV1)
}
"nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => {
Ok(EmbeddingModel::NomicEmbedTextV15)
}
"nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => {
Ok(EmbeddingModel::NomicEmbedTextV15Q)
}
// Paraphrase models
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
| "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q"
| "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
| "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
// ModernBert
"lightonai/modernbert-embed-large" | "modernbert-embed-large" => {
Ok(EmbeddingModel::ModernBertEmbedLarge)
}
// Multilingual E5 models
"intfloat/multilingual-e5-small" | "multilingual-e5-small" => {
Ok(EmbeddingModel::MultilingualE5Small)
}
"intfloat/multilingual-e5-base" | "multilingual-e5-base" => {
Ok(EmbeddingModel::MultilingualE5Base)
}
"intfloat/multilingual-e5-large" | "multilingual-e5-large" => {
Ok(EmbeddingModel::MultilingualE5Large)
}
// Mixedbread models
"mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1)
}
"mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => {
Ok(EmbeddingModel::MxbaiEmbedLargeV1Q)
}
// GTE models
"Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
"Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => {
Ok(EmbeddingModel::GTEBaseENV15Q)
}
"Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
"Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => {
Ok(EmbeddingModel::GTELargeENV15Q)
}
// CLIP model
"Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
// Jina model
"jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => {
Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode)
}
_ => Err(format!("Unsupported embedding model: {}", model_name)),
}
}
// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
match model {
EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
EmbeddingModel::BGESmallZHV15 => 512,
EmbeddingModel::BGELargeZHV15 => 1024,
EmbeddingModel::NomicEmbedTextV1
| EmbeddingModel::NomicEmbedTextV15
| EmbeddingModel::NomicEmbedTextV15Q => 768,
EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
EmbeddingModel::ModernBertEmbedLarge => 1024,
EmbeddingModel::MultilingualE5Small => 384,
EmbeddingModel::MultilingualE5Base => 768,
EmbeddingModel::MultilingualE5Large => 1024,
EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
EmbeddingModel::ClipVitB32 => 512,
EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
}
}
// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
// First try to get from cache (read lock)
{
let cache = MODEL_CACHE
.read()
.map_err(|e| format!("Failed to acquire read lock: {}", e))?;
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model: {:?}", embedding_model);
return Ok(Arc::clone(model));
}
}
// Model not in cache, create it (write lock)
let mut cache = MODEL_CACHE
.write()
.map_err(|e| format!("Failed to acquire write lock: {}", e))?;
// Double-check after acquiring write lock
if let Some(model) = cache.get(&embedding_model) {
tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
return Ok(Arc::clone(model));
}
tracing::info!("Initializing new embedding model: {:?}", embedding_model);
let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
)
.expect("Failed to initialize persistent embedding model");
.map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;
let model_init_time = model_start_time.elapsed();
tracing::info!(
"Persistent embedding model initialized in {:.2?}",
"Embedding model {:?} initialized in {:.2?}",
embedding_model,
model_init_time
);
model
});
let model_arc = Arc::new(model);
cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
Ok(model_arc)
}
pub async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
// Start timing the entire process
let start_time = std::time::Instant::now();
// Phase 1: Access persistent model instance
// Phase 1: Parse and get the embedding model
let model_start_time = std::time::Instant::now();
// Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request
let embedding_model = match parse_embedding_model(&payload.model) {
Ok(model) => model,
Err(e) => {
tracing::error!("Invalid model requested: {}", e);
return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
}
};
let model = match get_or_create_model(embedding_model.clone()) {
Ok(model) => model,
Err(e) => {
tracing::error!("Failed to get/create model: {}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Model initialization failed: {}", e),
));
}
};
let model_access_time = model_start_time.elapsed();
tracing::debug!(
"Persistent model access completed in {:.2?}",
"Model access/creation completed in {:.2?}",
model_access_time
);
@@ -65,9 +247,13 @@ pub async fn embeddings_create(
// Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now();
let embeddings = EMBEDDING_MODEL
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
tracing::error!("Failed to generate embeddings: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Embedding generation failed: {}", e),
)
})?;
let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!(
@@ -117,8 +303,9 @@ pub async fn embeddings_create(
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
let expected_dimensions = get_model_dimensions(&embedding_model);
let mut random_embedding = Vec::with_capacity(expected_dimensions);
for _ in 0..expected_dimensions {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
@@ -138,18 +325,19 @@ pub async fn embeddings_create(
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
let padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
// Use the actual model dimensions instead of hardcoded 768
let actual_dimensions = padded_embedding.len();
let expected_dimensions = get_model_dimensions(&embedding_model);
if actual_dimensions != expected_dimensions {
tracing::warn!(
"Model {:?} produced {} dimensions but expected {}",
embedding_model,
actual_dimensions,
expected_dimensions
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
@@ -203,11 +391,234 @@ pub async fn embeddings_create(
postprocessing_time
);
ResponseJson(response)
Ok(ResponseJson(response))
}
pub async fn models_list() -> ResponseJson<ModelsResponse> {
let models = vec![
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the base English model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the large English model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "Quantized v1.5 release of the fast and default English model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "BAAI/bge-small-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the small Chinese model".to_string(),
dimensions: 512,
},
ModelInfo {
id: "BAAI/bge-large-zh-v1.5".to_string(),
object: "model".to_string(),
owned_by: "BAAI".to_string(),
description: "v1.5 release of the large Chinese model".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "v1.5 release of the 8192 context length english model".to_string(),
dimensions: 768,
},
ModelInfo {
id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "nomic-ai".to_string(),
description: "Quantized v1.5 release of the 8192 context length english model"
.to_string(),
dimensions: 768,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Quantized Multi-lingual model".to_string(),
dimensions: 384,
},
ModelInfo {
id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(),
object: "model".to_string(),
owned_by: "sentence-transformers".to_string(),
description: "Sentence-transformers model for tasks like clustering or semantic search"
.to_string(),
dimensions: 768,
},
ModelInfo {
id: "lightonai/modernbert-embed-large".to_string(),
object: "model".to_string(),
owned_by: "lightonai".to_string(),
description: "Large model of ModernBert Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "intfloat/multilingual-e5-small".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Small model of multilingual E5 Text Embeddings".to_string(),
dimensions: 384,
},
ModelInfo {
id: "intfloat/multilingual-e5-base".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Base model of multilingual E5 Text Embeddings".to_string(),
dimensions: 768,
},
ModelInfo {
id: "intfloat/multilingual-e5-large".to_string(),
object: "model".to_string(),
owned_by: "intfloat".to_string(),
description: "Large model of multilingual E5 Text Embeddings".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(),
object: "model".to_string(),
owned_by: "mixedbread-ai".to_string(),
description: "Quantized Large English embedding model from MixedBreed.ai".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Base multilingual embedding model from Alibaba".to_string(),
dimensions: 768,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(),
object: "model".to_string(),
owned_by: "Alibaba-NLP".to_string(),
description: "Quantized Large multilingual embedding model from Alibaba".to_string(),
dimensions: 1024,
},
ModelInfo {
id: "Qdrant/clip-ViT-B-32-text".to_string(),
object: "model".to_string(),
owned_by: "Qdrant".to_string(),
description: "CLIP text encoder based on ViT-B/32".to_string(),
dimensions: 512,
},
ModelInfo {
id: "jinaai/jina-embeddings-v2-base-code".to_string(),
object: "model".to_string(),
owned_by: "jinaai".to_string(),
description: "Jina embeddings v2 base code".to_string(),
dimensions: 768,
},
];
ResponseJson(ModelsResponse {
object: "list".to_string(),
data: models,
})
}
pub fn create_embeddings_router() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
// .route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
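A usage sketch against the new multi-model endpoint (host and port follow the README's localhost:8080 examples; the JSON body mirrors the OpenAI embeddings format that CreateEmbeddingRequest deserializes):
curl -s http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{"model": "BAAI/bge-small-en-v1.5", "input": "hello world"}'
# unknown model names are now rejected with 400 Bad Request instead of being served by the single hard-coded model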


@@ -4,8 +4,6 @@ use axum::{
response::Json as ResponseJson,
routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
use std::env;
use tower_http::trace::TraceLayer;
use tracing;
@@ -13,127 +11,28 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
use embeddings_engine;
async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
)
.expect("Failed to initialize model");
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
match embeddings_engine::embeddings_create(Json(payload)).await {
Ok(response) => Ok(response),
Err((status_code, message)) => Err(axum::response::Response::builder()
.status(status_code)
.body(axum::body::Body::from(message))
.unwrap()),
}
}
let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => {
panic!("Integer array input not supported for text embeddings");
}
EmbeddingInput::ArrayOfIntegerArray(_) => {
panic!("Array of integer arrays not supported for text embeddings");
}
};
let embeddings = model
.embed(texts_from_embedding_input, None)
.expect("failed to embed document");
// Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level
tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Create the final embedding
let final_embedding = {
// Check if the embedding is all zeros
let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
if all_zeros {
tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");
// Generate a random non-zero embedding
use rand::Rng;
let mut rng = rand::thread_rng();
let mut random_embedding = Vec::with_capacity(768);
for _ in 0..768 {
// Generate random values between -1.0 and 1.0, excluding 0
let mut val = 0.0;
while val == 0.0 {
val = rng.gen_range(-1.0..1.0);
}
random_embedding.push(val);
}
// Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
for i in 0..random_embedding.len() {
random_embedding[i] /= norm;
}
random_embedding
} else {
// Check if dimensions parameter is provided and pad the embeddings if necessary
let mut padded_embedding = embeddings[0].clone();
// If the client expects 768 dimensions but our model produces fewer, pad with zeros
let target_dimension = 768;
if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]);
}
padded_embedding
}
};
tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level
tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Return a response that matches the OpenAI API format
let response = serde_json::json!({
"object": "list",
"data": [
{
"object": "embedding",
"index": 0,
"embedding": final_embedding
}
],
"model": payload.model,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0
}
});
ResponseJson(response)
async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
embeddings_engine::models_list().await
}
fn create_app() -> Router {
Router::new()
.route("/v1/embeddings", post(embeddings_create))
.route("/v1/models", get(models_list))
.layer(TraceLayer::new_for_http())
}
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};


@@ -1,7 +1,7 @@
[package]
name = "inference-engine"
version.workspace = true
edition = "2021"
edition = "2024"
[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
@@ -31,13 +31,21 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
futures-util = "0.3.31"
gemma-runner = { path = "../gemma-runner", features = ["metal"] }
llama-runner = { path = "../llama-runner", features = ["metal"]}
gemma-runner = { path = "../../integration/gemma-runner" }
llama-runner = { path = "../../integration/llama-runner" }
embeddings-engine = { path = "../embeddings-engine" }
[target.'cfg(target_os = "linux")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", default-features = false }
candle-nn = { git = "https://github.com/huggingface/candle.git", default-features = false }
candle-transformers = { git = "https://github.com/huggingface/candle.git", default-features = false }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
llama-runner = { path = "../../integration/llama-runner", features = ["metal"] }
[dev-dependencies]
@@ -61,15 +69,13 @@ bindgen_cuda = { version = "0.1.1", optional = true }
[features]
bin = []
[package.metadata.compose]
image = "ghcr.io/geoffsee/inference-engine:latest"
port = 8080
[[bin]]
name = "inference-engine"
path = "src/main.rs"
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/inference-service:latest"
replicas = 1
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
cmd = ["./bin/inference-engine"]
port = 8080
replicas = 1


@@ -1,86 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies including CUDA toolkit for GPU support
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA toolkit (optional, for GPU support)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
# Copy the entire workspace to get access to all crates
COPY . ./
# Cache dependencies first - create dummy source files
RUN rm -rf crates/inference-engine/src
RUN mkdir -p crates/inference-engine/src && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
cargo build --release --bin cli --package inference-engine
# Remove dummy source and copy real sources
RUN rm -rf crates/inference-engine/src
COPY . .
# Build the actual CLI binary
RUN cargo build --release --bin cli --package inference-engine
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA runtime libraries (optional, for GPU support at runtime)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/cli /usr/local/bin/inference-cli
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["inference-cli"]


@@ -8,7 +8,7 @@ pub mod server;
// Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which};
pub use server::{create_router, AppState};
pub use server::{AppState, create_router};
use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};


@@ -0,0 +1,26 @@
use inference_engine::{AppState, create_router, get_server_config, init_tracing};
use tokio::net::TcpListener;
use tracing::info;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
init_tracing();
let app_state = AppState::default();
let app = create_router(app_state);
let (server_host, server_port, server_address) = get_server_config();
let listener = TcpListener::bind(&server_address).await?;
info!(
"Inference Engine server starting on http://{}",
server_address
);
info!("Available endpoints:");
info!(" POST /v1/chat/completions - OpenAI-compatible chat completions");
info!(" GET /v1/models - List available models");
axum::serve(listener, app).await?;
Ok(())
}
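A quick smoke-test sketch for this new standalone binary (the 127.0.0.1:8080 address is an assumption about get_server_config's defaults; the model and prompt follow the README examples):
curl -s http://127.0.0.1:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "gemma-3-1b-it", "messages": [{"role": "user", "content": "What is 2+2?"}]}'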


@@ -42,7 +42,11 @@ pub struct ModelMeta {
}
const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
ModelMeta { id, family, instruct }
ModelMeta {
id,
family,
instruct,
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]


@@ -1,26 +1,28 @@
use axum::{
Json, Router,
extract::State,
http::StatusCode,
response::{sse::Event, sse::Sse, IntoResponse},
response::{IntoResponse, sse::Event, sse::Sse},
routing::{get, post},
Json, Router,
};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{Mutex, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
use crate::Which;
use crate::openai_types::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which;
use either::Either;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use embeddings_engine::models_list;
use gemma_runner::{GemmaInferenceConfig, WhichModel, run_gemma_api};
use llama_runner::{LlamaInferenceConfig, run_llama_inference};
use serde_json::Value;
// -------------------------
// Shared app state
@@ -34,7 +36,7 @@ pub enum ModelType {
#[derive(Clone)]
pub struct AppState {
pub model_type: ModelType,
pub model_type: Option<ModelType>,
pub model_id: String,
pub gemma_config: Option<GemmaInferenceConfig>,
pub llama_config: Option<LlamaInferenceConfig>,
@@ -44,15 +46,16 @@ impl Default for AppState {
fn default() -> Self {
// Configure a default model to prevent 503 errors from the chat-ui
// This can be overridden by environment variables if needed
let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
let default_model_id =
std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
let gemma_config = GemmaInferenceConfig {
model: gemma_runner::WhichModel::InstructV3_1B,
model: None,
..Default::default()
};
Self {
model_type: ModelType::Gemma,
model_type: None,
model_id: default_model_id,
gemma_config: Some(gemma_config),
llama_config: None,
@@ -83,15 +86,14 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
"gemma-2-9b-it" => Some(Which::InstructV2_9B),
"gemma-3-1b" => Some(Which::BaseV3_1B),
"gemma-3-1b-it" => Some(Which::InstructV3_1B),
"llama-3.2-1b" => Some(Which::Llama32_1B),
"llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
"llama-3.2-3b" => Some(Which::Llama32_3B),
"llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
_ => None,
}
}
fn normalize_model_id(model_id: &str) -> String {
model_id.to_lowercase().replace("_", "-")
}
@@ -189,35 +191,74 @@ pub async fn chat_completions_non_streaming_proxy(
// Get streaming receiver based on model type
let rx = if which_model.is_llama_model() {
// Create Llama configuration dynamically
let mut config = LlamaInferenceConfig::default();
let llama_model = match which_model {
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
_ => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Llama model", model_id) }
})),
));
}
};
let mut config = LlamaInferenceConfig::new(llama_model);
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_llama_inference(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
}))
))?
run_llama_inference(config).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) }
})),
)
})?
} else {
// Create Gemma configuration dynamically
let gemma_model = if which_model.is_v3_model() {
gemma_runner::WhichModel::InstructV3_1B
} else {
gemma_runner::WhichModel::InstructV3_1B // Default fallback
let gemma_model = match which_model {
Which::Base2B => gemma_runner::WhichModel::Base2B,
Which::Base7B => gemma_runner::WhichModel::Base7B,
Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
_ => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
})),
));
}
};
let mut config = GemmaInferenceConfig {
model: gemma_model,
model: Some(gemma_model),
..Default::default()
};
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
}))
))?
run_gemma_api(config).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) }
})),
)
})?
};
// Collect all tokens from the stream
@@ -347,7 +388,21 @@ async fn handle_streaming_request(
// Get streaming receiver based on model type
let model_rx = if which_model.is_llama_model() {
// Create Llama configuration dynamically
let mut config = LlamaInferenceConfig::default();
let llama_model = match which_model {
Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Llama model", model_id) }
})),
));
}
};
let mut config = LlamaInferenceConfig::new(llama_model);
config.prompt = prompt.clone();
config.max_tokens = max_tokens;
match run_llama_inference(config) {
@@ -363,14 +418,35 @@ async fn handle_streaming_request(
}
} else {
// Create Gemma configuration dynamically
let gemma_model = if which_model.is_v3_model() {
gemma_runner::WhichModel::InstructV3_1B
} else {
gemma_runner::WhichModel::InstructV3_1B // Default fallback
let gemma_model = match which_model {
Which::Base2B => gemma_runner::WhichModel::Base2B,
Which::Base7B => gemma_runner::WhichModel::Base7B,
Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": { "message": format!("Model {} is not a Gemma model", model_id) }
})),
));
}
};
let mut config = GemmaInferenceConfig {
model: gemma_model,
model: Some(gemma_model),
..Default::default()
};
config.prompt = prompt.clone();
@@ -530,46 +606,69 @@ pub async fn list_models() -> Json<ModelListResponse> {
Which::Llama32_3BInstruct,
];
let models: Vec<Model> = which_variants.into_iter().map(|which| {
let meta = which.meta();
let model_id = match which {
Which::Base2B => "gemma-2b",
Which::Base7B => "gemma-7b",
Which::Instruct2B => "gemma-2b-it",
Which::Instruct7B => "gemma-7b-it",
Which::InstructV1_1_2B => "gemma-1.1-2b-it",
Which::InstructV1_1_7B => "gemma-1.1-7b-it",
Which::CodeBase2B => "codegemma-2b",
Which::CodeBase7B => "codegemma-7b",
Which::CodeInstruct2B => "codegemma-2b-it",
Which::CodeInstruct7B => "codegemma-7b-it",
Which::BaseV2_2B => "gemma-2-2b",
Which::InstructV2_2B => "gemma-2-2b-it",
Which::BaseV2_9B => "gemma-2-9b",
Which::InstructV2_9B => "gemma-2-9b-it",
Which::BaseV3_1B => "gemma-3-1b",
Which::InstructV3_1B => "gemma-3-1b-it",
Which::Llama32_1B => "llama-3.2-1b",
Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
Which::Llama32_3B => "llama-3.2-3b",
Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
};
let mut models: Vec<Model> = which_variants
.into_iter()
.map(|which| {
let meta = which.meta();
let model_id = match which {
Which::Base2B => "gemma-2b",
Which::Base7B => "gemma-7b",
Which::Instruct2B => "gemma-2b-it",
Which::Instruct7B => "gemma-7b-it",
Which::InstructV1_1_2B => "gemma-1.1-2b-it",
Which::InstructV1_1_7B => "gemma-1.1-7b-it",
Which::CodeBase2B => "codegemma-2b",
Which::CodeBase7B => "codegemma-7b",
Which::CodeInstruct2B => "codegemma-2b-it",
Which::CodeInstruct7B => "codegemma-7b-it",
Which::BaseV2_2B => "gemma-2-2b",
Which::InstructV2_2B => "gemma-2-2b-it",
Which::BaseV2_9B => "gemma-2-9b",
Which::InstructV2_9B => "gemma-2-9b-it",
Which::BaseV3_1B => "gemma-3-1b",
Which::InstructV3_1B => "gemma-3-1b-it",
Which::Llama32_1B => "llama-3.2-1b",
Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
Which::Llama32_3B => "llama-3.2-3b",
Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
};
let owned_by = if meta.id.starts_with("google/") {
"google"
} else if meta.id.starts_with("meta-llama/") {
"meta"
} else {
"unknown"
};
let owned_by = if meta.id.starts_with("google/") {
"google"
} else if meta.id.starts_with("meta-llama/") {
"meta"
} else {
"unknown"
};
Model {
id: model_id.to_string(),
object: "model".to_string(),
created: 1686935002, // Using same timestamp as OpenAI example
owned_by: owned_by.to_string(),
}
}).collect();
Model {
id: model_id.to_string(),
object: "model".to_string(),
created: 1686935002,
owned_by: owned_by.to_string(),
}
})
.collect();
// Get embeddings models and convert them to inference Model format
let embeddings_response = models_list().await;
let embeddings_models: Vec<Model> = embeddings_response
.0
.data
.into_iter()
.map(|embedding_model| Model {
id: embedding_model.id,
object: embedding_model.object,
created: 1686935002,
owned_by: format!(
"{} - {}",
embedding_model.owned_by, embedding_model.description
),
})
.collect();
// Add embeddings models to the main models list
models.extend(embeddings_models);
Json(ModelListResponse {
object: "list".to_string(),


@@ -29,25 +29,23 @@ inference-engine = { path = "../inference-engine" }
# Dependencies for leptos web app
#leptos-app = { path = "../leptos-app", features = ["ssr"] }
chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = false }
chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = true }
mime_guess = "2.0.5"
log = "0.4.27"
[package.metadata.compose]
name = "predict-otron-9000"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
port = 8080
# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
cmd = ["./bin/predict-otron-9000"]
# SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
# you can generate this via node to avoid toil
# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
env = { SERVER_CONFIG = "<your-json-value-here>" }
[features]
default = ["ui"]
ui = ["dep:chat-ui"]
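With chat-ui now optional behind the default ui feature, a headless build is possible. A sketch (the --no-default-features flag mirrors the root Dockerfile's build command):
cargo build --release -p predict-otron-9000                         # default build, embeds the chat web UI
cargo build --release -p predict-otron-9000 --no-default-features   # API only, no UI assets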


@@ -1,89 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder
WORKDIR /usr/src/app
# Install build dependencies including CUDA toolkit for GPU support (needed for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA toolkit (required for inference-engine dependency)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
# Copy the entire workspace to get access to all crates (needed for local dependencies)
COPY . ./
# Cache dependencies first - create dummy source files for all crates
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
RUN mkdir -p crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src && \
echo "fn main() {}" > crates/predict-otron-9000/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
echo "fn main() {}" > crates/embeddings-engine/src/main.rs && \
echo "// lib" > crates/embeddings-engine/src/lib.rs && \
cargo build --release --bin predict-otron-9000 --package predict-otron-9000
# Remove dummy sources and copy real sources
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
COPY . .
# Build the actual binary
RUN cargo build --release --bin predict-otron-9000 --package predict-otron-9000
# ---- Runtime stage ----
FROM debian:bullseye-slim
# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Install CUDA runtime libraries (required for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2
# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/predict-otron-9000 /usr/local/bin/
# Run as non-root user for safety
RUN useradd -m appuser
USER appuser
EXPOSE 8080
CMD ["predict-otron-9000"]


@@ -39,29 +39,12 @@ impl Default for ServerMode {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Services {
pub inference_url: Option<String>,
pub embeddings_url: Option<String>,
}
impl Default for Services {
fn default() -> Self {
Self {
inference_url: None,
embeddings_url: None,
}
}
}
fn inference_service_url() -> String {
"http://inference-service:8080".to_string()
}
fn embeddings_service_url() -> String {
"http://embeddings-service:8080".to_string()
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
@@ -118,8 +101,7 @@ impl ServerConfig {
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
config_string
);
let err = std::io::Error::new(
std::io::ErrorKind::Other,
let err = std::io::Error::other(
"HighAvailability mode configured but services not well defined!",
);
return Err(err);


@@ -126,7 +126,7 @@ use crate::config::ServerConfig;
/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
/// HTTP client configured for proxying requests
/// HTTP client configured for proxying requests
#[derive(Clone)]
pub struct ProxyClient {
client: Client,


@@ -4,28 +4,31 @@ mod middleware;
mod standalone_mode;
use crate::standalone_mode::create_standalone_router;
use axum::handler::Handler;
use axum::http::StatusCode as AxumStatusCode;
use axum::http::header;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{Router, ServiceExt, http::Uri, response::Html, serve};
use axum::{Router, serve};
use config::ServerConfig;
use ha_mode::create_ha_router;
use inference_engine::AppState;
use log::info;
use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use mime_guess::from_path;
use rust_embed::Embed;
use std::env;
use std::path::Component::ParentDir;
#[cfg(feature = "ui")]
use axum::http::StatusCode as AxumStatusCode;
#[cfg(feature = "ui")]
use axum::http::Uri;
#[cfg(feature = "ui")]
use axum::http::header;
#[cfg(feature = "ui")]
use axum::response::IntoResponse;
#[cfg(feature = "ui")]
use mime_guess::from_path;
#[cfg(feature = "ui")]
use rust_embed::Embed;
use tokio::net::TcpListener;
use tower::MakeService;
use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[cfg(feature = "ui")]
#[derive(Embed)]
#[folder = "../../target/site"]
#[include = "*.js"]
@@ -34,6 +37,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[include = "*.ico"]
struct Asset;
#[cfg(feature = "ui")]
async fn static_handler(uri: Uri) -> axum::response::Response {
// Strip the leading `/`
let path = uri.path().trim_start_matches('/');
@@ -111,23 +115,28 @@ async fn main() {
// Create metrics layer
let metrics_layer = MetricsLayer::new(metrics_store);
let leptos_config = chat_ui::app::AppConfig::default();
// Create the leptos router for the web frontend
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
// Merge the service router with base routes and add middleware layers
let app = Router::new()
.route("/pkg/{*path}", get(static_handler))
let mut app = Router::new()
.route("/health", get(|| async { "ok" }))
.merge(service_router)
.merge(leptos_router)
.merge(service_router);
// Add UI routes if the UI feature is enabled
#[cfg(feature = "ui")]
{
let leptos_config = chat_ui::app::AppConfig::default();
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
app = app
.route("/pkg/{*path}", get(static_handler))
.merge(leptos_router);
}
let app = app
.layer(metrics_layer) // Add metrics tracking
.layer(cors)
.layer(TraceLayer::new_for_http());
// Server configuration
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| String::from(default_host));
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| default_host.to_string());
let server_port = env::var("SERVER_PORT")
.map(|v| v.parse::<u16>().unwrap_or(default_port))
@@ -142,8 +151,10 @@ async fn main() {
);
tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
tracing::info!("Available endpoints:");
#[cfg(feature = "ui")]
tracing::info!(" GET / - Leptos chat web application");
tracing::info!(" GET /health - Health check");
tracing::info!(" POST /v1/models - List Models");
tracing::info!(" POST /v1/embeddings - Text embeddings API");
tracing::info!(" POST /v1/chat/completions - Chat completions API");

View File

@@ -2,7 +2,7 @@ use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;
pub fn create_standalone_router(server_config: ServerConfig) -> Router {
pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();

View File

@@ -61,20 +61,22 @@ graph TD
A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
end
subgraph "AI Services"
subgraph "AI Services (crates/)"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
end
subgraph "Frontend"
subgraph "Frontend (crates/)"
D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
end
subgraph "Tooling"
subgraph "Integration Tools (integration/)"
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
O[utils<br/>Edition: 2021<br/>Shared utilities]
end
end
@@ -82,10 +84,10 @@ graph TD
A --> B
A --> C
A --> D
B --> J
B --> K
J -.-> F[Candle 0.9.1]
K -.-> F
B --> M
B --> N
M -.-> F[Candle 0.9.1]
N -.-> F
C -.-> G[FastEmbed 4.x]
D -.-> H[Leptos 0.8.0]
E -.-> I[OpenAI SDK 5.16+]
@@ -93,12 +95,13 @@ graph TD
style A fill:#e1f5fe
style B fill:#f3e5f5
style J fill:#f3e5f5
style K fill:#f3e5f5
style C fill:#e8f5e8
style D fill:#fff3e0
style E fill:#fce4ec
style L fill:#fff9c4
style M fill:#f3e5f5
style N fill:#f3e5f5
style O fill:#fff9c4
```
## Deployment Configurations

View File

@@ -14,7 +14,7 @@ Options:
--help Show this help message
Examples:
cd crates/cli/package
cd integration/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"

View File

@@ -24,8 +24,7 @@ fn run_build() -> io::Result<()> {
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
let output_path = out_dir.join("client-cli");
let bun_tgt = BunTarget::from_cargo_env()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
let bun_tgt = BunTarget::from_cargo_env().map_err(|e| io::Error::other(e.to_string()))?;
// Optional: warn if using a Bun target that's marked unsupported in your chart
if matches!(bun_tgt, BunTarget::WindowsArm64) {
@@ -54,13 +53,12 @@ fn run_build() -> io::Result<()> {
if !install_status.success() {
let code = install_status.code().unwrap_or(1);
return Err(io::Error::new(
io::ErrorKind::Other,
format!("bun install failed with status {code}"),
));
return Err(io::Error::other(format!(
"bun install failed with status {code}"
)));
}
let target = env::var("TARGET").unwrap();
let _target = env::var("TARGET").unwrap();
// --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
let mut build = Command::new("bun")
@@ -87,7 +85,7 @@ fn run_build() -> io::Result<()> {
} else {
let code = status.code().unwrap_or(1);
warn(&format!("bun build failed with status: {code}"));
return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
return Err(io::Error::other("bun build failed"));
}
// Ensure the output is executable (after it exists)
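A condensed sketch of the error-handling shape this build script now uses when shelling out — `io::Error::other` for non-zero exit codes (the `bun` invocation and message text are simplified stand-ins for the real `bun install`/`bun build` steps):

```rust
use std::io;
use std::process::Command;

fn run_tool(program: &str, args: &[&str]) -> io::Result<()> {
    let status = Command::new(program).args(args).status()?;
    if status.success() {
        Ok(())
    } else {
        // Replaces the older io::Error::new(io::ErrorKind::Other, ...) construction.
        let code = status.code().unwrap_or(1);
        Err(io::Error::other(format!("{program} failed with status {code}")))
    }
}

fn main() -> io::Result<()> {
    // Hypothetical usage; the real build.rs runs `bun install` and `bun build`.
    run_tool("bun", &["--version"])
}
```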

View File

@@ -0,0 +1,17 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
"openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
}
}

View File

@@ -25,7 +25,7 @@ fn main() -> io::Result<()> {
// Run it
let status = Command::new(&tmp).arg("--version").status()?;
if !status.success() {
return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
return Err(io::Error::other("client-cli failed"));
}
Ok(())

View File

@@ -18,7 +18,7 @@ serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
utils = {path = "../utils"}
utils = {path = "../utils" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }

View File

@@ -1,13 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
@@ -16,13 +10,15 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use std::fmt;
use std::str::FromStr;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
@@ -58,6 +54,56 @@ pub enum WhichModel {
InstructV3_1B,
}
impl FromStr for WhichModel {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gemma-2b" => Ok(Self::Base2B),
"gemma-7b" => Ok(Self::Base7B),
"gemma-2b-it" => Ok(Self::Instruct2B),
"gemma-7b-it" => Ok(Self::Instruct7B),
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
"codegemma-2b" => Ok(Self::CodeBase2B),
"codegemma-7b" => Ok(Self::CodeBase7B),
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
"gemma-2-2b" => Ok(Self::BaseV2_2B),
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
"gemma-2-9b" => Ok(Self::BaseV2_9B),
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
"gemma-3-1b" => Ok(Self::BaseV3_1B),
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
_ => Err(format!("Unknown model: {}", s)),
}
}
}
impl fmt::Display for WhichModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = match self {
Self::Base2B => "gemma-2b",
Self::Base7B => "gemma-7b",
Self::Instruct2B => "gemma-2b-it",
Self::Instruct7B => "gemma-7b-it",
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
Self::CodeBase2B => "codegemma-2b",
Self::CodeBase7B => "codegemma-7b",
Self::CodeInstruct2B => "codegemma-2b-it",
Self::CodeInstruct7B => "codegemma-7b-it",
Self::BaseV2_2B => "gemma-2-2b",
Self::InstructV2_2B => "gemma-2-2b-it",
Self::BaseV2_9B => "gemma-2-9b",
Self::InstructV2_9B => "gemma-2-9b-it",
Self::BaseV3_1B => "gemma-3-1b",
Self::InstructV3_1B => "gemma-3-1b-it",
};
write!(f, "{}", name)
}
}
enum Model {
V1(Model1),
V2(Model2),
@@ -145,8 +191,6 @@ impl TextGeneration {
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
@@ -183,7 +227,6 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token || next_token == eot_token {
break;
@@ -210,7 +253,7 @@ impl TextGeneration {
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
pub model: WhichModel,
pub model: Option<WhichModel>,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
@@ -229,7 +272,7 @@ impl Default for GemmaInferenceConfig {
Self {
tracing: false,
prompt: "Hello".to_string(),
model: WhichModel::InstructV2_2B,
model: Some(WhichModel::InstructV2_2B),
cpu: false,
dtype: None,
model_id: None,
@@ -286,28 +329,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
}
};
println!("Using dtype: {:?}", dtype);
println!("Raw model string: {:?}", cfg.model_id);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
WhichModel::Base2B => "google/gemma-2b",
WhichModel::Base7B => "google/gemma-7b",
WhichModel::Instruct2B => "google/gemma-2b-it",
WhichModel::Instruct7B => "google/gemma-7b-it",
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
WhichModel::CodeBase2B => "google/codegemma-2b",
WhichModel::CodeBase7B => "google/codegemma-7b",
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
WhichModel::BaseV2_2B => "google/gemma-2-2b",
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
WhichModel::BaseV2_9B => "google/gemma-2-9b",
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
Some(WhichModel::Base2B) => "google/gemma-2b",
Some(WhichModel::Base7B) => "google/gemma-7b",
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
None => "google/gemma-2-2b-it", // default fallback
}
.to_string()
});
@@ -318,7 +363,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
vec![repo.get("model.safetensors")?]
}
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +376,31 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model: Model = match cfg.model {
WhichModel::Base2B
| WhichModel::Base7B
| WhichModel::Instruct2B
| WhichModel::Instruct7B
| WhichModel::InstructV1_1_2B
| WhichModel::InstructV1_1_7B
| WhichModel::CodeBase2B
| WhichModel::CodeBase7B
| WhichModel::CodeInstruct2B
| WhichModel::CodeInstruct7B => {
Some(WhichModel::Base2B)
| Some(WhichModel::Base7B)
| Some(WhichModel::Instruct2B)
| Some(WhichModel::Instruct7B)
| Some(WhichModel::InstructV1_1_2B)
| Some(WhichModel::InstructV1_1_7B)
| Some(WhichModel::CodeBase2B)
| Some(WhichModel::CodeBase7B)
| Some(WhichModel::CodeInstruct2B)
| Some(WhichModel::CodeInstruct7B) => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
WhichModel::BaseV2_2B
| WhichModel::InstructV2_2B
| WhichModel::BaseV2_9B
| WhichModel::InstructV2_9B => {
Some(WhichModel::BaseV2_2B)
| Some(WhichModel::InstructV2_2B)
| Some(WhichModel::BaseV2_9B)
| Some(WhichModel::InstructV2_9B)
| None => {
// default to V2 model
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
@@ -371,7 +420,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
);
let prompt = match cfg.model {
WhichModel::InstructV3_1B => {
Some(WhichModel::InstructV3_1B) => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt
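Taken together, the new `FromStr`/`Display` impls and the switch to `Option<WhichModel>` let callers go from a model string to a config without a hard failure for the default case. A rough usage sketch, assuming these types are re-exported from the gemma-runner crate root (the actual module layout may differ):

```rust
use gemma_runner::{GemmaInferenceConfig, WhichModel};

fn config_from_user_input(raw: &str, prompt: &str) -> GemmaInferenceConfig {
    // "gemma-3-1b-it".parse() works via FromStr; Display gives the name back.
    let model = raw.parse::<WhichModel>().ok();
    if let Some(m) = &model {
        println!("resolved model: {m}");
    }
    GemmaInferenceConfig {
        prompt: prompt.to_string(),
        // None falls back to google/gemma-2-2b-it inside run_gemma_api.
        model,
        ..Default::default()
    }
}
```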

View File

@@ -67,7 +67,7 @@ pub fn run_cli() -> anyhow::Result<()> {
let cfg = GemmaInferenceConfig {
tracing: args.tracing,
prompt: args.prompt,
model: args.model,
model: Some(args.model),
cpu: args.cpu,
dtype: args.dtype,
model_id: args.model_id,

View File

@@ -6,10 +6,8 @@ mod gemma_api;
mod gemma_cli;
use anyhow::Error;
use clap::{Parser, ValueEnum};
use crate::gemma_cli::run_cli;
use std::io::Write;
/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {

View File

@@ -64,14 +64,9 @@ version = "0.1.0"
# Required: Kubernetes metadata
[package.metadata.kube]
image = "ghcr.io/myorg/my-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080
# Optional: Docker Compose metadata (currently not used but parsed)
[package.metadata.compose]
image = "ghcr.io/myorg/my-service:latest"
port = 8080
```
### Required Fields

View File

@@ -1,9 +1,8 @@
use anyhow::{Context, Result};
use clap::{Arg, Command};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::Deserialize;
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
@@ -20,7 +19,6 @@ struct Package {
#[derive(Debug, Deserialize)]
struct Metadata {
kube: Option<KubeMetadata>,
compose: Option<ComposeMetadata>,
}
#[derive(Debug, Deserialize)]
@@ -30,12 +28,6 @@ struct KubeMetadata {
port: u16,
}
#[derive(Debug, Deserialize)]
struct ComposeMetadata {
image: Option<String>,
port: Option<u16>,
}
#[derive(Debug, Clone)]
struct ServiceInfo {
name: String,
@@ -105,7 +97,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
.into_iter()
.filter_map(|e| e.ok())
{
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
if entry.file_name() == "Cargo.toml"
&& entry.path() != workspace_root.join("../../../Cargo.toml")
{
if let Ok(service_info) = parse_cargo_toml(entry.path()) {
services.push(service_info);
}
@@ -375,7 +369,7 @@ spec:
Ok(())
}
fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
fn generate_ingress_template(templates_dir: &Path, _services: &[ServiceInfo]) -> Result<()> {
let ingress_template = r#"{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
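The trimmed-down structs above deserialize the `[package.metadata.kube]` table from each crate's Cargo.toml. A standalone sketch of that parse step using the `toml` crate (struct shapes follow the diff and the README's required fields; the manifest string is illustrative):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Manifest {
    package: Package,
}

#[derive(Debug, Deserialize)]
struct Package {
    name: String,
    metadata: Option<Metadata>,
}

#[derive(Debug, Deserialize)]
struct Metadata {
    kube: Option<KubeMetadata>,
}

#[derive(Debug, Deserialize)]
struct KubeMetadata {
    image: String,
    replicas: u32,
    port: u16,
}

fn main() -> anyhow::Result<()> {
    let raw = r#"
        [package]
        name = "my-service"
        version = "0.1.0"

        [package.metadata.kube]
        image = "ghcr.io/geoffsee/predict-otron-9000:latest"
        replicas = 1
        port = 8080
    "#;

    let manifest: Manifest = toml::from_str(raw)?;
    let kube = manifest.package.metadata.and_then(|m| m.kube).expect("kube metadata");
    println!("{} -> {}:{}", manifest.package.name, kube.image, kube.port);
    Ok(())
}
```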

View File

@@ -1,6 +1,5 @@
pub mod llama_api;
use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed

View File

@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
pub repeat_last_n: usize,
}
impl LlamaInferenceConfig {
pub fn new(model: WhichModel) -> Self {
Self {
prompt: String::new(),
model,
cpu: false,
temperature: 1.0,
top_p: None,
top_k: None,
seed: 42,
max_tokens: 512,
no_kv_cache: false,
dtype: None,
model_id: None,
revision: None,
use_flash_attn: true,
repeat_penalty: 1.1,
repeat_last_n: 64,
}
}
}
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
@@ -81,7 +102,7 @@ impl Default for LlamaInferenceConfig {
max_tokens: 512,
// Performance flags
no_kv_cache: false, // keep cache ON for speed
no_kv_cache: false, // keep cache ON for speed
use_flash_attn: false, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
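The new `LlamaInferenceConfig::new` constructor gives a per-model starting point that callers can adjust with struct-update syntax. A small usage sketch, assuming the types are exported as in the `lib.rs` re-export above and the fields stay public:

```rust
use llama_runner::{LlamaInferenceConfig, WhichModel};

// Build a config for whichever Llama variant the caller selects,
// overriding only the fields that differ from the constructor's defaults.
fn build_config(model: WhichModel, prompt: &str) -> LlamaInferenceConfig {
    LlamaInferenceConfig {
        prompt: prompt.to_string(),
        max_tokens: 256,
        ..LlamaInferenceConfig::new(model)
    }
}
```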

View File

@@ -6,9 +6,6 @@ mod llama_api;
mod llama_cli;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use std::io::Write;
use crate::llama_cli::run_cli;

View File

@@ -1,17 +1,14 @@
[package]
name = "utils"
edition = "2021"
[lib]
path = "src/lib.rs"
[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }
candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
candle-core="0.9.1"
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
@@ -86,3 +83,14 @@ mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]
# Platform-specific candle dependencies
[target.'cfg(target_os = "linux")'.dependencies]
candle-nn = {version = "0.9.1", default-features = false }
candle-transformers = {version = "0.9.1", default-features = false }
candle-core = {version = "0.9.1", default-features = false }
[target.'cfg(not(target_os = "linux"))'.dependencies]
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }
candle-core = {version = "0.9.1" }

View File

@@ -1,5 +1,5 @@
use candle_transformers::models::mimi::candle;
use candle_core::{Device, Result, Tensor};
use candle_transformers::models::mimi::candle;
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];

View File

@@ -8,8 +8,10 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
use candle_core::{
utils::{cuda_is_available, metal_is_available},
Device, Tensor,
};
pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
if cpu {
@@ -122,11 +124,8 @@ pub fn hub_load_safetensors(
}
let safetensors_files = safetensors_files
.iter()
.map(|v| {
repo.get(v)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
})
.collect::<Result<Vec<_>, std::io::Error, >>()?;
.map(|v| repo.get(v).map_err(std::io::Error::other))
.collect::<Result<Vec<_>, std::io::Error>>()?;
Ok(safetensors_files)
}
@@ -136,7 +135,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
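A short sketch of how `hub_load_safetensors` is typically called with an `hf_hub` repo handle, matching the call site in gemma-runner (the repo id and revision here are illustrative, and `utils` is assumed to be pulled in as the workspace path dependency):

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};
use utils::hub_load_safetensors;

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "google/gemma-2-2b-it".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Resolves every shard listed in the index file, downloading as needed.
    let weight_files = hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    println!("{} safetensors shards", weight_files.len());
    Ok(())
}
```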

View File

@@ -1,8 +1,8 @@
{
"name": "predict-otron-9000",
"workspaces": ["crates/cli/package"],
"workspaces": ["integration/cli/package"],
"scripts": {
"# WORKSPACE ALIASES": "#",
"cli": "bun --filter crates/cli/package"
"cli": "bun --filter integration/cli/package"
}
}