Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00

Compare commits
9 Commits
4380ac69d3
e6f3351ebb
3992532f15
3ecdd9ffa0
296d4dbe7e
fb5098eba6
c1c583faab
1e02b12cda
ff55d882c7
.dockerignore (new file, 35 lines)
@@ -0,0 +1,35 @@
# Git
.git
.gitignore

# Rust
target/

# Documentation
README.md
*.md

# IDE
.vscode/
.idea/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Logs
*.log

# Environment
.env
.env.local

# Dependencies
node_modules/

# Build artifacts
dist/
build/
.fastembed_cache
.github/workflows/ci.yml (vendored, 2 changed lines)
@@ -44,7 +44,7 @@ jobs:
      - name: Clippy
        shell: bash
-       run: cargo clippy --all-targets
+       run: cargo clippy --all

      - name: Tests
        shell: bash
.github/workflows/docker.yml (vendored, new file, 46 lines)
@@ -0,0 +1,46 @@
name: Build and Push Docker Image

on:
  tags:
    - 'v*'
env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    permissions:
      contents: read
      packages: write

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Log in to Container Registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Extract metadata
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=sha

      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
          context: .
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
.github/workflows/release.yml (vendored, 2 changed lines)
@@ -45,7 +45,7 @@ jobs:
      - name: Clippy
        shell: bash
-       run: cargo clippy --all-targets
+       run: cargo clippy --all

      - name: Tests
        shell: bash
.gitignore (vendored, 1 changed line)
@@ -77,3 +77,4 @@ venv/
!/scripts/cli.ts
/**/.*.bun-build
/AGENTS.md
+.claude
Cargo.lock (generated, 3 changed lines)
@@ -2905,6 +2905,7 @@ dependencies = [
 "clap",
 "cpal",
 "either",
+"embeddings-engine",
 "futures-util",
 "gemma-runner",
 "imageproc 0.24.0",
@@ -7040,7 +7041,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"

[[package]]
name = "utils"
-version = "0.0.0"
+version = "0.1.4"
dependencies = [
 "ab_glyph",
 "accelerate-src",
Cargo.toml (12 changed lines)
@@ -3,17 +3,17 @@ members = [
    "crates/predict-otron-9000",
    "crates/inference-engine",
    "crates/embeddings-engine",
-   "crates/helm-chart-tool",
-   "crates/llama-runner",
-   "crates/gemma-runner",
-   "crates/cli",
+   "integration/helm-chart-tool",
+   "integration/llama-runner",
+   "integration/gemma-runner",
+   "integration/cli",
    "crates/chat-ui"
-, "crates/utils"]
+, "integration/utils"]
default-members = ["crates/predict-otron-9000"]
resolver = "2"

[workspace.package]
-version = "0.1.4"
+version = "0.1.6"

# Compiler optimization profiles for the workspace
[profile.release]
Dockerfile (new file, 50 lines)
@@ -0,0 +1,50 @@
# Multi-stage build for predict-otron-9000 workspace
FROM rust:1 AS builder

# Install system dependencies
RUN apt-get update && apt-get install -y \
    pkg-config \
    libssl-dev \
    && rm -rf /var/lib/apt/lists/*

# Create app directory
WORKDIR /app

# Copy workspace files
COPY Cargo.toml Cargo.lock ./
COPY crates/ ./crates/
COPY integration/ ./integration/

# Build all 3 main server binaries in release mode
RUN cargo build --release -p predict-otron-9000 --bin predict-otron-9000 --no-default-features -p embeddings-engine --bin embeddings-engine -p inference-engine --bin inference-engine

# Runtime stage
FROM debian:bookworm-slim AS runtime

# Install runtime dependencies
RUN apt-get update && apt-get install -y \
    ca-certificates \
    libssl3 \
    && rm -rf /var/lib/apt/lists/*

# Create app user
RUN useradd -r -s /bin/false -m -d /app appuser

# Set working directory
WORKDIR /app

# Copy binaries from builder stage
COPY --from=builder /app/target/release/predict-otron-9000 ./bin/
COPY --from=builder /app/target/release/embeddings-engine ./bin/
COPY --from=builder /app/target/release/inference-engine ./bin/
# Make binaries executable and change ownership
RUN chmod +x ./bin/* && chown -R appuser:appuser /app

# Switch to non-root user
USER appuser

# Expose ports (adjust as needed based on your services)
EXPOSE 8080 8081 8082

# Default command (can be overridden)
CMD ["./bin/predict-otron-9000"]
README.md (25 changed lines)
@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
> This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.

> By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
-Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
+Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.

A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.

@@ -53,14 +53,17 @@ The project uses a 9-crate Rust workspace plus TypeScript components:
crates/
├── predict-otron-9000/ # Main orchestration server (Rust 2024)
├── inference-engine/ # Multi-model inference orchestrator (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
└── chat-ui/ # WASM web frontend (Rust 2021)

integration/
├── cli/ # CLI client crate (Rust 2024)
│   └── package/
│       └── cli.ts # TypeScript/Bun CLI client
├── gemma-runner/ # Gemma model inference via Candle (Rust 2021)
├── llama-runner/ # Llama model inference via Candle (Rust 2021)
├── embeddings-engine/ # FastEmbed embeddings service (Rust 2024)
├── chat-ui/ # WASM web frontend (Rust 2021)
├── helm-chart-tool/ # Kubernetes deployment tooling (Rust 2024)
└── cli/ # CLI client crate (Rust 2024)
    └── package/
        └── cli.ts # TypeScript/Bun CLI client
└── utils/ # Shared utilities (Rust 2021)
```

### Service Architecture
@@ -160,16 +163,16 @@ cd crates/chat-ui
#### TypeScript CLI Client
```bash
# List available models
-cd crates/cli/package && bun run cli.ts --list-models
+cd integration/cli/package && bun run cli.ts --list-models

# Chat completion
-cd crates/cli/package && bun run cli.ts "What is the capital of France?"
+cd integration/cli/package && bun run cli.ts "What is the capital of France?"

# With specific model
-cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"

# Show help
-cd crates/cli/package && bun run cli.ts --help
+cd integration/cli/package && bun run cli.ts --help
```

## API Usage
@@ -464,7 +467,7 @@ curl -s http://localhost:8080/v1/models | jq

**CLI client test:**
```bash
-cd crates/cli/package && bun run cli.ts "What is 2+2?"
+cd integration/cli/package && bun run cli.ts "What is 2+2?"
```

**Web frontend:**
bun.lock (4 changed lines)
@@ -4,7 +4,7 @@
    "": {
      "name": "predict-otron-9000",
    },
-   "crates/cli/package": {
+   "integration/cli/package": {
      "name": "cli",
      "dependencies": {
        "install": "^0.13.0",
@@ -13,7 +13,7 @@
      },
    },
    "packages": {
-     "cli": ["cli@workspace:crates/cli/package"],
+     "cli": ["cli@workspace:integration/cli/package"],

      "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
crates/chat-ui/Cargo.toml
@@ -3,6 +3,7 @@ name = "chat-ui"
version = "0.1.0"
edition = "2021"


[lib]
crate-type = ["cdylib", "rlib"]

@@ -122,3 +123,7 @@ lib-default-features = false
#
# Optional. Defaults to "release".
lib-profile-release = "release"
+
+[[bin]]
+name = "chat-ui"
+path = "src/main.rs"
crates/chat-ui app source
@@ -257,7 +257,8 @@ pub fn send_chat_completion_stream(
                break;
            }

-           let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
+           let value =
+               js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
            let array = js_sys::Uint8Array::new(&value);
            let mut bytes = vec![0; array.length() as usize];
            array.copy_to(&mut bytes);
@@ -279,7 +280,9 @@
            }

            // Parse JSON chunk
-           if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
+           if let Ok(chunk) =
+               serde_json::from_str::<StreamChatResponse>(data)
+           {
                if let Some(choice) = chunk.choices.first() {
                    if let Some(content) = &choice.delta.content {
                        on_chunk(content.clone());
@@ -365,7 +368,7 @@ fn ChatPage() -> impl IntoView {

    // State for available models and selected model
    let available_models = RwSignal::new(Vec::<ModelInfo>::new());
-   let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
+   let selected_model = RwSignal::new(String::from("")); // Default model

    // State for streaming response
    let streaming_content = RwSignal::new(String::new());
@@ -382,6 +385,7 @@
        match fetch_models().await {
            Ok(models) => {
                available_models.set(models);
+               selected_model.set(String::from("gemma-3-1b-it"));
            }
            Err(error) => {
                console::log_1(&format!("Failed to fetch models: {}", error).into());
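The streaming handler above reads SSE `data:` lines and parses each one into a StreamChatResponse chunk before forwarding the delta text to the UI. A minimal sketch of that parse step outside the WASM context; the struct fields follow the OpenAI streaming format and are assumptions about chat-ui's exact types, and serde/serde_json are assumed to be available:

```rust
use serde::Deserialize;

// Hypothetical mirror of the chunk types used above.
#[derive(Deserialize)]
struct StreamChatResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: Delta,
}

#[derive(Deserialize)]
struct Delta {
    content: Option<String>,
}

/// Parse one SSE line; returns the delta text if the line carries a content chunk.
fn parse_sse_line(line: &str) -> Option<String> {
    let data = line.strip_prefix("data: ")?;
    if data == "[DONE]" {
        return None; // end-of-stream marker
    }
    let chunk: StreamChatResponse = serde_json::from_str(data).ok()?;
    chunk.choices.first()?.delta.content.clone()
}

fn main() {
    let line = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#;
    assert_eq!(parse_sse_line(line).as_deref(), Some("Hello"));
    println!("parsed: {:?}", parse_sse_line(line));
}
```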
crates/embeddings-engine/Cargo.toml
@@ -25,15 +25,9 @@ rand = "0.8.5"
async-openai = "0.28.3"
once_cell = "1.19.0"


[package.metadata.compose]
image = "ghcr.io/geoffsee/embeddings-service:latest"
port = 8080

# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/embeddings-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
cmd = ["./bin/embeddings-engine"]
replicas = 1
port = 8080
crates/embeddings-engine/Dockerfile (deleted, 42 lines)
@@ -1,42 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder

WORKDIR /usr/src/app

# Install build dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    pkg-config \
    libssl-dev \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Cache deps first
COPY . ./
RUN rm -rf src
RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "// lib" > src/lib.rs && cargo build --release
RUN rm -rf src

# Copy real sources and build
COPY . .
RUN cargo build --release

# ---- Runtime stage ----
FROM debian:bullseye-slim

# Install only what the compiled binary needs
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    libssl1.1 \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/embeddings-engine /usr/local/bin/

# Run as non-root user for safety
RUN useradd -m appuser
USER appuser

EXPOSE 8080
CMD ["embeddings-engine"]
crates/embeddings-engine/src/lib.rs
@@ -1,43 +1,225 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{Json, Router, response::Json as ResponseJson, routing::post};
use axum::{Json, Router, http::StatusCode, response::Json as ResponseJson, routing::post};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tower_http::trace::TraceLayer;
use tracing;

// Persistent model instance (singleton pattern)
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
    tracing::info!("Initializing persistent embedding model (singleton)");
// Cache for multiple embedding models
static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

#[derive(Serialize)]
pub struct ModelInfo {
    pub id: String,
    pub object: String,
    pub owned_by: String,
    pub description: String,
    pub dimensions: usize,
}

#[derive(Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelInfo>,
}

// Function to convert model name strings to EmbeddingModel enum variants
fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
    match model_name {
        // Sentence Transformers models
        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),

        // BGE models
        "BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
        "BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
        "BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
        "BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
        "BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
        "BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
        "BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
        "BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),

        // Nomic models
        "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
        "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
        "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),

        // Paraphrase models
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
        "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),

        // ModernBert
        "lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),

        // Multilingual E5 models
        "intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
        "intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
        "intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),

        // Mixedbread models
        "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
        "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),

        // GTE models
        "Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
        "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
        "Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
        "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),

        // CLIP model
        "Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),

        // Jina model
        "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),

        _ => Err(format!("Unsupported embedding model: {}", model_name)),
    }
}

// Function to get model dimensions
fn get_model_dimensions(model: &EmbeddingModel) -> usize {
    match model {
        EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
        EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
        EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
        EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
        EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
        EmbeddingModel::BGESmallZHV15 => 512,
        EmbeddingModel::BGELargeZHV15 => 1024,
        EmbeddingModel::NomicEmbedTextV1 | EmbeddingModel::NomicEmbedTextV15 | EmbeddingModel::NomicEmbedTextV15Q => 768,
        EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
        EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
        EmbeddingModel::ModernBertEmbedLarge => 1024,
        EmbeddingModel::MultilingualE5Small => 384,
        EmbeddingModel::MultilingualE5Base => 768,
        EmbeddingModel::MultilingualE5Large => 1024,
        EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
        EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
        EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
        EmbeddingModel::ClipVitB32 => 512,
        EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
    }
}

// Function to get or create a model from cache
fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
    // First try to get from cache (read lock)
    {
        let cache = MODEL_CACHE
            .read()
            .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
        if let Some(model) = cache.get(&embedding_model) {
            tracing::debug!("Using cached model: {:?}", embedding_model);
            return Ok(Arc::clone(model));
        }
    }

    // Model not in cache, create it (write lock)
    let mut cache = MODEL_CACHE
        .write()
        .map_err(|e| format!("Failed to acquire write lock: {}", e))?;

    // Double-check after acquiring write lock
    if let Some(model) = cache.get(&embedding_model) {
        tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
        return Ok(Arc::clone(model));
    }

    tracing::info!("Initializing new embedding model: {:?}", embedding_model);
    let model_start_time = std::time::Instant::now();

    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
        InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
    )
    .expect("Failed to initialize persistent embedding model");
    .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;

    let model_init_time = model_start_time.elapsed();
    tracing::info!(
        "Persistent embedding model initialized in {:.2?}",
        "Embedding model {:?} initialized in {:.2?}",
        embedding_model,
        model_init_time
    );

    model
});
    let model_arc = Arc::new(model);
    cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
    Ok(model_arc)
}

pub async fn embeddings_create(
    Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
    // Start timing the entire process
    let start_time = std::time::Instant::now();

    // Phase 1: Access persistent model instance
    // Phase 1: Parse and get the embedding model
    let model_start_time = std::time::Instant::now();

    // Access the lazy-initialized persistent model instance
    // This will only initialize the model on the first request
    let embedding_model = match parse_embedding_model(&payload.model) {
        Ok(model) => model,
        Err(e) => {
            tracing::error!("Invalid model requested: {}", e);
            return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
        }
    };

    let model = match get_or_create_model(embedding_model.clone()) {
        Ok(model) => model,
        Err(e) => {
            tracing::error!("Failed to get/create model: {}", e);
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Model initialization failed: {}", e),
            ));
        }
    };

    let model_access_time = model_start_time.elapsed();
    tracing::debug!(
        "Persistent model access completed in {:.2?}",
        "Model access/creation completed in {:.2?}",
        model_access_time
    );

@@ -65,9 +247,13 @@ pub async fn embeddings_create(
    // Phase 3: Generate embeddings
    let embedding_start_time = std::time::Instant::now();

    let embeddings = EMBEDDING_MODEL
        .embed(texts_from_embedding_input, None)
        .expect("failed to embed document");
    let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
        tracing::error!("Failed to generate embeddings: {}", e);
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Embedding generation failed: {}", e),
        )
    })?;

    let embedding_generation_time = embedding_start_time.elapsed();
    tracing::info!(
@@ -117,8 +303,9 @@ pub async fn embeddings_create(
            // Generate a random non-zero embedding
            use rand::Rng;
            let mut rng = rand::thread_rng();
            let mut random_embedding = Vec::with_capacity(768);
            for _ in 0..768 {
            let expected_dimensions = get_model_dimensions(&embedding_model);
            let mut random_embedding = Vec::with_capacity(expected_dimensions);
            for _ in 0..expected_dimensions {
                // Generate random values between -1.0 and 1.0, excluding 0
                let mut val = 0.0;
                while val == 0.0 {
@@ -138,18 +325,19 @@ pub async fn embeddings_create(
            random_embedding
        } else {
            // Check if dimensions parameter is provided and pad the embeddings if necessary
            let mut padded_embedding = embeddings[0].clone();
            let padded_embedding = embeddings[0].clone();

            // If the client expects 768 dimensions but our model produces fewer, pad with zeros
            let target_dimension = 768;
            if padded_embedding.len() < target_dimension {
                let padding_needed = target_dimension - padded_embedding.len();
                tracing::trace!(
                    "Padding embedding with {} zeros to reach {} dimensions",
                    padding_needed,
                    target_dimension
            // Use the actual model dimensions instead of hardcoded 768
            let actual_dimensions = padded_embedding.len();
            let expected_dimensions = get_model_dimensions(&embedding_model);

            if actual_dimensions != expected_dimensions {
                tracing::warn!(
                    "Model {:?} produced {} dimensions but expected {}",
                    embedding_model,
                    actual_dimensions,
                    expected_dimensions
                );
                padded_embedding.extend(vec![0.0; padding_needed]);
            }

            padded_embedding
@@ -203,11 +391,234 @@ pub async fn embeddings_create(
        postprocessing_time
    );

    ResponseJson(response)
    Ok(ResponseJson(response))
}

pub async fn models_list() -> ResponseJson<ModelsResponse> {
    let models = vec![
        ModelInfo { id: "sentence-transformers/all-MiniLM-L6-v2".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Sentence Transformer model, MiniLM-L6-v2".to_string(), dimensions: 384 },
        ModelInfo { id: "sentence-transformers/all-MiniLM-L6-v2-q".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Quantized Sentence Transformer model, MiniLM-L6-v2".to_string(), dimensions: 384 },
        ModelInfo { id: "sentence-transformers/all-MiniLM-L12-v2".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Sentence Transformer model, MiniLM-L12-v2".to_string(), dimensions: 384 },
        ModelInfo { id: "sentence-transformers/all-MiniLM-L12-v2-q".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Quantized Sentence Transformer model, MiniLM-L12-v2".to_string(), dimensions: 384 },
        ModelInfo { id: "BAAI/bge-base-en-v1.5".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "v1.5 release of the base English model".to_string(), dimensions: 768 },
        ModelInfo { id: "BAAI/bge-base-en-v1.5-q".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "Quantized v1.5 release of the base English model".to_string(), dimensions: 768 },
        ModelInfo { id: "BAAI/bge-large-en-v1.5".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "v1.5 release of the large English model".to_string(), dimensions: 1024 },
        ModelInfo { id: "BAAI/bge-large-en-v1.5-q".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "Quantized v1.5 release of the large English model".to_string(), dimensions: 1024 },
        ModelInfo { id: "BAAI/bge-small-en-v1.5".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "v1.5 release of the fast and default English model".to_string(), dimensions: 384 },
        ModelInfo { id: "BAAI/bge-small-en-v1.5-q".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "Quantized v1.5 release of the fast and default English model".to_string(), dimensions: 384 },
        ModelInfo { id: "BAAI/bge-small-zh-v1.5".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "v1.5 release of the small Chinese model".to_string(), dimensions: 512 },
        ModelInfo { id: "BAAI/bge-large-zh-v1.5".to_string(), object: "model".to_string(), owned_by: "BAAI".to_string(), description: "v1.5 release of the large Chinese model".to_string(), dimensions: 1024 },
        ModelInfo { id: "nomic-ai/nomic-embed-text-v1".to_string(), object: "model".to_string(), owned_by: "nomic-ai".to_string(), description: "8192 context length english model".to_string(), dimensions: 768 },
        ModelInfo { id: "nomic-ai/nomic-embed-text-v1.5".to_string(), object: "model".to_string(), owned_by: "nomic-ai".to_string(), description: "v1.5 release of the 8192 context length english model".to_string(), dimensions: 768 },
        ModelInfo { id: "nomic-ai/nomic-embed-text-v1.5-q".to_string(), object: "model".to_string(), owned_by: "nomic-ai".to_string(), description: "Quantized v1.5 release of the 8192 context length english model".to_string(), dimensions: 768 },
        ModelInfo { id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Multi-lingual model".to_string(), dimensions: 384 },
        ModelInfo { id: "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Quantized Multi-lingual model".to_string(), dimensions: 384 },
        ModelInfo { id: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".to_string(), object: "model".to_string(), owned_by: "sentence-transformers".to_string(), description: "Sentence-transformers model for tasks like clustering or semantic search".to_string(), dimensions: 768 },
        ModelInfo { id: "lightonai/modernbert-embed-large".to_string(), object: "model".to_string(), owned_by: "lightonai".to_string(), description: "Large model of ModernBert Text Embeddings".to_string(), dimensions: 1024 },
        ModelInfo { id: "intfloat/multilingual-e5-small".to_string(), object: "model".to_string(), owned_by: "intfloat".to_string(), description: "Small model of multilingual E5 Text Embeddings".to_string(), dimensions: 384 },
        ModelInfo { id: "intfloat/multilingual-e5-base".to_string(), object: "model".to_string(), owned_by: "intfloat".to_string(), description: "Base model of multilingual E5 Text Embeddings".to_string(), dimensions: 768 },
        ModelInfo { id: "intfloat/multilingual-e5-large".to_string(), object: "model".to_string(), owned_by: "intfloat".to_string(), description: "Large model of multilingual E5 Text Embeddings".to_string(), dimensions: 1024 },
        ModelInfo { id: "mixedbread-ai/mxbai-embed-large-v1".to_string(), object: "model".to_string(), owned_by: "mixedbread-ai".to_string(), description: "Large English embedding model from MixedBreed.ai".to_string(), dimensions: 1024 },
        ModelInfo { id: "mixedbread-ai/mxbai-embed-large-v1-q".to_string(), object: "model".to_string(), owned_by: "mixedbread-ai".to_string(), description: "Quantized Large English embedding model from MixedBreed.ai".to_string(), dimensions: 1024 },
        ModelInfo { id: "Alibaba-NLP/gte-base-en-v1.5".to_string(), object: "model".to_string(), owned_by: "Alibaba-NLP".to_string(), description: "Base multilingual embedding model from Alibaba".to_string(), dimensions: 768 },
        ModelInfo { id: "Alibaba-NLP/gte-base-en-v1.5-q".to_string(), object: "model".to_string(), owned_by: "Alibaba-NLP".to_string(), description: "Quantized Base multilingual embedding model from Alibaba".to_string(), dimensions: 768 },
        ModelInfo { id: "Alibaba-NLP/gte-large-en-v1.5".to_string(), object: "model".to_string(), owned_by: "Alibaba-NLP".to_string(), description: "Large multilingual embedding model from Alibaba".to_string(), dimensions: 1024 },
        ModelInfo { id: "Alibaba-NLP/gte-large-en-v1.5-q".to_string(), object: "model".to_string(), owned_by: "Alibaba-NLP".to_string(), description: "Quantized Large multilingual embedding model from Alibaba".to_string(), dimensions: 1024 },
        ModelInfo { id: "Qdrant/clip-ViT-B-32-text".to_string(), object: "model".to_string(), owned_by: "Qdrant".to_string(), description: "CLIP text encoder based on ViT-B/32".to_string(), dimensions: 512 },
        ModelInfo { id: "jinaai/jina-embeddings-v2-base-code".to_string(), object: "model".to_string(), owned_by: "jinaai".to_string(), description: "Jina embeddings v2 base code".to_string(), dimensions: 768 },
    ];

    ResponseJson(ModelsResponse {
        object: "list".to_string(),
        data: models,
    })
}

pub fn create_embeddings_router() -> Router {
    Router::new()
        .route("/v1/embeddings", post(embeddings_create))
        // .route("/v1/models", get(models_list))
        .layer(TraceLayer::new_for_http())
}
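The lib.rs rewrite above replaces the single lazily-initialized model with a keyed cache behind an RwLock: a read-lock fast path, then a write lock with a re-check before loading, so two racing requests do not both initialize the same model. A minimal, self-contained sketch of that double-checked pattern using only the standard library; the Model type here is a stand-in for fastembed's TextEmbedding, not the crate's real code:

```rust
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};

// Stand-in for an expensive-to-load embedding model.
struct Model(String);

fn cache() -> &'static RwLock<HashMap<String, Arc<Model>>> {
    static CACHE: OnceLock<RwLock<HashMap<String, Arc<Model>>>> = OnceLock::new();
    CACHE.get_or_init(|| RwLock::new(HashMap::new()))
}

fn get_or_create(name: &str) -> Result<Arc<Model>, String> {
    // Fast path: shared read lock, no loading.
    if let Some(m) = cache()
        .read()
        .map_err(|e| format!("read lock poisoned: {}", e))?
        .get(name)
    {
        return Ok(Arc::clone(m));
    }

    // Slow path: exclusive write lock, re-check before loading.
    let mut guard = cache()
        .write()
        .map_err(|e| format!("write lock poisoned: {}", e))?;
    if let Some(m) = guard.get(name) {
        return Ok(Arc::clone(m));
    }
    let loaded = Arc::new(Model(format!("loaded:{}", name)));
    guard.insert(name.to_string(), Arc::clone(&loaded));
    Ok(loaded)
}

fn main() {
    let a = get_or_create("bge-small-en-v1.5").unwrap();
    let b = get_or_create("bge-small-en-v1.5").unwrap();
    assert!(Arc::ptr_eq(&a, &b)); // second call hits the cache
    println!("cached model: {}", a.0);
}
```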
crates/embeddings-engine/src/main.rs
@@ -4,8 +4,6 @@ use axum::{
    response::Json as ResponseJson,
    routing::{get, post},
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize};
use std::env;
use tower_http::trace::TraceLayer;
use tracing;
@@ -13,127 +11,28 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";

use embeddings_engine;

async fn embeddings_create(
    Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> {
    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
    )
    .expect("Failed to initialize model");

    let embedding_input = payload.input;

    let texts_from_embedding_input = match embedding_input {
        EmbeddingInput::String(text) => vec![text],
        EmbeddingInput::StringArray(texts) => texts,
        EmbeddingInput::IntegerArray(_) => {
            panic!("Integer array input not supported for text embeddings");
) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
    match embeddings_engine::embeddings_create(Json(payload)).await {
        Ok(response) => Ok(response),
        Err((status_code, message)) => Err(axum::response::Response::builder()
            .status(status_code)
            .body(axum::body::Body::from(message))
            .unwrap()),
    }
        EmbeddingInput::ArrayOfIntegerArray(_) => {
            panic!("Array of integer arrays not supported for text embeddings");
        }
    };
}

    let embeddings = model
        .embed(texts_from_embedding_input, None)
        .expect("failed to embed document");

    // Only log detailed embedding information at trace level to reduce log volume
    tracing::trace!("Embeddings length: {}", embeddings.len());
    tracing::info!("Embedding dimension: {}", embeddings[0].len());

    // Log the first 10 values of the original embedding at trace level
    tracing::trace!(
        "Original embedding preview: {:?}",
        &embeddings[0][..10.min(embeddings[0].len())]
    );

    // Check if there are any NaN or zero values in the original embedding
    let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
    let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
    tracing::trace!(
        "Original embedding stats: NaN count={}, zero count={}",
        nan_count,
        zero_count
    );

    // Create the final embedding
    let final_embedding = {
        // Check if the embedding is all zeros
        let all_zeros = embeddings[0].iter().all(|&x| x == 0.0);
        if all_zeros {
            tracing::warn!("Embedding is all zeros. Generating random non-zero embedding.");

            // Generate a random non-zero embedding
            use rand::Rng;
            let mut rng = rand::thread_rng();
            let mut random_embedding = Vec::with_capacity(768);
            for _ in 0..768 {
                // Generate random values between -1.0 and 1.0, excluding 0
                let mut val = 0.0;
                while val == 0.0 {
                    val = rng.gen_range(-1.0..1.0);
                }
                random_embedding.push(val);
            }

            // Normalize the random embedding
            let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            for i in 0..random_embedding.len() {
                random_embedding[i] /= norm;
            }

            random_embedding
        } else {
            // Check if dimensions parameter is provided and pad the embeddings if necessary
            let mut padded_embedding = embeddings[0].clone();

            // If the client expects 768 dimensions but our model produces fewer, pad with zeros
            let target_dimension = 768;
            if padded_embedding.len() < target_dimension {
                let padding_needed = target_dimension - padded_embedding.len();
                tracing::trace!(
                    "Padding embedding with {} zeros to reach {} dimensions",
                    padding_needed,
                    target_dimension
                );
                padded_embedding.extend(vec![0.0; padding_needed]);
            }

            padded_embedding
        }
    };

    tracing::trace!("Final embedding dimension: {}", final_embedding.len());

    // Log the first 10 values of the final embedding at trace level
    tracing::trace!(
        "Final embedding preview: {:?}",
        &final_embedding[..10.min(final_embedding.len())]
    );

    // Return a response that matches the OpenAI API format
    let response = serde_json::json!({
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": 0,
                "embedding": final_embedding
            }
        ],
        "model": payload.model,
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0
        }
    });
    ResponseJson(response)
async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
    embeddings_engine::models_list().await
}

fn create_app() -> Router {
    Router::new()
        .route("/v1/embeddings", post(embeddings_create))
        .route("/v1/models", get(models_list))
        .layer(TraceLayer::new_for_http())
}
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
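With main.rs now delegating to the shared embeddings_engine handlers, the standalone server still serves POST /v1/embeddings and GET /v1/models in the OpenAI shape on DEFAULT_SERVER_PORT. A hedged client sketch for exercising it; this assumes reqwest (with its "json" feature), tokio, and serde_json are available as dev dependencies, and that the server is already running locally:

```rust
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();

    // Request an embedding from the locally running embeddings server.
    let body = json!({
        "model": "BAAI/bge-small-en-v1.5",
        "input": "predict-otron-9000 test sentence"
    });
    let resp: Value = client
        .post("http://127.0.0.1:8080/v1/embeddings")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // The response follows the OpenAI list format: data[0].embedding is the vector.
    let dims = resp["data"][0]["embedding"]
        .as_array()
        .map(|v| v.len())
        .unwrap_or(0);
    println!("model={} dimensions={}", resp["model"], dims);
    Ok(())
}
```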
crates/inference-engine/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "inference-engine"
version.workspace = true
-edition = "2021"
+edition = "2024"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" }
@@ -31,13 +31,21 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
uuid = { version = "1.7.0", features = ["v4"] }
reborrow = "0.5.5"
futures-util = "0.3.31"
gemma-runner = { path = "../gemma-runner", features = ["metal"] }
llama-runner = { path = "../llama-runner", features = ["metal"]}
gemma-runner = { path = "../../integration/gemma-runner" }
llama-runner = { path = "../../integration/llama-runner" }
embeddings-engine = { path = "../embeddings-engine" }

[target.'cfg(target_os = "linux")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", default-features = false }
candle-nn = { git = "https://github.com/huggingface/candle.git", default-features = false }
candle-transformers = { git = "https://github.com/huggingface/candle.git", default-features = false }

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
llama-runner = { path = "../../integration/llama-runner", features = ["metal"] }


[dev-dependencies]
@@ -61,15 +69,13 @@ bindgen_cuda = { version = "0.1.1", optional = true }
[features]
bin = []


[package.metadata.compose]
image = "ghcr.io/geoffsee/inference-engine:latest"
port = 8080

[[bin]]
name = "inference-engine"
path = "src/main.rs"

# generates kubernetes manifests
[package.metadata.kube]
image = "ghcr.io/geoffsee/inference-service:latest"
replicas = 1
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
cmd = ["./bin/inference-engine"]
port = 8080
replicas = 1
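The manifest above splits Candle dependencies by target: Metal-enabled builds on macOS, default features off on Linux. The same platform split can be expressed in code with cfg attributes; a small illustrative sketch where the Accelerator enum and function names are placeholders, not Candle's API:

```rust
// Compile-time backend selection mirroring the per-target dependency split:
// Metal acceleration on macOS, plain CPU elsewhere.
#[derive(Debug)]
enum Accelerator {
    Metal,
    Cpu,
}

#[cfg(target_os = "macos")]
fn pick_accelerator() -> Accelerator {
    Accelerator::Metal
}

#[cfg(not(target_os = "macos"))]
fn pick_accelerator() -> Accelerator {
    Accelerator::Cpu
}

fn main() {
    // On macOS this prints Metal; on Linux or Windows it prints Cpu.
    println!("selected accelerator: {:?}", pick_accelerator());
}
```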
crates/inference-engine/Dockerfile (deleted, 86 lines)
@@ -1,86 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder

WORKDIR /usr/src/app

# Install build dependencies including CUDA toolkit for GPU support
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    pkg-config \
    libssl-dev \
    build-essential \
    wget \
    gnupg2 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install CUDA toolkit (optional, for GPU support)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
    dpkg -i cuda-keyring_1.0-1_all.deb && \
    apt-get update && \
    apt-get install -y --no-install-recommends \
    cuda-minimal-build-11-8 \
    libcublas-dev-11-8 \
    libcurand-dev-11-8 \
    && rm -rf /var/lib/apt/lists/* \
    && rm cuda-keyring_1.0-1_all.deb

# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}

# Copy the entire workspace to get access to all crates
COPY . ./

# Cache dependencies first - create dummy source files
RUN rm -rf crates/inference-engine/src
RUN mkdir -p crates/inference-engine/src && \
    echo "fn main() {}" > crates/inference-engine/src/main.rs && \
    echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
    echo "// lib" > crates/inference-engine/src/lib.rs && \
    cargo build --release --bin cli --package inference-engine

# Remove dummy source and copy real sources
RUN rm -rf crates/inference-engine/src
COPY . .

# Build the actual CLI binary
RUN cargo build --release --bin cli --package inference-engine

# ---- Runtime stage ----
FROM debian:bullseye-slim

# Install runtime dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    libssl1.1 \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

# Install CUDA runtime libraries (optional, for GPU support at runtime)
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    wget \
    gnupg2 \
    && wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
    && dpkg -i cuda-keyring_1.0-1_all.deb \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
    cuda-cudart-11-8 \
    libcublas11 \
    libcurand10 \
    && rm -rf /var/lib/apt/lists/* \
    && rm cuda-keyring_1.0-1_all.deb \
    && apt-get purge -y wget gnupg2

# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/cli /usr/local/bin/inference-cli

# Run as non-root user for safety
RUN useradd -m appuser
USER appuser

EXPOSE 8080
CMD ["inference-cli"]
crates/inference-engine/src/lib.rs
@@ -8,7 +8,7 @@ pub mod server;
// Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which};
-pub use server::{create_router, AppState};
+pub use server::{AppState, create_router};

use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
crates/inference-engine/src/main.rs (new file, 26 lines)
@@ -0,0 +1,26 @@
use inference_engine::{AppState, create_router, get_server_config, init_tracing};
use tokio::net::TcpListener;
use tracing::info;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    init_tracing();

    let app_state = AppState::default();
    let app = create_router(app_state);

    let (server_host, server_port, server_address) = get_server_config();
    let listener = TcpListener::bind(&server_address).await?;

    info!(
        "Inference Engine server starting on http://{}",
        server_address
    );
    info!("Available endpoints:");
    info!(" POST /v1/chat/completions - OpenAI-compatible chat completions");
    info!(" GET /v1/models - List available models");

    axum::serve(listener, app).await?;

    Ok(())
}
crates/inference-engine/src/model.rs
@@ -42,7 +42,11 @@ pub struct ModelMeta {
}

const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-    ModelMeta { id, family, instruct }
+    ModelMeta {
+        id,
+        family,
+        instruct,
+    }
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
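The const fn m shorthand above builds ModelMeta values at compile time, which lets the model catalog live in a static table. A hedged sketch of that pattern with local placeholder types (the real ModelMeta and Family enums in the crate carry more variants and fields):

```rust
// Local stand-ins for the crate's ModelMeta/Family types.
#[derive(Debug, Clone, Copy)]
enum Family {
    Gemma,
    Llama,
}

#[derive(Debug, Clone, Copy)]
struct ModelMeta {
    id: &'static str,
    family: Family,
    instruct: bool,
}

// Same shape as the diff's constructor: usable in const contexts.
const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
    ModelMeta { id, family, instruct }
}

// A static catalog built entirely at compile time.
static CATALOG: [ModelMeta; 3] = [
    m("gemma-3-1b-it", Family::Gemma, true),
    m("gemma-2-2b", Family::Gemma, false),
    m("llama-3.2-1b-instruct", Family::Llama, true),
];

fn main() {
    for meta in CATALOG.iter().filter(|meta| meta.instruct) {
        println!("{} ({:?})", meta.id, meta.family);
    }
}
```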
crates/inference-engine/src/server.rs
@@ -1,26 +1,28 @@
use axum::{
    Json, Router,
    extract::State,
    http::StatusCode,
    response::{sse::Event, sse::Sse, IntoResponse},
    response::{IntoResponse, sse::Event, sse::Sse},
    routing::{get, post},
    Json, Router,
};
use futures_util::stream::{self, Stream};
use std::convert::Infallible;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{Mutex, mpsc};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;

use crate::Which;
use crate::openai_types::{
    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
    ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which;
use either::Either;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use embeddings_engine::models_list;
use gemma_runner::{GemmaInferenceConfig, WhichModel, run_gemma_api};
use llama_runner::{LlamaInferenceConfig, run_llama_inference};
use serde_json::Value;
// -------------------------
// Shared app state
@@ -34,7 +36,7 @@ pub enum ModelType {

#[derive(Clone)]
pub struct AppState {
    pub model_type: ModelType,
    pub model_type: Option<ModelType>,
    pub model_id: String,
    pub gemma_config: Option<GemmaInferenceConfig>,
    pub llama_config: Option<LlamaInferenceConfig>,
@@ -44,15 +46,16 @@ impl Default for AppState {
    fn default() -> Self {
        // Configure a default model to prevent 503 errors from the chat-ui
        // This can be overridden by environment variables if needed
        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
        let default_model_id =
            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());

        let gemma_config = GemmaInferenceConfig {
            model: gemma_runner::WhichModel::InstructV3_1B,
            model: None,
            ..Default::default()
        };

        Self {
            model_type: ModelType::Gemma,
            model_type: None,
            model_id: default_model_id,
            gemma_config: Some(gemma_config),
            llama_config: None,
@@ -83,15 +86,14 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
        "gemma-2-9b-it" => Some(Which::InstructV2_9B),
        "gemma-3-1b" => Some(Which::BaseV3_1B),
        "gemma-3-1b-it" => Some(Which::InstructV3_1B),
        "llama-3.2-1b" => Some(Which::Llama32_1B),
        "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
        "llama-3.2-3b" => Some(Which::Llama32_3B),
        "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
        _ => None,
    }
}


fn normalize_model_id(model_id: &str) -> String {
    model_id.to_lowercase().replace("_", "-")
}
@@ -189,35 +191,74 @@ pub async fn chat_completions_non_streaming_proxy(
    // Get streaming receiver based on model type
    let rx = if which_model.is_llama_model() {
        // Create Llama configuration dynamically
        let mut config = LlamaInferenceConfig::default();
        let llama_model = match which_model {
            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
            _ => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
                    })),
                ));
            }
        };
        let mut config = LlamaInferenceConfig::new(llama_model);
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        run_llama_inference(config).map_err(|e| (
        run_llama_inference(config).map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": { "message": format!("Error initializing Llama model: {}", e) }
                }))
        ))?
                })),
            )
        })?
    } else {
        // Create Gemma configuration dynamically
        let gemma_model = if which_model.is_v3_model() {
            gemma_runner::WhichModel::InstructV3_1B
        } else {
            gemma_runner::WhichModel::InstructV3_1B // Default fallback
        let gemma_model = match which_model {
            Which::Base2B => gemma_runner::WhichModel::Base2B,
            Which::Base7B => gemma_runner::WhichModel::Base7B,
            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
            _ => {
                return Err((
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
                    })),
                ));
            }
        };

        let mut config = GemmaInferenceConfig {
            model: gemma_model,
            model: Some(gemma_model),
            ..Default::default()
        };
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        run_gemma_api(config).map_err(|e| (
        run_gemma_api(config).map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
                }))
        ))?
                })),
            )
        })?
    };

    // Collect all tokens from the stream
@@ -347,7 +388,21 @@ async fn handle_streaming_request(
    // Get streaming receiver based on model type
    let model_rx = if which_model.is_llama_model() {
        // Create Llama configuration dynamically
        let mut config = LlamaInferenceConfig::default();
        let llama_model = match which_model {
            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
            _ => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
                    })),
                ));
            }
        };
        let mut config = LlamaInferenceConfig::new(llama_model);
        config.prompt = prompt.clone();
        config.max_tokens = max_tokens;
        match run_llama_inference(config) {
@@ -363,14 +418,35 @@
        }
    } else {
        // Create Gemma configuration dynamically
        let gemma_model = if which_model.is_v3_model() {
            gemma_runner::WhichModel::InstructV3_1B
        } else {
            gemma_runner::WhichModel::InstructV3_1B // Default fallback
        let gemma_model = match which_model {
            Which::Base2B => gemma_runner::WhichModel::Base2B,
            Which::Base7B => gemma_runner::WhichModel::Base7B,
            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
            _ => {
                return Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
                    })),
                ));
            }
        };

        let mut config = GemmaInferenceConfig {
            model: gemma_model,
            model: Some(gemma_model),
            ..Default::default()
        };
        config.prompt = prompt.clone();
@@ -530,7 +606,9 @@ pub async fn list_models() -> Json<ModelListResponse> {
        Which::Llama32_3BInstruct,
    ];

    let models: Vec<Model> = which_variants.into_iter().map(|which| {
    let mut models: Vec<Model> = which_variants
        .into_iter()
        .map(|which| {
            let meta = which.meta();
            let model_id = match which {
                Which::Base2B => "gemma-2b",
@@ -566,10 +644,31 @@
            Model {
                id: model_id.to_string(),
                object: "model".to_string(),
                created: 1686935002, // Using same timestamp as OpenAI example
                created: 1686935002,
                owned_by: owned_by.to_string(),
            }
        }).collect();
        })
        .collect();

    // Get embeddings models and convert them to inference Model format
    let embeddings_response = models_list().await;
    let embeddings_models: Vec<Model> = embeddings_response
        .0
        .data
        .into_iter()
        .map(|embedding_model| Model {
            id: embedding_model.id,
            object: embedding_model.object,
            created: 1686935002,
            owned_by: format!(
                "{} - {}",
                embedding_model.owned_by, embedding_model.description
            ),
        })
        .collect();

    // Add embeddings models to the main models list
    models.extend(embeddings_models);

    Json(ModelListResponse {
        object: "list".to_string(),
@@ -29,25 +29,23 @@ inference-engine = { path = "../inference-engine" }
|
||||
|
||||
# Dependencies for leptos web app
|
||||
#leptos-app = { path = "../leptos-app", features = ["ssr"] }
|
||||
chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = false }
|
||||
chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = true }
|
||||
|
||||
mime_guess = "2.0.5"
|
||||
log = "0.4.27"
|
||||
|
||||
|
||||
[package.metadata.compose]
|
||||
name = "predict-otron-9000"
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
port = 8080
|
||||
|
||||
|
||||
# generates kubernetes manifests
|
||||
[package.metadata.kube]
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
replicas = 1
|
||||
port = 8080
|
||||
cmd = ["./bin/predict-otron-9000"]
|
||||
# SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
|
||||
# you can generate this via node to avoid toil
|
||||
# const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
|
||||
# console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
|
||||
env = { SERVER_CONFIG = "<your-json-value-here>" }
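For reference, a minimal Rust sketch of producing the same SERVER_CONFIG value that the node one-liner above generates; the keys mirror the example in the comments, serde_json is assumed to be available, and the quote-escaping step only matters when the JSON has to sit inside another quoted string:

```rust
use serde_json::json;

fn main() {
    // Same shape as the SERVER_CONFIG example documented in the comments above.
    let server_config = json!({
        "serverMode": "HighAvailability",
        "services": {
            "inference_url": "http://custom-inference:9000",
            "embeddings_url": "http://custom-embeddings:9001"
        }
    });
    // Escape the quotes, mirroring the node one-liner, for embedding in a quoted env value.
    println!("{}", server_config.to_string().replace('"', "\\\""));
}
```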

[features]
default = ["ui"]
ui = ["dep:chat-ui"]

@@ -1,89 +0,0 @@
# ---- Build stage ----
FROM rust:1-slim-bullseye AS builder

WORKDIR /usr/src/app

# Install build dependencies including CUDA toolkit for GPU support (needed for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
build-essential \
wget \
gnupg2 \
curl \
&& rm -rf /var/lib/apt/lists/*

# Install CUDA toolkit (required for inference-engine dependency)
# This is a minimal CUDA installation for building
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
dpkg -i cuda-keyring_1.0-1_all.deb && \
apt-get update && \
apt-get install -y --no-install-recommends \
cuda-minimal-build-11-8 \
libcublas-dev-11-8 \
libcurand-dev-11-8 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb

# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}

# Copy the entire workspace to get access to all crates (needed for local dependencies)
COPY . ./

# Cache dependencies first - create dummy source files for all crates
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
RUN mkdir -p crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src && \
echo "fn main() {}" > crates/predict-otron-9000/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/main.rs && \
echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
echo "// lib" > crates/inference-engine/src/lib.rs && \
echo "fn main() {}" > crates/embeddings-engine/src/main.rs && \
echo "// lib" > crates/embeddings-engine/src/lib.rs && \
cargo build --release --bin predict-otron-9000 --package predict-otron-9000

# Remove dummy sources and copy real sources
RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
COPY . .

# Build the actual binary
RUN cargo build --release --bin predict-otron-9000 --package predict-otron-9000

# ---- Runtime stage ----
FROM debian:bullseye-slim

# Install runtime dependencies
RUN apt-get update && \
apt-get install -y --no-install-recommends \
libssl1.1 \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*

# Install CUDA runtime libraries (required for inference-engine dependency)
RUN apt-get update && \
apt-get install -y --no-install-recommends \
wget \
gnupg2 \
&& wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
&& dpkg -i cuda-keyring_1.0-1_all.deb \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
cuda-cudart-11-8 \
libcublas11 \
libcurand10 \
&& rm -rf /var/lib/apt/lists/* \
&& rm cuda-keyring_1.0-1_all.deb \
&& apt-get purge -y wget gnupg2

# Copy binary from builder
COPY --from=builder /usr/src/app/target/release/predict-otron-9000 /usr/local/bin/

# Run as non-root user for safety
RUN useradd -m appuser
USER appuser

EXPOSE 8080
CMD ["predict-otron-9000"]

@@ -39,29 +39,12 @@ impl Default for ServerMode {
}
}

#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Services {
pub inference_url: Option<String>,
pub embeddings_url: Option<String>,
}

impl Default for Services {
fn default() -> Self {
Self {
inference_url: None,
embeddings_url: None,
}
}
}

fn inference_service_url() -> String {
"http://inference-service:8080".to_string()
}

fn embeddings_service_url() -> String {
"http://embeddings-service:8080".to_string()
}

impl Default for ServerConfig {
fn default() -> Self {
Self {
@@ -118,8 +101,7 @@ impl ServerConfig {
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
config_string
);
let err = std::io::Error::new(
std::io::ErrorKind::Other,
let err = std::io::Error::other(
"HighAvailability mode configured but services not well defined!",
);
return Err(err);

@@ -4,28 +4,31 @@ mod middleware;
mod standalone_mode;

use crate::standalone_mode::create_standalone_router;
use axum::handler::Handler;
use axum::http::StatusCode as AxumStatusCode;
use axum::http::header;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::{Router, ServiceExt, http::Uri, response::Html, serve};
use axum::{Router, serve};
use config::ServerConfig;
use ha_mode::create_ha_router;
use inference_engine::AppState;
use log::info;
use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use mime_guess::from_path;
use rust_embed::Embed;
use std::env;
use std::path::Component::ParentDir;

#[cfg(feature = "ui")]
use axum::http::StatusCode as AxumStatusCode;
#[cfg(feature = "ui")]
use axum::http::Uri;
#[cfg(feature = "ui")]
use axum::http::header;
#[cfg(feature = "ui")]
use axum::response::IntoResponse;
#[cfg(feature = "ui")]
use mime_guess::from_path;
#[cfg(feature = "ui")]
use rust_embed::Embed;
use tokio::net::TcpListener;
use tower::MakeService;
use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[cfg(feature = "ui")]
#[derive(Embed)]
#[folder = "../../target/site"]
#[include = "*.js"]
@@ -34,6 +37,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[include = "*.ico"]
struct Asset;

#[cfg(feature = "ui")]
async fn static_handler(uri: Uri) -> axum::response::Response {
// Strip the leading `/`
let path = uri.path().trim_start_matches('/');
@@ -111,23 +115,28 @@ async fn main() {
// Create metrics layer
let metrics_layer = MetricsLayer::new(metrics_store);

let leptos_config = chat_ui::app::AppConfig::default();

// Create the leptos router for the web frontend
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);

// Merge the service router with base routes and add middleware layers
let app = Router::new()
.route("/pkg/{*path}", get(static_handler))
let mut app = Router::new()
.route("/health", get(|| async { "ok" }))
.merge(service_router)
.merge(leptos_router)
.merge(service_router);

// Add UI routes if the UI feature is enabled
#[cfg(feature = "ui")]
{
let leptos_config = chat_ui::app::AppConfig::default();
let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
app = app
.route("/pkg/{*path}", get(static_handler))
.merge(leptos_router);
}

let app = app
.layer(metrics_layer) // Add metrics tracking
.layer(cors)
.layer(TraceLayer::new_for_http());

// Server configuration
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| String::from(default_host));
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| default_host.to_string());

let server_port = env::var("SERVER_PORT")
.map(|v| v.parse::<u16>().unwrap_or(default_port))
@@ -142,8 +151,10 @@ async fn main() {
);
tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
tracing::info!("Available endpoints:");
#[cfg(feature = "ui")]
tracing::info!(" GET / - Leptos chat web application");
tracing::info!(" GET /health - Health check");
tracing::info!(" POST /v1/models - List Models");
tracing::info!(" POST /v1/embeddings - Text embeddings API");
tracing::info!(" POST /v1/chat/completions - Chat completions API");

@@ -2,7 +2,7 @@ use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;

pub fn create_standalone_router(server_config: ServerConfig) -> Router {
pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();

@@ -61,20 +61,22 @@ graph TD
A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
end

subgraph "AI Services"
subgraph "AI Services (crates/)"
B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
end

subgraph "Frontend"
subgraph "Frontend (crates/)"
D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
end

subgraph "Tooling"
subgraph "Integration Tools (integration/)"
L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
O[utils<br/>Edition: 2021<br/>Shared utilities]
end
end

@@ -82,10 +84,10 @@ graph TD
A --> B
A --> C
A --> D
B --> J
B --> K
J -.-> F[Candle 0.9.1]
K -.-> F
B --> M
B --> N
M -.-> F[Candle 0.9.1]
N -.-> F
C -.-> G[FastEmbed 4.x]
D -.-> H[Leptos 0.8.0]
E -.-> I[OpenAI SDK 5.16+]
@@ -93,12 +95,13 @@ graph TD

style A fill:#e1f5fe
style B fill:#f3e5f5
style J fill:#f3e5f5
style K fill:#f3e5f5
style C fill:#e8f5e8
style D fill:#fff3e0
style E fill:#fce4ec
style L fill:#fff9c4
style M fill:#f3e5f5
style N fill:#f3e5f5
style O fill:#fff9c4
```

## Deployment Configurations

@@ -14,7 +14,7 @@ Options:
--help                Show this help message

Examples:
cd crates/cli/package
cd integration/cli/package
bun run cli.ts "What is the capital of France?"
bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
bun run cli.ts --prompt "Who was the 16th president of the United States?"
@@ -24,8 +24,7 @@ fn run_build() -> io::Result<()> {
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
let output_path = out_dir.join("client-cli");

let bun_tgt = BunTarget::from_cargo_env()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
let bun_tgt = BunTarget::from_cargo_env().map_err(|e| io::Error::other(e.to_string()))?;

// Optional: warn if using a Bun target that’s marked unsupported in your chart
if matches!(bun_tgt, BunTarget::WindowsArm64) {
@@ -54,13 +53,12 @@ fn run_build() -> io::Result<()> {

if !install_status.success() {
let code = install_status.code().unwrap_or(1);
return Err(io::Error::new(
io::ErrorKind::Other,
format!("bun install failed with status {code}"),
));
return Err(io::Error::other(format!(
"bun install failed with status {code}"
)));
}

let target = env::var("TARGET").unwrap();
let _target = env::var("TARGET").unwrap();

// --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
let mut build = Command::new("bun")
@@ -87,7 +85,7 @@ fn run_build() -> io::Result<()> {
} else {
let code = status.code().unwrap_or(1);
warn(&format!("bun build failed with status: {code}"));
return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
return Err(io::Error::other("bun build failed"));
}

// Ensure the output is executable (after it exists)
17
integration/cli/package/bun.lock
Normal file
17
integration/cli/package/bun.lock
Normal file
@@ -0,0 +1,17 @@
{
"lockfileVersion": 1,
"workspaces": {
"": {
"name": "cli",
"dependencies": {
"install": "^0.13.0",
"openai": "^5.16.0",
},
},
},
"packages": {
"install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],

"openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
}
}
@@ -25,7 +25,7 @@ fn main() -> io::Result<()> {
// Run it
let status = Command::new(&tmp).arg("--version").status()?;
if !status.success() {
return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
return Err(io::Error::other("client-cli failed"));
}

Ok(())
@@ -18,7 +18,7 @@ serde_json = "1.0"
tracing = "0.1"
tracing-chrome = "0.7"
tracing-subscriber = "0.3"
utils = {path = "../utils"}
utils = {path = "../utils" }

[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
@@ -1,13 +1,7 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;

// Removed gemma_cli import as it's not needed for the API
use candle_core::{DType, Device, Tensor};
@@ -16,13 +10,15 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

use std::fmt;
use std::str::FromStr;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use tokenizers::Tokenizer;
use utils::hub_load_safetensors;
use utils::token_output_stream::TokenOutputStream;

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum WhichModel {
#[value(name = "gemma-2b")]
Base2B,
@@ -58,6 +54,56 @@ pub enum WhichModel {
InstructV3_1B,
}

impl FromStr for WhichModel {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gemma-2b" => Ok(Self::Base2B),
"gemma-7b" => Ok(Self::Base7B),
"gemma-2b-it" => Ok(Self::Instruct2B),
"gemma-7b-it" => Ok(Self::Instruct7B),
"gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
"gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
"codegemma-2b" => Ok(Self::CodeBase2B),
"codegemma-7b" => Ok(Self::CodeBase7B),
"codegemma-2b-it" => Ok(Self::CodeInstruct2B),
"codegemma-7b-it" => Ok(Self::CodeInstruct7B),
"gemma-2-2b" => Ok(Self::BaseV2_2B),
"gemma-2-2b-it" => Ok(Self::InstructV2_2B),
"gemma-2-9b" => Ok(Self::BaseV2_9B),
"gemma-2-9b-it" => Ok(Self::InstructV2_9B),
"gemma-3-1b" => Ok(Self::BaseV3_1B),
"gemma-3-1b-it" => Ok(Self::InstructV3_1B),
_ => Err(format!("Unknown model: {}", s)),
}
}
}

impl fmt::Display for WhichModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = match self {
Self::Base2B => "gemma-2b",
Self::Base7B => "gemma-7b",
Self::Instruct2B => "gemma-2b-it",
Self::Instruct7B => "gemma-7b-it",
Self::InstructV1_1_2B => "gemma-1.1-2b-it",
Self::InstructV1_1_7B => "gemma-1.1-7b-it",
Self::CodeBase2B => "codegemma-2b",
Self::CodeBase7B => "codegemma-7b",
Self::CodeInstruct2B => "codegemma-2b-it",
Self::CodeInstruct7B => "codegemma-7b-it",
Self::BaseV2_2B => "gemma-2-2b",
Self::InstructV2_2B => "gemma-2-2b-it",
Self::BaseV2_9B => "gemma-2-9b",
Self::InstructV2_9B => "gemma-2-9b-it",
Self::BaseV3_1B => "gemma-3-1b",
Self::InstructV3_1B => "gemma-3-1b-it",
};
write!(f, "{}", name)
}
}
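A minimal usage sketch for the new FromStr/Display impls (the gemma_runner::WhichModel path is assumed from the handler code earlier in this diff; only the impls themselves are shown above): parse an OpenAI-style model id into the enum, then format it back to the same string.

```rust
use std::str::FromStr;

use gemma_runner::WhichModel; // assumed re-export path; the diff only shows the impls

fn main() {
    // FromStr accepts the ids listed in the match above.
    let model = WhichModel::from_str("gemma-3-1b-it").expect("unknown model id");
    assert_eq!(model, WhichModel::InstructV3_1B);
    // Display round-trips back to the canonical id string.
    assert_eq!(model.to_string(), "gemma-3-1b-it");
}
```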

enum Model {
V1(Model1),
V2(Model2),
@@ -145,8 +191,6 @@ impl TextGeneration {
// Make sure stdout isn't holding anything (if caller also prints).
std::io::stdout().flush()?;

let mut generated_tokens = 0usize;

let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
@@ -183,7 +227,6 @@ impl TextGeneration {

let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;

if next_token == eos_token || next_token == eot_token {
break;
@@ -210,7 +253,7 @@ impl TextGeneration {
pub struct GemmaInferenceConfig {
pub tracing: bool,
pub prompt: String,
pub model: WhichModel,
pub model: Option<WhichModel>,
pub cpu: bool,
pub dtype: Option<String>,
pub model_id: Option<String>,
@@ -229,7 +272,7 @@ impl Default for GemmaInferenceConfig {
Self {
tracing: false,
prompt: "Hello".to_string(),
model: WhichModel::InstructV2_2B,
model: Some(WhichModel::InstructV2_2B),
cpu: false,
dtype: None,
model_id: None,
@@ -286,28 +329,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
}
};
println!("Using dtype: {:?}", dtype);
println!("Raw model string: {:?}", cfg.model_id);

let start = std::time::Instant::now();
let api = Api::new()?;

let model_id = cfg.model_id.unwrap_or_else(|| {
match cfg.model {
WhichModel::Base2B => "google/gemma-2b",
WhichModel::Base7B => "google/gemma-7b",
WhichModel::Instruct2B => "google/gemma-2b-it",
WhichModel::Instruct7B => "google/gemma-7b-it",
WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
WhichModel::CodeBase2B => "google/codegemma-2b",
WhichModel::CodeBase7B => "google/codegemma-7b",
WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
WhichModel::BaseV2_2B => "google/gemma-2-2b",
WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
WhichModel::BaseV2_9B => "google/gemma-2-9b",
WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
Some(WhichModel::Base2B) => "google/gemma-2b",
Some(WhichModel::Base7B) => "google/gemma-7b",
Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
None => "google/gemma-2-2b-it", // default fallback
}
.to_string()
});
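A sketch of the calling pattern this Option change enables, using only fields and the run_gemma_api signature visible in the hunks above (running it standalone and the exact error type are assumptions): with model left as None, the fallback repo google/gemma-2-2b-it is used.

```rust
use gemma_runner::{run_gemma_api, GemmaInferenceConfig}; // assumed re-export path

fn main() -> anyhow::Result<()> {
    let mut config = GemmaInferenceConfig {
        model: None, // resolves to "google/gemma-2-2b-it" via the fallback above
        ..Default::default()
    };
    config.prompt = "Hello".to_string();
    config.max_tokens = 64;

    // run_gemma_api hands back a Receiver of generated text chunks; drain it to print them.
    let rx = run_gemma_api(config)?;
    for chunk in rx {
        print!("{}", chunk?);
    }
    Ok(())
}
```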
@@ -318,7 +363,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let tokenizer_filename = repo.get("tokenizer.json")?;
let config_filename = repo.get("config.json")?;
let filenames = match cfg.model {
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
vec![repo.get("model.safetensors")?]
}
_ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +376,31 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };

let model: Model = match cfg.model {
WhichModel::Base2B
| WhichModel::Base7B
| WhichModel::Instruct2B
| WhichModel::Instruct7B
| WhichModel::InstructV1_1_2B
| WhichModel::InstructV1_1_7B
| WhichModel::CodeBase2B
| WhichModel::CodeBase7B
| WhichModel::CodeInstruct2B
| WhichModel::CodeInstruct7B => {
Some(WhichModel::Base2B)
| Some(WhichModel::Base7B)
| Some(WhichModel::Instruct2B)
| Some(WhichModel::Instruct7B)
| Some(WhichModel::InstructV1_1_2B)
| Some(WhichModel::InstructV1_1_7B)
| Some(WhichModel::CodeBase2B)
| Some(WhichModel::CodeBase7B)
| Some(WhichModel::CodeInstruct2B)
| Some(WhichModel::CodeInstruct7B) => {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model)
}
WhichModel::BaseV2_2B
| WhichModel::InstructV2_2B
| WhichModel::BaseV2_9B
| WhichModel::InstructV2_9B => {
Some(WhichModel::BaseV2_2B)
| Some(WhichModel::InstructV2_2B)
| Some(WhichModel::BaseV2_9B)
| Some(WhichModel::InstructV2_9B)
| None => {
// default to V2 model
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model)
}
WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
Model::V3(model)
@@ -371,7 +420,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
);

let prompt = match cfg.model {
WhichModel::InstructV3_1B => {
Some(WhichModel::InstructV3_1B) => {
format!(
"<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
cfg.prompt
@@ -67,7 +67,7 @@ pub fn run_cli() -> anyhow::Result<()> {
let cfg = GemmaInferenceConfig {
tracing: args.tracing,
prompt: args.prompt,
model: args.model,
model: Some(args.model),
cpu: args.cpu,
dtype: args.dtype,
model_id: args.model_id,
@@ -6,10 +6,8 @@ mod gemma_api;
mod gemma_cli;

use anyhow::Error;
use clap::{Parser, ValueEnum};

use crate::gemma_cli::run_cli;
use std::io::Write;

/// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> {
@@ -64,14 +64,9 @@ version = "0.1.0"

# Required: Kubernetes metadata
[package.metadata.kube]
image = "ghcr.io/myorg/my-service:latest"
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
replicas = 1
port = 8080

# Optional: Docker Compose metadata (currently not used but parsed)
[package.metadata.compose]
image = "ghcr.io/myorg/my-service:latest"
port = 8080
```

### Required Fields
@@ -1,9 +1,8 @@
use anyhow::{Context, Result};
use clap::{Arg, Command};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::Deserialize;
use std::fs;
use std::path::{Path, PathBuf};
use std::path::Path;
use walkdir::WalkDir;

#[derive(Debug, Deserialize)]
@@ -20,7 +19,6 @@ struct Package {
#[derive(Debug, Deserialize)]
struct Metadata {
kube: Option<KubeMetadata>,
compose: Option<ComposeMetadata>,
}

#[derive(Debug, Deserialize)]
@@ -30,12 +28,6 @@ struct KubeMetadata {
port: u16,
}

#[derive(Debug, Deserialize)]
struct ComposeMetadata {
image: Option<String>,
port: Option<u16>,
}

#[derive(Debug, Clone)]
struct ServiceInfo {
name: String,
@@ -105,7 +97,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
.into_iter()
.filter_map(|e| e.ok())
{
if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
if entry.file_name() == "Cargo.toml"
&& entry.path() != workspace_root.join("../../../Cargo.toml")
{
if let Ok(service_info) = parse_cargo_toml(entry.path()) {
services.push(service_info);
}
@@ -375,7 +369,7 @@ spec:
Ok(())
}

fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
fn generate_ingress_template(templates_dir: &Path, _services: &[ServiceInfo]) -> Result<()> {
let ingress_template = r#"{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
@@ -1,6 +1,5 @@
pub mod llama_api;

use clap::ValueEnum;
pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};

// Re-export constants and types that might be needed
@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
pub repeat_last_n: usize,
}

impl LlamaInferenceConfig {
pub fn new(model: WhichModel) -> Self {
Self {
prompt: String::new(),
model,
cpu: false,
temperature: 1.0,
top_p: None,
top_k: None,
seed: 42,
max_tokens: 512,
no_kv_cache: false,
dtype: None,
model_id: None,
revision: None,
use_flash_attn: true,
repeat_penalty: 1.1,
repeat_last_n: 64,
}
}
}
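A short sketch mirroring how the streaming handler above uses the new constructor; the crate paths come from the lib.rs re-export shown earlier, while the shape of the stream and its error type are not visible in this diff and are left as comments.

```rust
use llama_runner::{run_llama_inference, LlamaInferenceConfig, WhichModel};

fn main() {
    let mut config = LlamaInferenceConfig::new(WhichModel::Llama32_1BInstruct);
    config.prompt = "Who was the 16th president of the United States?".to_string();
    config.max_tokens = 128;

    // handle_streaming_request matches on the result and forwards the stream on success.
    match run_llama_inference(config) {
        Ok(_stream) => {
            // consume the streamed tokens here, as the handler does
        }
        Err(e) => eprintln!("llama inference failed: {e}"),
    }
}
```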
impl Default for LlamaInferenceConfig {
fn default() -> Self {
Self {
@@ -6,9 +6,6 @@ mod llama_api;
mod llama_cli;

use anyhow::Result;
use clap::{Parser, ValueEnum};

use std::io::Write;

use crate::llama_cli::run_cli;

@@ -1,17 +1,14 @@
[package]
name = "utils"
edition = "2021"

[lib]
path = "src/lib.rs"

[dependencies]
accelerate-src = {version = "0.3.2", optional = true }
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }

candle-flash-attn = {version = "0.9.1", optional = true }
candle-onnx = {version = "0.9.1", optional = true }
candle-core="0.9.1"
csv = "1.3.0"
anyhow = "1.0.99"
cudarc = {version = "0.17.3", optional = true }
@@ -86,3 +83,14 @@ mimi = ["cpal", "symphonia", "rubato"]
snac = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
tekken = ["tekken-rs"]

# Platform-specific candle dependencies
[target.'cfg(target_os = "linux")'.dependencies]
candle-nn = {version = "0.9.1", default-features = false }
candle-transformers = {version = "0.9.1", default-features = false }
candle-core = {version = "0.9.1", default-features = false }

[target.'cfg(not(target_os = "linux"))'.dependencies]
candle-nn = {version = "0.9.1" }
candle-transformers = {version = "0.9.1" }
candle-core = {version = "0.9.1" }
@@ -1,5 +1,5 @@
use candle_transformers::models::mimi::candle;
use candle_core::{Device, Result, Tensor};
use candle_transformers::models::mimi::candle;

pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
@@ -8,8 +8,10 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};

use candle_core::{
utils::{cuda_is_available, metal_is_available},
Device, Tensor,
};

pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
if cpu {
@@ -122,11 +124,8 @@ pub fn hub_load_safetensors(
}
let safetensors_files = safetensors_files
.iter()
.map(|v| {
repo.get(v)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
})
.collect::<Result<Vec<_>, std::io::Error, >>()?;
.map(|v| repo.get(v).map_err(std::io::Error::other))
.collect::<Result<Vec<_>, std::io::Error>>()?;
Ok(safetensors_files)
}
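For context, a sketch of how the gemma-runner code earlier in this diff drives this helper (the hf-hub Api/Repo setup mirrors that code; the model id and "main" revision are just example values):

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};
use utils::hub_load_safetensors;

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "google/gemma-2-2b-it".to_string(),
        RepoType::Model,
        "main".to_string(),
    ));
    // Resolves every shard listed in the index file, downloading into the HF cache as needed.
    let shards = hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    println!("fetched {} safetensors shards", shards.len());
    Ok(())
}
```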

@@ -136,7 +135,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => anyhow::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
@@ -1,8 +1,8 @@
{
"name": "predict-otron-9000",
"workspaces": ["crates/cli/package"],
"workspaces": ["integration/cli/package"],
"scripts": {
"# WORKSPACE ALIASES": "#",
"cli": "bun --filter crates/cli/package"
"cli": "bun --filter integration/cli/package"
}
}