Mirror of https://github.com/geoffsee/predict-otron-9001.git
Synced 2025-09-08 22:46:44 +00:00

Compare commits — 9 commits:

  4380ac69d3
  e6f3351ebb
  3992532f15
  3ecdd9ffa0
  296d4dbe7e
  fb5098eba6
  c1c583faab
  1e02b12cda
  ff55d882c7
.dockerignore  (new file, +35)
@@ -0,0 +1,35 @@
+# Git
+.git
+.gitignore
+
+# Rust
+target/
+
+# Documentation
+README.md
+*.md
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+
+# Environment
+.env
+.env.local
+
+# Dependencies
+node_modules/
+
+# Build artifacts
+dist/
+build/
+.fastembed_cache
.github/workflows/ci.yml  (vendored, 2 lines changed)
@@ -44,7 +44,7 @@ jobs:

       - name: Clippy
         shell: bash
-        run: cargo clippy --all-targets
+        run: cargo clippy --all

       - name: Tests
         shell: bash
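The two Clippy invocations above are not equivalent, which is worth keeping in mind when reading this change. A quick comparison, assuming a standard Rust toolchain with the Clippy component installed:

```bash
# --all-targets lints every target of the selected packages
# (lib, bins, examples, tests, benches):
cargo clippy --all-targets

# --all is a deprecated alias for --workspace: it widens the selection to every
# workspace member, but only their default targets are checked:
cargo clippy --workspace

# Both behaviors can be combined when that is what CI actually wants:
cargo clippy --workspace --all-targets
```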
.github/workflows/docker.yml  (vendored, new file, +46)
@@ -0,0 +1,46 @@
+name: Build and Push Docker Image
+
+on:
+  tags:
+    - 'v*'
+env:
+  REGISTRY: ghcr.io
+  IMAGE_NAME: ${{ github.repository }}
+
+jobs:
+  build-and-push:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      packages: write
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Log in to Container Registry
+        uses: docker/login-action@v3
+        with:
+          registry: ${{ env.REGISTRY }}
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract metadata
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+          tags: |
+            type=ref,event=branch
+            type=ref,event=pr
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=sha
+
+      - name: Build and push Docker image
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
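The workflow is aimed at tag-driven releases (the `tags: ['v*']` filter) and publishes to GHCR under the repository name. A sketch of the expected flow; the tag value and the `ghcr.io/geoffsee/predict-otron-9001` image path are illustrative:

```bash
# Pushing a version tag matching 'v*' is what the workflow filters on:
git tag v0.1.6
git push origin v0.1.6

# After the run completes, the image should be pullable from GHCR; the path is
# derived from ${{ github.repository }}, so adjust it to the actual repository:
docker pull ghcr.io/geoffsee/predict-otron-9001:0.1.6
```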
.github/workflows/release.yml  (vendored, 2 lines changed)
@@ -45,7 +45,7 @@ jobs:

       - name: Clippy
         shell: bash
-        run: cargo clippy --all-targets
+        run: cargo clippy --all

       - name: Tests
         shell: bash
.gitignore  (vendored, +1)
@@ -77,3 +77,4 @@ venv/
 !/scripts/cli.ts
 /**/.*.bun-build
 /AGENTS.md
+.claude
Cargo.lock  (generated, 3 lines changed)
@@ -2905,6 +2905,7 @@ dependencies = [
  "clap",
  "cpal",
  "either",
+ "embeddings-engine",
  "futures-util",
  "gemma-runner",
  "imageproc 0.24.0",
@@ -7040,7 +7041,7 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"

 [[package]]
 name = "utils"
-version = "0.0.0"
+version = "0.1.4"
 dependencies = [
  "ab_glyph",
  "accelerate-src",
Cargo.toml  (12 lines changed)
@@ -3,17 +3,17 @@ members = [
     "crates/predict-otron-9000",
     "crates/inference-engine",
     "crates/embeddings-engine",
-    "crates/helm-chart-tool",
+    "integration/helm-chart-tool",
-    "crates/llama-runner",
+    "integration/llama-runner",
-    "crates/gemma-runner",
+    "integration/gemma-runner",
-    "crates/cli",
+    "integration/cli",
     "crates/chat-ui"
-, "crates/utils"]
+, "integration/utils"]
 default-members = ["crates/predict-otron-9000"]
 resolver = "2"

 [workspace.package]
-version = "0.1.4"
+version = "0.1.6"

 # Compiler optimization profiles for the workspace
 [profile.release]
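Only the on-disk locations change here; the package names stay the same, so existing `cargo -p` invocations keep resolving. A minimal sketch, assuming the moved crates kept their package names:

```bash
# Members moved from crates/ to integration/ are still addressed by package name:
cargo build -p gemma-runner
cargo build -p helm-chart-tool

# The default member is unchanged, so a bare build still targets the orchestrator:
cargo build    # builds crates/predict-otron-9000
```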
Dockerfile  (new file, +50)
@@ -0,0 +1,50 @@
+# Multi-stage build for predict-otron-9000 workspace
+FROM rust:1 AS builder
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    pkg-config \
+    libssl-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create app directory
+WORKDIR /app
+
+# Copy workspace files
+COPY Cargo.toml Cargo.lock ./
+COPY crates/ ./crates/
+COPY integration/ ./integration/
+
+# Build all 3 main server binaries in release mode
+RUN cargo build --release -p predict-otron-9000 --bin predict-otron-9000 --no-default-features -p embeddings-engine --bin embeddings-engine -p inference-engine --bin inference-engine
+
+# Runtime stage
+FROM debian:bookworm-slim AS runtime
+
+# Install runtime dependencies
+RUN apt-get update && apt-get install -y \
+    ca-certificates \
+    libssl3 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create app user
+RUN useradd -r -s /bin/false -m -d /app appuser
+
+# Set working directory
+WORKDIR /app
+
+# Copy binaries from builder stage
+COPY --from=builder /app/target/release/predict-otron-9000 ./bin/
+COPY --from=builder /app/target/release/embeddings-engine ./bin/
+COPY --from=builder /app/target/release/inference-engine ./bin/
+# Make binaries executable and change ownership
+RUN chmod +x ./bin/* && chown -R appuser:appuser /app
+
+# Switch to non-root user
+USER appuser
+
+# Expose ports (adjust as needed based on your services)
+EXPOSE 8080 8081 8082
+
+# Default command (can be overridden)
+CMD ["./bin/predict-otron-9000"]
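A minimal sketch of building and running the new root image; the image tag is illustrative, and 8080 is the first of the ports the Dockerfile exposes:

```bash
# Build the multi-stage image from the repository root:
docker build -t predict-otron-9000:local .

# Run the default command (the orchestration server) and publish its HTTP port:
docker run --rm -p 8080:8080 predict-otron-9000:local

# The same image carries all three binaries, so the command can be overridden:
docker run --rm -p 8080:8080 predict-otron-9000:local ./bin/embeddings-engine
```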
README.md  (25 lines changed)
@@ -12,7 +12,7 @@ AI inference Server with OpenAI-compatible API (Limited Features)
 > This project is an educational aide for bootstrapping my understanding of language model inferencing at the lowest levels I can, serving as a "rubber-duck" solution for Kubernetes based performance-oriented inference capabilities on air-gapped networks.

 > By isolating application behaviors in components at the crate level, development reduces to a short feedback loop for validation and integration, ultimately smoothing the learning curve for scalable AI systems.
-Stability is currently best effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.
+Stability is currently best-effort. Many models require unique configuration. When stability is achieved, this project will be promoted to the seemueller-io GitHub organization under a different name.

 A comprehensive multi-service AI platform built around local LLM inference, embeddings, and web interfaces.

@@ -53,14 +53,17 @@ The project uses a 9-crate Rust workspace plus TypeScript components:
 crates/
 ├── predict-otron-9000/   # Main orchestration server (Rust 2024)
 ├── inference-engine/     # Multi-model inference orchestrator (Rust 2021)
+├── embeddings-engine/    # FastEmbed embeddings service (Rust 2024)
+└── chat-ui/              # WASM web frontend (Rust 2021)
+
+integration/
+├── cli/                  # CLI client crate (Rust 2024)
+│   └── package/
+│       └── cli.ts        # TypeScript/Bun CLI client
 ├── gemma-runner/         # Gemma model inference via Candle (Rust 2021)
 ├── llama-runner/         # Llama model inference via Candle (Rust 2021)
-├── embeddings-engine/    # FastEmbed embeddings service (Rust 2024)
-├── chat-ui/              # WASM web frontend (Rust 2021)
 ├── helm-chart-tool/      # Kubernetes deployment tooling (Rust 2024)
-└── cli/                  # CLI client crate (Rust 2024)
-    └── package/
-        └── cli.ts        # TypeScript/Bun CLI client
+└── utils/                # Shared utilities (Rust 2021)
 ```

 ### Service Architecture

@@ -160,16 +163,16 @@ cd crates/chat-ui
 #### TypeScript CLI Client
 ```bash
 # List available models
-cd crates/cli/package && bun run cli.ts --list-models
+cd integration/cli/package && bun run cli.ts --list-models

 # Chat completion
-cd crates/cli/package && bun run cli.ts "What is the capital of France?"
+cd integration/cli/package && bun run cli.ts "What is the capital of France?"

 # With specific model
-cd crates/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
+cd integration/cli/package && bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"

 # Show help
-cd crates/cli/package && bun run cli.ts --help
+cd integration/cli/package && bun run cli.ts --help
 ```

 ## API Usage

@@ -464,7 +467,7 @@ curl -s http://localhost:8080/v1/models | jq

 **CLI client test:**
 ```bash
-cd crates/cli/package && bun run cli.ts "What is 2+2?"
+cd integration/cli/package && bun run cli.ts "What is 2+2?"
 ```

 **Web frontend:**
bun.lock  (4 lines changed)
@@ -4,7 +4,7 @@
   "": {
     "name": "predict-otron-9000",
   },
-  "crates/cli/package": {
+  "integration/cli/package": {
     "name": "cli",
     "dependencies": {
       "install": "^0.13.0",
@@ -13,7 +13,7 @@
     },
   },
   "packages": {
-    "cli": ["cli@workspace:crates/cli/package"],
+    "cli": ["cli@workspace:integration/cli/package"],

     "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
crates/chat-ui/Cargo.toml
@@ -3,6 +3,7 @@ name = "chat-ui"
 version = "0.1.0"
 edition = "2021"
+

 [lib]
 crate-type = ["cdylib", "rlib"]

@@ -122,3 +123,7 @@ lib-default-features = false
 #
 # Optional. Defaults to "release".
 lib-profile-release = "release"
+
+[[bin]]
+name = "chat-ui"
+path = "src/main.rs"
chat-ui frontend source (send_chat_completion_stream / ChatPage)
@@ -257,7 +257,8 @@ pub fn send_chat_completion_stream(
                 break;
             }

-            let value = js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
+            let value =
+                js_sys::Reflect::get(&result, &JsValue::from_str("value")).unwrap();
             let array = js_sys::Uint8Array::new(&value);
             let mut bytes = vec![0; array.length() as usize];
             array.copy_to(&mut bytes);
@@ -279,7 +280,9 @@ pub fn send_chat_completion_stream(
             }

             // Parse JSON chunk
-            if let Ok(chunk) = serde_json::from_str::<StreamChatResponse>(data) {
+            if let Ok(chunk) =
+                serde_json::from_str::<StreamChatResponse>(data)
+            {
                 if let Some(choice) = chunk.choices.first() {
                     if let Some(content) = &choice.delta.content {
                         on_chunk(content.clone());
@@ -365,7 +368,7 @@ fn ChatPage() -> impl IntoView {

     // State for available models and selected model
     let available_models = RwSignal::new(Vec::<ModelInfo>::new());
-    let selected_model = RwSignal::new(String::from("gemma-3-1b-it")); // Default model
+    let selected_model = RwSignal::new(String::from("")); // Default model

     // State for streaming response
     let streaming_content = RwSignal::new(String::new());
@@ -382,6 +385,7 @@ fn ChatPage() -> impl IntoView {
         match fetch_models().await {
             Ok(models) => {
                 available_models.set(models);
+                selected_model.set(String::from("gemma-3-1b-it"));
             }
             Err(error) => {
                 console::log_1(&format!("Failed to fetch models: {}", error).into());
crates/embeddings-engine/Cargo.toml
@@ -25,15 +25,9 @@ rand = "0.8.5"
 async-openai = "0.28.3"
 once_cell = "1.19.0"

-
-
-[package.metadata.compose]
-image = "ghcr.io/geoffsee/embeddings-service:latest"
-port = 8080
-
-
 # generates kubernetes manifests
 [package.metadata.kube]
-image = "ghcr.io/geoffsee/embeddings-service:latest"
+image = "ghcr.io/geoffsee/predict-otron-9000:latest"
+cmd = ["./bin/embeddings-engine"]
 replicas = 1
 port = 8080
embeddings-engine Dockerfile  (deleted, -42)
@@ -1,42 +0,0 @@
-# ---- Build stage ----
-FROM rust:1-slim-bullseye AS builder
-
-WORKDIR /usr/src/app
-
-# Install build dependencies
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    pkg-config \
-    libssl-dev \
-    build-essential \
-    && rm -rf /var/lib/apt/lists/*
-
-# Cache deps first
-COPY . ./
-RUN rm -rf src
-RUN mkdir src && echo "fn main() {}" > src/main.rs && echo "// lib" > src/lib.rs && cargo build --release
-RUN rm -rf src
-
-# Copy real sources and build
-COPY . .
-RUN cargo build --release
-
-# ---- Runtime stage ----
-FROM debian:bullseye-slim
-
-# Install only what the compiled binary needs
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    libssl1.1 \
-    ca-certificates \
-    && rm -rf /var/lib/apt/lists/*
-
-# Copy binary from builder
-COPY --from=builder /usr/src/app/target/release/embeddings-engine /usr/local/bin/
-
-# Run as non-root user for safety
-RUN useradd -m appuser
-USER appuser
-
-EXPOSE 8080
-CMD ["embeddings-engine"]
embeddings-engine src/lib.rs  (43 → 225 lines: the single hard-coded model is replaced by a cache of selectable models)
@@ -1,43 +1,225 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
-use axum::{Json, Router, response::Json as ResponseJson, routing::post};
+use axum::{Json, Router, http::StatusCode, response::Json as ResponseJson, routing::post};
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use once_cell::sync::Lazy;
+use serde::Serialize;
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
 use tower_http::trace::TraceLayer;
-use tracing;

-// Persistent model instance (singleton pattern)
-static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
-    tracing::info!("Initializing persistent embedding model (singleton)");
+// Cache for multiple embedding models
+static MODEL_CACHE: Lazy<RwLock<HashMap<EmbeddingModel, Arc<TextEmbedding>>>> =
+    Lazy::new(|| RwLock::new(HashMap::new()));
+
+#[derive(Serialize)]
+pub struct ModelInfo {
+    pub id: String,
+    pub object: String,
+    pub owned_by: String,
+    pub description: String,
+    pub dimensions: usize,
+}
+
+#[derive(Serialize)]
+pub struct ModelsResponse {
+    pub object: String,
+    pub data: Vec<ModelInfo>,
+}
+
+// Function to convert model name strings to EmbeddingModel enum variants
+fn parse_embedding_model(model_name: &str) -> Result<EmbeddingModel, String> {
+    match model_name {
+        "sentence-transformers/all-MiniLM-L6-v2" | "all-minilm-l6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
+        "sentence-transformers/all-MiniLM-L6-v2-q" | "all-minilm-l6-v2-q" => Ok(EmbeddingModel::AllMiniLML6V2Q),
+        "sentence-transformers/all-MiniLM-L12-v2" | "all-minilm-l12-v2" => Ok(EmbeddingModel::AllMiniLML12V2),
+        "sentence-transformers/all-MiniLM-L12-v2-q" | "all-minilm-l12-v2-q" => Ok(EmbeddingModel::AllMiniLML12V2Q),
+        "BAAI/bge-base-en-v1.5" | "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
+        "BAAI/bge-base-en-v1.5-q" | "bge-base-en-v1.5-q" => Ok(EmbeddingModel::BGEBaseENV15Q),
+        "BAAI/bge-large-en-v1.5" | "bge-large-en-v1.5" => Ok(EmbeddingModel::BGELargeENV15),
+        "BAAI/bge-large-en-v1.5-q" | "bge-large-en-v1.5-q" => Ok(EmbeddingModel::BGELargeENV15Q),
+        "BAAI/bge-small-en-v1.5" | "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
+        "BAAI/bge-small-en-v1.5-q" | "bge-small-en-v1.5-q" => Ok(EmbeddingModel::BGESmallENV15Q),
+        "BAAI/bge-small-zh-v1.5" | "bge-small-zh-v1.5" => Ok(EmbeddingModel::BGESmallZHV15),
+        "BAAI/bge-large-zh-v1.5" | "bge-large-zh-v1.5" => Ok(EmbeddingModel::BGELargeZHV15),
+        "nomic-ai/nomic-embed-text-v1" | "nomic-embed-text-v1" => Ok(EmbeddingModel::NomicEmbedTextV1),
+        "nomic-ai/nomic-embed-text-v1.5" | "nomic-embed-text-v1.5" | "nomic-text-embed" => Ok(EmbeddingModel::NomicEmbedTextV15),
+        "nomic-ai/nomic-embed-text-v1.5-q" | "nomic-embed-text-v1.5-q" => Ok(EmbeddingModel::NomicEmbedTextV15Q),
+        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" | "paraphrase-multilingual-minilm-l12-v2" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2),
+        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q" | "paraphrase-multilingual-minilm-l12-v2-q" => Ok(EmbeddingModel::ParaphraseMLMiniLML12V2Q),
+        "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" | "paraphrase-multilingual-mpnet-base-v2" => Ok(EmbeddingModel::ParaphraseMLMpnetBaseV2),
+        "lightonai/modernbert-embed-large" | "modernbert-embed-large" => Ok(EmbeddingModel::ModernBertEmbedLarge),
+        "intfloat/multilingual-e5-small" | "multilingual-e5-small" => Ok(EmbeddingModel::MultilingualE5Small),
+        "intfloat/multilingual-e5-base" | "multilingual-e5-base" => Ok(EmbeddingModel::MultilingualE5Base),
+        "intfloat/multilingual-e5-large" | "multilingual-e5-large" => Ok(EmbeddingModel::MultilingualE5Large),
+        "mixedbread-ai/mxbai-embed-large-v1" | "mxbai-embed-large-v1" => Ok(EmbeddingModel::MxbaiEmbedLargeV1),
+        "mixedbread-ai/mxbai-embed-large-v1-q" | "mxbai-embed-large-v1-q" => Ok(EmbeddingModel::MxbaiEmbedLargeV1Q),
+        "Alibaba-NLP/gte-base-en-v1.5" | "gte-base-en-v1.5" => Ok(EmbeddingModel::GTEBaseENV15),
+        "Alibaba-NLP/gte-base-en-v1.5-q" | "gte-base-en-v1.5-q" => Ok(EmbeddingModel::GTEBaseENV15Q),
+        "Alibaba-NLP/gte-large-en-v1.5" | "gte-large-en-v1.5" => Ok(EmbeddingModel::GTELargeENV15),
+        "Alibaba-NLP/gte-large-en-v1.5-q" | "gte-large-en-v1.5-q" => Ok(EmbeddingModel::GTELargeENV15Q),
+        "Qdrant/clip-ViT-B-32-text" | "clip-vit-b-32" => Ok(EmbeddingModel::ClipVitB32),
+        "jinaai/jina-embeddings-v2-base-code" | "jina-embeddings-v2-base-code" => Ok(EmbeddingModel::JinaEmbeddingsV2BaseCode),
+        _ => Err(format!("Unsupported embedding model: {}", model_name)),
+    }
+}
+
+// Function to get model dimensions
+fn get_model_dimensions(model: &EmbeddingModel) -> usize {
+    match model {
+        EmbeddingModel::AllMiniLML6V2 | EmbeddingModel::AllMiniLML6V2Q => 384,
+        EmbeddingModel::AllMiniLML12V2 | EmbeddingModel::AllMiniLML12V2Q => 384,
+        EmbeddingModel::BGEBaseENV15 | EmbeddingModel::BGEBaseENV15Q => 768,
+        EmbeddingModel::BGELargeENV15 | EmbeddingModel::BGELargeENV15Q => 1024,
+        EmbeddingModel::BGESmallENV15 | EmbeddingModel::BGESmallENV15Q => 384,
+        EmbeddingModel::BGESmallZHV15 => 512,
+        EmbeddingModel::BGELargeZHV15 => 1024,
+        EmbeddingModel::NomicEmbedTextV1
+        | EmbeddingModel::NomicEmbedTextV15
+        | EmbeddingModel::NomicEmbedTextV15Q => 768,
+        EmbeddingModel::ParaphraseMLMiniLML12V2 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => 384,
+        EmbeddingModel::ParaphraseMLMpnetBaseV2 => 768,
+        EmbeddingModel::ModernBertEmbedLarge => 1024,
+        EmbeddingModel::MultilingualE5Small => 384,
+        EmbeddingModel::MultilingualE5Base => 768,
+        EmbeddingModel::MultilingualE5Large => 1024,
+        EmbeddingModel::MxbaiEmbedLargeV1 | EmbeddingModel::MxbaiEmbedLargeV1Q => 1024,
+        EmbeddingModel::GTEBaseENV15 | EmbeddingModel::GTEBaseENV15Q => 768,
+        EmbeddingModel::GTELargeENV15 | EmbeddingModel::GTELargeENV15Q => 1024,
+        EmbeddingModel::ClipVitB32 => 512,
+        EmbeddingModel::JinaEmbeddingsV2BaseCode => 768,
+    }
+}
+
+// Function to get or create a model from cache
+fn get_or_create_model(embedding_model: EmbeddingModel) -> Result<Arc<TextEmbedding>, String> {
+    // First try to get from cache (read lock)
+    {
+        let cache = MODEL_CACHE
+            .read()
+            .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
+        if let Some(model) = cache.get(&embedding_model) {
+            tracing::debug!("Using cached model: {:?}", embedding_model);
+            return Ok(Arc::clone(model));
+        }
+    }
+
+    // Model not in cache, create it (write lock)
+    let mut cache = MODEL_CACHE
+        .write()
+        .map_err(|e| format!("Failed to acquire write lock: {}", e))?;
+
+    // Double-check after acquiring write lock
+    if let Some(model) = cache.get(&embedding_model) {
+        tracing::debug!("Using cached model (double-check): {:?}", embedding_model);
+        return Ok(Arc::clone(model));
+    }
+
+    tracing::info!("Initializing new embedding model: {:?}", embedding_model);
     let model_start_time = std::time::Instant::now();

     let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
+        InitOptions::new(embedding_model.clone()).with_show_download_progress(true),
     )
-    .expect("Failed to initialize persistent embedding model");
+    .map_err(|e| format!("Failed to initialize model {:?}: {}", embedding_model, e))?;

     let model_init_time = model_start_time.elapsed();
     tracing::info!(
-        "Persistent embedding model initialized in {:.2?}",
+        "Embedding model {:?} initialized in {:.2?}",
+        embedding_model,
         model_init_time
     );

-    model
-});
+    let model_arc = Arc::new(model);
+    cache.insert(embedding_model.clone(), Arc::clone(&model_arc));
+    Ok(model_arc)
+}

 pub async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
-) -> ResponseJson<serde_json::Value> {
+) -> Result<ResponseJson<serde_json::Value>, (StatusCode, String)> {
     // Start timing the entire process
     let start_time = std::time::Instant::now();

-    // Phase 1: Access persistent model instance
+    // Phase 1: Parse and get the embedding model
     let model_start_time = std::time::Instant::now();

-    // Access the lazy-initialized persistent model instance
-    // This will only initialize the model on the first request
+    let embedding_model = match parse_embedding_model(&payload.model) {
+        Ok(model) => model,
+        Err(e) => {
+            tracing::error!("Invalid model requested: {}", e);
+            return Err((StatusCode::BAD_REQUEST, format!("Invalid model: {}", e)));
+        }
+    };
+
+    let model = match get_or_create_model(embedding_model.clone()) {
+        Ok(model) => model,
+        Err(e) => {
+            tracing::error!("Failed to get/create model: {}", e);
+            return Err((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("Model initialization failed: {}", e),
+            ));
+        }
+    };

     let model_access_time = model_start_time.elapsed();
     tracing::debug!(
-        "Persistent model access completed in {:.2?}",
+        "Model access/creation completed in {:.2?}",
         model_access_time
     );

@@ -65,9 +247,13 @@ pub async fn embeddings_create(
     // Phase 3: Generate embeddings
     let embedding_start_time = std::time::Instant::now();

-    let embeddings = EMBEDDING_MODEL
-        .embed(texts_from_embedding_input, None)
-        .expect("failed to embed document");
+    let embeddings = model.embed(texts_from_embedding_input, None).map_err(|e| {
+        tracing::error!("Failed to generate embeddings: {}", e);
+        (
+            StatusCode::INTERNAL_SERVER_ERROR,
+            format!("Embedding generation failed: {}", e),
+        )
+    })?;

     let embedding_generation_time = embedding_start_time.elapsed();
     tracing::info!(

@@ -117,8 +303,9 @@ pub async fn embeddings_create(
         // Generate a random non-zero embedding
         use rand::Rng;
         let mut rng = rand::thread_rng();
-        let mut random_embedding = Vec::with_capacity(768);
-        for _ in 0..768 {
+        let expected_dimensions = get_model_dimensions(&embedding_model);
+        let mut random_embedding = Vec::with_capacity(expected_dimensions);
+        for _ in 0..expected_dimensions {
             // Generate random values between -1.0 and 1.0, excluding 0
             let mut val = 0.0;
             while val == 0.0 {

@@ -138,18 +325,19 @@ pub async fn embeddings_create(
         random_embedding
     } else {
         // Check if dimensions parameter is provided and pad the embeddings if necessary
-        let mut padded_embedding = embeddings[0].clone();
+        let padded_embedding = embeddings[0].clone();

-        // If the client expects 768 dimensions but our model produces fewer, pad with zeros
-        let target_dimension = 768;
-        if padded_embedding.len() < target_dimension {
-            let padding_needed = target_dimension - padded_embedding.len();
-            tracing::trace!(
-                "Padding embedding with {} zeros to reach {} dimensions",
-                padding_needed,
-                target_dimension
-            );
-            padded_embedding.extend(vec![0.0; padding_needed]);
+        // Use the actual model dimensions instead of hardcoded 768
+        let actual_dimensions = padded_embedding.len();
+        let expected_dimensions = get_model_dimensions(&embedding_model);
+
+        if actual_dimensions != expected_dimensions {
+            tracing::warn!(
+                "Model {:?} produced {} dimensions but expected {}",
+                embedding_model,
+                actual_dimensions,
+                expected_dimensions
+            );
         }

         padded_embedding

@@ -203,11 +391,234 @@ pub async fn embeddings_create(
         postprocessing_time
     );

-    ResponseJson(response)
+    Ok(ResponseJson(response))
 }
+
+pub async fn models_list() -> ResponseJson<ModelsResponse> {
+    let models = vec![ /* one ModelInfo (object: "model") per row of the table below */ ];
+    ResponseJson(ModelsResponse {
+        object: "list".to_string(),
+        data: models,
+    })
+}

 pub fn create_embeddings_router() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
+        // .route("/v1/models", get(models_list))
         .layer(TraceLayer::new_for_http())
 }

Models registered by the new models_list() function:

| id | owned_by | description | dimensions |
|---|---|---|---|
| sentence-transformers/all-MiniLM-L6-v2 | sentence-transformers | Sentence Transformer model, MiniLM-L6-v2 | 384 |
| sentence-transformers/all-MiniLM-L6-v2-q | sentence-transformers | Quantized Sentence Transformer model, MiniLM-L6-v2 | 384 |
| sentence-transformers/all-MiniLM-L12-v2 | sentence-transformers | Sentence Transformer model, MiniLM-L12-v2 | 384 |
| sentence-transformers/all-MiniLM-L12-v2-q | sentence-transformers | Quantized Sentence Transformer model, MiniLM-L12-v2 | 384 |
| BAAI/bge-base-en-v1.5 | BAAI | v1.5 release of the base English model | 768 |
| BAAI/bge-base-en-v1.5-q | BAAI | Quantized v1.5 release of the base English model | 768 |
| BAAI/bge-large-en-v1.5 | BAAI | v1.5 release of the large English model | 1024 |
| BAAI/bge-large-en-v1.5-q | BAAI | Quantized v1.5 release of the large English model | 1024 |
| BAAI/bge-small-en-v1.5 | BAAI | v1.5 release of the fast and default English model | 384 |
| BAAI/bge-small-en-v1.5-q | BAAI | Quantized v1.5 release of the fast and default English model | 384 |
| BAAI/bge-small-zh-v1.5 | BAAI | v1.5 release of the small Chinese model | 512 |
| BAAI/bge-large-zh-v1.5 | BAAI | v1.5 release of the large Chinese model | 1024 |
| nomic-ai/nomic-embed-text-v1 | nomic-ai | 8192 context length english model | 768 |
| nomic-ai/nomic-embed-text-v1.5 | nomic-ai | v1.5 release of the 8192 context length english model | 768 |
| nomic-ai/nomic-embed-text-v1.5-q | nomic-ai | Quantized v1.5 release of the 8192 context length english model | 768 |
| sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 | sentence-transformers | Multi-lingual model | 384 |
| sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2-q | sentence-transformers | Quantized Multi-lingual model | 384 |
| sentence-transformers/paraphrase-multilingual-mpnet-base-v2 | sentence-transformers | Sentence-transformers model for tasks like clustering or semantic search | 768 |
| lightonai/modernbert-embed-large | lightonai | Large model of ModernBert Text Embeddings | 1024 |
| intfloat/multilingual-e5-small | intfloat | Small model of multilingual E5 Text Embeddings | 384 |
| intfloat/multilingual-e5-base | intfloat | Base model of multilingual E5 Text Embeddings | 768 |
| intfloat/multilingual-e5-large | intfloat | Large model of multilingual E5 Text Embeddings | 1024 |
| mixedbread-ai/mxbai-embed-large-v1 | mixedbread-ai | Large English embedding model from MixedBreed.ai | 1024 |
| mixedbread-ai/mxbai-embed-large-v1-q | mixedbread-ai | Quantized Large English embedding model from MixedBreed.ai | 1024 |
| Alibaba-NLP/gte-base-en-v1.5 | Alibaba-NLP | Base multilingual embedding model from Alibaba | 768 |
| Alibaba-NLP/gte-base-en-v1.5-q | Alibaba-NLP | Quantized Base multilingual embedding model from Alibaba | 768 |
| Alibaba-NLP/gte-large-en-v1.5 | Alibaba-NLP | Large multilingual embedding model from Alibaba | 1024 |
| Alibaba-NLP/gte-large-en-v1.5-q | Alibaba-NLP | Quantized Large multilingual embedding model from Alibaba | 1024 |
| Qdrant/clip-ViT-B-32-text | Qdrant | CLIP text encoder based on ViT-B/32 | 512 |
| jinaai/jina-embeddings-v2-base-code | jinaai | Jina embeddings v2 base code | 768 |
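A minimal way to exercise the new model-selection path once the embeddings service is running; the host, the port (8080 is the service default), and the input text are illustrative:

```bash
# Request embeddings from one of the newly supported models:
curl -s http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{"model": "BAAI/bge-small-en-v1.5", "input": "hello world"}'

# A model name outside the supported list is now rejected with 400 Bad Request,
# where the old code silently served the single hard-coded model instead:
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{"model": "not-a-real-model", "input": "hello world"}'
```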
embeddings-engine src/main.rs  (standalone binary now delegates to the library)
@@ -4,8 +4,6 @@ use axum::{
     response::Json as ResponseJson,
     routing::{get, post},
 };
-use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
-use serde::{Deserialize, Serialize};
 use std::env;
 use tower_http::trace::TraceLayer;
 use tracing;
@@ -13,127 +11,28 @@ use tracing;
 const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
 const DEFAULT_SERVER_PORT: &str = "8080";

+use embeddings_engine;
+
 async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
-) -> ResponseJson<serde_json::Value> {
-    let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
-    )
-    .expect("Failed to initialize model");
-    [… the remainder of the old inline implementation is deleted: input handling,
-    the NaN/zero checks, the random-fallback and pad-to-768 logic, and the
-    hand-built OpenAI-style JSON response — the equivalent logic now lives in
-    embeddings-engine/src/lib.rs as shown in the previous section]
-}
+) -> Result<ResponseJson<serde_json::Value>, axum::response::Response> {
+    match embeddings_engine::embeddings_create(Json(payload)).await {
+        Ok(response) => Ok(response),
+        Err((status_code, message)) => Err(axum::response::Response::builder()
+            .status(status_code)
+            .body(axum::body::Body::from(message))
+            .unwrap()),
+    }
+}
+
+async fn models_list() -> ResponseJson<embeddings_engine::ModelsResponse> {
+    embeddings_engine::models_list().await
+}

 fn create_app() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
+        .route("/v1/models", get(models_list))
         .layer(TraceLayer::new_for_http())
 }
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
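With the `/v1/models` route now wired into the standalone binary, it can be smoke-tested the same way the README checks the orchestrator; localhost:8080 matches the binary's defaults, and `jq` is optional:

```bash
# List the embedding models exposed by the new GET /v1/models route:
curl -s http://localhost:8080/v1/models | jq '.data[].id'
```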
crates/inference-engine/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "inference-engine"
 version.workspace = true
-edition = "2021"
+edition = "2024"

 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
@@ -31,13 +31,21 @@ utoipa = { version = "4.2.0", features = ["axum_extras"] }
 uuid = { version = "1.7.0", features = ["v4"] }
 reborrow = "0.5.5"
 futures-util = "0.3.31"
-gemma-runner = { path = "../gemma-runner", features = ["metal"] }
+gemma-runner = { path = "../../integration/gemma-runner" }
-llama-runner = { path = "../llama-runner", features = ["metal"]}
+llama-runner = { path = "../../integration/llama-runner" }
+embeddings-engine = { path = "../embeddings-engine" }
+
+[target.'cfg(target_os = "linux")'.dependencies]
+candle-core = { git = "https://github.com/huggingface/candle.git", default-features = false }
+candle-nn = { git = "https://github.com/huggingface/candle.git", default-features = false }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", default-features = false }

 [target.'cfg(target_os = "macos")'.dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+gemma-runner = { path = "../../integration/gemma-runner", features = ["metal"] }
+llama-runner = { path = "../../integration/llama-runner", features = ["metal"] }


 [dev-dependencies]
@@ -61,15 +69,13 @@ bindgen_cuda = { version = "0.1.1", optional = true }
 [features]
 bin = []

-[package.metadata.compose]
-image = "ghcr.io/geoffsee/inference-engine:latest"
-port = 8080
+[[bin]]
+name = "inference-engine"
+path = "src/main.rs"

 # generates kubernetes manifests
 [package.metadata.kube]
-image = "ghcr.io/geoffsee/inference-service:latest"
+image = "ghcr.io/geoffsee/predict-otron-9000:latest"
-replicas = 1
+cmd = ["./bin/inference-engine"]
 port = 8080
+replicas = 1
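Because the GPU-related choices now live in target-specific dependency tables, the build command itself stays platform-neutral; a sketch, assuming default features:

```bash
# The same invocation resolves the Metal-enabled Candle crates and runners on
# macOS and the default-features = false variants on Linux, with no extra flags:
cargo build --release -p inference-engine --bin inference-engine
```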
inference-engine Dockerfile  (deleted, -86)
@@ -1,86 +0,0 @@
-# ---- Build stage ----
-FROM rust:1-slim-bullseye AS builder
-
-WORKDIR /usr/src/app
-
-# Install build dependencies including CUDA toolkit for GPU support
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    pkg-config \
-    libssl-dev \
-    build-essential \
-    wget \
-    gnupg2 \
-    curl \
-    && rm -rf /var/lib/apt/lists/*
-
-# Install CUDA toolkit (optional, for GPU support)
-# This is a minimal CUDA installation for building
-RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
-    dpkg -i cuda-keyring_1.0-1_all.deb && \
-    apt-get update && \
-    apt-get install -y --no-install-recommends \
-    cuda-minimal-build-11-8 \
-    libcublas-dev-11-8 \
-    libcurand-dev-11-8 \
-    && rm -rf /var/lib/apt/lists/* \
-    && rm cuda-keyring_1.0-1_all.deb
-
-# Set CUDA environment variables
-ENV CUDA_HOME=/usr/local/cuda
-ENV PATH=${CUDA_HOME}/bin:${PATH}
-ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
-
-# Copy the entire workspace to get access to all crates
-COPY . ./
-
-# Cache dependencies first - create dummy source files
-RUN rm -rf crates/inference-engine/src
-RUN mkdir -p crates/inference-engine/src && \
-    echo "fn main() {}" > crates/inference-engine/src/main.rs && \
-    echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
-    echo "// lib" > crates/inference-engine/src/lib.rs && \
-    cargo build --release --bin cli --package inference-engine
-
-# Remove dummy source and copy real sources
-RUN rm -rf crates/inference-engine/src
-COPY . .
-
-# Build the actual CLI binary
-RUN cargo build --release --bin cli --package inference-engine
-
-# ---- Runtime stage ----
-FROM debian:bullseye-slim
-
-# Install runtime dependencies
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    libssl1.1 \
-    ca-certificates \
-    && rm -rf /var/lib/apt/lists/*
-
-# Install CUDA runtime libraries (optional, for GPU support at runtime)
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    wget \
-    gnupg2 \
-    && wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
-    && dpkg -i cuda-keyring_1.0-1_all.deb \
-    && apt-get update \
-    && apt-get install -y --no-install-recommends \
-    cuda-cudart-11-8 \
-    libcublas11 \
-    libcurand10 \
-    && rm -rf /var/lib/apt/lists/* \
-    && rm cuda-keyring_1.0-1_all.deb \
-    && apt-get purge -y wget gnupg2
-
-# Copy binary from builder
-COPY --from=builder /usr/src/app/target/release/cli /usr/local/bin/inference-cli
-
-# Run as non-root user for safety
-RUN useradd -m appuser
-USER appuser
-
-EXPOSE 8080
-CMD ["inference-cli"]
|
|||||||
// Re-export key components for easier access
|
// Re-export key components for easier access
|
||||||
pub use inference::ModelInference;
|
pub use inference::ModelInference;
|
||||||
pub use model::{Model, Which};
|
pub use model::{Model, Which};
|
||||||
pub use server::{create_router, AppState};
|
pub use server::{AppState, create_router};
|
||||||
|
|
||||||
use std::env;
|
use std::env;
|
||||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||||
|
crates/inference-engine/src/main.rs  (new file, +26)
@@ -0,0 +1,26 @@
+use inference_engine::{AppState, create_router, get_server_config, init_tracing};
+use tokio::net::TcpListener;
+use tracing::info;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    init_tracing();
+
+    let app_state = AppState::default();
+    let app = create_router(app_state);
+
+    let (server_host, server_port, server_address) = get_server_config();
+    let listener = TcpListener::bind(&server_address).await?;
+
+    info!(
+        "Inference Engine server starting on http://{}",
+        server_address
+    );
+    info!("Available endpoints:");
+    info!("  POST /v1/chat/completions - OpenAI-compatible chat completions");
+    info!("  GET /v1/models - List available models");
+
+    axum::serve(listener, app).await?;
+
+    Ok(())
+}
@@ -42,7 +42,11 @@ pub struct ModelMeta
 }
 
 const fn m(id: &'static str, family: Family, instruct: bool) -> ModelMeta {
-    ModelMeta { id, family, instruct }
+    ModelMeta {
+        id,
+        family,
+        instruct,
+    }
 }
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
@@ -1,26 +1,28 @@
 use axum::{
+    Json, Router,
     extract::State,
     http::StatusCode,
-    response::{sse::Event, sse::Sse, IntoResponse},
+    response::{IntoResponse, sse::Event, sse::Sse},
     routing::{get, post},
-    Json, Router,
 };
 use futures_util::stream::{self, Stream};
 use std::convert::Infallible;
+use std::str::FromStr;
 use std::sync::Arc;
-use tokio::sync::{mpsc, Mutex};
+use tokio::sync::{Mutex, mpsc};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;
 
+use crate::Which;
 use crate::openai_types::{
     ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
     ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
 };
-use crate::Which;
 use either::Either;
-use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
-use llama_runner::{run_llama_inference, LlamaInferenceConfig};
+use embeddings_engine::models_list;
+use gemma_runner::{GemmaInferenceConfig, WhichModel, run_gemma_api};
+use llama_runner::{LlamaInferenceConfig, run_llama_inference};
 use serde_json::Value;
 // -------------------------
 // Shared app state
@@ -34,7 +36,7 @@ pub enum ModelType {
 
 #[derive(Clone)]
 pub struct AppState {
-    pub model_type: ModelType,
+    pub model_type: Option<ModelType>,
     pub model_id: String,
    pub gemma_config: Option<GemmaInferenceConfig>,
     pub llama_config: Option<LlamaInferenceConfig>,
@@ -44,15 +46,16 @@ impl Default for AppState {
     fn default() -> Self {
         // Configure a default model to prevent 503 errors from the chat-ui
         // This can be overridden by environment variables if needed
-        let default_model_id = std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
+        let default_model_id =
+            std::env::var("DEFAULT_MODEL").unwrap_or_else(|_| "gemma-3-1b-it".to_string());
 
         let gemma_config = GemmaInferenceConfig {
-            model: gemma_runner::WhichModel::InstructV3_1B,
+            model: None,
             ..Default::default()
         };
 
         Self {
-            model_type: ModelType::Gemma,
+            model_type: None,
             model_id: default_model_id,
             gemma_config: Some(gemma_config),
             llama_config: None,
@@ -83,15 +86,14 @@ fn model_id_to_which(model_id: &str) -> Option<Which> {
         "gemma-2-9b-it" => Some(Which::InstructV2_9B),
         "gemma-3-1b" => Some(Which::BaseV3_1B),
         "gemma-3-1b-it" => Some(Which::InstructV3_1B),
+        "llama-3.2-1b" => Some(Which::Llama32_1B),
         "llama-3.2-1b-instruct" => Some(Which::Llama32_1BInstruct),
+        "llama-3.2-3b" => Some(Which::Llama32_3B),
         "llama-3.2-3b-instruct" => Some(Which::Llama32_3BInstruct),
         _ => None,
     }
 }
-
-
-
 
 fn normalize_model_id(model_id: &str) -> String {
     model_id.to_lowercase().replace("_", "-")
 }
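For reference, a standalone sketch of the normalization step that feeds the lookup above; the `model_id_to_which` half needs the crate's `Which` enum, so only the string handling is reproduced here, with the function body copied from the hunk:

```rust
// Sketch: requested model ids are lowercased and underscores become hyphens
// before the lookup in model_id_to_which(), so "Gemma_3_1B_IT" and
// "gemma-3-1b-it" resolve to the same variant.
fn normalize_model_id(model_id: &str) -> String {
    model_id.to_lowercase().replace("_", "-")
}

fn main() {
    assert_eq!(normalize_model_id("Gemma_3_1B_IT"), "gemma-3-1b-it");
    assert_eq!(
        normalize_model_id("llama-3.2-1b-instruct"),
        "llama-3.2-1b-instruct"
    );
    println!("normalization ok");
}
```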
@@ -189,35 +191,74 @@ pub async fn chat_completions_non_streaming_proxy(
     // Get streaming receiver based on model type
     let rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    })),
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_llama_inference(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Llama model: {}", e) }
-            }))
-        ))?
+        run_llama_inference(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Llama model: {}", e) }
+                })),
+            )
+        })?
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::BAD_REQUEST,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    })),
+                ));
+            }
         };
 
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
-        run_gemma_api(config).map_err(|e| (
-            StatusCode::INTERNAL_SERVER_ERROR,
-            Json(serde_json::json!({
-                "error": { "message": format!("Error initializing Gemma model: {}", e) }
-            }))
-        ))?
+        run_gemma_api(config).map_err(|e| {
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(serde_json::json!({
+                    "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                })),
+            )
+        })?
     };
 
     // Collect all tokens from the stream
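The early returns above rely on axum implementing `IntoResponse` for `(StatusCode, Json<T>)`; a minimal sketch of that pattern (the helper name is illustrative, not part of the crate):

```rust
use axum::{Json, http::StatusCode, response::IntoResponse};
use serde_json::json;

// Illustrative helper: a (StatusCode, Json<Value>) tuple is itself a response,
// which is why the match arms can `return Err((StatusCode::BAD_REQUEST, Json(...)))`
// and let axum render the JSON error body with that status.
fn unsupported_model(model_id: &str) -> impl IntoResponse {
    (
        StatusCode::BAD_REQUEST,
        Json(json!({
            "error": { "message": format!("Model {} is not a Llama model", model_id) }
        })),
    )
}

fn main() {
    // Constructing the value is enough to show it type-checks as a response.
    let _ = unsupported_model("gemma-3-1b-it");
}
```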
@@ -347,7 +388,21 @@ async fn handle_streaming_request(
     // Get streaming receiver based on model type
     let model_rx = if which_model.is_llama_model() {
         // Create Llama configuration dynamically
-        let mut config = LlamaInferenceConfig::default();
+        let llama_model = match which_model {
+            Which::Llama32_1B => llama_runner::WhichModel::Llama32_1B,
+            Which::Llama32_1BInstruct => llama_runner::WhichModel::Llama32_1BInstruct,
+            Which::Llama32_3B => llama_runner::WhichModel::Llama32_3B,
+            Which::Llama32_3BInstruct => llama_runner::WhichModel::Llama32_3BInstruct,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Llama model", model_id) }
+                    })),
+                ));
+            }
+        };
+        let mut config = LlamaInferenceConfig::new(llama_model);
         config.prompt = prompt.clone();
         config.max_tokens = max_tokens;
         match run_llama_inference(config) {
@@ -363,14 +418,35 @@ async fn handle_streaming_request(
         }
     } else {
         // Create Gemma configuration dynamically
-        let gemma_model = if which_model.is_v3_model() {
-            gemma_runner::WhichModel::InstructV3_1B
-        } else {
-            gemma_runner::WhichModel::InstructV3_1B // Default fallback
+        let gemma_model = match which_model {
+            Which::Base2B => gemma_runner::WhichModel::Base2B,
+            Which::Base7B => gemma_runner::WhichModel::Base7B,
+            Which::Instruct2B => gemma_runner::WhichModel::Instruct2B,
+            Which::Instruct7B => gemma_runner::WhichModel::Instruct7B,
+            Which::InstructV1_1_2B => gemma_runner::WhichModel::InstructV1_1_2B,
+            Which::InstructV1_1_7B => gemma_runner::WhichModel::InstructV1_1_7B,
+            Which::CodeBase2B => gemma_runner::WhichModel::CodeBase2B,
+            Which::CodeBase7B => gemma_runner::WhichModel::CodeBase7B,
+            Which::CodeInstruct2B => gemma_runner::WhichModel::CodeInstruct2B,
+            Which::CodeInstruct7B => gemma_runner::WhichModel::CodeInstruct7B,
+            Which::BaseV2_2B => gemma_runner::WhichModel::BaseV2_2B,
+            Which::InstructV2_2B => gemma_runner::WhichModel::InstructV2_2B,
+            Which::BaseV2_9B => gemma_runner::WhichModel::BaseV2_9B,
+            Which::InstructV2_9B => gemma_runner::WhichModel::InstructV2_9B,
+            Which::BaseV3_1B => gemma_runner::WhichModel::BaseV3_1B,
+            Which::InstructV3_1B => gemma_runner::WhichModel::InstructV3_1B,
+            _ => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(serde_json::json!({
+                        "error": { "message": format!("Model {} is not a Gemma model", model_id) }
+                    })),
+                ));
+            }
         };
 
         let mut config = GemmaInferenceConfig {
-            model: gemma_model,
+            model: Some(gemma_model),
             ..Default::default()
         };
         config.prompt = prompt.clone();
@@ -530,46 +606,69 @@ pub async fn list_models() -> Json<ModelListResponse> {
         Which::Llama32_3BInstruct,
     ];
 
-    let models: Vec<Model> = which_variants.into_iter().map(|which| {
+    let mut models: Vec<Model> = which_variants
+        .into_iter()
+        .map(|which| {
             let meta = which.meta();
             let model_id = match which {
                 Which::Base2B => "gemma-2b",
                 Which::Base7B => "gemma-7b",
                 Which::Instruct2B => "gemma-2b-it",
                 Which::Instruct7B => "gemma-7b-it",
                 Which::InstructV1_1_2B => "gemma-1.1-2b-it",
                 Which::InstructV1_1_7B => "gemma-1.1-7b-it",
                 Which::CodeBase2B => "codegemma-2b",
                 Which::CodeBase7B => "codegemma-7b",
                 Which::CodeInstruct2B => "codegemma-2b-it",
                 Which::CodeInstruct7B => "codegemma-7b-it",
                 Which::BaseV2_2B => "gemma-2-2b",
                 Which::InstructV2_2B => "gemma-2-2b-it",
                 Which::BaseV2_9B => "gemma-2-9b",
                 Which::InstructV2_9B => "gemma-2-9b-it",
                 Which::BaseV3_1B => "gemma-3-1b",
                 Which::InstructV3_1B => "gemma-3-1b-it",
                 Which::Llama32_1B => "llama-3.2-1b",
                 Which::Llama32_1BInstruct => "llama-3.2-1b-instruct",
                 Which::Llama32_3B => "llama-3.2-3b",
                 Which::Llama32_3BInstruct => "llama-3.2-3b-instruct",
             };
 
             let owned_by = if meta.id.starts_with("google/") {
                 "google"
             } else if meta.id.starts_with("meta-llama/") {
                 "meta"
             } else {
                 "unknown"
             };
 
             Model {
                 id: model_id.to_string(),
                 object: "model".to_string(),
-                created: 1686935002, // Using same timestamp as OpenAI example
+                created: 1686935002,
                 owned_by: owned_by.to_string(),
             }
-    }).collect();
+        })
+        .collect();
+
+    // Get embeddings models and convert them to inference Model format
+    let embeddings_response = models_list().await;
+    let embeddings_models: Vec<Model> = embeddings_response
+        .0
+        .data
+        .into_iter()
+        .map(|embedding_model| Model {
+            id: embedding_model.id,
+            object: embedding_model.object,
+            created: 1686935002,
+            owned_by: format!(
+                "{} - {}",
+                embedding_model.owned_by, embedding_model.description
+            ),
+        })
+        .collect();
+
+    // Add embeddings models to the main models list
+    models.extend(embeddings_models);
 
     Json(ModelListResponse {
         object: "list".to_string(),
@@ -29,25 +29,23 @@ inference-engine = { path = "../inference-engine" }
 
 # Dependencies for leptos web app
 #leptos-app = { path = "../leptos-app", features = ["ssr"] }
-chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = false }
+chat-ui = { path = "../chat-ui", features = ["ssr", "hydrate"], optional = true }
 
 mime_guess = "2.0.5"
 log = "0.4.27"
 
-
-[package.metadata.compose]
-name = "predict-otron-9000"
-image = "ghcr.io/geoffsee/predict-otron-9000:latest"
-port = 8080
-
 
 # generates kubernetes manifests
 [package.metadata.kube]
 image = "ghcr.io/geoffsee/predict-otron-9000:latest"
 replicas = 1
 port = 8080
+cmd = ["./bin/predict-otron-9000"]
 # SERVER_CONFIG Example: {\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}
 # you can generate this via node to avoid toil
 # const server_config = {serverMode: "HighAvailability", services: {inference_url: "http://custom-inference:9000", embeddings_url: "http://custom-embeddings:9001"} };
 # console.log(JSON.stringify(server_config).replace(/"/g, '\\"'));
 env = { SERVER_CONFIG = "<your-json-value-here>" }
 
+[features]
+default = ["ui"]
+ui = ["dep:chat-ui"]
@@ -1,89 +0,0 @@
-# ---- Build stage ----
-FROM rust:1-slim-bullseye AS builder
-
-WORKDIR /usr/src/app
-
-# Install build dependencies including CUDA toolkit for GPU support (needed for inference-engine dependency)
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    pkg-config \
-    libssl-dev \
-    build-essential \
-    wget \
-    gnupg2 \
-    curl \
-    && rm -rf /var/lib/apt/lists/*
-
-# Install CUDA toolkit (required for inference-engine dependency)
-# This is a minimal CUDA installation for building
-RUN wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb && \
-    dpkg -i cuda-keyring_1.0-1_all.deb && \
-    apt-get update && \
-    apt-get install -y --no-install-recommends \
-    cuda-minimal-build-11-8 \
-    libcublas-dev-11-8 \
-    libcurand-dev-11-8 \
-    && rm -rf /var/lib/apt/lists/* \
-    && rm cuda-keyring_1.0-1_all.deb
-
-# Set CUDA environment variables
-ENV CUDA_HOME=/usr/local/cuda
-ENV PATH=${CUDA_HOME}/bin:${PATH}
-ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
-
-# Copy the entire workspace to get access to all crates (needed for local dependencies)
-COPY . ./
-
-# Cache dependencies first - create dummy source files for all crates
-RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
-RUN mkdir -p crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src && \
-    echo "fn main() {}" > crates/predict-otron-9000/src/main.rs && \
-    echo "fn main() {}" > crates/inference-engine/src/main.rs && \
-    echo "fn main() {}" > crates/inference-engine/src/cli_main.rs && \
-    echo "// lib" > crates/inference-engine/src/lib.rs && \
-    echo "fn main() {}" > crates/embeddings-engine/src/main.rs && \
-    echo "// lib" > crates/embeddings-engine/src/lib.rs && \
-    cargo build --release --bin predict-otron-9000 --package predict-otron-9000
-
-# Remove dummy sources and copy real sources
-RUN rm -rf crates/predict-otron-9000/src crates/inference-engine/src crates/embeddings-engine/src
-COPY . .
-
-# Build the actual binary
-RUN cargo build --release --bin predict-otron-9000 --package predict-otron-9000
-
-# ---- Runtime stage ----
-FROM debian:bullseye-slim
-
-# Install runtime dependencies
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    libssl1.1 \
-    ca-certificates \
-    && rm -rf /var/lib/apt/lists/*
-
-# Install CUDA runtime libraries (required for inference-engine dependency)
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends \
-    wget \
-    gnupg2 \
-    && wget https://developer.download.nvidia.com/compute/cuda/repos/debian11/x86_64/cuda-keyring_1.0-1_all.deb \
-    && dpkg -i cuda-keyring_1.0-1_all.deb \
-    && apt-get update \
-    && apt-get install -y --no-install-recommends \
-    cuda-cudart-11-8 \
-    libcublas11 \
-    libcurand10 \
-    && rm -rf /var/lib/apt/lists/* \
-    && rm cuda-keyring_1.0-1_all.deb \
-    && apt-get purge -y wget gnupg2
-
-# Copy binary from builder
-COPY --from=builder /usr/src/app/target/release/predict-otron-9000 /usr/local/bin/
-
-# Run as non-root user for safety
-RUN useradd -m appuser
-USER appuser
-
-EXPOSE 8080
-CMD ["predict-otron-9000"]
@@ -39,29 +39,12 @@ impl Default for ServerMode {
     }
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
 pub struct Services {
     pub inference_url: Option<String>,
     pub embeddings_url: Option<String>,
 }
 
-impl Default for Services {
-    fn default() -> Self {
-        Self {
-            inference_url: None,
-            embeddings_url: None,
-        }
-    }
-}
-
-fn inference_service_url() -> String {
-    "http://inference-service:8080".to_string()
-}
-
-fn embeddings_service_url() -> String {
-    "http://embeddings-service:8080".to_string()
-}
-
 impl Default for ServerConfig {
     fn default() -> Self {
         Self {
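Since both fields are `Option`s, the derived `Default` is equivalent to the hand-written impl that was removed; a small sketch (serde with the `derive` feature assumed):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Services {
    pub inference_url: Option<String>,
    pub embeddings_url: Option<String>,
}

fn main() {
    // Derived Default yields None for both URLs, matching the removed impl.
    let services = Services::default();
    assert!(services.inference_url.is_none());
    assert!(services.embeddings_url.is_none());
}
```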
@@ -118,8 +101,7 @@ impl ServerConfig {
                 "HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
                 config_string
             );
-            let err = std::io::Error::new(
-                std::io::ErrorKind::Other,
+            let err = std::io::Error::other(
                 "HighAvailability mode configured but services not well defined!",
             );
             return Err(err);
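`std::io::Error::other` (stable since Rust 1.74) is shorthand for `io::Error::new(io::ErrorKind::Other, ...)`, so this replacement and the matching ones in the build scripts below are behavior-preserving; a quick check:

```rust
use std::io;

fn main() {
    let err = io::Error::other("HighAvailability mode configured but services not well defined!");
    // The shorthand still reports ErrorKind::Other, as the two-argument form did.
    assert_eq!(err.kind(), io::ErrorKind::Other);
    println!("{err}");
}
```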
@@ -126,7 +126,7 @@ use crate::config::ServerConfig;
 /// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
 /// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
 
 /// HTTP client configured for proxying requests
 #[derive(Clone)]
 pub struct ProxyClient {
     client: Client,
@@ -4,28 +4,31 @@ mod middleware;
 mod standalone_mode;
 
 use crate::standalone_mode::create_standalone_router;
-use axum::handler::Handler;
-use axum::http::StatusCode as AxumStatusCode;
-use axum::http::header;
-use axum::response::IntoResponse;
 use axum::routing::get;
-use axum::{Router, ServiceExt, http::Uri, response::Html, serve};
+use axum::{Router, serve};
 use config::ServerConfig;
 use ha_mode::create_ha_router;
-use inference_engine::AppState;
-use log::info;
 use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
-use mime_guess::from_path;
-use rust_embed::Embed;
 use std::env;
-use std::path::Component::ParentDir;
+#[cfg(feature = "ui")]
+use axum::http::StatusCode as AxumStatusCode;
+#[cfg(feature = "ui")]
+use axum::http::Uri;
+#[cfg(feature = "ui")]
+use axum::http::header;
+#[cfg(feature = "ui")]
+use axum::response::IntoResponse;
+#[cfg(feature = "ui")]
+use mime_guess::from_path;
+#[cfg(feature = "ui")]
+use rust_embed::Embed;
 use tokio::net::TcpListener;
-use tower::MakeService;
-use tower_http::classify::ServerErrorsFailureClass::StatusCode;
 use tower_http::cors::{Any, CorsLayer};
 use tower_http::trace::TraceLayer;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
+#[cfg(feature = "ui")]
 #[derive(Embed)]
 #[folder = "../../target/site"]
 #[include = "*.js"]
|
|||||||
#[include = "*.ico"]
|
#[include = "*.ico"]
|
||||||
struct Asset;
|
struct Asset;
|
||||||
|
|
||||||
|
#[cfg(feature = "ui")]
|
||||||
async fn static_handler(uri: Uri) -> axum::response::Response {
|
async fn static_handler(uri: Uri) -> axum::response::Response {
|
||||||
// Strip the leading `/`
|
// Strip the leading `/`
|
||||||
let path = uri.path().trim_start_matches('/');
|
let path = uri.path().trim_start_matches('/');
|
||||||
@@ -111,23 +115,28 @@ async fn main() {
     // Create metrics layer
     let metrics_layer = MetricsLayer::new(metrics_store);
 
-    let leptos_config = chat_ui::app::AppConfig::default();
-
-    // Create the leptos router for the web frontend
-    let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
-
     // Merge the service router with base routes and add middleware layers
-    let app = Router::new()
-        .route("/pkg/{*path}", get(static_handler))
+    let mut app = Router::new()
         .route("/health", get(|| async { "ok" }))
-        .merge(service_router)
-        .merge(leptos_router)
+        .merge(service_router);
+
+    // Add UI routes if the UI feature is enabled
+    #[cfg(feature = "ui")]
+    {
+        let leptos_config = chat_ui::app::AppConfig::default();
+        let leptos_router = chat_ui::app::create_router(leptos_config.config.leptos_options);
+        app = app
+            .route("/pkg/{*path}", get(static_handler))
+            .merge(leptos_router);
+    }
+
+    let app = app
         .layer(metrics_layer) // Add metrics tracking
         .layer(cors)
         .layer(TraceLayer::new_for_http());
 
     // Server configuration
-    let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| String::from(default_host));
+    let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| default_host.to_string());
 
     let server_port = env::var("SERVER_PORT")
         .map(|v| v.parse::<u16>().unwrap_or(default_port))
@@ -142,8 +151,10 @@ async fn main() {
     );
     tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
     tracing::info!("Available endpoints:");
+    #[cfg(feature = "ui")]
     tracing::info!(" GET / - Leptos chat web application");
     tracing::info!(" GET /health - Health check");
+    tracing::info!(" POST /v1/models - List Models");
     tracing::info!(" POST /v1/embeddings - Text embeddings API");
     tracing::info!(" POST /v1/chat/completions - Chat completions API");
 
@@ -2,7 +2,7 @@ use crate::config::ServerConfig;
 use axum::Router;
 use inference_engine::AppState;
 
-pub fn create_standalone_router(server_config: ServerConfig) -> Router {
+pub fn create_standalone_router(_server_config: ServerConfig) -> Router {
     // Create unified router by merging embeddings and inference routers (existing behavior)
     let embeddings_router = embeddings_engine::create_embeddings_router();
 
@@ -61,20 +61,22 @@ graph TD
        A[predict-otron-9000<br/>Edition: 2024<br/>Port: 8080]
     end
 
-    subgraph "AI Services"
+    subgraph "AI Services (crates/)"
         B[inference-engine<br/>Edition: 2021<br/>Port: 8080<br/>Multi-model orchestrator]
-        J[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
-        K[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
         C[embeddings-engine<br/>Edition: 2024<br/>Port: 8080<br/>FastEmbed]
     end
 
-    subgraph "Frontend"
+    subgraph "Frontend (crates/)"
         D[chat-ui<br/>Edition: 2021<br/>Port: 8788<br/>WASM UI]
     end
 
-    subgraph "Tooling"
+    subgraph "Integration Tools (integration/)"
         L[helm-chart-tool<br/>Edition: 2024<br/>K8s deployment]
         E[cli<br/>Edition: 2024<br/>TypeScript/Bun CLI]
+        M[gemma-runner<br/>Edition: 2021<br/>Gemma via Candle]
+        N[llama-runner<br/>Edition: 2021<br/>Llama via Candle]
+        O[utils<br/>Edition: 2021<br/>Shared utilities]
     end
     end
 
@@ -82,10 +84,10 @@ graph TD
     A --> B
     A --> C
     A --> D
-    B --> J
-    B --> K
-    J -.-> F[Candle 0.9.1]
-    K -.-> F
+    B --> M
+    B --> N
+    M -.-> F[Candle 0.9.1]
+    N -.-> F
     C -.-> G[FastEmbed 4.x]
     D -.-> H[Leptos 0.8.0]
     E -.-> I[OpenAI SDK 5.16+]
@@ -93,12 +95,13 @@ graph TD
 
     style A fill:#e1f5fe
     style B fill:#f3e5f5
-    style J fill:#f3e5f5
-    style K fill:#f3e5f5
     style C fill:#e8f5e8
     style D fill:#fff3e0
     style E fill:#fce4ec
     style L fill:#fff9c4
+    style M fill:#f3e5f5
+    style N fill:#f3e5f5
+    style O fill:#fff9c4
 ```
 
 ## Deployment Configurations
@@ -14,7 +14,7 @@ Options:
   --help          Show this help message
 
 Examples:
-  cd crates/cli/package
+  cd integration/cli/package
   bun run cli.ts "What is the capital of France?"
   bun run cli.ts --model gemma-3-1b-it --prompt "Hello, world!"
   bun run cli.ts --prompt "Who was the 16th president of the United States?"
@@ -24,8 +24,7 @@ fn run_build() -> io::Result<()> {
     let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"));
     let output_path = out_dir.join("client-cli");
 
-    let bun_tgt = BunTarget::from_cargo_env()
-        .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
+    let bun_tgt = BunTarget::from_cargo_env().map_err(|e| io::Error::other(e.to_string()))?;
 
     // Optional: warn if using a Bun target that’s marked unsupported in your chart
     if matches!(bun_tgt, BunTarget::WindowsArm64) {
@@ -54,13 +53,12 @@ fn run_build() -> io::Result<()> {
 
     if !install_status.success() {
         let code = install_status.code().unwrap_or(1);
-        return Err(io::Error::new(
-            io::ErrorKind::Other,
-            format!("bun install failed with status {code}"),
-        ));
+        return Err(io::Error::other(format!(
+            "bun install failed with status {code}"
+        )));
     }
 
-    let target = env::var("TARGET").unwrap();
+    let _target = env::var("TARGET").unwrap();
 
     // --- bun build (in ./package), emit to OUT_DIR, keep temps inside OUT_DIR ---
     let mut build = Command::new("bun")
|
|||||||
} else {
|
} else {
|
||||||
let code = status.code().unwrap_or(1);
|
let code = status.code().unwrap_or(1);
|
||||||
warn(&format!("bun build failed with status: {code}"));
|
warn(&format!("bun build failed with status: {code}"));
|
||||||
return Err(io::Error::new(io::ErrorKind::Other, "bun build failed"));
|
return Err(io::Error::other("bun build failed"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the output is executable (after it exists)
|
// Ensure the output is executable (after it exists)
|
17
integration/cli/package/bun.lock
Normal file
@@ -0,0 +1,17 @@
+{
+  "lockfileVersion": 1,
+  "workspaces": {
+    "": {
+      "name": "cli",
+      "dependencies": {
+        "install": "^0.13.0",
+        "openai": "^5.16.0",
+      },
+    },
+  },
+  "packages": {
+    "install": ["install@0.13.0", "", {}, "sha512-zDml/jzr2PKU9I8J/xyZBQn8rPCAY//UOYNmR01XwNwyfhEWObo2SWfSl1+0tm1u6PhxLwDnfsT/6jB7OUxqFA=="],
+
+    "openai": ["openai@5.19.1", "", { "peerDependencies": { "ws": "^8.18.0", "zod": "^3.23.8" }, "optionalPeers": ["ws", "zod"], "bin": { "openai": "bin/cli" } }, "sha512-zSqnUF7oR9ksmpusKkpUgkNrj8Sl57U+OyzO8jzc7LUjTMg4DRfR3uCm+EIMA6iw06sRPNp4t7ojp3sCpEUZRQ=="],
+  }
+}
@@ -25,7 +25,7 @@ fn main() -> io::Result<()> {
     // Run it
     let status = Command::new(&tmp).arg("--version").status()?;
     if !status.success() {
-        return Err(io::Error::new(io::ErrorKind::Other, "client-cli failed"));
+        return Err(io::Error::other("client-cli failed"));
     }
 
     Ok(())
|
|||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-chrome = "0.7"
|
tracing-chrome = "0.7"
|
||||||
tracing-subscriber = "0.3"
|
tracing-subscriber = "0.3"
|
||||||
utils = {path = "../utils"}
|
utils = {path = "../utils" }
|
||||||
|
|
||||||
[target.'cfg(target_os = "macos")'.dependencies]
|
[target.'cfg(target_os = "macos")'.dependencies]
|
||||||
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
@@ -1,13 +1,7 @@
-#[cfg(feature = "accelerate")]
-extern crate accelerate_src;
-#[cfg(feature = "mkl")]
-extern crate intel_mkl_src;
-
 use anyhow::{Error as E, Result};
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
-use clap::ValueEnum;
 
 // Removed gemma_cli import as it's not needed for the API
 use candle_core::{DType, Device, Tensor};
@@ -16,13 +10,15 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
+use std::fmt;
+use std::str::FromStr;
 use std::sync::mpsc::{self, Receiver, Sender};
 use std::thread;
 use tokenizers::Tokenizer;
 use utils::hub_load_safetensors;
 use utils::token_output_stream::TokenOutputStream;
 
-#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 pub enum WhichModel {
     #[value(name = "gemma-2b")]
     Base2B,
@@ -58,6 +54,56 @@ pub enum WhichModel {
     InstructV3_1B,
 }
 
+impl FromStr for WhichModel {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "gemma-2b" => Ok(Self::Base2B),
+            "gemma-7b" => Ok(Self::Base7B),
+            "gemma-2b-it" => Ok(Self::Instruct2B),
+            "gemma-7b-it" => Ok(Self::Instruct7B),
+            "gemma-1.1-2b-it" => Ok(Self::InstructV1_1_2B),
+            "gemma-1.1-7b-it" => Ok(Self::InstructV1_1_7B),
+            "codegemma-2b" => Ok(Self::CodeBase2B),
+            "codegemma-7b" => Ok(Self::CodeBase7B),
+            "codegemma-2b-it" => Ok(Self::CodeInstruct2B),
+            "codegemma-7b-it" => Ok(Self::CodeInstruct7B),
+            "gemma-2-2b" => Ok(Self::BaseV2_2B),
+            "gemma-2-2b-it" => Ok(Self::InstructV2_2B),
+            "gemma-2-9b" => Ok(Self::BaseV2_9B),
+            "gemma-2-9b-it" => Ok(Self::InstructV2_9B),
+            "gemma-3-1b" => Ok(Self::BaseV3_1B),
+            "gemma-3-1b-it" => Ok(Self::InstructV3_1B),
+            _ => Err(format!("Unknown model: {}", s)),
+        }
+    }
+}
+
+impl fmt::Display for WhichModel {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let name = match self {
+            Self::Base2B => "gemma-2b",
+            Self::Base7B => "gemma-7b",
+            Self::Instruct2B => "gemma-2b-it",
+            Self::Instruct7B => "gemma-7b-it",
+            Self::InstructV1_1_2B => "gemma-1.1-2b-it",
+            Self::InstructV1_1_7B => "gemma-1.1-7b-it",
+            Self::CodeBase2B => "codegemma-2b",
+            Self::CodeBase7B => "codegemma-7b",
+            Self::CodeInstruct2B => "codegemma-2b-it",
+            Self::CodeInstruct7B => "codegemma-7b-it",
+            Self::BaseV2_2B => "gemma-2-2b",
+            Self::InstructV2_2B => "gemma-2-2b-it",
+            Self::BaseV2_9B => "gemma-2-9b",
+            Self::InstructV2_9B => "gemma-2-9b-it",
+            Self::BaseV3_1B => "gemma-3-1b",
+            Self::InstructV3_1B => "gemma-3-1b-it",
+        };
+        write!(f, "{}", name)
+    }
+}
+
 enum Model {
     V1(Model1),
     V2(Model2),
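Assuming the gemma-runner crate is pulled in as a workspace dependency, the new impls give model names a string round-trip, for example:

```rust
use std::str::FromStr;
// Assumes the gemma-runner crate from this workspace is available as a dependency.
use gemma_runner::WhichModel;

fn main() {
    let model = WhichModel::from_str("gemma-3-1b-it").expect("known model id");
    // Display mirrors FromStr, so the id survives a round-trip.
    assert_eq!(model.to_string(), "gemma-3-1b-it");
    assert!(WhichModel::from_str("not-a-model").is_err());
}
```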
@@ -145,8 +191,6 @@ impl TextGeneration {
     // Make sure stdout isn't holding anything (if caller also prints).
     std::io::stdout().flush()?;
 
-    let mut generated_tokens = 0usize;
-
     let eos_token = match self.tokenizer.get_token("<eos>") {
         Some(token) => token,
         None => anyhow::bail!("cannot find the <eos> token"),
@@ -183,7 +227,6 @@ impl TextGeneration {
 
     let next_token = self.logits_processor.sample(&logits)?;
     tokens.push(next_token);
-    generated_tokens += 1;
 
     if next_token == eos_token || next_token == eot_token {
         break;
@@ -210,7 +253,7 @@ impl TextGeneration {
 pub struct GemmaInferenceConfig {
     pub tracing: bool,
     pub prompt: String,
-    pub model: WhichModel,
+    pub model: Option<WhichModel>,
     pub cpu: bool,
     pub dtype: Option<String>,
     pub model_id: Option<String>,
@@ -229,7 +272,7 @@ impl Default for GemmaInferenceConfig {
     Self {
         tracing: false,
         prompt: "Hello".to_string(),
-        model: WhichModel::InstructV2_2B,
+        model: Some(WhichModel::InstructV2_2B),
         cpu: false,
         dtype: None,
         model_id: None,
@@ -286,28 +329,30 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
         }
     };
     println!("Using dtype: {:?}", dtype);
+    println!("Raw model string: {:?}", cfg.model_id);
 
     let start = std::time::Instant::now();
     let api = Api::new()?;
 
     let model_id = cfg.model_id.unwrap_or_else(|| {
         match cfg.model {
-            WhichModel::Base2B => "google/gemma-2b",
-            WhichModel::Base7B => "google/gemma-7b",
-            WhichModel::Instruct2B => "google/gemma-2b-it",
-            WhichModel::Instruct7B => "google/gemma-7b-it",
-            WhichModel::InstructV1_1_2B => "google/gemma-1.1-2b-it",
-            WhichModel::InstructV1_1_7B => "google/gemma-1.1-7b-it",
-            WhichModel::CodeBase2B => "google/codegemma-2b",
-            WhichModel::CodeBase7B => "google/codegemma-7b",
-            WhichModel::CodeInstruct2B => "google/codegemma-2b-it",
-            WhichModel::CodeInstruct7B => "google/codegemma-7b-it",
-            WhichModel::BaseV2_2B => "google/gemma-2-2b",
-            WhichModel::InstructV2_2B => "google/gemma-2-2b-it",
-            WhichModel::BaseV2_9B => "google/gemma-2-9b",
-            WhichModel::InstructV2_9B => "google/gemma-2-9b-it",
-            WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
-            WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
+            Some(WhichModel::Base2B) => "google/gemma-2b",
+            Some(WhichModel::Base7B) => "google/gemma-7b",
+            Some(WhichModel::Instruct2B) => "google/gemma-2b-it",
+            Some(WhichModel::Instruct7B) => "google/gemma-7b-it",
+            Some(WhichModel::InstructV1_1_2B) => "google/gemma-1.1-2b-it",
+            Some(WhichModel::InstructV1_1_7B) => "google/gemma-1.1-7b-it",
+            Some(WhichModel::CodeBase2B) => "google/codegemma-2b",
+            Some(WhichModel::CodeBase7B) => "google/codegemma-7b",
+            Some(WhichModel::CodeInstruct2B) => "google/codegemma-2b-it",
+            Some(WhichModel::CodeInstruct7B) => "google/codegemma-7b-it",
+            Some(WhichModel::BaseV2_2B) => "google/gemma-2-2b",
+            Some(WhichModel::InstructV2_2B) => "google/gemma-2-2b-it",
+            Some(WhichModel::BaseV2_9B) => "google/gemma-2-9b",
+            Some(WhichModel::InstructV2_9B) => "google/gemma-2-9b-it",
+            Some(WhichModel::BaseV3_1B) => "google/gemma-3-1b-pt",
+            Some(WhichModel::InstructV3_1B) => "google/gemma-3-1b-it",
+            None => "google/gemma-2-2b-it", // default fallback
         }
         .to_string()
     });
@@ -318,7 +363,9 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let config_filename = repo.get("config.json")?;
     let filenames = match cfg.model {
-        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => vec![repo.get("model.safetensors")?],
+        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
+            vec![repo.get("model.safetensors")?]
+        }
         _ => hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("Retrieved files in {:?}", start.elapsed());
@@ -329,29 +376,31 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
 
     let model: Model = match cfg.model {
-        WhichModel::Base2B
-        | WhichModel::Base7B
-        | WhichModel::Instruct2B
-        | WhichModel::Instruct7B
-        | WhichModel::InstructV1_1_2B
-        | WhichModel::InstructV1_1_7B
-        | WhichModel::CodeBase2B
-        | WhichModel::CodeBase7B
-        | WhichModel::CodeInstruct2B
-        | WhichModel::CodeInstruct7B => {
+        Some(WhichModel::Base2B)
+        | Some(WhichModel::Base7B)
+        | Some(WhichModel::Instruct2B)
+        | Some(WhichModel::Instruct7B)
+        | Some(WhichModel::InstructV1_1_2B)
+        | Some(WhichModel::InstructV1_1_7B)
+        | Some(WhichModel::CodeBase2B)
+        | Some(WhichModel::CodeBase7B)
+        | Some(WhichModel::CodeInstruct2B)
+        | Some(WhichModel::CodeInstruct7B) => {
             let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
             Model::V1(model)
         }
-        WhichModel::BaseV2_2B
-        | WhichModel::InstructV2_2B
-        | WhichModel::BaseV2_9B
-        | WhichModel::InstructV2_9B => {
+        Some(WhichModel::BaseV2_2B)
+        | Some(WhichModel::InstructV2_2B)
+        | Some(WhichModel::BaseV2_9B)
+        | Some(WhichModel::InstructV2_9B)
+        | None => {
+            // default to V2 model
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
             Model::V2(model)
         }
-        WhichModel::BaseV3_1B | WhichModel::InstructV3_1B => {
+        Some(WhichModel::BaseV3_1B) | Some(WhichModel::InstructV3_1B) => {
             let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model3::new(cfg.use_flash_attn, &config, vb)?;
             Model::V3(model)
@@ -371,7 +420,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
     );
 
     let prompt = match cfg.model {
-        WhichModel::InstructV3_1B => {
+        Some(WhichModel::InstructV3_1B) => {
             format!(
                 "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n",
                 cfg.prompt
@@ -67,7 +67,7 @@ pub fn run_cli() -> anyhow::Result<()> {
     let cfg = GemmaInferenceConfig {
         tracing: args.tracing,
         prompt: args.prompt,
-        model: args.model,
+        model: Some(args.model),
         cpu: args.cpu,
         dtype: args.dtype,
         model_id: args.model_id,
@@ -6,10 +6,8 @@ mod gemma_api;
 mod gemma_cli;
 
 use anyhow::Error;
-use clap::{Parser, ValueEnum};
 
 use crate::gemma_cli::run_cli;
-use std::io::Write;
 
 /// just a placeholder, not used for anything
 fn main() -> std::result::Result<(), Error> {
@@ -64,14 +64,9 @@ version = "0.1.0"
 
 # Required: Kubernetes metadata
 [package.metadata.kube]
-image = "ghcr.io/myorg/my-service:latest"
+image = "ghcr.io/geoffsee/predict-otron-9000:latest"
 replicas = 1
 port = 8080
-
-# Optional: Docker Compose metadata (currently not used but parsed)
-[package.metadata.compose]
-image = "ghcr.io/myorg/my-service:latest"
-port = 8080
 ```
 
 ### Required Fields
@@ -1,9 +1,8 @@
 use anyhow::{Context, Result};
 use clap::{Arg, Command};
-use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
+use serde::Deserialize;
 use std::fs;
-use std::path::{Path, PathBuf};
+use std::path::Path;
 use walkdir::WalkDir;
 
 #[derive(Debug, Deserialize)]
@@ -20,7 +19,6 @@ struct Package {
 #[derive(Debug, Deserialize)]
 struct Metadata {
     kube: Option<KubeMetadata>,
-    compose: Option<ComposeMetadata>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -30,12 +28,6 @@ struct KubeMetadata {
     port: u16,
 }
 
-#[derive(Debug, Deserialize)]
-struct ComposeMetadata {
-    image: Option<String>,
-    port: Option<u16>,
-}
-
 #[derive(Debug, Clone)]
 struct ServiceInfo {
     name: String,
@@ -105,7 +97,9 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
         .into_iter()
         .filter_map(|e| e.ok())
     {
-        if entry.file_name() == "Cargo.toml" && entry.path() != workspace_root.join("Cargo.toml") {
+        if entry.file_name() == "Cargo.toml"
+            && entry.path() != workspace_root.join("../../../Cargo.toml")
+        {
            if let Ok(service_info) = parse_cargo_toml(entry.path()) {
                services.push(service_info);
            }
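The reshaped condition follows the usual walkdir pattern for collecting workspace member manifests while skipping the root manifest. A self-contained sketch of that pattern (the exclusion path is simplified here; the tool above compares against a manifest several directories up):

```rust
use std::path::{Path, PathBuf};
use walkdir::WalkDir;

// Collect every member Cargo.toml under `workspace_root`, skipping the
// workspace's own manifest.
fn find_member_manifests(workspace_root: &Path) -> Vec<PathBuf> {
    let root_manifest = workspace_root.join("Cargo.toml");
    WalkDir::new(workspace_root)
        .into_iter()
        .filter_map(|e| e.ok())
        .filter(|e| e.file_name() == "Cargo.toml" && e.path() != root_manifest)
        .map(|e| e.path().to_path_buf())
        .collect()
}

fn main() {
    for manifest in find_member_manifests(Path::new(".")) {
        println!("{}", manifest.display());
    }
}
```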
@@ -375,7 +369,7 @@ spec:
     Ok(())
 }
 
-fn generate_ingress_template(templates_dir: &Path, services: &[ServiceInfo]) -> Result<()> {
+fn generate_ingress_template(templates_dir: &Path, _services: &[ServiceInfo]) -> Result<()> {
     let ingress_template = r#"{{- if .Values.ingress.enabled -}}
 apiVersion: networking.k8s.io/v1
 kind: Ingress
@@ -1,6 +1,5 @@
 pub mod llama_api;
 
-use clap::ValueEnum;
 pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
 
 // Re-export constants and types that might be needed
@@ -57,6 +57,27 @@ pub struct LlamaInferenceConfig {
     pub repeat_last_n: usize,
 }
 
+impl LlamaInferenceConfig {
+    pub fn new(model: WhichModel) -> Self {
+        Self {
+            prompt: String::new(),
+            model,
+            cpu: false,
+            temperature: 1.0,
+            top_p: None,
+            top_k: None,
+            seed: 42,
+            max_tokens: 512,
+            no_kv_cache: false,
+            dtype: None,
+            model_id: None,
+            revision: None,
+            use_flash_attn: true,
+            repeat_penalty: 1.1,
+            repeat_last_n: 64,
+        }
+    }
+}
 impl Default for LlamaInferenceConfig {
     fn default() -> Self {
         Self {
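A short usage sketch of the new constructor: start from `LlamaInferenceConfig::new` and override a few of the fields listed above before running inference. This assumes the remaining fields are `pub` like `repeat_last_n`, and that the caller imports `LlamaInferenceConfig` and `WhichModel` from the crate's public API shown earlier:

```rust
// Sketch: construct a config for a given model, then tune a few fields.
fn build_config(model: WhichModel, prompt: &str) -> LlamaInferenceConfig {
    let mut cfg = LlamaInferenceConfig::new(model);
    cfg.prompt = prompt.to_string();
    cfg.max_tokens = 1024;  // allow longer completions than the 512 default
    cfg.temperature = 0.7;  // a bit less random than the 1.0 default
    cfg
}
```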
@@ -81,7 +102,7 @@ impl Default for LlamaInferenceConfig {
             max_tokens: 512,
 
             // Performance flags
             no_kv_cache: false, // keep cache ON for speed
             use_flash_attn: false, // great speed boost if supported
 
             // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
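The precision comment above maps directly onto candle's `DType`; here is a tiny sketch of the bf16-with-fp16-fallback choice it describes (whether bf16 is supported is left to the caller in this sketch):

```rust
use candle_core::DType;

// Prefer bf16 where the hardware supports it, otherwise fall back to f16.
fn pick_dtype(bf16_supported: bool) -> DType {
    if bf16_supported { DType::BF16 } else { DType::F16 }
}

fn main() {
    println!("{:?}", pick_dtype(true));
}
```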
@@ -6,9 +6,6 @@ mod llama_api;
 mod llama_cli;
 
 use anyhow::Result;
-use clap::{Parser, ValueEnum};
-
-use std::io::Write;
 
 use crate::llama_cli::run_cli;
 
@@ -1,17 +1,14 @@
 [package]
 name = "utils"
+edition = "2021"
 
 [lib]
 path = "src/lib.rs"
 
 [dependencies]
 accelerate-src = {version = "0.3.2", optional = true }
-candle-nn = {version = "0.9.1" }
-candle-transformers = {version = "0.9.1" }
-
 candle-flash-attn = {version = "0.9.1", optional = true }
 candle-onnx = {version = "0.9.1", optional = true }
-candle-core="0.9.1"
 csv = "1.3.0"
 anyhow = "1.0.99"
 cudarc = {version = "0.17.3", optional = true }
@@ -86,3 +83,14 @@ mimi = ["cpal", "symphonia", "rubato"]
 snac = ["cpal", "symphonia", "rubato"]
 depth_anything_v2 = ["palette", "enterpolation"]
 tekken = ["tekken-rs"]
+
+# Platform-specific candle dependencies
+[target.'cfg(target_os = "linux")'.dependencies]
+candle-nn = {version = "0.9.1", default-features = false }
+candle-transformers = {version = "0.9.1", default-features = false }
+candle-core = {version = "0.9.1", default-features = false }
+
+[target.'cfg(not(target_os = "linux"))'.dependencies]
+candle-nn = {version = "0.9.1" }
+candle-transformers = {version = "0.9.1" }
+candle-core = {version = "0.9.1" }
@@ -1,5 +1,5 @@
-use candle_transformers::models::mimi::candle;
 use candle_core::{Device, Result, Tensor};
+use candle_transformers::models::mimi::candle;
 
 pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
 pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
@@ -8,8 +8,10 @@ pub mod coco_classes;
 pub mod imagenet;
 pub mod token_output_stream;
 pub mod wav;
-use candle_core::{Device, Tensor, utils::{cuda_is_available, metal_is_available}};
+use candle_core::{
+    utils::{cuda_is_available, metal_is_available},
+    Device, Tensor,
+};
 
 pub fn device(cpu: bool) -> Result<Device, anyhow::Error> {
     if cpu {
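For reference, a usage sketch of the `device` helper whose signature appears above; given the imports, it presumably prefers CUDA or Metal when available and otherwise falls back to CPU. The `utils::` path is assumed from the crate manifest earlier in this diff:

```rust
use candle_core::{DType, Tensor};

fn demo() -> anyhow::Result<()> {
    // false = do not force CPU; let the helper pick CUDA/Metal if present.
    let dev = utils::device(false)?;
    let zeros = Tensor::zeros((2, 3), DType::F32, &dev)?;
    println!("tensor on {:?} with shape {:?}", zeros.device(), zeros.shape());
    Ok(())
}
```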
@@ -122,11 +124,8 @@ pub fn hub_load_safetensors(
     }
     let safetensors_files = safetensors_files
         .iter()
-        .map(|v| {
-            repo.get(v)
-                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
-        })
-        .collect::<Result<Vec<_>, std::io::Error, >>()?;
+        .map(|v| repo.get(v).map_err(std::io::Error::other))
+        .collect::<Result<Vec<_>, std::io::Error>>()?;
     Ok(safetensors_files)
 }
 
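The rewritten chain leans on `std::io::Error::other`, stabilized in Rust 1.74 as a shorthand for `Error::new(ErrorKind::Other, e)`. A standalone illustration of the equivalence:

```rust
use std::io::{Error, ErrorKind};

fn main() {
    // The two constructors produce errors of the same kind; `Error::other`
    // is simply the shorter spelling used in the diff above.
    let verbose = Error::new(ErrorKind::Other, "safetensors shard missing");
    let concise = Error::other("safetensors shard missing");
    assert_eq!(verbose.kind(), concise.kind());
    println!("{verbose} / {concise}");
}
```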
@@ -136,7 +135,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
 ) -> Result<Vec<std::path::PathBuf>, anyhow::Error> {
     let path = path.as_ref();
     let jsfile = std::fs::File::open(path.join(json_file))?;
-    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
+    let json: serde_json::Value =
+        serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
         None => anyhow::bail!("no weight map in {json_file:?}"),
         Some(serde_json::Value::Object(map)) => map,
@@ -1,8 +1,8 @@
 {
   "name": "predict-otron-9000",
-  "workspaces": ["crates/cli/package"],
+  "workspaces": ["integration/cli/package"],
   "scripts": {
     "# WORKSPACE ALIASES": "#",
-    "cli": "bun --filter crates/cli/package"
+    "cli": "bun --filter integration/cli/package"
   }
 }