cleanup, add ci

Author: geoffsee
Date:   2025-08-31 10:31:07 -04:00
Parent: 419e1c2ea7
Commit: f5d2a85f2e
42 changed files with 1740 additions and 705 deletions

.github/dependabot.yml (new file, +49 lines)

@@ -0,0 +1,49 @@
version: 2
updates:
  # Monitor Rust dependencies in the main crate
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "weekly"
      day: "monday"
      time: "09:00"
      timezone: "UTC"
    # Focus on security updates with higher priority
    open-pull-requests-limit: 10
    reviewers:
      - "security-team"
    assignees:
      - "maintainer"
    labels:
      - "dependencies"
      - "security"
    # Security updates get higher priority
    allow:
      - dependency-type: "all"
    # Group minor and patch updates to reduce noise
    # Separate major updates for careful review
    ignore:
      - dependency-name: "*"
        update-types: ["version-update:semver-major"]
    commit-message:
      prefix: "deps"
      include: "scope"

  # Monitor security updates more frequently
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "daily"
    # Only security updates in daily checks
    allow:
      - dependency-type: "direct"
        update-types: ["security"]
      - dependency-type: "indirect"
        update-types: ["security"]
    open-pull-requests-limit: 5
    labels:
      - "security-update"
      - "high-priority"
    commit-message:
      prefix: "security"
      include: "scope"
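The labels configured above ("dependencies"/"security" for the weekly pass, "security-update"/"high-priority" for the daily pass) make the resulting PRs easy to triage. A minimal sketch, assuming the GitHub CLI (`gh`) is installed and authenticated for this repository:

```bash
# List open weekly dependency PRs (label from the config above).
gh pr list --label "dependencies" --state open

# List the daily security-only PRs.
gh pr list --label "security-update" --state open
```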

.github/workflows/ci.yml (new file, +47 lines)

@@ -0,0 +1,47 @@
name: CI

on:
  push:
  pull_request:

jobs:
  build:
    name: build-and-test
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable

      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt

      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check

      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets

      - name: Tests
        shell: bash
        run: cargo test --all

      - name: Build Docs
        shell: bash
        run: |
          cargo doc -p predict-otron-9000 --no-deps
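The gates in this workflow can be reproduced locally before pushing; a minimal sketch of the same commands, assuming a stable toolchain from rustup:

```bash
#!/usr/bin/env bash
# Mirror the CI steps locally (run from the workspace root).
set -euo pipefail

rustup component add clippy rustfmt          # same components the workflow installs
cargo fmt --all -- --check                   # formatting gate
cargo clippy --all-targets                   # lint gate
cargo test --all                             # test suite
cargo doc -p predict-otron-9000 --no-deps    # docs build
```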

.github/workflows/release.yml (new file, +232 lines)

@@ -0,0 +1,232 @@
name: Release

on:
  push:
    tags:
      - 'v*'

env:
  CARGO_TERM_COLOR: always

jobs:
  test:
    name: Test before release
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: crates/predict-otron-9000
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable

      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt

      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check

      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets

      - name: Tests
        shell: bash
        run: cargo test --all

#  publish:
#    name: Publish to crates.io
#    runs-on: ubuntu-latest
#    permissions:
#      id-token: write # Required for OIDC token exchange https://crates.io/docs/trusted-publishing
#    needs: test
#    defaults:
#      run:
#        working-directory: crates/predict-otron-9000
#    steps:
#      - name: Checkout
#        uses: actions/checkout@v4
#
#      - uses: actions/cache@v4
#        with:
#          path: |
#            ~/.cargo/bin/
#            ~/.cargo/registry/index/
#            ~/.cargo/registry/cache/
#            ~/.cargo/git/db/
#            target/
#          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
#
#      - name: Setup Rust
#        run: rustup update stable && rustup default stable
#
#      - name: Verify tag matches version
#        run: |
#          TAG_VERSION=${GITHUB_REF#refs/tags/v}
#          CARGO_VERSION=$(cargo metadata --no-deps --format-version 1 | jq -r '.packages[0].version')
#          if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then
#            echo "Tag version ($TAG_VERSION) does not match Cargo.toml version ($CARGO_VERSION)"
#            exit 1
#          fi
#
#      # See Trusted publishing: https://crates.io/docs/trusted-publishing
#      - uses: rust-lang/crates-io-auth-action@v1
#        id: auth
#
#      - run: cargo publish
#        env:
#          CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }}

  build-binaries:
    name: Build binaries
    runs-on: ${{ matrix.os }}
    needs: test
    strategy:
      fail-fast: false
      matrix:
        include:
          - target: x86_64-unknown-linux-gnu
            os: ubuntu-latest
            name: predict-otron-9000-x86_64-unknown-linux-gnu
          - target: x86_64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-x86_64-apple-darwin
          - target: aarch64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-aarch64-apple-darwin
          - target: x86_64-pc-windows-msvc
            os: windows-latest
            name: predict-otron-9000-x86_64-pc-windows-msvc.exe
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable

      - name: Add target
        run: rustup target add ${{ matrix.target }}

      - name: Build binary
        run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000
        env:
          CARGO_TERM_COLOR: always

      - name: Package binary (Unix)
        if: matrix.os != 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          tar czf ../../../${{ matrix.name }}.tar.gz predict-otron-9000
          cd ../../../

      - name: Package binary (Windows)
        if: matrix.os == 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          7z a ../../../${{ matrix.name }}.zip predict-otron-9000.exe
          cd ../../../

      - name: Upload binary artifacts (Unix)
        if: matrix.os != 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.tar.gz

      - name: Upload binary artifacts (Windows)
        if: matrix.os == 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.zip

  release:
    name: Create GitHub Release
    runs-on: ubuntu-latest
    needs: [test, build-binaries]
    permissions:
      contents: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Extract tag name
        id: tag
        run: echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT

      - name: Generate changelog
        id: changelog
        run: |
          # Get the previous tag
          PREV_TAG=$(git describe --tags --abbrev=0 HEAD^ 2>/dev/null || echo "")

          # Generate changelog
          if [ -n "$PREV_TAG" ]; then
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            git log --pretty=format:"* %s (%h)" ${PREV_TAG}..HEAD >> changelog.md
            echo "" >> changelog.md
            echo "" >> changelog.md
            echo "**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREV_TAG}...${{ steps.tag.outputs.tag }}" >> changelog.md
          else
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            echo "Initial release of predict-otron-9000" >> changelog.md
            echo "" >> changelog.md
            echo "OpenAI Compatible Inference Server" >> changelog.md
          fi

          # Set the changelog as output (handle multiline)
          echo "changelog<<EOF" >> $GITHUB_OUTPUT
          cat changelog.md >> $GITHUB_OUTPUT
          echo "EOF" >> $GITHUB_OUTPUT

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: artifacts

      - name: Create Release
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          if [[ "${{ steps.tag.outputs.tag }}" == *"-"* ]]; then
            PRERELEASE_FLAG="--prerelease"
          else
            PRERELEASE_FLAG=""
          fi

          gh release create "${{ steps.tag.outputs.tag }}" \
            --title "Release ${{ steps.tag.outputs.tag }}" \
            --notes-file changelog.md \
            $PRERELEASE_FLAG \
            artifacts/predict-otron-9000-x86_64-unknown-linux-gnu/predict-otron-9000-x86_64-unknown-linux-gnu.tar.gz \
            artifacts/predict-otron-9000-x86_64-apple-darwin/predict-otron-9000-x86_64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-aarch64-apple-darwin/predict-otron-9000-aarch64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-x86_64-pc-windows-msvc.exe/predict-otron-9000-x86_64-pc-windows-msvc.exe.zip
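Given the `v*` tag trigger and the hyphen check in the Create Release step, publishing a release reduces to pushing a tag; a short sketch (version numbers are illustrative):

```bash
# Tag and push to trigger the Release workflow.
git tag v0.1.0
git push origin v0.1.0

# A tag containing a hyphen is marked as a prerelease by the workflow above.
git tag v0.1.1-rc.1
git push origin v0.1.1-rc.1
```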

.gitignore (+2 lines)

@@ -76,3 +76,5 @@ venv/
 *.bak
 *.backup
 *~
+/scripts/cli
+!/scripts/cli.ts
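The new rules ignore `/scripts/cli` (presumably a build output of `scripts/cli.ts`) while the negated pattern keeps the TypeScript source tracked; `git check-ignore` can confirm which rule applies (a quick sketch, run inside the repository):

```bash
# Prints the .gitignore rule that matches the built CLI; scripts/cli.ts is
# exempted by the negated pattern and stays tracked.
git check-ignore -v scripts/cli
```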

View File

@@ -287,7 +287,7 @@ cargo test --workspace
 **End-to-end test script:**
 ```bash
-./test.sh
+./smoke_test.sh
 ```
 This script:
@@ -478,7 +478,7 @@ cd crates/leptos-app && ./run.sh &
 **Integration test:**
 ```bash
-./test.sh
+./smoke_test.sh
 ```
 **Cleanup:**

View File

@@ -1,9 +1,5 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput}; use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{ use axum::{Json, Router, response::Json as ResponseJson, routing::post};
response::Json as ResponseJson, routing::{post},
Json,
Router,
};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
@@ -13,15 +9,18 @@ use tracing;
static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| { static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
tracing::info!("Initializing persistent embedding model (singleton)"); tracing::info!("Initializing persistent embedding model (singleton)");
let model_start_time = std::time::Instant::now(); let model_start_time = std::time::Instant::now();
let model = TextEmbedding::try_new( let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true) InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
) )
.expect("Failed to initialize persistent embedding model"); .expect("Failed to initialize persistent embedding model");
let model_init_time = model_start_time.elapsed(); let model_init_time = model_start_time.elapsed();
tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time); tracing::info!(
"Persistent embedding model initialized in {:.2?}",
model_init_time
);
model model
}); });
@@ -30,18 +29,21 @@ pub async fn embeddings_create(
) -> ResponseJson<serde_json::Value> { ) -> ResponseJson<serde_json::Value> {
// Start timing the entire process // Start timing the entire process
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
// Phase 1: Access persistent model instance // Phase 1: Access persistent model instance
let model_start_time = std::time::Instant::now(); let model_start_time = std::time::Instant::now();
// Access the lazy-initialized persistent model instance // Access the lazy-initialized persistent model instance
// This will only initialize the model on the first request // This will only initialize the model on the first request
let model_access_time = model_start_time.elapsed(); let model_access_time = model_start_time.elapsed();
tracing::debug!("Persistent model access completed in {:.2?}", model_access_time); tracing::debug!(
"Persistent model access completed in {:.2?}",
model_access_time
);
// Phase 2: Process input // Phase 2: Process input
let input_start_time = std::time::Instant::now(); let input_start_time = std::time::Instant::now();
let embedding_input = payload.input; let embedding_input = payload.input;
let texts_from_embedding_input = match embedding_input { let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text], EmbeddingInput::String(text) => vec![text],
@@ -53,41 +55,58 @@ pub async fn embeddings_create(
panic!("Array of integer arrays not supported for text embeddings"); panic!("Array of integer arrays not supported for text embeddings");
} }
}; };
let input_processing_time = input_start_time.elapsed(); let input_processing_time = input_start_time.elapsed();
tracing::debug!("Input processing completed in {:.2?}", input_processing_time); tracing::debug!(
"Input processing completed in {:.2?}",
input_processing_time
);
// Phase 3: Generate embeddings // Phase 3: Generate embeddings
let embedding_start_time = std::time::Instant::now(); let embedding_start_time = std::time::Instant::now();
let embeddings = EMBEDDING_MODEL let embeddings = EMBEDDING_MODEL
.embed(texts_from_embedding_input, None) .embed(texts_from_embedding_input, None)
.expect("failed to embed document"); .expect("failed to embed document");
let embedding_generation_time = embedding_start_time.elapsed(); let embedding_generation_time = embedding_start_time.elapsed();
tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time); tracing::info!(
"Embedding generation completed in {:.2?}",
embedding_generation_time
);
// Memory usage estimation (approximate) // Memory usage estimation (approximate)
let embedding_size_bytes = embeddings.iter() let embedding_size_bytes = embeddings
.iter()
.map(|e| e.len() * std::mem::size_of::<f32>()) .map(|e| e.len() * std::mem::size_of::<f32>())
.sum::<usize>(); .sum::<usize>();
tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0); tracing::debug!(
"Embedding size: {:.2} MB",
embedding_size_bytes as f64 / 1024.0 / 1024.0
);
// Only log detailed embedding information at trace level to reduce log volume // Only log detailed embedding information at trace level to reduce log volume
tracing::trace!("Embeddings length: {}", embeddings.len()); tracing::trace!("Embeddings length: {}", embeddings.len());
tracing::info!("Embedding dimension: {}", embeddings[0].len()); tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level // Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]); tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding // Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count(); let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count(); let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count); tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Phase 4: Post-process embeddings // Phase 4: Post-process embeddings
let postprocessing_start_time = std::time::Instant::now(); let postprocessing_start_time = std::time::Instant::now();
// Create the final embedding // Create the final embedding
let final_embedding = { let final_embedding = {
// Check if the embedding is all zeros // Check if the embedding is all zeros
@@ -110,6 +129,8 @@ pub async fn embeddings_create(
// Normalize the random embedding // Normalize the random embedding
let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
#[allow(clippy::needless_range_loop)]
for i in 0..random_embedding.len() { for i in 0..random_embedding.len() {
random_embedding[i] /= norm; random_embedding[i] /= norm;
} }
@@ -123,25 +144,35 @@ pub async fn embeddings_create(
let target_dimension = 768; let target_dimension = 768;
if padded_embedding.len() < target_dimension { if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len(); let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension); tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]); padded_embedding.extend(vec![0.0; padding_needed]);
} }
padded_embedding padded_embedding
} }
}; };
let postprocessing_time = postprocessing_start_time.elapsed(); let postprocessing_time = postprocessing_start_time.elapsed();
tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time); tracing::debug!(
"Embedding post-processing completed in {:.2?}",
postprocessing_time
);
tracing::trace!("Final embedding dimension: {}", final_embedding.len()); tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level // Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]); tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Phase 5: Prepare response // Phase 5: Prepare response
let response_start_time = std::time::Instant::now(); let response_start_time = std::time::Instant::now();
// Return a response that matches the OpenAI API format // Return a response that matches the OpenAI API format
let response = serde_json::json!({ let response = serde_json::json!({
"object": "list", "object": "list",
@@ -158,10 +189,10 @@ pub async fn embeddings_create(
"total_tokens": 0 "total_tokens": 0
} }
}); });
let response_time = response_start_time.elapsed(); let response_time = response_start_time.elapsed();
tracing::debug!("Response preparation completed in {:.2?}", response_time); tracing::debug!("Response preparation completed in {:.2?}", response_time);
// Log total time and breakdown // Log total time and breakdown
let total_time = start_time.elapsed(); let total_time = start_time.elapsed();
tracing::info!( tracing::info!(
@@ -171,7 +202,7 @@ pub async fn embeddings_create(
embedding_generation_time, embedding_generation_time,
postprocessing_time postprocessing_time
); );
ResponseJson(response) ResponseJson(response)
} }
@@ -179,4 +210,4 @@ pub fn create_embeddings_router() -> Router {
Router::new() Router::new()
.route("/v1/embeddings", post(embeddings_create)) .route("/v1/embeddings", post(embeddings_create))
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
} }
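This router serves `POST /v1/embeddings` in the OpenAI shape; the standalone server in the next diff binds to `127.0.0.1:8080` by default. A hedged sketch of exercising the endpoint (model name and input taken from the test further down; host and port assume the defaults):

```bash
# Assumes the embeddings service is running on its default 127.0.0.1:8080.
curl -s http://127.0.0.1:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
        "model": "nomic-text-embed",
        "input": "The food was delicious and the waiter..."
      }'
# The handler pads results to 768 dimensions (target_dimension above),
# so data[0].embedding should contain 768 floats.
```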

View File

@@ -1,8 +1,8 @@
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput}; use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use axum::{ use axum::{
response::Json as ResponseJson, routing::{get, post}, Json, Router,
Json, response::Json as ResponseJson,
Router, routing::{get, post},
}; };
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -13,19 +13,17 @@ use tracing;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1"; const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080"; const DEFAULT_SERVER_PORT: &str = "8080";
async fn embeddings_create( async fn embeddings_create(
Json(payload): Json<CreateEmbeddingRequest>, Json(payload): Json<CreateEmbeddingRequest>,
) -> ResponseJson<serde_json::Value> { ) -> ResponseJson<serde_json::Value> {
let model = TextEmbedding::try_new( let model = TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true) InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
) )
.expect("Failed to initialize model"); .expect("Failed to initialize model");
let embedding_input = payload.input;
let embedding_input = payload.input; let texts_from_embedding_input = match embedding_input {
let texts_from_embedding_input = match embedding_input {
EmbeddingInput::String(text) => vec![text], EmbeddingInput::String(text) => vec![text],
EmbeddingInput::StringArray(texts) => texts, EmbeddingInput::StringArray(texts) => texts,
EmbeddingInput::IntegerArray(_) => { EmbeddingInput::IntegerArray(_) => {
@@ -45,12 +43,19 @@ async fn embeddings_create(
tracing::info!("Embedding dimension: {}", embeddings[0].len()); tracing::info!("Embedding dimension: {}", embeddings[0].len());
// Log the first 10 values of the original embedding at trace level // Log the first 10 values of the original embedding at trace level
tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]); tracing::trace!(
"Original embedding preview: {:?}",
&embeddings[0][..10.min(embeddings[0].len())]
);
// Check if there are any NaN or zero values in the original embedding // Check if there are any NaN or zero values in the original embedding
let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count(); let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count(); let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count); tracing::trace!(
"Original embedding stats: NaN count={}, zero count={}",
nan_count,
zero_count
);
// Create the final embedding // Create the final embedding
let final_embedding = { let final_embedding = {
@@ -87,7 +92,11 @@ async fn embeddings_create(
let target_dimension = 768; let target_dimension = 768;
if padded_embedding.len() < target_dimension { if padded_embedding.len() < target_dimension {
let padding_needed = target_dimension - padded_embedding.len(); let padding_needed = target_dimension - padded_embedding.len();
tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension); tracing::trace!(
"Padding embedding with {} zeros to reach {} dimensions",
padding_needed,
target_dimension
);
padded_embedding.extend(vec![0.0; padding_needed]); padded_embedding.extend(vec![0.0; padding_needed]);
} }
@@ -98,7 +107,10 @@ async fn embeddings_create(
tracing::trace!("Final embedding dimension: {}", final_embedding.len()); tracing::trace!("Final embedding dimension: {}", final_embedding.len());
// Log the first 10 values of the final embedding at trace level // Log the first 10 values of the final embedding at trace level
tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]); tracing::trace!(
"Final embedding preview: {:?}",
&final_embedding[..10.min(final_embedding.len())]
);
// Return a response that matches the OpenAI API format // Return a response that matches the OpenAI API format
let response = serde_json::json!({ let response = serde_json::json!({
@@ -120,7 +132,7 @@ async fn embeddings_create(
} }
fn create_app() -> Router { fn create_app() -> Router {
Router::new() Router::new()
.route("/v1/embeddings", post(embeddings_create)) .route("/v1/embeddings", post(embeddings_create))
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
} }
@@ -143,21 +155,21 @@ async fn main() {
.init(); .init();
let app = create_app(); let app = create_app();
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string()); let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string()); let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
let server_address = format!("{}:{}", server_host, server_port); let server_address = format!("{}:{}", server_host, server_port);
let listener = tokio::net::TcpListener::bind(server_address).await.unwrap(); let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
tracing::info!("Listening on {}", listener.local_addr().unwrap()); tracing::info!("Listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).await.unwrap();
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use axum::body::to_bytes; use axum::body::Body;
use axum::body::Body; use axum::body::to_bytes;
use axum::http::StatusCode; use axum::http::StatusCode;
use tower::ServiceExt; use tower::ServiceExt;
#[tokio::test] #[tokio::test]
async fn test_embeddings_create() { async fn test_embeddings_create() {
@@ -168,11 +180,13 @@ mod tests {
let body = CreateEmbeddingRequest { let body = CreateEmbeddingRequest {
model: "nomic-text-embed".to_string(), model: "nomic-text-embed".to_string(),
input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]), input: EmbeddingInput::from(vec![
encoding_format: None, "The food was delicious and the waiter...".to_string(),
user: None, ]),
dimensions: Some(768), encoding_format: None,
}; user: None,
dimensions: Some(768),
};
let response = app let response = app
.oneshot( .oneshot(

View File

@@ -3,16 +3,14 @@ name = "gemma-runner"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git" } candle-core = { git = "https://github.com/huggingface/candle.git" }
candle-nn = { git = "https://github.com/huggingface/candle.git" } candle-nn = { git = "https://github.com/huggingface/candle.git" }
candle-transformers = { git = "https://github.com/huggingface/candle.git" } candle-transformers = { git = "https://github.com/huggingface/candle.git" }
candle-examples = { git = "https://github.com/huggingface/candle.git" } candle-examples = { git = "https://github.com/huggingface/candle.git" }
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
hf-hub = "0.4" hf-hub = "0.4"
tokenizers = "0.21" tokenizers = "0.21"
anyhow = "1.0" anyhow = "1.0"
@@ -22,6 +20,12 @@ tracing = "0.1"
tracing-chrome = "0.7" tracing-chrome = "0.7"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
[target.'cfg(target_os = "macos")'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[features] [features]
default = [] default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
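With the Metal-enabled candle crates scoped to macOS targets and a `cuda` feature defined, acceleration is opt-in per platform; a sketch of the corresponding build commands (package name from the manifest above):

```bash
# Default (CPU) build; on macOS the target-specific dependency table pulls in
# the Metal features automatically, no extra flag required.
cargo build --release -p gemma-runner

# CUDA build, assuming the CUDA toolkit is installed on the host.
cargo build --release -p gemma-runner --features cuda
```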

View File

@@ -4,10 +4,10 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::ValueEnum;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use clap::ValueEnum;
// Removed gemma_cli import as it's not needed for the API // Removed gemma_cli import as it's not needed for the API
use candle_core::{utils, DType, Device, Tensor}; use candle_core::{utils, DType, Device, Tensor};
@@ -119,7 +119,12 @@ impl TextGeneration {
/// Stream-only generation: sends freshly generated token strings over `tx`. /// Stream-only generation: sends freshly generated token strings over `tx`.
/// (Does not send the prompt tokens; only newly generated model tokens.) /// (Does not send the prompt tokens; only newly generated model tokens.)
fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> { fn run_stream(
&mut self,
prompt: &str,
sample_len: usize,
tx: Sender<Result<String>>,
) -> Result<()> {
self.tokenizer.clear(); self.tokenizer.clear();
// Encode prompt (context only; do not emit prompt tokens to the stream). // Encode prompt (context only; do not emit prompt tokens to the stream).
@@ -303,7 +308,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
WhichModel::BaseV3_1B => "google/gemma-3-1b-pt", WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
WhichModel::InstructV3_1B => "google/gemma-3-1b-it", WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
} }
.to_string() .to_string()
}); });
println!("Loading model: {}", &model_id); println!("Loading model: {}", &model_id);
@@ -337,7 +342,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
let model = Model1::new(cfg.use_flash_attn, &config, vb)?; let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
Model::V1(model) Model::V1(model)
} }
WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => { WhichModel::BaseV2_2B
| WhichModel::InstructV2_2B
| WhichModel::BaseV2_9B
| WhichModel::InstructV2_9B => {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?; let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(cfg.use_flash_attn, &config, vb)?; let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
Model::V2(model) Model::V2(model)

View File

@@ -1,6 +1,6 @@
use std::io::Write;
use clap::Parser;
use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel}; use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
use clap::Parser;
use std::io::Write;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)] #[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
@@ -94,4 +94,4 @@ pub fn run_cli() -> anyhow::Result<()> {
} }
} }
Ok(()) Ok(())
} }

View File

@@ -2,8 +2,8 @@
extern crate accelerate_src; extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
mod gemma_cli;
mod gemma_api; mod gemma_api;
mod gemma_cli;
use anyhow::Error; use anyhow::Error;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
@@ -14,4 +14,4 @@ use std::io::Write;
/// just a placeholder, not used for anything /// just a placeholder, not used for anything
fn main() -> std::result::Result<(), Error> { fn main() -> std::result::Result<(), Error> {
run_cli() run_cli()
} }

View File

@@ -84,7 +84,10 @@ fn main() -> Result<()> {
let services = discover_services(workspace_path)?; let services = discover_services(workspace_path)?;
println!("Found {} services:", services.len()); println!("Found {} services:", services.len());
for service in &services { for service in &services {
println!(" - {}: {} (port {})", service.name, service.image, service.port); println!(
" - {}: {} (port {})",
service.name, service.image, service.port
);
} }
generate_helm_chart(output_path, chart_name, &services)?; generate_helm_chart(output_path, chart_name, &services)?;
@@ -115,17 +118,20 @@ fn discover_services(workspace_path: &str) -> Result<Vec<ServiceInfo>> {
fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> { fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
let content = fs::read_to_string(path) let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?; .with_context(|| format!("Failed to read Cargo.toml at {:?}", path))?;
let cargo_toml: CargoToml = toml::from_str(&content) let cargo_toml: CargoToml = toml::from_str(&content)
.with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?; .with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;
let package = cargo_toml.package let package = cargo_toml
.package
.ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?; .ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;
let metadata = package.metadata let metadata = package
.metadata
.ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?; .ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;
let kube_metadata = metadata.kube let kube_metadata = metadata
.kube
.ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?; .ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;
Ok(ServiceInfo { Ok(ServiceInfo {
@@ -136,7 +142,11 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
}) })
} }
fn generate_helm_chart(output_path: &str, chart_name: &str, services: &[ServiceInfo]) -> Result<()> { fn generate_helm_chart(
output_path: &str,
chart_name: &str,
services: &[ServiceInfo],
) -> Result<()> {
let chart_dir = Path::new(output_path); let chart_dir = Path::new(output_path);
let templates_dir = chart_dir.join("templates"); let templates_dir = chart_dir.join("templates");
@@ -512,4 +522,4 @@ fn generate_helmignore(chart_dir: &Path) -> Result<()> {
fs::write(chart_dir.join(".helmignore"), helmignore_content)?; fs::write(chart_dir.join(".helmignore"), helmignore_content)?;
Ok(()) Ok(())
} }

View File

@@ -3,18 +3,6 @@ name = "inference-engine"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[[bin]]
name="gemma_inference"
path = "src/gemma_inference.rs"
required-features = ["bin"]
[[bin]]
name="llama_inference"
path = "src/llama_inference.rs"
required-features = ["bin"]
[dependencies] [dependencies]
accelerate-src = { version = "0.3.2", optional = true } accelerate-src = { version = "0.3.2", optional = true }
candle-datasets = { version = "=0.9.1", optional = true } candle-datasets = { version = "=0.9.1", optional = true }

View File

@@ -30,4 +30,4 @@ pub trait ModelInference {
} }
/// Factory function type for creating model inference implementations /// Factory function type for creating model inference implementations
pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>; pub type ModelInferenceFactory = fn() -> Result<Box<dyn ModelInference>>;

View File

@@ -1,19 +1,19 @@
// Expose modules for testing and library usage // Expose modules for testing and library usage
pub mod token_output_stream;
pub mod model; pub mod model;
pub mod text_generation;
pub mod utilities_lib;
pub mod openai_types; pub mod openai_types;
pub mod text_generation;
pub mod token_output_stream;
pub mod utilities_lib;
// pub mod cli; // pub mod cli;
pub mod server;
pub mod inference; pub mod inference;
pub mod server;
// Re-export key components for easier access // Re-export key components for easier access
pub use inference::ModelInference;
pub use model::{Model, Which}; pub use model::{Model, Which};
pub use server::{create_router, AppState};
pub use text_generation::TextGeneration; pub use text_generation::TextGeneration;
pub use token_output_stream::TokenOutputStream; pub use token_output_stream::TokenOutputStream;
pub use server::{AppState, create_router};
pub use inference::ModelInference;
use std::env; use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View File

@@ -1,8 +1,8 @@
// use candle_core::Tensor; // use candle_core::Tensor;
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1}; use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2}; use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3}; use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Which { pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
} }
impl Model { impl Model {
pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> { pub fn forward(
&mut self,
input_ids: &candle_core::Tensor,
pos: usize,
) -> candle_core::Result<candle_core::Tensor> {
match self { match self {
Self::V1(m) => m.forward(input_ids, pos), Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos), Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {
pub fn is_instruct_model(&self) -> bool { pub fn is_instruct_model(&self) -> bool {
match self { match self {
Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false, Self::Base2B
| Self::Base7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::BaseV2_2B
| Self::BaseV2_9B
| Self::BaseV3_1B => false,
_ => true, _ => true,
} }
} }
@@ -100,4 +110,4 @@ impl Which {
pub fn is_llama_model(&self) -> bool { pub fn is_llama_model(&self) -> bool {
matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B) matches!(self, Self::LlamaInstruct3_2_1B | Self::LlamaInstruct3_2_3B)
} }
} }

View File

@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
); );
impl ToSchema<'_> for MessageInnerContent { impl ToSchema<'_> for MessageInnerContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) { fn schema() -> (
&'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
( (
"MessageInnerContent", "MessageInnerContent",
utoipa::openapi::RefOr::T(message_inner_content_schema()), utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MessageContent( pub struct MessageContent(
#[serde(with = "either::serde_untagged")] #[serde(with = "either::serde_untagged")]
pub Either<String, Vec<HashMap<String, MessageInnerContent>>>, pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
); );
impl ToSchema<'_> for MessageContent { impl ToSchema<'_> for MessageContent {
fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) { fn schema() -> (
("MessageContent", utoipa::openapi::RefOr::T(message_content_schema())) &'static str,
utoipa::openapi::RefOr<utoipa::openapi::Schema>,
) {
(
"MessageContent",
utoipa::openapi::RefOr::T(message_content_schema()),
)
} }
} }
@@ -213,4 +222,4 @@ pub struct ModelListResponse {
pub object: String, pub object: String,
/// Array of available models /// Array of available models
pub data: Vec<Model>, pub data: Vec<Model>,
} }

View File

@@ -6,19 +6,22 @@ use axum::{
Json, Router, Json, Router,
}; };
use futures_util::stream::{self, Stream}; use futures_util::stream::{self, Stream};
use tokio_stream::wrappers::UnboundedReceiverStream;
use std::convert::Infallible; use std::convert::Infallible;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{Mutex, mpsc}; use tokio::sync::{mpsc, Mutex};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid; use uuid::Uuid;
use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage}; use crate::openai_types::{
ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
};
use crate::Which; use crate::Which;
use either::Either; use either::Either;
use serde_json::Value;
use gemma_runner::{run_gemma_api, GemmaInferenceConfig}; use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
use llama_runner::{run_llama_inference, LlamaInferenceConfig}; use llama_runner::{run_llama_inference, LlamaInferenceConfig};
use serde_json::Value;
// ------------------------- // -------------------------
// Shared app state // Shared app state
// ------------------------- // -------------------------
@@ -62,12 +65,15 @@ fn normalize_model_id(model_id: &str) -> String {
fn build_gemma_prompt(messages: &[Message]) -> String { fn build_gemma_prompt(messages: &[Message]) -> String {
let mut prompt = String::new(); let mut prompt = String::new();
for message in messages { for message in messages {
match message.role.as_str() { match message.role.as_str() {
"system" => { "system" => {
if let Some(MessageContent(Either::Left(content))) = &message.content { if let Some(MessageContent(Either::Left(content))) = &message.content {
prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content)); prompt.push_str(&format!(
"<start_of_turn>system\n{}<end_of_turn>\n",
content
));
} }
} }
"user" => { "user" => {
@@ -83,7 +89,7 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
_ => {} _ => {}
} }
} }
prompt.push_str("<start_of_turn>model\n"); prompt.push_str("<start_of_turn>model\n");
prompt prompt
} }
@@ -97,9 +103,13 @@ pub async fn chat_completions(
Json(request): Json<ChatCompletionRequest>, Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> { ) -> Result<impl IntoResponse, (StatusCode, String)> {
if !request.stream.unwrap_or(false) { if !request.stream.unwrap_or(false) {
return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response()); return Ok(chat_completions_non_streaming_proxy(state, request)
.await
.into_response());
} }
Ok(chat_completions_stream(state, request).await.into_response()) Ok(chat_completions_stream(state, request)
.await
.into_response())
} }
pub async fn chat_completions_non_streaming_proxy( pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
ModelType::Gemma => build_gemma_prompt(&request.messages), ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => { ModelType::Llama => {
// For Llama, just use the last user message for now // For Llama, just use the last user message for now
request.messages.last() request
.messages
.last()
.and_then(|m| m.content.as_ref()) .and_then(|m| m.content.as_ref())
.and_then(|c| match c { .and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()), MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
}; };
// Get streaming receiver based on model type // Get streaming receiver based on model type
let rx = match state.model_type { let rx =
ModelType::Gemma => { match state.model_type {
if let Some(mut config) = state.gemma_config { ModelType::Gemma => {
config.prompt = prompt.clone(); if let Some(mut config) = state.gemma_config {
config.max_tokens = max_tokens; config.prompt = prompt.clone();
run_gemma_api(config).map_err(|e| ( config.max_tokens = max_tokens;
run_gemma_api(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) } "error": { "message": format!("Error initializing Gemma model: {}", e) }
})) }))
))? ))?
} else { } else {
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" } "error": { "message": "Gemma configuration not available" }
})) })),
)); ));
}
} }
} ModelType::Llama => {
ModelType::Llama => { if let Some(mut config) = state.llama_config {
if let Some(mut config) = state.llama_config { config.prompt = prompt.clone();
config.prompt = prompt.clone(); config.max_tokens = max_tokens;
config.max_tokens = max_tokens; run_llama_inference(config).map_err(|e| (
run_llama_inference(config).map_err(|e| (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) } "error": { "message": format!("Error initializing Llama model: {}", e) }
})) }))
))? ))?
} else { } else {
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Llama configuration not available" } "error": { "message": "Llama configuration not available" }
})) })),
)); ));
}
} }
} };
};
// Collect all tokens from the stream // Collect all tokens from the stream
let mut completion = String::new(); let mut completion = String::new();
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
ModelType::Gemma => build_gemma_prompt(&request.messages), ModelType::Gemma => build_gemma_prompt(&request.messages),
ModelType::Llama => { ModelType::Llama => {
// For Llama, just use the last user message for now // For Llama, just use the last user message for now
request.messages.last() request
.messages
.last()
.and_then(|m| m.content.as_ref()) .and_then(|m| m.content.as_ref())
.and_then(|c| match c { .and_then(|c| match c {
MessageContent(Either::Left(text)) => Some(text.clone()), MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
model: model_id.clone(), model: model_id.clone(),
choices: vec![ChatCompletionChunkChoice { choices: vec![ChatCompletionChunkChoice {
index: 0, index: 0,
delta: Delta { role: Some("assistant".to_string()), content: None }, delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None, finish_reason: None,
}], }],
}; };
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": format!("Error initializing Gemma model: {}", e) } "error": { "message": format!("Error initializing Gemma model: {}", e) }
})) })),
)); ));
} }
} }
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Gemma configuration not available" } "error": { "message": "Gemma configuration not available" }
})) })),
)); ));
} }
} }
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": format!("Error initializing Llama model: {}", e) } "error": { "message": format!("Error initializing Llama model: {}", e) }
})) })),
)); ));
} }
} }
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ Json(serde_json::json!({
"error": { "message": "Llama configuration not available" } "error": { "message": "Llama configuration not available" }
})) })),
)); ));
} }
} }
@@ -386,16 +404,20 @@ async fn handle_streaming_request(
if recent_tokens.len() > REPETITION_WINDOW { if recent_tokens.len() > REPETITION_WINDOW {
recent_tokens.remove(0); recent_tokens.remove(0);
} }
// Check for repetitive patterns // Check for repetitive patterns
if recent_tokens.len() >= 4 { if recent_tokens.len() >= 4 {
let last_token = &recent_tokens[recent_tokens.len() - 1]; let last_token = &recent_tokens[recent_tokens.len() - 1];
let second_last = &recent_tokens[recent_tokens.len() - 2]; let second_last = &recent_tokens[recent_tokens.len() - 2];
if last_token == second_last { if last_token == second_last {
repetition_count += 1; repetition_count += 1;
tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count); tracing::warn!(
"Detected repetition pattern: '{}' (count: {})",
last_token,
repetition_count
);
if repetition_count >= MAX_REPETITION_COUNT { if repetition_count >= MAX_REPETITION_COUNT {
tracing::info!("Stopping generation due to excessive repetition"); tracing::info!("Stopping generation due to excessive repetition");
break; break;
@@ -412,11 +434,14 @@ async fn handle_streaming_request(
model: model_id_clone.clone(), model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice { choices: vec![ChatCompletionChunkChoice {
index: 0, index: 0,
delta: Delta { role: None, content: Some(token) }, delta: Delta {
role: None,
content: Some(token),
},
finish_reason: None, finish_reason: None,
}], }],
}; };
if let Ok(json) = serde_json::to_string(&chunk) { if let Ok(json) = serde_json::to_string(&chunk) {
let _ = tx.send(Ok(Event::default().data(json))); let _ = tx.send(Ok(Event::default().data(json)));
} }
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
model: model_id_clone.clone(), model: model_id_clone.clone(),
choices: vec![ChatCompletionChunkChoice { choices: vec![ChatCompletionChunkChoice {
index: 0, index: 0,
delta: Delta { role: None, content: None }, delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()), finish_reason: Some("stop".to_string()),
}], }],
}; };
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
Ok(Sse::new(stream)) Ok(Sse::new(stream))
} }
// ------------------------- // -------------------------
// Router // Router
// ------------------------- // -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
}) })
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -681,10 +706,7 @@ mod tests {
let prompt = build_gemma_prompt(&messages); let prompt = build_gemma_prompt(&messages);
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\ let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
<start_of_turn>model\nWho's there?<end_of_turn>\n\
<start_of_turn>user\nGemma.<end_of_turn>\n\
<start_of_turn>model\n";
assert_eq!(prompt, expected); assert_eq!(prompt, expected);
} }
@@ -698,15 +720,13 @@ mod tests {
#[test] #[test]
fn test_missing_content() { fn test_missing_content() {
let messages = vec![ let messages = vec![Message {
Message { role: "user".to_string(),
role: "user".to_string(), content: None,
content: None, name: None,
name: None, }];
}
];
let prompt = build_gemma_prompt(&messages); let prompt = build_gemma_prompt(&messages);
assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n"); assert_eq!(prompt, "<start_of_turn>model\n");
} }
} }
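The server exposes OpenAI-style chat completions (streaming and non-streaming) plus a model list. A hedged sketch of a streaming request; the base URL, route paths, and model id are assumptions for illustration, not values confirmed by this diff:

```bash
# Adjust to wherever the inference server is actually bound.
BASE_URL="http://127.0.0.1:8080"

# Model list (assuming the conventional OpenAI path).
curl -s "$BASE_URL/v1/models"

# Streaming chat completion over SSE; -N disables curl's output buffering.
curl -N "$BASE_URL/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
        "model": "gemma-3-1b-it",
        "stream": true,
        "messages": [{"role": "user", "content": "Knock knock."}]
      }'
```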

File diff suppressed because it is too large.

View File

@@ -84,4 +84,4 @@ impl TokenOutputStream {
self.prev_index = 0; self.prev_index = 0;
self.current_index = 0; self.current_index = 0;
} }
} }

View File

@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
) -> Result<Vec<std::path::PathBuf>> { ) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref(); let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?; let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?; let json: serde_json::Value =
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
let weight_map = match json.get("weight_map") { let weight_map = match json.get("weight_map") {
None => candle_core::bail!("no weight map in {json_file:?}"), None => candle_core::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map, Some(serde_json::Value::Object(map)) => map,
@@ -164,4 +165,4 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
.map(|v| path.join(v)) .map(|v| path.join(v))
.collect(); .collect();
Ok(safetensors_files) Ok(safetensors_files)
} }
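`hub_load_local_safetensors` resolves shard files from the `weight_map` of a safetensors index JSON and joins them onto the local directory. A hypothetical layout it would consume, with `jq` used only to preview the same mapping (directory and file names are illustrative):

```bash
# Hypothetical local model directory; the index file is whatever is passed
# as the json_file argument.
ls ./models/example-model
#   model.safetensors.index.json
#   model-00001-of-00002.safetensors
#   model-00002-of-00002.safetensors

# Preview the shard set the function will return (unique weight_map values).
jq -r '.weight_map[]' ./models/example-model/model.safetensors.index.json | sort -u
```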

View File

@@ -9,7 +9,10 @@ mod tests {
// Test a few representative model variants // Test a few representative model variants
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b"); assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it"); assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it"); assert_eq!(
Which::InstructV1_1_2B.to_model_id(),
"google/gemma-1.1-2b-it"
);
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b"); assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b"); assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it"); assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
@@ -64,4 +67,4 @@ mod tests {
// Note: Testing the Model enum's forward method would require creating actual model instances, // Note: Testing the Model enum's forward method would require creating actual model instances,
// which is complex and would require loading model weights. This is better suited for // which is complex and would require loading model weights. This is better suited for
// integration tests or mocking the models. // integration tests or mocking the models.
} }

View File

@@ -106,7 +106,7 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?; let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32]; let tokens = vec![1u32, 2u32, 3u32];
// Create a mock TextGeneration instance // Create a mock TextGeneration instance
// Since we can't easily create a full TextGeneration instance without a model, // Since we can't easily create a full TextGeneration instance without a model,
// we'll test the logic by creating a simple struct with the necessary fields // we'll test the logic by creating a simple struct with the necessary fields
@@ -115,7 +115,7 @@ mod tests {
repeat_last_n: usize, repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>, penalty_cache: HashMap<usize, f32>,
} }
impl MockTextGeneration { impl MockTextGeneration {
fn apply_cached_repeat_penalty( fn apply_cached_repeat_penalty(
&mut self, &mut self,
@@ -167,16 +167,17 @@ mod tests {
Ok((result, elapsed)) Ok((result, elapsed))
} }
} }
let mut mock_gen = MockTextGeneration { let mut mock_gen = MockTextGeneration {
repeat_penalty: 1.0, // No penalty repeat_penalty: 1.0, // No penalty
repeat_last_n: 3, repeat_last_n: 3,
penalty_cache: HashMap::new(), penalty_cache: HashMap::new(),
}; };
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?; let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?; let result_data = result_logits.to_vec1::<f32>()?;
// With no penalty, logits should be unchanged // With no penalty, logits should be unchanged
assert_eq!(result_data, logits_data); assert_eq!(result_data, logits_data);
Ok(()) Ok(())
@@ -189,13 +190,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?; let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32]; let tokens = vec![1u32, 2u32, 3u32];
struct MockTextGeneration { struct MockTextGeneration {
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>, penalty_cache: HashMap<usize, f32>,
} }
impl MockTextGeneration { impl MockTextGeneration {
fn apply_cached_repeat_penalty( fn apply_cached_repeat_penalty(
&mut self, &mut self,
@@ -238,16 +239,17 @@ mod tests {
Ok((result, elapsed)) Ok((result, elapsed))
} }
} }
let mut mock_gen = MockTextGeneration { let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, // Apply penalty repeat_penalty: 2.0, // Apply penalty
repeat_last_n: 3, repeat_last_n: 3,
penalty_cache: HashMap::new(), penalty_cache: HashMap::new(),
}; };
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?; let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?; let result_data = result_logits.to_vec1::<f32>()?;
// Tokens 1, 2, 3 should be penalized (divided by 2.0) // Tokens 1, 2, 3 should be penalized (divided by 2.0)
let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0] let expected = vec![1.0f32, 1.0, 1.5, 2.0, 5.0]; // [1.0, 2.0/2.0, 3.0/2.0, 4.0/2.0, 5.0]
assert_eq!(result_data, expected); assert_eq!(result_data, expected);
@@ -261,13 +263,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?; let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache let tokens = vec![1u32, 1u32, 1u32]; // Repeated token should use cache
struct MockTextGeneration { struct MockTextGeneration {
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>, penalty_cache: HashMap<usize, f32>,
} }
impl MockTextGeneration { impl MockTextGeneration {
fn apply_cached_repeat_penalty( fn apply_cached_repeat_penalty(
&mut self, &mut self,
@@ -308,20 +310,21 @@ mod tests {
Ok((result, elapsed)) Ok((result, elapsed))
} }
} }
let mut mock_gen = MockTextGeneration { let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, repeat_penalty: 2.0,
repeat_last_n: 3, repeat_last_n: 3,
penalty_cache: HashMap::new(), penalty_cache: HashMap::new(),
}; };
// First call should cache the penalty for token 1 // First call should cache the penalty for token 1
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?; let (_result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
// Cache should contain the penalized value for token 1 // Cache should contain the penalized value for token 1
assert!(mock_gen.penalty_cache.contains_key(&1)); assert!(mock_gen.penalty_cache.contains_key(&1));
assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0 assert_eq!(mock_gen.penalty_cache.get(&1), Some(&1.0)); // 2.0 / 2.0 = 1.0
Ok(()) Ok(())
} }
@@ -332,13 +335,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?; let logits = Tensor::new(&logits_data[..], &device)?;
let tokens: Vec<u32> = vec![]; // Empty tokens let tokens: Vec<u32> = vec![]; // Empty tokens
struct MockTextGeneration { struct MockTextGeneration {
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>, penalty_cache: HashMap<usize, f32>,
} }
impl MockTextGeneration { impl MockTextGeneration {
fn apply_cached_repeat_penalty( fn apply_cached_repeat_penalty(
&mut self, &mut self,
@@ -379,16 +382,17 @@ mod tests {
Ok((result, elapsed)) Ok((result, elapsed))
} }
} }
let mut mock_gen = MockTextGeneration { let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, repeat_penalty: 2.0,
repeat_last_n: 3, repeat_last_n: 3,
penalty_cache: HashMap::new(), penalty_cache: HashMap::new(),
}; };
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?; let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?; let result_data = result_logits.to_vec1::<f32>()?;
// With empty tokens, logits should be unchanged // With empty tokens, logits should be unchanged
assert_eq!(result_data, logits_data); assert_eq!(result_data, logits_data);
Ok(()) Ok(())
@@ -401,13 +405,13 @@ mod tests {
let logits_data = vec![1.0f32, 2.0, 3.0]; let logits_data = vec![1.0f32, 2.0, 3.0];
let logits = Tensor::new(&logits_data[..], &device)?; let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds let tokens = vec![1u32, 5u32, 10u32]; // Token 5 and 10 are out of bounds
struct MockTextGeneration { struct MockTextGeneration {
repeat_penalty: f32, repeat_penalty: f32,
repeat_last_n: usize, repeat_last_n: usize,
penalty_cache: HashMap<usize, f32>, penalty_cache: HashMap<usize, f32>,
} }
impl MockTextGeneration { impl MockTextGeneration {
fn apply_cached_repeat_penalty( fn apply_cached_repeat_penalty(
&mut self, &mut self,
@@ -448,16 +452,17 @@ mod tests {
Ok((result, elapsed)) Ok((result, elapsed))
} }
} }
let mut mock_gen = MockTextGeneration { let mut mock_gen = MockTextGeneration {
repeat_penalty: 2.0, repeat_penalty: 2.0,
repeat_last_n: 3, repeat_last_n: 3,
penalty_cache: HashMap::new(), penalty_cache: HashMap::new(),
}; };
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?; let (result_logits, _duration) =
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
let result_data = result_logits.to_vec1::<f32>()?; let result_data = result_logits.to_vec1::<f32>()?;
// Only token 1 should be penalized, out-of-bounds tokens should be ignored // Only token 1 should be penalized, out-of-bounds tokens should be ignored
let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0] let expected = vec![1.0f32, 1.0, 3.0]; // [1.0, 2.0/2.0, 3.0]
assert_eq!(result_data, expected); assert_eq!(result_data, expected);
@@ -471,52 +476,52 @@ mod tests {
// Since creating a real TextGeneration instance requires a Model which needs model weights,
// we'll create a test that demonstrates the method is now public and can be accessed.
// The comprehensive functionality testing is already covered by the mock tests above.
// Test data setup
let device = Device::Cpu;
let logits_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 3u32];
// Test that we can create the necessary components
let tokenizer = create_test_tokenizer()?;
// The method is now public as confirmed by making it pub fn apply_cached_repeat_penalty
// This test verifies the method signature and that it's accessible from external code
// We could create a TextGeneration instance if we had a way to mock the Model,
// but for now we confirm that the existing mock tests cover the functionality
// and the method is properly exposed as public
println!("apply_cached_repeat_penalty method is now public and accessible for testing");
assert!(true);
Ok(())
}
// Integration test that demonstrates the method usage pattern
#[test]
fn test_apply_cached_repeat_penalty_usage_pattern() -> Result<()> {
// This test demonstrates how the apply_cached_repeat_penalty method would be used
// in practice, even though we can't create a full TextGeneration instance in unit tests
let device = Device::Cpu;
let logits_data = vec![1.5f32, 2.5, 3.5, 4.5, 5.5];
let logits = Tensor::new(&logits_data[..], &device)?;
let tokens = vec![1u32, 2u32, 1u32, 3u32]; // Repeated token 1 to test caching
// Test parameters that would be used with TextGeneration
let repeat_penalty = 1.2f32;
let repeat_last_n = 3usize;
let mut penalty_cache: HashMap<usize, f32> = HashMap::new();
// Simulate the method's logic to verify it works as expected
let start_time = std::time::Instant::now();
if repeat_penalty != 1.0 {
let start_at = tokens.len().saturating_sub(repeat_last_n);
let penalty_tokens = &tokens[start_at..];
let mut logits_vec = logits.to_vec1::<f32>()?;
for &token_id in penalty_tokens {
let token_id = token_id as usize;
if token_id < logits_vec.len() {
@@ -531,14 +536,14 @@ mod tests {
}
}
}
let _duration = start_time.elapsed();
// Verify that tokens were processed correctly
assert!(penalty_cache.contains_key(&1)); // Token 1 should be cached
assert!(penalty_cache.contains_key(&2)); // Token 2 should be cached
assert!(penalty_cache.contains_key(&3)); // Token 3 should be cached
println!("Successfully demonstrated apply_cached_repeat_penalty usage pattern");
Ok(())
}
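The mock implementations above all repeat the same cached repeat-penalty routine. A minimal standalone sketch of that logic, operating on a plain slice instead of a candle `Tensor` (a simplification for illustration; the real method works on tensors and also returns the elapsed time):

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Simplified sketch of the cached repeat-penalty logic exercised by the mock
// tests: only the last `repeat_last_n` tokens are penalized, and a token seen
// more than once reuses its cached penalized value.
fn apply_cached_repeat_penalty(
    logits: &mut [f32],
    tokens: &[u32],
    repeat_penalty: f32,
    repeat_last_n: usize,
    penalty_cache: &mut HashMap<usize, f32>,
) -> Duration {
    let start = Instant::now();
    if repeat_penalty != 1.0 {
        let start_at = tokens.len().saturating_sub(repeat_last_n);
        for &token_id in &tokens[start_at..] {
            let token_id = token_id as usize;
            if token_id < logits.len() {
                let penalized = *penalty_cache
                    .entry(token_id)
                    .or_insert(logits[token_id] / repeat_penalty);
                logits[token_id] = penalized;
            }
        }
    }
    start.elapsed()
}

fn main() {
    let mut logits = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
    let mut cache = HashMap::new();
    apply_cached_repeat_penalty(&mut logits, &[1, 2, 3], 2.0, 3, &mut cache);
    // Matches the expectations in the penalty tests above.
    assert_eq!(logits, vec![1.0, 1.0, 1.5, 2.0, 5.0]);
    assert_eq!(cache.get(&1), Some(&1.0));
}
```

Because the cache is keyed by token id, a token repeated inside the penalty window is divided by the penalty only once, which is exactly what the caching test asserts.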

View File

@@ -1,7 +1,7 @@
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;
use std::path::PathBuf;
use anyhow::Result; use anyhow::Result;
use inference_engine::token_output_stream::TokenOutputStream;
use std::path::PathBuf;
use tokenizers::Tokenizer;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@@ -19,7 +19,7 @@ mod tests {
fn test_new_token_output_stream() -> Result<()> { fn test_new_token_output_stream() -> Result<()> {
let tokenizer = create_test_tokenizer()?; let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer); let token_stream = TokenOutputStream::new(tokenizer);
// Check that the token stream was created successfully // Check that the token stream was created successfully
assert!(token_stream.tokenizer().get_vocab(true).len() > 0); assert!(token_stream.tokenizer().get_vocab(true).len() > 0);
Ok(()) Ok(())
@@ -29,18 +29,18 @@ mod tests {
fn test_clear() -> Result<()> { fn test_clear() -> Result<()> {
let tokenizer = create_test_tokenizer()?; let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer); let mut token_stream = TokenOutputStream::new(tokenizer);
// Add a token // Add a token
let token_id = token_stream.get_token("<eos>").unwrap(); let token_id = token_stream.get_token("<eos>").unwrap();
token_stream.next_token(token_id)?; token_stream.next_token(token_id)?;
// Clear the stream // Clear the stream
token_stream.clear(); token_stream.clear();
// Check that the stream is empty by trying to decode all // Check that the stream is empty by trying to decode all
let decoded = token_stream.decode_all()?; let decoded = token_stream.decode_all()?;
assert_eq!(decoded, ""); assert_eq!(decoded, "");
Ok(()) Ok(())
} }
@@ -48,15 +48,15 @@ mod tests {
fn test_get_token() -> Result<()> { fn test_get_token() -> Result<()> {
let tokenizer = create_test_tokenizer()?; let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer); let token_stream = TokenOutputStream::new(tokenizer);
// Get a token that should exist // Get a token that should exist
let eos_token = token_stream.get_token("<eos>"); let eos_token = token_stream.get_token("<eos>");
assert!(eos_token.is_some()); assert!(eos_token.is_some());
// Get a token that shouldn't exist // Get a token that shouldn't exist
let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>"); let nonexistent_token = token_stream.get_token("<this_token_does_not_exist>");
assert!(nonexistent_token.is_none()); assert!(nonexistent_token.is_none());
Ok(()) Ok(())
} }
@@ -64,11 +64,14 @@ mod tests {
fn test_next_token_and_decode() -> Result<()> { fn test_next_token_and_decode() -> Result<()> {
let tokenizer = create_test_tokenizer()?; let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer); let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens // Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap(); let hello_tokens = token_stream
.tokenizer()
.encode("Hello world", true)
.unwrap();
let token_ids = hello_tokens.get_ids(); let token_ids = hello_tokens.get_ids();
// Add tokens one by one // Add tokens one by one
let mut output = String::new(); let mut output = String::new();
for &token_id in token_ids { for &token_id in token_ids {
@@ -76,16 +79,16 @@ mod tests {
output.push_str(&text); output.push_str(&text);
} }
} }
// Get any remaining text // Get any remaining text
if let Some(rest) = token_stream.decode_rest()? { if let Some(rest) = token_stream.decode_rest()? {
output.push_str(&rest); output.push_str(&rest);
} }
// Check the output // Check the output
assert!(!output.is_empty()); assert!(!output.is_empty());
assert_eq!(output.trim(), "Hello world"); assert_eq!(output.trim(), "Hello world");
Ok(()) Ok(())
} }
@@ -93,22 +96,25 @@ mod tests {
fn test_decode_all() -> Result<()> { fn test_decode_all() -> Result<()> {
let tokenizer = create_test_tokenizer()?; let tokenizer = create_test_tokenizer()?;
let mut token_stream = TokenOutputStream::new(tokenizer); let mut token_stream = TokenOutputStream::new(tokenizer);
// Get some tokens // Get some tokens
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap(); let hello_tokens = token_stream
.tokenizer()
.encode("Hello world", true)
.unwrap();
let token_ids = hello_tokens.get_ids(); let token_ids = hello_tokens.get_ids();
// Add tokens one by one // Add tokens one by one
for &token_id in token_ids { for &token_id in token_ids {
token_stream.next_token(token_id)?; token_stream.next_token(token_id)?;
} }
// Decode all // Decode all
let decoded = token_stream.decode_all()?; let decoded = token_stream.decode_all()?;
// Check the output // Check the output
assert_eq!(decoded.trim(), "Hello world"); assert_eq!(decoded.trim(), "Hello world");
Ok(()) Ok(())
} }
@@ -116,14 +122,14 @@ mod tests {
fn test_into_inner() -> Result<()> { fn test_into_inner() -> Result<()> {
let tokenizer = create_test_tokenizer()?; let tokenizer = create_test_tokenizer()?;
let token_stream = TokenOutputStream::new(tokenizer); let token_stream = TokenOutputStream::new(tokenizer);
// Get the inner tokenizer // Get the inner tokenizer
let inner_tokenizer = token_stream.into_inner(); let inner_tokenizer = token_stream.into_inner();
// Check that the inner tokenizer works // Check that the inner tokenizer works
let encoded = inner_tokenizer.encode("Test", true).unwrap(); let encoded = inner_tokenizer.encode("Test", true).unwrap();
assert!(encoded.get_ids().len() > 0); assert!(encoded.get_ids().len() > 0);
Ok(()) Ok(())
} }
} }
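The streaming pattern these tests exercise can be condensed into a single helper. This is only a sketch: it assumes the `TokenOutputStream` API used above (`new`, `tokenizer()`, `next_token`, `decode_rest`) and takes a tokenizer file path instead of the test helper's bundled tokenizer.

```rust
use anyhow::Result;
use inference_engine::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;

// Encode a prompt, then re-decode it incrementally the way a generation loop
// would emit tokens one at a time.
fn stream_decode(tokenizer_path: &str, prompt: &str) -> Result<String> {
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
    let mut stream = TokenOutputStream::new(tokenizer);

    let encoded = stream
        .tokenizer()
        .encode(prompt, true)
        .map_err(anyhow::Error::msg)?;
    let ids: Vec<u32> = encoded.get_ids().to_vec();

    let mut output = String::new();
    for id in ids {
        // next_token buffers ids and yields text once it forms valid UTF-8
        // (signature assumed from the usage in the tests above).
        if let Some(piece) = stream.next_token(id)? {
            output.push_str(&piece);
        }
    }
    if let Some(rest) = stream.decode_rest()? {
        output.push_str(&rest);
    }
    Ok(output)
}
```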

View File

@@ -5,6 +5,25 @@ use leptos_router::{
StaticSegment, StaticSegment,
}; };
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{FinishReason, Role};
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
Model as OpenAIModel,
},
Client,
};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
@@ -12,25 +31,7 @@ use std::collections::VecDeque;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
use uuid::Uuid; use uuid::Uuid;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
use js_sys::Date;
#[cfg(feature = "hydrate")]
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent}; use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
#[cfg(feature = "hydrate")]
use futures_util::StreamExt;
#[cfg(feature = "hydrate")]
use async_openai_wasm::{
types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
},
Client,
};
#[cfg(feature = "hydrate")]
use async_openai_wasm::config::OpenAIConfig;
#[cfg(feature = "hydrate")]
use async_openai_wasm::types::{Role, FinishReason};
#[cfg(feature = "hydrate")]
use leptos::task::spawn_local;
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -43,11 +44,15 @@ pub struct Message {
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent(pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>); pub struct MessageContent(
pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>,
);
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageInnerContent(pub either::Either<String, std::collections::HashMap<String, String>>); pub struct MessageInnerContent(
pub either::Either<String, std::collections::HashMap<String, String>>,
);
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -62,27 +67,40 @@ const DEFAULT_MODEL: &str = "default";
#[cfg(feature = "hydrate")] #[cfg(feature = "hydrate")]
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> { async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"
);
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string()); let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config); let client = Client::with_config(config);
match client.models().list().await { match client.models().list().await {
Ok(response) => { Ok(response) => {
let model_count = response.data.len(); let model_count = response.data.len();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Successfully fetched {} models",
model_count
);
if model_count > 0 { if model_count > 0 {
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect(); let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Available models: {:?}",
model_names
);
} else { } else {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: No models returned by server"); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: No models returned by server"
);
} }
Ok(response.data) Ok(response.data)
}, }
Err(e) => { Err(e) => {
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e); leptos::logging::log!(
"[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}",
e
);
Err(format!("Failed to fetch models: {}", e)) Err(format!("Failed to fetch models: {}", e))
} }
} }
@@ -150,7 +168,7 @@ fn ChatInterface() -> impl IntoView {
{ {
ChatInterfaceImpl() ChatInterfaceImpl()
} }
#[cfg(not(feature = "hydrate"))] #[cfg(not(feature = "hydrate"))]
{ {
view! { view! {
@@ -252,7 +270,7 @@ fn ChatInterfaceImpl() -> impl IntoView {
let current_model = selected_model.get_untracked(); let current_model = selected_model.get_untracked();
let total_messages = chat_messages.len(); let total_messages = chat_messages.len();
leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}", leptos::logging::log!("[DEBUG_LOG] send_message: Preparing request - model: '{}', history_count: {}, total_messages: {}",
current_model, history_count, total_messages); current_model, history_count, total_messages);
@@ -267,17 +285,17 @@ fn ChatInterfaceImpl() -> impl IntoView {
// Send request // Send request
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string()); let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
let client = Client::with_config(config); let client = Client::with_config(config);
leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model); leptos::logging::log!("[DEBUG_LOG] send_message: Sending request to http://localhost:8080/v1 with model: '{}'", current_model);
match client.chat().create_stream(request).await { match client.chat().create_stream(request).await {
Ok(mut stream) => { Ok(mut stream) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream"); leptos::logging::log!("[DEBUG_LOG] send_message: Successfully created stream");
let mut assistant_created = false; let mut assistant_created = false;
let mut content_appended = false; let mut content_appended = false;
let mut chunks_received = 0; let mut chunks_received = 0;
while let Some(next) = stream.next().await { while let Some(next) = stream.next().await {
match next { match next {
Ok(chunk) => { Ok(chunk) => {
@@ -335,7 +353,11 @@ fn ChatInterfaceImpl() -> impl IntoView {
} }
} }
Err(e) => { Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e); leptos::logging::log!(
"[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}",
chunks_received,
e
);
set_messages.update(|msgs| { set_messages.update(|msgs| {
msgs.push_back(Message { msgs.push_back(Message {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
@@ -364,7 +386,10 @@ fn ChatInterfaceImpl() -> impl IntoView {
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received); leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
} }
Err(e) => { Err(e) => {
leptos::logging::log!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e); leptos::logging::log!(
"[DEBUG_LOG] send_message: Request failed with error: {:?}",
e
);
let error_message = Message { let error_message = Message {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
role: "system".to_string(), role: "system".to_string(),
@@ -404,7 +429,8 @@ fn ChatInterfaceImpl() -> impl IntoView {
}; };
let messages_list = move || { let messages_list = move || {
messages.get() messages
.get()
.into_iter() .into_iter()
.map(|message| { .map(|message| {
let role_class = match message.role.as_str() { let role_class = match message.role.as_str() {
@@ -439,7 +465,7 @@ fn ChatInterfaceImpl() -> impl IntoView {
<h1>"Chat Interface"</h1> <h1>"Chat Interface"</h1>
<div class="model-selector"> <div class="model-selector">
<label for="model-select">"Model: "</label> <label for="model-select">"Model: "</label>
<select <select
id="model-select" id="model-select"
on:change=on_model_change on:change=on_model_change
prop:value=selected_model prop:value=selected_model
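The reformatted `fetch_available_models` above reduces to a small, reusable call: point the OpenAI-compatible client at the local gateway's base URL and list its models. A condensed sketch (hydrate-only, same `async-openai-wasm` calls as the function above, with the debug logging omitted):

```rust
use async_openai_wasm::{config::OpenAIConfig, types::Model as OpenAIModel, Client};

async fn list_local_models() -> Result<Vec<OpenAIModel>, String> {
    // Same base URL the chat interface uses for completions.
    let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
    let client = Client::with_config(config);
    client
        .models()
        .list()
        .await
        .map(|response| response.data)
        .map_err(|e| format!("Failed to fetch models: {}", e))
}
```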

View File

@@ -10,10 +10,10 @@ pub fn hydrate() {
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
pub fn create_leptos_router() -> axum::Router { pub fn create_leptos_router() -> axum::Router {
use crate::app::*;
use axum::Router; use axum::Router;
use leptos::prelude::*; use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes}; use leptos_axum::{generate_route_list, LeptosRoutes};
use crate::app::*;
let conf = get_configuration(None).unwrap(); let conf = get_configuration(None).unwrap();
let leptos_options = conf.leptos_options; let leptos_options = conf.leptos_options;

View File

@@ -1,12 +1,11 @@
#[cfg(feature = "ssr")] #[cfg(feature = "ssr")]
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
use axum::Router; use axum::Router;
use leptos::logging::log; use leptos::logging::log;
use leptos::prelude::*; use leptos::prelude::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
use leptos_app::app::*; use leptos_app::app::*;
use leptos_axum::{generate_route_list, LeptosRoutes};
let conf = get_configuration(None).unwrap(); let conf = get_configuration(None).unwrap();
let addr = conf.leptos_options.site_addr; let addr = conf.leptos_options.site_addr;

View File

@@ -18,6 +18,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", features = ["
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] } candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
[target.'cfg(not(target_os = "macos"))'.dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
[features] [features]
default = [] default = []
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]

View File

@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
// Re-export constants and types that might be needed // Re-export constants and types that might be needed
pub const EOS_TOKEN: &str = "</s>"; pub const EOS_TOKEN: &str = "</s>";

View File

@@ -1,14 +1,14 @@
use crate::EOS_TOKEN;
use anyhow::{bail, Error as E}; use anyhow::{bail, Error as E};
use candle_core::{utils, DType, Device, Tensor}; use candle_core::{utils, DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling}; use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::{Llama, LlamaConfig};
use candle_transformers::models::llama as model; use candle_transformers::models::llama as model;
use candle_transformers::models::llama::{Llama, LlamaConfig};
use clap::ValueEnum;
use hf_hub::api::sync::Api; use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType}; use hf_hub::{Repo, RepoType};
use std::sync::mpsc::{self, Receiver}; use std::sync::mpsc::{self, Receiver};
use clap::ValueEnum;
use crate::{EOS_TOKEN};
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)] #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
pub enum WhichModel { pub enum WhichModel {
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
max_tokens: 512, max_tokens: 512,
// Performance flags // Performance flags
no_kv_cache: false, // keep cache ON for speed no_kv_cache: false, // keep cache ON for speed
use_flash_attn: true, // great speed boost if supported use_flash_attn: true, // great speed boost if supported
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed. // Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
dtype: Some("bf16".to_string()), dtype: Some("bf16".to_string()),
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
} }
} }
fn device(cpu: bool) -> anyhow::Result<Device> { fn device(cpu: bool) -> anyhow::Result<Device> {
if cpu { if cpu {
Ok(Device::Cpu) Ok(Device::Cpu)
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
} }
} }
fn hub_load_safetensors( fn hub_load_safetensors(
api: &hf_hub::api::sync::ApiRepo, api: &hf_hub::api::sync::ApiRepo,
json_file: &str, json_file: &str,
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct", WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0", WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
} }
.to_string() .to_string()
}); });
println!("Loading model: {}", model_id); println!("Loading model: {}", model_id);
let revision = cfg.revision.clone().unwrap_or("main".to_string()); let revision = cfg.revision.clone().unwrap_or("main".to_string());
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
Ok(rx) Ok(rx)
} }

View File

@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
} }
} }
pub fn run_cli() -> anyhow::Result<()> { pub fn run_cli() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
let cfg = args.into(); let cfg = args.into();
@@ -106,4 +105,4 @@ pub fn run_cli() -> anyhow::Result<()> {
} }
} }
Ok(()) Ok(())
} }

View File

@@ -2,8 +2,8 @@
extern crate accelerate_src; extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
mod llama_cli;
mod llama_api; mod llama_api;
mod llama_cli;
use anyhow::Result; use anyhow::Result;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;
const EOS_TOKEN: &str = "</s>"; const EOS_TOKEN: &str = "</s>";
fn main() -> Result<()> { fn main() -> Result<()> {
run_cli() run_cli()
} }

View File

@@ -1,7 +1,9 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::env; use std::env;
use tracing::info;
use tracing::log::error;
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ServerConfig { pub struct ServerConfig {
#[serde(default = "default_server_host")] #[serde(default = "default_server_host")]
@@ -10,14 +12,16 @@ pub struct ServerConfig {
pub server_port: u16, pub server_port: u16,
pub server_mode: ServerMode, pub server_mode: ServerMode,
#[serde(default)] #[serde(default)]
pub services: Services, pub services: Option<Services>,
} }
fn default_server_host() -> String { fn default_server_host() -> String {
"127.0.0.1".to_string() "127.0.0.1".to_string()
} }
fn default_server_port() -> u16 { 8080 } fn default_server_port() -> u16 {
8080
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "PascalCase")] #[serde(rename_all = "PascalCase")]
@@ -34,17 +38,15 @@ impl Default for ServerMode {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Services { pub struct Services {
#[serde(default = "inference_service_url")] pub inference_url: Option<String>,
pub inference_url: String, pub embeddings_url: Option<String>,
#[serde(default = "embeddings_service_url")]
pub embeddings_url: String,
} }
impl Default for Services { impl Default for Services {
fn default() -> Self { fn default() -> Self {
Self { Self {
inference_url: inference_service_url(), inference_url: None,
embeddings_url: embeddings_service_url(), embeddings_url: None,
} }
} }
} }
@@ -63,7 +65,7 @@ impl Default for ServerConfig {
server_host: "127.0.0.1".to_string(), server_host: "127.0.0.1".to_string(),
server_port: 8080, server_port: 8080,
server_mode: ServerMode::Standalone, server_mode: ServerMode::Standalone,
services: Services::default(), services: Some(Services::default()),
} }
} }
} }
@@ -73,21 +75,19 @@ impl ServerConfig {
/// Falls back to default (Local mode) if not set or invalid /// Falls back to default (Local mode) if not set or invalid
pub fn from_env() -> Self { pub fn from_env() -> Self {
match env::var("SERVER_CONFIG") { match env::var("SERVER_CONFIG") {
Ok(config_str) => { Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
match serde_json::from_str::<ServerConfig>(&config_str) { Ok(config) => {
Ok(config) => { tracing::info!("Loaded server configuration: {:?}", config);
tracing::info!("Loaded server configuration: {:?}", config); config
config
}
Err(e) => {
tracing::warn!(
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
e
);
ServerConfig::default()
}
} }
} Err(e) => {
tracing::warn!(
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
e
);
ServerConfig::default()
}
},
Err(_) => { Err(_) => {
tracing::info!("SERVER_CONFIG not set, Standalone mode active"); tracing::info!("SERVER_CONFIG not set, Standalone mode active");
ServerConfig::default() ServerConfig::default()
@@ -96,18 +96,52 @@ impl ServerConfig {
} }
/// Check if the server should run in high availability mode /// Check if the server should run in high availability mode
pub fn is_high_availability(&self) -> bool { pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
self.server_mode == ServerMode::HighAvailability if self.server_mode == ServerMode::HighAvailability {
let services_well_defined: bool = self.clone().services.is_some();
let inference_url_well_defined: bool =
services_well_defined && self.clone().services.unwrap().inference_url.is_some();
let embeddings_well_defined: bool =
services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();
let is_well_defined_for_ha =
services_well_defined && inference_url_well_defined && embeddings_well_defined;
if !is_well_defined_for_ha {
let config_string = serde_json::to_string_pretty(&self).unwrap();
error!(
"HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
config_string
);
let err = std::io::Error::new(
std::io::ErrorKind::Other,
"HighAvailability mode configured but services not well defined!",
);
return Err(err);
}
}
Ok(self.server_mode == ServerMode::HighAvailability)
} }
/// Get the inference service URL for proxying /// Get the inference service URL for proxying
pub fn inference_url(&self) -> &str { pub fn inference_url(&self) -> Option<String> {
&self.services.inference_url if self.services.is_some() {
self.services.clone()?.inference_url
} else {
None
}
} }
/// Get the embeddings service URL for proxying /// Get the embeddings service URL for proxying
pub fn embeddings_url(&self) -> &str { pub fn embeddings_url(&self) -> Option<String> {
&self.services.embeddings_url if self.services.is_some() {
self.services.clone()?.embeddings_url
} else {
None
}
} }
} }
@@ -119,7 +153,7 @@ mod tests {
fn test_default_config() { fn test_default_config() {
let config = ServerConfig::default(); let config = ServerConfig::default();
assert_eq!(config.server_mode, ServerMode::Standalone); assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability()); assert!(!config.is_high_availability().unwrap());
} }
#[test] #[test]
@@ -134,23 +168,26 @@ mod tests {
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::HighAvailability); assert_eq!(config.server_mode, ServerMode::HighAvailability);
assert!(config.is_high_availability()); assert!(config.is_high_availability().unwrap());
assert_eq!(config.inference_url(), "http://inference-service:8080"); assert_eq!(
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080"); config.inference_url().unwrap(),
"http://inference-service:8080"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://embeddings-service:8080"
);
} }
#[test] #[test]
fn test_local_mode_config() { fn test_local_mode_config() {
let config_json = r#"{ let config_json = r#"{
"serverMode": "Local" "serverMode": "Standalone"
}"#; }"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.server_mode, ServerMode::Standalone); assert_eq!(config.server_mode, ServerMode::Standalone);
assert!(!config.is_high_availability()); assert!(!config.is_high_availability().unwrap());
// Should use default URLs
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
} }
#[test] #[test]
@@ -164,17 +201,26 @@ mod tests {
}"#; }"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert_eq!(config.inference_url(), "http://custom-inference:9000"); assert_eq!(
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001"); config.inference_url().unwrap(),
"http://custom-inference:9000"
);
assert_eq!(
config.embeddings_url().unwrap(),
"http://custom-embeddings:9001"
);
} }
#[test] #[test]
fn test_minimal_high_availability_config() { fn test_minimal_high_availability_config_error() {
let config_json = r#"{"serverMode": "HighAvailability"}"#; let config_json = r#"{"serverMode": "HighAvailability"}"#;
let config: ServerConfig = serde_json::from_str(config_json).unwrap(); let config: ServerConfig = serde_json::from_str(config_json).unwrap();
assert!(config.is_high_availability());
// Should use default URLs let is_high_availability = config.is_high_availability();
assert_eq!(config.inference_url(), "http://inference-service:8080");
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080"); assert!(is_high_availability.is_err());
// // Should use default URLs
// assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
// assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
} }
} }
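Taken together, the `Option`-based `Services` and the now-fallible `is_high_availability()` mean a HighAvailability config is only accepted when both service URLs are present. A compact sketch of both cases, using the same structs as the tests above:

```rust
#[test]
fn ha_mode_requires_both_service_urls() {
    let ok = ServerConfig {
        server_host: "127.0.0.1".to_string(),
        server_port: 8080,
        server_mode: ServerMode::HighAvailability,
        services: Some(Services {
            inference_url: Some("http://inference-service:8080".to_string()),
            embeddings_url: Some("http://embeddings-service:8080".to_string()),
        }),
    };
    assert!(ok.is_high_availability().unwrap());
    assert_eq!(
        ok.inference_url().as_deref(),
        Some("http://inference-service:8080")
    );

    // The same mode without service URLs is rejected instead of silently
    // falling back to defaults.
    let incomplete = ServerConfig {
        services: None,
        ..ok
    };
    assert!(incomplete.is_high_availability().is_err());
}
```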

View File

@@ -1,7 +1,9 @@
mod config; mod config;
mod middleware; mod middleware;
mod proxy; mod proxy;
mod standalone;
use crate::standalone::create_standalone_router;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::get; use axum::routing::get;
use axum::{Router, http::Uri, response::Html, serve}; use axum::{Router, http::Uri, response::Html, serve};
@@ -11,6 +13,7 @@ use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
use proxy::create_proxy_router; use proxy::create_proxy_router;
use rust_embed::Embed; use rust_embed::Embed;
use std::env; use std::env;
use std::path::Component::ParentDir;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_http::classify::ServerErrorsFailureClass::StatusCode; use tower_http::classify::ServerErrorsFailureClass::StatusCode;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
@@ -49,33 +52,19 @@ async fn main() {
let default_host = server_config.server_host.clone(); let default_host = server_config.server_host.clone();
let default_port = server_config.server_port; let default_port = server_config.server_port;
// Create router based on server mode let service_router = match server_config.clone().is_high_availability() {
let service_router = if server_config.clone().is_high_availability() { Ok(is_ha) => {
tracing::info!("Running in HighAvailability mode - proxying to external services"); if is_ha {
tracing::info!(" Inference service URL: {}", server_config.inference_url()); log_config(server_config.clone());
tracing::info!( create_proxy_router(server_config.clone())
" Embeddings service URL: {}", } else {
server_config.embeddings_url() log_config(server_config.clone());
); create_standalone_router(server_config)
}
// Use proxy router that forwards requests to external services }
create_proxy_router(server_config.clone()) Err(error) => {
} else { panic!("{}", error);
tracing::info!("Running in Standalone mode - using embedded services"); }
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();
// Create AppState with correct model configuration
let app_state = AppState::default();
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);
// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
}; };
// Create CORS layer // Create CORS layer
@@ -124,5 +113,25 @@ async fn main() {
serve(listener, app).await.unwrap(); serve(listener, app).await.unwrap();
} }
fn log_config(config: ServerConfig) {
match config.is_high_availability() {
Ok(is_high) => {
if is_high {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
tracing::info!(
"Embeddings service URL: {}",
config.embeddings_url().unwrap()
);
} else {
tracing::info!("Running in Standalone mode");
}
}
Err(error) => {
panic!("{}", error);
}
}
}
// Chat completions handler that properly uses the inference server crate's error handling // Chat completions handler that properly uses the inference server crate's error handling
// This function is no longer needed as we're using the inference_engine router directly // This function is no longer needed as we're using the inference_engine router directly
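Downstream of this match, both modes share the same wiring: the selected router gets a permissive CORS layer and is served on the configured address. A condensed sketch of that tail end of `main` (the bind and CORS details are assumed from the imports above):

```rust
// Sketch only: layer CORS over whichever router was selected, then serve it.
let cors = CorsLayer::new()
    .allow_origin(Any)
    .allow_methods(Any)
    .allow_headers(Any);
let app = service_router.layer(cors);

let listener = TcpListener::bind(format!("{default_host}:{default_port}"))
    .await
    .unwrap();
serve(listener, app).await.unwrap();
```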

View File

@@ -2,6 +2,8 @@ use axum::{
extract::MatchedPath, extract::MatchedPath,
http::{Request, Response}, http::{Request, Response},
}; };
use std::fmt;
use std::task::ready;
use std::{ use std::{
future::Future, future::Future,
pin::Pin, pin::Pin,
@@ -12,8 +14,6 @@ use std::{
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower::{Layer, Service}; use tower::{Layer, Service};
use tracing::{debug, info}; use tracing::{debug, info};
use std::task::ready;
use std::fmt;
/// Performance metrics for a specific endpoint /// Performance metrics for a specific endpoint
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@@ -33,16 +33,16 @@ impl EndpointMetrics {
pub fn add_response_time(&mut self, time_ms: u64) { pub fn add_response_time(&mut self, time_ms: u64) {
self.count += 1; self.count += 1;
self.total_time_ms += time_ms; self.total_time_ms += time_ms;
if self.min_time_ms == 0 || time_ms < self.min_time_ms { if self.min_time_ms == 0 || time_ms < self.min_time_ms {
self.min_time_ms = time_ms; self.min_time_ms = time_ms;
} }
if time_ms > self.max_time_ms { if time_ms > self.max_time_ms {
self.max_time_ms = time_ms; self.max_time_ms = time_ms;
} }
} }
/// Get the average response time in milliseconds /// Get the average response time in milliseconds
pub fn avg_time_ms(&self) -> f64 { pub fn avg_time_ms(&self) -> f64 {
if self.count == 0 { if self.count == 0 {
@@ -51,12 +51,15 @@ impl EndpointMetrics {
self.total_time_ms as f64 / self.count as f64 self.total_time_ms as f64 / self.count as f64
} }
} }
/// Get a human-readable summary of the metrics /// Get a human-readable summary of the metrics
pub fn summary(&self) -> String { pub fn summary(&self) -> String {
format!( format!(
"requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms", "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms self.count,
self.avg_time_ms(),
self.min_time_ms,
self.max_time_ms
) )
} }
} }
@@ -75,14 +78,16 @@ impl MetricsStore {
endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())), endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
} }
} }
/// Record a request's timing information /// Record a request's timing information
pub async fn record(&self, path: String, time_ms: u64) { pub async fn record(&self, path: String, time_ms: u64) {
let mut endpoints = self.endpoints.lock().await; let mut endpoints = self.endpoints.lock().await;
let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default); let metrics = endpoints
.entry(path)
.or_insert_with(EndpointMetrics::default);
metrics.add_response_time(time_ms); metrics.add_response_time(time_ms);
} }
/// Get metrics for all endpoints /// Get metrics for all endpoints
pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> { pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
let endpoints = self.endpoints.lock().await; let endpoints = self.endpoints.lock().await;
@@ -91,12 +96,12 @@ impl MetricsStore {
.map(|(k, v)| (k.clone(), v.clone())) .map(|(k, v)| (k.clone(), v.clone()))
.collect() .collect()
} }
/// Log a summary of all metrics /// Log a summary of all metrics
pub async fn log_summary(&self) { pub async fn log_summary(&self) {
let metrics = self.get_all().await; let metrics = self.get_all().await;
info!("Performance metrics summary:"); info!("Performance metrics summary:");
for (path, metric) in metrics { for (path, metric) in metrics {
info!(" {}: {}", path, metric.summary()); info!(" {}: {}", path, metric.summary());
} }
@@ -163,26 +168,28 @@ where
} else { } else {
req.uri().path().to_string() req.uri().path().to_string()
}; };
let method = req.method().clone(); let method = req.method().clone();
let start = Instant::now(); let start = Instant::now();
let metrics_store = self.metrics_store.clone(); let metrics_store = self.metrics_store.clone();
let future = self.inner.call(req); let future = self.inner.call(req);
Box::pin(async move { Box::pin(async move {
let response = future.await?; let response = future.await?;
let time = start.elapsed(); let time = start.elapsed();
let status = response.status(); let status = response.status();
let time_ms = time.as_millis() as u64; let time_ms = time.as_millis() as u64;
// Record the timing in our metrics store // Record the timing in our metrics store
metrics_store.record(format!("{} {}", method, path), time_ms).await; metrics_store
.record(format!("{} {}", method, path), time_ms)
.await;
// Log the request timing // Log the request timing
debug!("{} {} {} - {} ms", method, path, status, time_ms); debug!("{} {} {} - {} ms", method, path, status, time_ms);
Ok(response) Ok(response)
}) })
} }
@@ -214,7 +221,7 @@ impl Future for MetricsLoggerFuture {
metrics_store.log_summary().await; metrics_store.log_summary().await;
}); });
} }
Poll::Pending Poll::Pending
} }
} }
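A small example of driving the metrics store directly; it assumes a tokio runtime and an initialized `tracing` subscriber (e.g. `tracing-subscriber`) so the summary lines are actually emitted:

```rust
#[tokio::main]
async fn main() {
    // Assumed: tracing-subscriber is available for log output.
    tracing_subscriber::fmt::init();

    let store = MetricsStore::new();
    store.record("GET /v1/models".to_string(), 12).await;
    store.record("POST /v1/chat/completions".to_string(), 850).await;
    store.record("POST /v1/chat/completions".to_string(), 430).await;

    // Emits one line per endpoint: "requests: N, avg: X.XXms, min: Yms, max: Zms".
    store.log_summary().await;
}
```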

View File

@@ -1,7 +1,3 @@
pub mod metrics; pub mod metrics;
pub use metrics::{ pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
MetricsStore,
MetricsLoggerFuture,
MetricsLayer,
};

View File

@@ -1,10 +1,10 @@
use axum::{ use axum::{
Router,
body::Body, body::Body,
extract::{Request, State}, extract::{Request, State},
http::{HeaderMap, Method, StatusCode, Uri}, http::{HeaderMap, Method, StatusCode, Uri},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{get, post}, routing::{get, post},
Router,
}; };
use reqwest::Client; use reqwest::Client;
use serde_json::Value; use serde_json::Value;
@@ -47,10 +47,16 @@ async fn proxy_chat_completions(
headers: HeaderMap, headers: HeaderMap,
body: Body, body: Body,
) -> Result<Response, StatusCode> { ) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url()); let target_url = format!(
"{}/v1/chat/completions",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration")
);
tracing::info!("Proxying chat completions request to: {}", target_url); tracing::info!("Proxying chat completions request to: {}", target_url);
// Extract body as bytes // Extract body as bytes
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await { let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes, Ok(bytes) => bytes,
@@ -63,7 +69,9 @@ async fn proxy_chat_completions(
// Check if this is a streaming request // Check if this is a streaming request
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) { let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
if let Ok(json) = serde_json::from_str::<Value>(&body_str) { if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false) json.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false)
} else { } else {
false false
} }
@@ -72,7 +80,8 @@ async fn proxy_chat_completions(
}; };
// Forward the request // Forward the request
let mut req_builder = proxy_client.client let mut req_builder = proxy_client
.client
.post(&target_url) .post(&target_url)
.body(body_bytes.to_vec()); .body(body_bytes.to_vec());
@@ -85,8 +94,7 @@ async fn proxy_chat_completions(
match req_builder.send().await { match req_builder.send().await {
Ok(response) => { Ok(response) => {
let mut resp_builder = Response::builder() let mut resp_builder = Response::builder().status(response.status());
.status(response.status());
// Forward response headers // Forward response headers
for (name, value) in response.headers().iter() { for (name, value) in response.headers().iter() {
@@ -99,14 +107,12 @@ async fn proxy_chat_completions(
if is_streaming { if is_streaming {
// For streaming, we need to forward the response as-is // For streaming, we need to forward the response as-is
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder .header("content-type", "text/plain; charset=utf-8")
.header("content-type", "text/plain; charset=utf-8") .header("cache-control", "no-cache")
.header("cache-control", "no-cache") .header("connection", "keep-alive")
.header("connection", "keep-alive") .body(Body::from(body))
.body(Body::from(body)) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read streaming response body: {}", e); tracing::error!("Failed to read streaming response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +121,9 @@ async fn proxy_chat_completions(
} else { } else {
// For non-streaming, forward the JSON response // For non-streaming, forward the JSON response
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder .body(Body::from(body))
.body(Body::from(body)) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read response body: {}", e); tracing::error!("Failed to read response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,10 +143,16 @@ async fn proxy_models(
State(proxy_client): State<ProxyClient>, State(proxy_client): State<ProxyClient>,
headers: HeaderMap, headers: HeaderMap,
) -> Result<Response, StatusCode> { ) -> Result<Response, StatusCode> {
let target_url = format!("{}/v1/models", proxy_client.config.inference_url()); let target_url = format!(
"{}/v1/models",
proxy_client
.config
.inference_url()
.expect("Invalid Configuration Detected")
);
tracing::info!("Proxying models request to: {}", target_url); tracing::info!("Proxying models request to: {}", target_url);
let mut req_builder = proxy_client.client.get(&target_url); let mut req_builder = proxy_client.client.get(&target_url);
// Forward relevant headers // Forward relevant headers
@@ -154,8 +164,7 @@ async fn proxy_models(
match req_builder.send().await { match req_builder.send().await {
Ok(response) => { Ok(response) => {
let mut resp_builder = Response::builder() let mut resp_builder = Response::builder().status(response.status());
.status(response.status());
// Forward response headers // Forward response headers
for (name, value) in response.headers().iter() { for (name, value) in response.headers().iter() {
@@ -165,11 +174,9 @@ async fn proxy_models(
} }
match response.bytes().await { match response.bytes().await {
Ok(body) => { Ok(body) => resp_builder
resp_builder .body(Body::from(body))
.body(Body::from(body)) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
Err(e) => { Err(e) => {
tracing::error!("Failed to read models response body: {}", e); tracing::error!("Failed to read models response body: {}", e);
Err(StatusCode::INTERNAL_SERVER_ERROR) Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,10 +196,16 @@ async fn proxy_embeddings(
    headers: HeaderMap,
    body: Body,
) -> Result<Response, StatusCode> {
    let target_url = format!(
        "{}/v1/embeddings",
        proxy_client
            .config
            .embeddings_url()
            .expect("Invalid Configuration Detected")
    );

    tracing::info!("Proxying embeddings request to: {}", target_url);

    // Extract body as bytes
    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
        Ok(bytes) => bytes,
@@ -203,7 +216,8 @@ async fn proxy_embeddings(
    };

    // Forward the request
    let mut req_builder = proxy_client
        .client
        .post(&target_url)
        .body(body_bytes.to_vec());
@@ -216,8 +230,7 @@ async fn proxy_embeddings(
    match req_builder.send().await {
        Ok(response) => {
            let mut resp_builder = Response::builder().status(response.status());

            // Forward response headers
            for (name, value) in response.headers().iter() {
@@ -227,11 +240,9 @@ async fn proxy_embeddings(
            }

            match response.bytes().await {
                Ok(body) => resp_builder
                    .body(Body::from(body))
                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                Err(e) => {
                    tracing::error!("Failed to read embeddings response body: {}", e);
                    Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -250,7 +261,7 @@ fn should_forward_header(header_name: &str) -> bool {
    match header_name.to_lowercase().as_str() {
        "content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
        "host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
        _ => true, // Forward other headers by default
    }
}
@@ -259,7 +270,7 @@ fn should_forward_response_header(header_name: &str) -> bool {
    match header_name.to_lowercase().as_str() {
        "content-type" | "content-length" | "cache-control" | "connection" => true,
        "server" | "date" => false, // Don't forward server-specific headers
        _ => true, // Forward other headers by default
    }
}
@@ -290,14 +301,20 @@ mod tests {
            server_host: "127.0.0.1".to_string(),
            server_port: 8080,
            server_mode: ServerMode::HighAvailability,
            services: Some(Services {
                inference_url: Some("http://test-inference:8080".to_string()),
                embeddings_url: Some("http://test-embeddings:8080".to_string()),
            }),
        };

        let proxy_client = ProxyClient::new(config);
        assert_eq!(
            proxy_client.config.inference_url().unwrap().as_str(),
            "http://test-inference:8080"
        );
        assert_eq!(
            proxy_client.config.embeddings_url().unwrap().as_str(),
            "http://test-embeddings:8080"
        );
    }
}
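
Note: the updated test relies on Services being optional and on the URL accessors returning Option. The following is a minimal sketch of what those types and accessors might look like, assuming String-backed URLs and a two-variant ServerMode; it is an illustration only, not the crate's actual definition.

// Hypothetical sketch only; the real config types in the crate may differ.
#[derive(Clone, Debug)]
pub enum ServerMode {
    Standalone,
    HighAvailability,
}

#[derive(Clone, Debug)]
pub struct Services {
    pub inference_url: Option<String>,
    pub embeddings_url: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ServerConfig {
    pub server_host: String,
    pub server_port: u16,
    pub server_mode: ServerMode,
    pub services: Option<Services>,
}

impl ServerConfig {
    // Accessors return Option so callers choose how to handle a missing URL:
    // the proxy uses .expect(...), the tests use .unwrap().
    pub fn inference_url(&self) -> Option<String> {
        self.services.as_ref().and_then(|s| s.inference_url.clone())
    }

    pub fn embeddings_url(&self) -> Option<String> {
        self.services.as_ref().and_then(|s| s.embeddings_url.clone())
    }
}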

View File

@@ -0,0 +1,19 @@
use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;

pub fn create_standalone_router(server_config: ServerConfig) -> Router {
    // Create unified router by merging embeddings and inference routers (existing behavior)
    let embeddings_router = embeddings_engine::create_embeddings_router();

    // Create AppState with correct model configuration
    let app_state = AppState::default();

    // Get the inference router directly from the inference engine
    let inference_router = inference_engine::create_router(app_state);

    // Merge the local routers
    Router::new()
        .merge(embeddings_router)
        .merge(inference_router)
}
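
Usage note (illustration only, not part of this commit): the merged Router can be served like any other axum router. The sketch below assumes axum 0.7 with tokio; host and port are placeholders.

// Hypothetical serving sketch, assuming axum 0.7 + tokio.
use axum::Router;
use tokio::net::TcpListener;

async fn serve_standalone(app: Router, host: &str, port: u16) -> std::io::Result<()> {
    // `app` would come from create_standalone_router(server_config).
    let listener = TcpListener::bind((host, port)).await?;
    axum::serve(listener, app).await
}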

389
scripts/build_all_platforms.sh Executable file
View File

@@ -0,0 +1,389 @@
#!/bin/bash

# Cross-platform build script for predict-otron-9000
# Builds all workspace crates for common platforms

set -euo pipefail

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Configuration
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
BUILD_DIR="${PROJECT_ROOT}/build"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# Supported platforms
PLATFORMS=(
    "x86_64-unknown-linux-gnu"
    "x86_64-pc-windows-msvc"
    "x86_64-apple-darwin"
    "aarch64-apple-darwin"
    "aarch64-unknown-linux-gnu"
)

# Main binaries to build
MAIN_BINARIES=(
    "predict-otron-9000"
    "embeddings-engine"
)

# Inference engine binaries (with bin feature)
INFERENCE_BINARIES=(
    "gemma_inference"
    "llama_inference"
)

# Other workspace binaries
OTHER_BINARIES=(
    "helm-chart-tool"
)

print_header() {
    echo -e "${BLUE}================================${NC}"
    echo -e "${BLUE}$1${NC}"
    echo -e "${BLUE}================================${NC}"
}

print_info() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

print_warn() {
    echo -e "${YELLOW}[WARN]${NC} $1"
}

print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

check_dependencies() {
    print_header "Checking Dependencies"

    # Check rust
    if ! command -v cargo >/dev/null 2>&1; then
        print_error "Rust/Cargo is not installed"
        exit 1
    fi

    # Check cargo-leptos for WASM frontend
    if ! command -v cargo-leptos >/dev/null 2>&1; then
        print_warn "cargo-leptos not found. Installing..."
        cargo install cargo-leptos
    fi

    print_info "All dependencies available"
}

install_targets() {
    print_header "Installing Rust Targets"

    for platform in "${PLATFORMS[@]}"; do
        print_info "Installing target: $platform"
        rustup target add "$platform" || {
            print_warn "Failed to install target $platform (may not be available on this host)"
        }
    done

    # Add WASM target for leptos
    print_info "Installing wasm32-unknown-unknown target for Leptos"
    rustup target add wasm32-unknown-unknown
}

create_build_dirs() {
    print_header "Setting up Build Directory"

    rm -rf "$BUILD_DIR"
    mkdir -p "$BUILD_DIR"

    for platform in "${PLATFORMS[@]}"; do
        mkdir -p "$BUILD_DIR/$platform"
    done

    mkdir -p "$BUILD_DIR/web"

    print_info "Build directories created"
}

build_leptos_app() {
    print_header "Building Leptos Web Frontend"

    cd "$PROJECT_ROOT/crates/leptos-app"

    # Build the WASM frontend
    print_info "Building WASM frontend with cargo-leptos..."
    cargo leptos build --release || {
        print_error "Failed to build Leptos WASM frontend"
        return 1
    }

    # Copy built assets to build directory
    if [ -d "target/site" ]; then
        cp -r target/site/* "$BUILD_DIR/web/"
        print_info "Leptos frontend built and copied to $BUILD_DIR/web/"
    else
        print_error "Leptos build output not found at target/site"
        return 1
    fi

    cd "$PROJECT_ROOT"
}

get_platform_features() {
    local platform="$1"
    local features=""

    case "$platform" in
        *-apple-darwin)
            # macOS uses Metal but routes to CPU for Gemma stability
            features=""
            ;;
        *-unknown-linux-gnu|*-pc-windows-msvc)
            # Linux and Windows can use CUDA if available
            features=""
            ;;
        *)
            features=""
            ;;
    esac

    echo "$features"
}

build_binary_for_platform() {
    local binary_name="$1"
    local platform="$2"
    local package_name="$3"
    local additional_args="$4"

    print_info "Building $binary_name for $platform"

    local features=$(get_platform_features "$platform")
    local feature_flag=""
    if [ -n "$features" ]; then
        feature_flag="--features $features"
    fi

    # Build command
    local build_cmd="cargo build --release --target $platform --bin $binary_name"
    if [ -n "$package_name" ]; then
        build_cmd="$build_cmd --package $package_name"
    fi
    if [ -n "$additional_args" ]; then
        build_cmd="$build_cmd $additional_args"
    fi
    if [ -n "$feature_flag" ]; then
        build_cmd="$build_cmd $feature_flag"
    fi

    print_info "Running: $build_cmd"

    if eval "$build_cmd"; then
        # Copy binary to build directory
        local target_dir="target/$platform/release"
        local binary_file="$binary_name"

        # Add .exe extension for Windows
        if [[ "$platform" == *-pc-windows-msvc ]]; then
            binary_file="$binary_name.exe"
        fi

        if [ -f "$target_dir/$binary_file" ]; then
            cp "$target_dir/$binary_file" "$BUILD_DIR/$platform/"
            print_info "$binary_name built and copied for $platform"
        else
            print_error "Binary not found: $target_dir/$binary_file"
            return 1
        fi
    else
        print_error "Failed to build $binary_name for $platform"
        return 1
    fi
}

build_for_platform() {
    local platform="$1"
    print_header "Building for $platform"

    local failed_builds=()

    # Build main binaries
    for binary in "${MAIN_BINARIES[@]}"; do
        if ! build_binary_for_platform "$binary" "$platform" "$binary" ""; then
            failed_builds+=("$binary")
        fi
    done

    # Build inference engine binaries with bin feature
    for binary in "${INFERENCE_BINARIES[@]}"; do
        if ! build_binary_for_platform "$binary" "$platform" "inference-engine" "--features bin"; then
            failed_builds+=("$binary")
        fi
    done

    # Build other workspace binaries
    for binary in "${OTHER_BINARIES[@]}"; do
        if ! build_binary_for_platform "$binary" "$platform" "$binary" ""; then
            failed_builds+=("$binary")
        fi
    done

    if [ ${#failed_builds[@]} -eq 0 ]; then
        print_info "✓ All binaries built successfully for $platform"
    else
        print_warn "Some builds failed for $platform: ${failed_builds[*]}"
    fi
}

create_archives() {
    print_header "Creating Release Archives"

    cd "$BUILD_DIR"

    for platform in "${PLATFORMS[@]}"; do
        if [ -d "$platform" ] && [ -n "$(ls -A "$platform" 2>/dev/null)" ]; then
            local archive_name="predict-otron-9000-${platform}-${TIMESTAMP}"
            print_info "Creating archive for $platform"

            # Create platform-specific directory with all files
            mkdir -p "$archive_name"
            cp -r "$platform"/* "$archive_name/"

            # Add web assets to each platform archive
            if [ -d "web" ]; then
                mkdir -p "$archive_name/web"
                cp -r web/* "$archive_name/web/"
            fi

            # Create README for the platform
            cat > "$archive_name/README.txt" << EOF
Predict-Otron-9000 - Platform: $platform
Build Date: $(date)
========================================

Binaries included:
$(ls -1 "$platform")

Web Frontend:
- Located in the 'web' directory
- Serve with any static file server on port 8788 or configure your server

Usage:
1. Start the main server: ./predict-otron-9000
2. Start embeddings service: ./embeddings-engine
3. Access web interface at http://localhost:8080 (served by main server)

For more information, visit: https://github.com/geoffsee/predict-otron-9000
EOF

            # Create tar.gz archive
            tar -czf "${archive_name}.tar.gz" "$archive_name"
            rm -rf "$archive_name"

            print_info "✓ Created ${archive_name}.tar.gz"
        else
            print_warn "No binaries found for $platform, skipping archive"
        fi
    done

    cd "$PROJECT_ROOT"
}

generate_build_report() {
    print_header "Build Report"

    echo "Build completed at: $(date)"
    echo "Build directory: $BUILD_DIR"
    echo ""
    echo "Archives created:"
    ls -la "$BUILD_DIR"/*.tar.gz 2>/dev/null || echo "No archives created"

    echo ""
    echo "Platform directories:"
    for platform in "${PLATFORMS[@]}"; do
        if [ -d "$BUILD_DIR/$platform" ]; then
            echo " $platform:"
            ls -la "$BUILD_DIR/$platform" | sed 's/^/ /'
        fi
    done

    if [ -d "$BUILD_DIR/web" ]; then
        echo ""
        echo "Web frontend assets:"
        ls -la "$BUILD_DIR/web" | head -10 | sed 's/^/ /'
        if [ $(ls -1 "$BUILD_DIR/web" | wc -l) -gt 10 ]; then
            echo " ... and $(( $(ls -1 "$BUILD_DIR/web" | wc -l) - 10 )) more files"
        fi
    fi
}

main() {
    print_header "Predict-Otron-9000 Cross-Platform Build Script"

    cd "$PROJECT_ROOT"

    check_dependencies
    install_targets
    create_build_dirs

    # Build Leptos web frontend first
    build_leptos_app

    # Build for each platform
    for platform in "${PLATFORMS[@]}"; do
        build_for_platform "$platform"
    done

    create_archives
    generate_build_report

    print_header "Build Complete!"
    print_info "All artifacts are available in: $BUILD_DIR"
}

# Handle command line arguments
case "${1:-}" in
    --help|-h)
        echo "Usage: $0 [options]"
        echo ""
        echo "Cross-platform build script for predict-otron-9000"
        echo ""
        echo "Options:"
        echo " --help, -h Show this help message"
        echo " --platforms Show supported platforms"
        echo " --clean Clean build directory before building"
        echo ""
        echo "Supported platforms:"
        for platform in "${PLATFORMS[@]}"; do
            echo " - $platform"
        done
        echo ""
        echo "Prerequisites:"
        echo " - Rust toolchain with rustup"
        echo " - cargo-leptos (will be installed if missing)"
        echo " - Platform-specific toolchains for cross-compilation"
        echo ""
        exit 0
        ;;
    --platforms)
        echo "Supported platforms:"
        for platform in "${PLATFORMS[@]}"; do
            echo " - $platform"
        done
        exit 0
        ;;
    --clean)
        print_info "Cleaning build directory..."
        rm -rf "$BUILD_DIR"
        print_info "Build directory cleaned"
        ;;
esac

main "$@"

19
scripts/build_cli.sh Executable file
View File

@@ -0,0 +1,19 @@
#!/usr/bin/env sh
set -e

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
TEMP_DIR="$SCRIPT_DIR/temp"

mkdir -p "$TEMP_DIR"

cp "$SCRIPT_DIR/cli.ts" "$TEMP_DIR/cli.ts"
cp "$SCRIPT_DIR/../package.json" "$TEMP_DIR/package.json"

(
  cd "$TEMP_DIR"
  bun i
  bun build ./cli.ts --compile --outfile "$SCRIPT_DIR/cli"
)

rm -rf "$TEMP_DIR"

View File

@@ -1,30 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail

PROMPT=${1:-"Say hello in one short sentence."}
MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
MAX_NEW=${3:-64}
FORCE_CPU=${FORCE_CPU:-0}

# Optional: keep HF cache local to repo if not already set
export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}

BIN="$(dirname "$0")/../target/release/llama_infer"
if [[ ! -x "$BIN" ]]; then
  echo "Building llama-runner (release)..."
  cargo build -p llama-runner --release
fi

echo "Running llama inference..." >&2

ARGS=(
  --model-id "$MODEL"
  --prompt "$PROMPT"
  --max-new-tokens "$MAX_NEW"
)

if [[ "$FORCE_CPU" == "1" || "$FORCE_CPU" == "true" ]]; then
  ARGS+=( --force-cpu )
fi

"$BIN" "${ARGS[@]}"