Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
cleanup, add ci
.github/dependabot.yml (vendored, new file, 49 lines)
@@ -0,0 +1,49 @@
version: 2
updates:
  # Monitor Rust dependencies in the main crate
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "weekly"
      day: "monday"
      time: "09:00"
      timezone: "UTC"
    # Focus on security updates with higher priority
    open-pull-requests-limit: 10
    reviewers:
      - "security-team"
    assignees:
      - "maintainer"
    labels:
      - "dependencies"
      - "security"
    # Security updates get higher priority
    allow:
      - dependency-type: "all"
    # Group minor and patch updates to reduce noise
    # Separate major updates for careful review
    ignore:
      - dependency-name: "*"
        update-types: ["version-update:semver-major"]
    commit-message:
      prefix: "deps"
      include: "scope"

  # Monitor security updates more frequently
  - package-ecosystem: "cargo"
    directory: "/crates/predict-otron-9000"
    schedule:
      interval: "daily"
    # Only security updates in daily checks
    allow:
      - dependency-type: "direct"
        update-types: ["security"]
      - dependency-type: "indirect"
        update-types: ["security"]
    open-pull-requests-limit: 5
    labels:
      - "security-update"
      - "high-priority"
    commit-message:
      prefix: "security"
      include: "scope"
.github/workflows/ci.yml (vendored, new file, 47 lines)
@@ -0,0 +1,47 @@
name: CI

on:
  push:
  pull_request:

jobs:
  build:
    name: build-and-test
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable

      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt

      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check

      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets

      - name: Tests
        shell: bash
        run: cargo test --all

      - name: Build Docs
        shell: bash
        run: |
          cargo doc -p predict-otron-9000 --no-deps
.github/workflows/release.yml (vendored, new file, 232 lines)
@@ -0,0 +1,232 @@
name: Release

on:
  push:
    tags:
      - 'v*'

env:
  CARGO_TERM_COLOR: always

jobs:
  test:
    name: Test before release
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: crates/predict-otron-9000
    strategy:
      fail-fast: false
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable

      - name: Install clippy and rustfmt
        run: rustup component add clippy rustfmt

      - name: Cargo fmt (check)
        run: cargo fmt --all -- --check

      - name: Clippy
        shell: bash
        run: cargo clippy --all-targets

      - name: Tests
        shell: bash
        run: cargo test --all

#  publish:
#    name: Publish to crates.io
#    runs-on: ubuntu-latest
#    permissions:
#      id-token: write # Required for OIDC token exchange https://crates.io/docs/trusted-publishing
#    needs: test
#    defaults:
#      run:
#        working-directory: crates/predict-otron-9000
#    steps:
#      - name: Checkout
#        uses: actions/checkout@v4
#
#      - uses: actions/cache@v4
#        with:
#          path: |
#            ~/.cargo/bin/
#            ~/.cargo/registry/index/
#            ~/.cargo/registry/cache/
#            ~/.cargo/git/db/
#            target/
#          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
#
#      - name: Setup Rust
#        run: rustup update stable && rustup default stable
#
#      - name: Verify tag matches version
#        run: |
#          TAG_VERSION=${GITHUB_REF#refs/tags/v}
#          CARGO_VERSION=$(cargo metadata --no-deps --format-version 1 | jq -r '.packages[0].version')
#          if [ "$TAG_VERSION" != "$CARGO_VERSION" ]; then
#            echo "Tag version ($TAG_VERSION) does not match Cargo.toml version ($CARGO_VERSION)"
#            exit 1
#          fi
#
#      # See Trusted publishing: https://crates.io/docs/trusted-publishing
#      - uses: rust-lang/crates-io-auth-action@v1
#        id: auth
#
#      - run: cargo publish
#        env:
#          CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }}

  build-binaries:
    name: Build binaries
    runs-on: ${{ matrix.os }}
    needs: test
    strategy:
      fail-fast: false
      matrix:
        include:
          - target: x86_64-unknown-linux-gnu
            os: ubuntu-latest
            name: predict-otron-9000-x86_64-unknown-linux-gnu
          - target: x86_64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-x86_64-apple-darwin
          - target: aarch64-apple-darwin
            os: macos-latest
            name: predict-otron-9000-aarch64-apple-darwin
          - target: x86_64-pc-windows-msvc
            os: windows-latest
            name: predict-otron-9000-x86_64-pc-windows-msvc.exe
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            target/
          key: ${{ runner.os }}-${{ matrix.target }}-cargo-${{ hashFiles('**/Cargo.lock') }}

      - name: Setup Rust
        run: rustup update stable && rustup default stable

      - name: Add target
        run: rustup target add ${{ matrix.target }}

      - name: Build binary
        run: cargo build --release --target ${{ matrix.target }} -p predict-otron-9000
        env:
          CARGO_TERM_COLOR: always

      - name: Package binary (Unix)
        if: matrix.os != 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          tar czf ../../../${{ matrix.name }}.tar.gz predict-otron-9000
          cd ../../../

      - name: Package binary (Windows)
        if: matrix.os == 'windows-latest'
        run: |
          cd target/${{ matrix.target }}/release
          7z a ../../../${{ matrix.name }}.zip predict-otron-9000.exe
          cd ../../../

      - name: Upload binary artifacts (Unix)
        if: matrix.os != 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.tar.gz

      - name: Upload binary artifacts (Windows)
        if: matrix.os == 'windows-latest'
        uses: actions/upload-artifact@v4
        with:
          name: ${{ matrix.name }}
          path: ${{ matrix.name }}.zip

  release:
    name: Create GitHub Release
    runs-on: ubuntu-latest
    needs: [test, build-binaries]
    permissions:
      contents: write
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Extract tag name
        id: tag
        run: echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT

      - name: Generate changelog
        id: changelog
        run: |
          # Get the previous tag
          PREV_TAG=$(git describe --tags --abbrev=0 HEAD^ 2>/dev/null || echo "")

          # Generate changelog
          if [ -n "$PREV_TAG" ]; then
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            git log --pretty=format:"* %s (%h)" ${PREV_TAG}..HEAD >> changelog.md
            echo "" >> changelog.md
            echo "" >> changelog.md
            echo "**Full Changelog**: https://github.com/${{ github.repository }}/compare/${PREV_TAG}...${{ steps.tag.outputs.tag }}" >> changelog.md
          else
            echo "## What's Changed" > changelog.md
            echo "" >> changelog.md
            echo "Initial release of predict-otron-9000" >> changelog.md
            echo "" >> changelog.md
            echo "OpenAI Compatible Inference Server" >> changelog.md
          fi

          # Set the changelog as output (handle multiline)
          echo "changelog<<EOF" >> $GITHUB_OUTPUT
          cat changelog.md >> $GITHUB_OUTPUT
          echo "EOF" >> $GITHUB_OUTPUT

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: artifacts

      - name: Create Release
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          if [[ "${{ steps.tag.outputs.tag }}" == *"-"* ]]; then
            PRERELEASE_FLAG="--prerelease"
          else
            PRERELEASE_FLAG=""
          fi

          gh release create "${{ steps.tag.outputs.tag }}" \
            --title "Release ${{ steps.tag.outputs.tag }}" \
            --notes-file changelog.md \
            $PRERELEASE_FLAG \
            artifacts/predict-otron-9000-x86_64-unknown-linux-gnu/predict-otron-9000-x86_64-unknown-linux-gnu.tar.gz \
            artifacts/predict-otron-9000-x86_64-apple-darwin/predict-otron-9000-x86_64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-aarch64-apple-darwin/predict-otron-9000-aarch64-apple-darwin.tar.gz \
            artifacts/predict-otron-9000-x86_64-pc-windows-msvc.exe/predict-otron-9000-x86_64-pc-windows-msvc.exe.zip
.gitignore (vendored, 2 lines added)
@@ -76,3 +76,5 @@ venv/
 *.bak
 *.backup
 *~
+/scripts/cli
+!/scripts/cli.ts
@@ -287,7 +287,7 @@ cargo test --workspace

 **End-to-end test script:**
 ```bash
-./test.sh
+./smoke_test.sh
 ```

 This script:
@@ -478,7 +478,7 @@ cd crates/leptos-app && ./run.sh &

 **Integration test:**
 ```bash
-./test.sh
+./smoke_test.sh
 ```

 **Cleanup:**
@@ -1,9 +1,5 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
-use axum::{
-    response::Json as ResponseJson, routing::{post},
-    Json,
-    Router,
-};
+use axum::{Json, Router, response::Json as ResponseJson, routing::post};
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use once_cell::sync::Lazy;
 use tower_http::trace::TraceLayer;
@@ -15,12 +11,15 @@ static EMBEDDING_MODEL: Lazy<TextEmbedding> = Lazy::new(|| {
     let model_start_time = std::time::Instant::now();

     let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
+        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
     )
     .expect("Failed to initialize persistent embedding model");

     let model_init_time = model_start_time.elapsed();
-    tracing::info!("Persistent embedding model initialized in {:.2?}", model_init_time);
+    tracing::info!(
+        "Persistent embedding model initialized in {:.2?}",
+        model_init_time
+    );

     model
 });
@@ -37,7 +36,10 @@ pub async fn embeddings_create(
     // Access the lazy-initialized persistent model instance
     // This will only initialize the model on the first request
     let model_access_time = model_start_time.elapsed();
-    tracing::debug!("Persistent model access completed in {:.2?}", model_access_time);
+    tracing::debug!(
+        "Persistent model access completed in {:.2?}",
+        model_access_time
+    );

     // Phase 2: Process input
     let input_start_time = std::time::Instant::now();
@@ -55,7 +57,10 @@ pub async fn embeddings_create(
     };

     let input_processing_time = input_start_time.elapsed();
-    tracing::debug!("Input processing completed in {:.2?}", input_processing_time);
+    tracing::debug!(
+        "Input processing completed in {:.2?}",
+        input_processing_time
+    );

     // Phase 3: Generate embeddings
     let embedding_start_time = std::time::Instant::now();
@@ -65,25 +70,39 @@ pub async fn embeddings_create(
         .expect("failed to embed document");

     let embedding_generation_time = embedding_start_time.elapsed();
-    tracing::info!("Embedding generation completed in {:.2?}", embedding_generation_time);
+    tracing::info!(
+        "Embedding generation completed in {:.2?}",
+        embedding_generation_time
+    );

     // Memory usage estimation (approximate)
-    let embedding_size_bytes = embeddings.iter()
+    let embedding_size_bytes = embeddings
+        .iter()
         .map(|e| e.len() * std::mem::size_of::<f32>())
         .sum::<usize>();
-    tracing::debug!("Embedding size: {:.2} MB", embedding_size_bytes as f64 / 1024.0 / 1024.0);
+    tracing::debug!(
+        "Embedding size: {:.2} MB",
+        embedding_size_bytes as f64 / 1024.0 / 1024.0
+    );

     // Only log detailed embedding information at trace level to reduce log volume
     tracing::trace!("Embeddings length: {}", embeddings.len());
     tracing::info!("Embedding dimension: {}", embeddings[0].len());

     // Log the first 10 values of the original embedding at trace level
-    tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
+    tracing::trace!(
+        "Original embedding preview: {:?}",
+        &embeddings[0][..10.min(embeddings[0].len())]
+    );

     // Check if there are any NaN or zero values in the original embedding
     let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
     let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
-    tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
+    tracing::trace!(
+        "Original embedding stats: NaN count={}, zero count={}",
+        nan_count,
+        zero_count
+    );

     // Phase 4: Post-process embeddings
     let postprocessing_start_time = std::time::Instant::now();
@@ -110,6 +129,8 @@ pub async fn embeddings_create(

             // Normalize the random embedding
             let norm: f32 = random_embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
+
+            #[allow(clippy::needless_range_loop)]
             for i in 0..random_embedding.len() {
                 random_embedding[i] /= norm;
             }
@@ -123,7 +144,11 @@ pub async fn embeddings_create(
             let target_dimension = 768;
             if padded_embedding.len() < target_dimension {
                 let padding_needed = target_dimension - padded_embedding.len();
-                tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
+                tracing::trace!(
+                    "Padding embedding with {} zeros to reach {} dimensions",
+                    padding_needed,
+                    target_dimension
+                );
                 padded_embedding.extend(vec![0.0; padding_needed]);
             }

@@ -132,12 +157,18 @@ pub async fn embeddings_create(
     };

     let postprocessing_time = postprocessing_start_time.elapsed();
-    tracing::debug!("Embedding post-processing completed in {:.2?}", postprocessing_time);
+    tracing::debug!(
+        "Embedding post-processing completed in {:.2?}",
+        postprocessing_time
+    );

     tracing::trace!("Final embedding dimension: {}", final_embedding.len());

     // Log the first 10 values of the final embedding at trace level
-    tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
+    tracing::trace!(
+        "Final embedding preview: {:?}",
+        &final_embedding[..10.min(final_embedding.len())]
+    );

     // Phase 5: Prepare response
     let response_start_time = std::time::Instant::now();
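For orientation, the post-processing these hunks reformat reduces to two small steps: scaling a vector to unit L2 norm and zero-padding it to a fixed dimension (768 in this file). A minimal standalone sketch, with helper names that are illustrative only and not part of the crate's API:

```rust
// Sketch only: l2_normalize and pad_to_dimension are hypothetical helpers
// mirroring the normalization and padding branches shown in the hunks above.

/// Scale a vector to unit L2 norm, as done for the fallback random embedding.
fn l2_normalize(v: &mut [f32]) {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

/// Zero-pad a vector up to `target_dimension`, mirroring the padding branch above.
fn pad_to_dimension(mut v: Vec<f32>, target_dimension: usize) -> Vec<f32> {
    if v.len() < target_dimension {
        let padding_needed = target_dimension - v.len();
        v.extend(vec![0.0; padding_needed]);
    }
    v
}

fn main() {
    let mut e = vec![3.0, 4.0];
    l2_normalize(&mut e); // becomes [0.6, 0.8]
    let e = pad_to_dimension(e, 768);
    assert_eq!(e.len(), 768);
}
```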
@@ -1,8 +1,8 @@
 use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
 use axum::{
-    response::Json as ResponseJson, routing::{get, post},
-    Json,
-    Router,
+    Json, Router,
+    response::Json as ResponseJson,
+    routing::{get, post},
 };
 use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
 use serde::{Deserialize, Serialize};
@@ -13,19 +13,17 @@ use tracing;
 const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
 const DEFAULT_SERVER_PORT: &str = "8080";

-
 async fn embeddings_create(
     Json(payload): Json<CreateEmbeddingRequest>,
 ) -> ResponseJson<serde_json::Value> {
     let model = TextEmbedding::try_new(
-        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true)
+        InitOptions::new(EmbeddingModel::NomicEmbedTextV15).with_show_download_progress(true),
     )
     .expect("Failed to initialize model");
-
     let embedding_input = payload.input;
-
+
     let texts_from_embedding_input = match embedding_input {
         EmbeddingInput::String(text) => vec![text],
         EmbeddingInput::StringArray(texts) => texts,
         EmbeddingInput::IntegerArray(_) => {
@@ -45,12 +43,19 @@ async fn embeddings_create(
     tracing::info!("Embedding dimension: {}", embeddings[0].len());

     // Log the first 10 values of the original embedding at trace level
-    tracing::trace!("Original embedding preview: {:?}", &embeddings[0][..10.min(embeddings[0].len())]);
+    tracing::trace!(
+        "Original embedding preview: {:?}",
+        &embeddings[0][..10.min(embeddings[0].len())]
+    );

     // Check if there are any NaN or zero values in the original embedding
     let nan_count = embeddings[0].iter().filter(|&&x| x.is_nan()).count();
     let zero_count = embeddings[0].iter().filter(|&&x| x == 0.0).count();
-    tracing::trace!("Original embedding stats: NaN count={}, zero count={}", nan_count, zero_count);
+    tracing::trace!(
+        "Original embedding stats: NaN count={}, zero count={}",
+        nan_count,
+        zero_count
+    );

     // Create the final embedding
     let final_embedding = {
@@ -87,7 +92,11 @@ async fn embeddings_create(
             let target_dimension = 768;
             if padded_embedding.len() < target_dimension {
                 let padding_needed = target_dimension - padded_embedding.len();
-                tracing::trace!("Padding embedding with {} zeros to reach {} dimensions", padding_needed, target_dimension);
+                tracing::trace!(
+                    "Padding embedding with {} zeros to reach {} dimensions",
+                    padding_needed,
+                    target_dimension
+                );
                 padded_embedding.extend(vec![0.0; padding_needed]);
             }

@@ -98,7 +107,10 @@ async fn embeddings_create(
     tracing::trace!("Final embedding dimension: {}", final_embedding.len());

     // Log the first 10 values of the final embedding at trace level
-    tracing::trace!("Final embedding preview: {:?}", &final_embedding[..10.min(final_embedding.len())]);
+    tracing::trace!(
+        "Final embedding preview: {:?}",
+        &final_embedding[..10.min(final_embedding.len())]
+    );

     // Return a response that matches the OpenAI API format
     let response = serde_json::json!({
@@ -120,7 +132,7 @@ async fn embeddings_create(
 }

 fn create_app() -> Router {
     Router::new()
         .route("/v1/embeddings", post(embeddings_create))
         .layer(TraceLayer::new_for_http())
 }
@@ -143,21 +155,21 @@ async fn main() {
         .init();
     let app = create_app();

     let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
     let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
     let server_address = format!("{}:{}", server_host, server_port);
     let listener = tokio::net::TcpListener::bind(server_address).await.unwrap();
     tracing::info!("Listening on {}", listener.local_addr().unwrap());
     axum::serve(listener, app).await.unwrap();
 }

 #[cfg(test)]
 mod tests {
     use super::*;
-    use axum::body::to_bytes;
     use axum::body::Body;
+    use axum::body::to_bytes;
     use axum::http::StatusCode;
     use tower::ServiceExt;

     #[tokio::test]
     async fn test_embeddings_create() {
@@ -168,11 +180,13 @@ mod tests {

         let body = CreateEmbeddingRequest {
             model: "nomic-text-embed".to_string(),
-            input: EmbeddingInput::from(vec!["The food was delicious and the waiter...".to_string()]),
+            input: EmbeddingInput::from(vec![
+                "The food was delicious and the waiter...".to_string(),
+            ]),
             encoding_format: None,
             user: None,
             dimensions: Some(768),
         };

         let response = app
             .oneshot(
@@ -3,16 +3,14 @@ name = "gemma-runner"
 version = "0.1.0"
 edition = "2021"

 [dependencies]
 candle-core = { git = "https://github.com/huggingface/candle.git" }
 candle-nn = { git = "https://github.com/huggingface/candle.git" }
 candle-transformers = { git = "https://github.com/huggingface/candle.git" }
 candle-examples = { git = "https://github.com/huggingface/candle.git" }

-[target.'cfg(target_os = "macos")'.dependencies]
-candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
-candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
 hf-hub = "0.4"
 tokenizers = "0.21"
 anyhow = "1.0"
@@ -22,6 +20,12 @@ tracing = "0.1"
 tracing-chrome = "0.7"
 tracing-subscriber = "0.3"

+[target.'cfg(target_os = "macos")'.dependencies]
+candle-core = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
+
 [features]
 default = []
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
@@ -4,10 +4,10 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
-use clap::ValueEnum;
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
+use clap::ValueEnum;

 // Removed gemma_cli import as it's not needed for the API
 use candle_core::{utils, DType, Device, Tensor};
@@ -119,7 +119,12 @@ impl TextGeneration {

     /// Stream-only generation: sends freshly generated token strings over `tx`.
     /// (Does not send the prompt tokens; only newly generated model tokens.)
-    fn run_stream(&mut self, prompt: &str, sample_len: usize, tx: Sender<Result<String>>) -> Result<()> {
+    fn run_stream(
+        &mut self,
+        prompt: &str,
+        sample_len: usize,
+        tx: Sender<Result<String>>,
+    ) -> Result<()> {
         self.tokenizer.clear();

         // Encode prompt (context only; do not emit prompt tokens to the stream).
@@ -303,7 +308,7 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
             WhichModel::BaseV3_1B => "google/gemma-3-1b-pt",
             WhichModel::InstructV3_1B => "google/gemma-3-1b-it",
         }
         .to_string()
     });

     println!("Loading model: {}", &model_id);
@@ -337,7 +342,10 @@ pub fn run_gemma_api(cfg: GemmaInferenceConfig) -> Result<Receiver<Result<String
             let model = Model1::new(cfg.use_flash_attn, &config, vb)?;
             Model::V1(model)
         }
-        WhichModel::BaseV2_2B | WhichModel::InstructV2_2B | WhichModel::BaseV2_9B | WhichModel::InstructV2_9B => {
+        WhichModel::BaseV2_2B
+        | WhichModel::InstructV2_2B
+        | WhichModel::BaseV2_9B
+        | WhichModel::InstructV2_9B => {
             let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
             let model = Model2::new(cfg.use_flash_attn, &config, vb)?;
             Model::V2(model)
@@ -1,6 +1,6 @@
-use std::io::Write;
-use clap::Parser;
 use crate::gemma_api::{run_gemma_api, GemmaInferenceConfig, WhichModel};
+use clap::Parser;
+use std::io::Write;

 #[derive(Parser, Debug)]
 #[command(author, version, about = "Fast Gemma inference with Candle", long_about = None)]
@@ -2,8 +2,8 @@
 extern crate accelerate_src;
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
-mod gemma_cli;
 mod gemma_api;
+mod gemma_cli;

 use anyhow::Error;
 use clap::{Parser, ValueEnum};
@@ -84,7 +84,10 @@ fn main() -> Result<()> {
     let services = discover_services(workspace_path)?;
     println!("Found {} services:", services.len());
     for service in &services {
-        println!(" - {}: {} (port {})", service.name, service.image, service.port);
+        println!(
+            " - {}: {} (port {})",
+            service.name, service.image, service.port
+        );
     }

     generate_helm_chart(output_path, chart_name, &services)?;
@@ -119,13 +122,16 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
     let cargo_toml: CargoToml = toml::from_str(&content)
         .with_context(|| format!("Failed to parse Cargo.toml at {:?}", path))?;

-    let package = cargo_toml.package
+    let package = cargo_toml
+        .package
         .ok_or_else(|| anyhow::anyhow!("No package section found in {:?}", path))?;

-    let metadata = package.metadata
+    let metadata = package
+        .metadata
         .ok_or_else(|| anyhow::anyhow!("No metadata section found in {:?}", path))?;

-    let kube_metadata = metadata.kube
+    let kube_metadata = metadata
+        .kube
         .ok_or_else(|| anyhow::anyhow!("No kube metadata found in {:?}", path))?;

     Ok(ServiceInfo {
@@ -136,7 +142,11 @@ fn parse_cargo_toml(path: &Path) -> Result<ServiceInfo> {
     })
 }

-fn generate_helm_chart(output_path: &str, chart_name: &str, services: &[ServiceInfo]) -> Result<()> {
+fn generate_helm_chart(
+    output_path: &str,
+    chart_name: &str,
+    services: &[ServiceInfo],
+) -> Result<()> {
     let chart_dir = Path::new(output_path);
     let templates_dir = chart_dir.join("templates");

@@ -3,18 +3,6 @@ name = "inference-engine"
 version = "0.1.0"
 edition = "2021"

-
-[[bin]]
-name="gemma_inference"
-path = "src/gemma_inference.rs"
-required-features = ["bin"]
-
-[[bin]]
-name="llama_inference"
-path = "src/llama_inference.rs"
-required-features = ["bin"]
-
-
 [dependencies]
 accelerate-src = { version = "0.3.2", optional = true }
 candle-datasets = { version = "=0.9.1", optional = true }
@@ -1,19 +1,19 @@
 // Expose modules for testing and library usage
-pub mod token_output_stream;
 pub mod model;
-pub mod text_generation;
-pub mod utilities_lib;
 pub mod openai_types;
+pub mod text_generation;
+pub mod token_output_stream;
+pub mod utilities_lib;
 // pub mod cli;
-pub mod server;
 pub mod inference;
+pub mod server;

 // Re-export key components for easier access
+pub use inference::ModelInference;
 pub use model::{Model, Which};
+pub use server::{create_router, AppState};
 pub use text_generation::TextGeneration;
 pub use token_output_stream::TokenOutputStream;
-pub use server::{AppState, create_router};
-pub use inference::ModelInference;

 use std::env;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -1,8 +1,8 @@
 // use candle_core::Tensor;
+use candle_transformers::models::csm::{LlamaConfig, LlamaModel};
 use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
 use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
 use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};
-use candle_transformers::models::csm::{LlamaConfig, LlamaModel};

 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 pub enum Which {
@@ -52,7 +52,11 @@ pub enum Model {
 }

 impl Model {
-    pub fn forward(&mut self, input_ids: &candle_core::Tensor, pos: usize) -> candle_core::Result<candle_core::Tensor> {
+    pub fn forward(
+        &mut self,
+        input_ids: &candle_core::Tensor,
+        pos: usize,
+    ) -> candle_core::Result<candle_core::Tensor> {
         match self {
             Self::V1(m) => m.forward(input_ids, pos),
             Self::V2(m) => m.forward(input_ids, pos),
@@ -88,7 +92,13 @@ impl Which {

     pub fn is_instruct_model(&self) -> bool {
         match self {
-            Self::Base2B | Self::Base7B | Self::CodeBase2B | Self::CodeBase7B | Self::BaseV2_2B | Self::BaseV2_9B | Self::BaseV3_1B => false,
+            Self::Base2B
+            | Self::Base7B
+            | Self::CodeBase2B
+            | Self::CodeBase7B
+            | Self::BaseV2_2B
+            | Self::BaseV2_9B
+            | Self::BaseV3_1B => false,
             _ => true,
         }
     }
@@ -10,7 +10,10 @@ pub struct MessageInnerContent(
 );

 impl ToSchema<'_> for MessageInnerContent {
-    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
+    fn schema() -> (
+        &'static str,
+        utoipa::openapi::RefOr<utoipa::openapi::Schema>,
+    ) {
         (
             "MessageInnerContent",
             utoipa::openapi::RefOr::T(message_inner_content_schema()),
@@ -45,12 +48,18 @@ fn message_inner_content_schema() -> utoipa::openapi::Schema {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct MessageContent(
     #[serde(with = "either::serde_untagged")]
     pub Either<String, Vec<HashMap<String, MessageInnerContent>>>,
 );

 impl ToSchema<'_> for MessageContent {
-    fn schema() -> (&'static str, utoipa::openapi::RefOr<utoipa::openapi::Schema>) {
-        ("MessageContent", utoipa::openapi::RefOr::T(message_content_schema()))
+    fn schema() -> (
+        &'static str,
+        utoipa::openapi::RefOr<utoipa::openapi::Schema>,
+    ) {
+        (
+            "MessageContent",
+            utoipa::openapi::RefOr::T(message_content_schema()),
+        )
     }
 }

@@ -6,19 +6,22 @@ use axum::{
     Json, Router,
 };
 use futures_util::stream::{self, Stream};
-use tokio_stream::wrappers::UnboundedReceiverStream;
 use std::convert::Infallible;
 use std::sync::Arc;
-use tokio::sync::{Mutex, mpsc};
+use tokio::sync::{mpsc, Mutex};
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use tower_http::cors::{Any, CorsLayer};
 use uuid::Uuid;

-use crate::openai_types::{ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest, ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage};
+use crate::openai_types::{
+    ChatCompletionChoice, ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionRequest,
+    ChatCompletionResponse, Delta, Message, MessageContent, Model, ModelListResponse, Usage,
+};
 use crate::Which;
 use either::Either;
-use serde_json::Value;
 use gemma_runner::{run_gemma_api, GemmaInferenceConfig};
 use llama_runner::{run_llama_inference, LlamaInferenceConfig};
+use serde_json::Value;
 // -------------------------
 // Shared app state
 // -------------------------
@@ -67,7 +70,10 @@ fn build_gemma_prompt(messages: &[Message]) -> String {
         match message.role.as_str() {
             "system" => {
                 if let Some(MessageContent(Either::Left(content))) = &message.content {
-                    prompt.push_str(&format!("<start_of_turn>system\n{}<end_of_turn>\n", content));
+                    prompt.push_str(&format!(
+                        "<start_of_turn>system\n{}<end_of_turn>\n",
+                        content
+                    ));
                 }
             }
             "user" => {
@@ -97,9 +103,13 @@ pub async fn chat_completions(
     Json(request): Json<ChatCompletionRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, String)> {
     if !request.stream.unwrap_or(false) {
-        return Ok(chat_completions_non_streaming_proxy(state, request).await.into_response());
+        return Ok(chat_completions_non_streaming_proxy(state, request)
+            .await
+            .into_response());
     }
-    Ok(chat_completions_stream(state, request).await.into_response())
+    Ok(chat_completions_stream(state, request)
+        .await
+        .into_response())
 }

 pub async fn chat_completions_non_streaming_proxy(
@@ -136,7 +146,9 @@ pub async fn chat_completions_non_streaming_proxy(
         ModelType::Gemma => build_gemma_prompt(&request.messages),
         ModelType::Llama => {
             // For Llama, just use the last user message for now
-            request.messages.last()
+            request
+                .messages
+                .last()
                 .and_then(|m| m.content.as_ref())
                 .and_then(|c| match c {
                     MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -147,46 +159,47 @@ pub async fn chat_completions_non_streaming_proxy(
     };

     // Get streaming receiver based on model type
-    let rx = match state.model_type {
-        ModelType::Gemma => {
-            if let Some(mut config) = state.gemma_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_gemma_api(config).map_err(|e| (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                    }))
-                ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Gemma configuration not available" }
-                    }))
-                ));
-            }
-        }
-        ModelType::Llama => {
-            if let Some(mut config) = state.llama_config {
-                config.prompt = prompt.clone();
-                config.max_tokens = max_tokens;
-                run_llama_inference(config).map_err(|e| (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": format!("Error initializing Llama model: {}", e) }
-                    }))
-                ))?
-            } else {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(serde_json::json!({
-                        "error": { "message": "Llama configuration not available" }
-                    }))
-                ));
-            }
-        }
-    };
+    let rx =
+        match state.model_type {
+            ModelType::Gemma => {
+                if let Some(mut config) = state.gemma_config {
+                    config.prompt = prompt.clone();
+                    config.max_tokens = max_tokens;
+                    run_gemma_api(config).map_err(|e| (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(serde_json::json!({
+                            "error": { "message": format!("Error initializing Gemma model: {}", e) }
+                        }))
+                    ))?
+                } else {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(serde_json::json!({
+                            "error": { "message": "Gemma configuration not available" }
+                        })),
+                    ));
+                }
+            }
+            ModelType::Llama => {
+                if let Some(mut config) = state.llama_config {
+                    config.prompt = prompt.clone();
+                    config.max_tokens = max_tokens;
+                    run_llama_inference(config).map_err(|e| (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(serde_json::json!({
+                            "error": { "message": format!("Error initializing Llama model: {}", e) }
+                        }))
+                    ))?
+                } else {
+                    return Err((
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        Json(serde_json::json!({
+                            "error": { "message": "Llama configuration not available" }
+                        })),
+                    ));
+                }
+            }
+        };

     // Collect all tokens from the stream
     let mut completion = String::new();
@@ -281,7 +294,9 @@ async fn handle_streaming_request(
         ModelType::Gemma => build_gemma_prompt(&request.messages),
         ModelType::Llama => {
             // For Llama, just use the last user message for now
-            request.messages.last()
+            request
+                .messages
+                .last()
                 .and_then(|m| m.content.as_ref())
                 .and_then(|c| match c {
                     MessageContent(Either::Left(text)) => Some(text.clone()),
@@ -303,7 +318,10 @@ async fn handle_streaming_request(
         model: model_id.clone(),
         choices: vec![ChatCompletionChunkChoice {
             index: 0,
-            delta: Delta { role: Some("assistant".to_string()), content: None },
+            delta: Delta {
+                role: Some("assistant".to_string()),
+                content: None,
+            },
             finish_reason: None,
         }],
     };
@@ -324,7 +342,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Error initializing Gemma model: {}", e) }
-                    }))
+                    })),
                 ));
             }
         }
@@ -333,7 +351,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": "Gemma configuration not available" }
-                    }))
+                    })),
                 ));
             }
         }
@@ -348,7 +366,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": format!("Error initializing Llama model: {}", e) }
-                    }))
+                    })),
                 ));
             }
         }
@@ -357,7 +375,7 @@ async fn handle_streaming_request(
                     StatusCode::INTERNAL_SERVER_ERROR,
                     Json(serde_json::json!({
                         "error": { "message": "Llama configuration not available" }
-                    }))
+                    })),
                 ));
             }
         }
@@ -394,7 +412,11 @@ async fn handle_streaming_request(

             if last_token == second_last {
                 repetition_count += 1;
-                tracing::warn!("Detected repetition pattern: '{}' (count: {})", last_token, repetition_count);
+                tracing::warn!(
+                    "Detected repetition pattern: '{}' (count: {})",
+                    last_token,
+                    repetition_count
+                );

                 if repetition_count >= MAX_REPETITION_COUNT {
                     tracing::info!("Stopping generation due to excessive repetition");
@@ -412,7 +434,10 @@ async fn handle_streaming_request(
                 model: model_id_clone.clone(),
                 choices: vec![ChatCompletionChunkChoice {
                     index: 0,
-                    delta: Delta { role: None, content: Some(token) },
+                    delta: Delta {
+                        role: None,
+                        content: Some(token),
+                    },
                     finish_reason: None,
                 }],
             };
@@ -436,7 +461,10 @@ async fn handle_streaming_request(
                 model: model_id_clone.clone(),
                 choices: vec![ChatCompletionChunkChoice {
                     index: 0,
-                    delta: Delta { role: None, content: None },
+                    delta: Delta {
+                        role: None,
+                        content: None,
+                    },
                     finish_reason: Some("stop".to_string()),
                 }],
             };
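The repetition guard touched in the first hunk above compares the newest token against the previous one and stops the stream after too many consecutive repeats. A rough sketch of that idea under assumed names (the diff does not show the value of MAX_REPETITION_COUNT or whether the counter resets on a non-repeated token):

```rust
// Sketch only: the constant value and the reset behaviour are assumptions,
// not taken from the diff; should_stop is a hypothetical helper.
const MAX_REPETITION_COUNT: usize = 5;

fn should_stop(tokens: &[String], repetition_count: &mut usize) -> bool {
    if tokens.len() >= 2 && tokens[tokens.len() - 1] == tokens[tokens.len() - 2] {
        *repetition_count += 1;
    } else {
        *repetition_count = 0;
    }
    *repetition_count >= MAX_REPETITION_COUNT
}

fn main() {
    let mut repetition_count = 0;
    let mut tokens: Vec<String> = Vec::new();
    for t in ["a", "b", "b", "b", "b", "b", "b"] {
        tokens.push(t.to_string());
        if should_stop(&tokens, &mut repetition_count) {
            println!("stopping after {} tokens", tokens.len());
            break;
        }
    }
}
```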
@@ -451,8 +479,6 @@ async fn handle_streaming_request(
     Ok(Sse::new(stream))
 }

-
-
 // -------------------------
 // Router
 // -------------------------
@@ -647,7 +673,6 @@ pub async fn list_models() -> Json<ModelListResponse> {
     })
 }

-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -681,10 +706,7 @@ mod tests {
|
|||||||
|
|
||||||
let prompt = build_gemma_prompt(&messages);
|
let prompt = build_gemma_prompt(&messages);
|
||||||
|
|
||||||
let expected = "<start_of_turn>user\nSystem message\n\nKnock knock.<end_of_turn>\n\
|
let expected = "<start_of_turn>system\nSystem message<end_of_turn>\n<start_of_turn>user\nKnock knock.<end_of_turn>\n<start_of_turn>model\nWho's there?<end_of_turn>\n<start_of_turn>user\nGemma.<end_of_turn>\n<start_of_turn>model\n";
|
||||||
<start_of_turn>model\nWho's there?<end_of_turn>\n\
|
|
||||||
<start_of_turn>user\nGemma.<end_of_turn>\n\
|
|
||||||
<start_of_turn>model\n";
|
|
||||||
|
|
||||||
assert_eq!(prompt, expected);
|
assert_eq!(prompt, expected);
|
||||||
}
|
}
|
||||||
@@ -698,15 +720,13 @@ mod tests {
 
     #[test]
    fn test_missing_content() {
-        let messages = vec![
-            Message {
-                role: "user".to_string(),
-                content: None,
-                name: None,
-            }
-        ];
+        let messages = vec![Message {
+            role: "user".to_string(),
+            content: None,
+            name: None,
+        }];
 
         let prompt = build_gemma_prompt(&messages);
-        assert_eq!(prompt, "<start_of_turn>user\n<end_of_turn>\n<start_of_turn>model\n");
+        assert_eq!(prompt, "<start_of_turn>model\n");
     }
 }
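Illustrative aside (not part of the diff): the updated expectations above imply a prompt builder that emits one `<start_of_turn>…<end_of_turn>` block per message that has content, maps the `assistant` role to Gemma's `model` turn, skips content-less messages, and always ends by opening a fresh `model` turn. A minimal sketch consistent with those tests; the `Message` shape and the role mapping are assumptions, not the crate's actual implementation:

```rust
#[allow(dead_code)]
struct Message {
    role: String,
    content: Option<String>,
    name: Option<String>,
}

fn build_gemma_prompt(messages: &[Message]) -> String {
    let mut prompt = String::new();
    for m in messages {
        // Messages without content are skipped, matching test_missing_content above.
        if let Some(content) = &m.content {
            // Assumption: "assistant" is rendered as Gemma's "model" role.
            let role = if m.role == "assistant" { "model" } else { m.role.as_str() };
            prompt.push_str(&format!("<start_of_turn>{role}\n{content}<end_of_turn>\n"));
        }
    }
    // Always finish by opening the model turn so generation continues from here.
    prompt.push_str("<start_of_turn>model\n");
    prompt
}
```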
@@ -1,8 +1,8 @@
 use anyhow::{Error as E, Result};
 use candle_core::{DType, Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
-use tokenizers::Tokenizer;
 use std::collections::HashMap;
+use tokenizers::Tokenizer;
 
 use crate::model::Model;
 use crate::token_output_stream::TokenOutputStream;
@@ -83,12 +83,12 @@ impl TextGeneration {
|
|||||||
let err_string = err.to_string();
|
let err_string = err.to_string();
|
||||||
|
|
||||||
// Check if the error is about unsupported operations or shape mismatches
|
// Check if the error is about unsupported operations or shape mismatches
|
||||||
if (err_string.contains("no metal implementation for") ||
|
if (err_string.contains("no metal implementation for")
|
||||||
err_string.contains("no cuda implementation for") ||
|
|| err_string.contains("no cuda implementation for")
|
||||||
err_string.contains("shape mismatch") ||
|
|| err_string.contains("shape mismatch")
|
||||||
err_string.contains("broadcast_add")) &&
|
|| err_string.contains("broadcast_add"))
|
||||||
self.cpu_device.is_some() {
|
&& self.cpu_device.is_some()
|
||||||
|
{
|
||||||
// Extract operation name for better logging
|
// Extract operation name for better logging
|
||||||
let op_name = if let Some(idx) = err_string.find("for ") {
|
let op_name = if let Some(idx) = err_string.find("for ") {
|
||||||
&err_string[(idx + 4)..]
|
&err_string[(idx + 4)..]
|
||||||
@@ -99,7 +99,10 @@ impl TextGeneration {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Log the fallback
|
// Log the fallback
|
||||||
tracing::warn!("The primary device does not support {}. Falling back to CPU.", op_name);
|
tracing::warn!(
|
||||||
|
"The primary device does not support {}. Falling back to CPU.",
|
||||||
|
op_name
|
||||||
|
);
|
||||||
|
|
||||||
// Move input to CPU and try again
|
// Move input to CPU and try again
|
||||||
let cpu_device = self.cpu_device.as_ref().unwrap();
|
let cpu_device = self.cpu_device.as_ref().unwrap();
|
||||||
@@ -108,7 +111,9 @@ impl TextGeneration {
|
|||||||
|
|
||||||
// Don't try primary device for future operations
|
// Don't try primary device for future operations
|
||||||
self.try_primary_device = false;
|
self.try_primary_device = false;
|
||||||
tracing::info!("Successfully executed on CPU. Will use CPU for subsequent operations.");
|
tracing::info!(
|
||||||
|
"Successfully executed on CPU. Will use CPU for subsequent operations."
|
||||||
|
);
|
||||||
|
|
||||||
// Move result back to original device
|
// Move result back to original device
|
||||||
cpu_result.to_device(&self.device).map_err(E::msg)
|
cpu_result.to_device(&self.device).map_err(E::msg)
|
||||||
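Illustrative aside (not part of the diff): the error handling reformatted above follows a "try the accelerator, retry on CPU" pattern keyed on the error message. A simplified sketch of that pattern; the closure stands in for the model forward pass and `with_cpu_fallback` is a made-up name, whereas the real method is `execute_with_fallback` on `TextGeneration`:

```rust
use candle_core::{Device, Tensor};

/// Illustrative only: retry a tensor operation on CPU when the accelerator lacks a kernel.
fn with_cpu_fallback<F>(input: &Tensor, primary: &Device, op: F) -> anyhow::Result<Tensor>
where
    F: Fn(&Tensor) -> candle_core::Result<Tensor>,
{
    match op(input) {
        Ok(out) => Ok(out),
        Err(err) => {
            let msg = err.to_string();
            // Same heuristic as the code above: only fall back for missing kernels or shape issues.
            let recoverable = msg.contains("no metal implementation for")
                || msg.contains("no cuda implementation for")
                || msg.contains("shape mismatch")
                || msg.contains("broadcast_add");
            if !recoverable {
                return Err(err.into());
            }
            tracing::warn!("Primary device does not support this op ({}); falling back to CPU", msg);
            // Move the input to CPU, retry, then move the result back to the primary device.
            let cpu_out = op(&input.to_device(&Device::Cpu)?)?;
            Ok(cpu_out.to_device(primary)?)
        }
    }
}
```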
@@ -174,8 +179,12 @@ impl TextGeneration {
|
|||||||
// Log cache efficiency statistics
|
// Log cache efficiency statistics
|
||||||
if !penalty_tokens.is_empty() {
|
if !penalty_tokens.is_empty() {
|
||||||
let cache_efficiency = (cache_hits.get() as f32 / penalty_tokens.len() as f32) * 100.0;
|
let cache_efficiency = (cache_hits.get() as f32 / penalty_tokens.len() as f32) * 100.0;
|
||||||
tracing::trace!("Repeat penalty cache hits: {}/{} ({:.1}%)",
|
tracing::trace!(
|
||||||
cache_hits.get(), penalty_tokens.len(), cache_efficiency);
|
"Repeat penalty cache hits: {}/{} ({:.1}%)",
|
||||||
|
cache_hits.get(),
|
||||||
|
penalty_tokens.len(),
|
||||||
|
cache_efficiency
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new tensor with the modified logits (single tensor creation)
|
// Create a new tensor with the modified logits (single tensor creation)
|
||||||
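Illustrative aside (not part of the diff): the cache-efficiency log above belongs to a repeat-penalty pass that memoizes the penalized logit per token id. A self-contained sketch of that idea on a plain `f32` slice; the crate's `apply_cached_repeat_penalty` operates on tensors and also returns timing information:

```rust
use std::collections::HashMap;

/// Penalize logits of recently generated tokens, caching the penalized value per token id.
/// Returns (cache_hits, cache_misses) so callers can log cache efficiency as above.
fn apply_repeat_penalty_cached(
    logits: &mut [f32],
    penalty_tokens: &[u32],
    penalty: f32,
    cache: &mut HashMap<u32, f32>,
) -> (usize, usize) {
    let (mut hits, mut misses) = (0usize, 0usize);
    if (penalty - 1.0).abs() < f32::EPSILON {
        return (hits, misses); // a penalty of 1.0 is a no-op
    }
    for &tok in penalty_tokens {
        let idx = tok as usize;
        if idx >= logits.len() {
            continue; // ignore out-of-vocabulary ids
        }
        let value = match cache.get(&tok).copied() {
            Some(cached) => {
                hits += 1;
                cached
            }
            None => {
                misses += 1;
                // Standard repeat penalty: shrink positive logits, push negative ones further down.
                let v = if logits[idx] > 0.0 { logits[idx] / penalty } else { logits[idx] * penalty };
                cache.insert(tok, v);
                v
            }
        };
        logits[idx] = value;
    }
    (hits, misses)
}
```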
@@ -201,7 +210,9 @@ impl TextGeneration {
|
|||||||
self.penalty_cache.clear();
|
self.penalty_cache.clear();
|
||||||
tracing::debug!("Cleared penalty cache due to size limit");
|
tracing::debug!("Cleared penalty cache due to size limit");
|
||||||
} else {
|
} else {
|
||||||
tracing::debug!("Maintaining penalty cache across generation for better repetition prevention");
|
tracing::debug!(
|
||||||
|
"Maintaining penalty cache across generation for better repetition prevention"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 1: Tokenize input
|
// Phase 1: Tokenize input
|
||||||
@@ -280,7 +291,8 @@ impl TextGeneration {
|
|||||||
let token_start = std::time::Instant::now();
|
let token_start = std::time::Instant::now();
|
||||||
|
|
||||||
// Apply repeat penalty using optimized cached implementation
|
// Apply repeat penalty using optimized cached implementation
|
||||||
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (current_logits, repeat_time) =
|
||||||
|
self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
repeat_penalty_times.push(repeat_time);
|
repeat_penalty_times.push(repeat_time);
|
||||||
|
|
||||||
// Track token sampling
|
// Track token sampling
|
||||||
@@ -320,43 +332,43 @@ impl TextGeneration {
|
|||||||
tracing::debug!("Using standard generation approach");
|
tracing::debug!("Using standard generation approach");
|
||||||
|
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let token_start = std::time::Instant::now();
|
let token_start = std::time::Instant::now();
|
||||||
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
let ctxt = &tokens[start_pos..];
|
let ctxt = &tokens[start_pos..];
|
||||||
|
|
||||||
// Track tensor operations and model forward pass
|
// Track tensor operations and model forward pass
|
||||||
let forward_start = std::time::Instant::now();
|
let forward_start = std::time::Instant::now();
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
let logits = self.execute_with_fallback(&input, start_pos)?;
|
let logits = self.execute_with_fallback(&input, start_pos)?;
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
let forward_time = forward_start.elapsed();
|
let forward_time = forward_start.elapsed();
|
||||||
forward_times.push(forward_time);
|
forward_times.push(forward_time);
|
||||||
|
|
||||||
// Apply repeat penalty using optimized cached implementation
|
// Apply repeat penalty using optimized cached implementation
|
||||||
let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
|
let (logits, repeat_time) = self.apply_cached_repeat_penalty(logits, &tokens)?;
|
||||||
repeat_penalty_times.push(repeat_time);
|
repeat_penalty_times.push(repeat_time);
|
||||||
|
|
||||||
// Track token sampling
|
// Track token sampling
|
||||||
let sampling_start = std::time::Instant::now();
|
let sampling_start = std::time::Instant::now();
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
let sampling_time = sampling_start.elapsed();
|
let sampling_time = sampling_start.elapsed();
|
||||||
sampling_times.push(sampling_time);
|
sampling_times.push(sampling_time);
|
||||||
|
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
if next_token == eos_token || next_token == eot_token {
|
if next_token == eos_token || next_token == eot_token {
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_time = token_start.elapsed();
|
||||||
|
token_times.push(token_time);
|
||||||
}
|
}
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let token_time = token_start.elapsed();
|
|
||||||
token_times.push(token_time);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let dt = start_gen.elapsed();
|
let dt = start_gen.elapsed();
|
||||||
@@ -387,7 +399,8 @@ impl TextGeneration {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let avg_repeat_time = if !repeat_penalty_times.is_empty() {
|
let avg_repeat_time = if !repeat_penalty_times.is_empty() {
|
||||||
repeat_penalty_times.iter().sum::<std::time::Duration>() / repeat_penalty_times.len() as u32
|
repeat_penalty_times.iter().sum::<std::time::Duration>()
|
||||||
|
/ repeat_penalty_times.len() as u32
|
||||||
} else {
|
} else {
|
||||||
std::time::Duration::from_secs(0)
|
std::time::Duration::from_secs(0)
|
||||||
};
|
};
|
||||||
@@ -409,15 +422,18 @@ impl TextGeneration {
|
|||||||
tracing::info!("Tokens generated: {}", generated_tokens);
|
tracing::info!("Tokens generated: {}", generated_tokens);
|
||||||
tracing::info!("Generation speed: {:.2} tokens/second", tokens_per_second);
|
tracing::info!("Generation speed: {:.2} tokens/second", tokens_per_second);
|
||||||
tracing::info!("Average time per token: {:.2?}", avg_token_time);
|
tracing::info!("Average time per token: {:.2?}", avg_token_time);
|
||||||
tracing::debug!(" - Forward pass: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
" - Forward pass: {:.2?} ({:.1}%)",
|
||||||
avg_forward_time,
|
avg_forward_time,
|
||||||
avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!(" - Repeat penalty: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
" - Repeat penalty: {:.2?} ({:.1}%)",
|
||||||
avg_repeat_time,
|
avg_repeat_time,
|
||||||
avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!(" - Sampling: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
" - Sampling: {:.2?} ({:.1}%)",
|
||||||
avg_sampling_time,
|
avg_sampling_time,
|
||||||
avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
@@ -425,15 +441,18 @@ impl TextGeneration {
|
|||||||
// Log total request time
|
// Log total request time
|
||||||
let total_time = start_time.elapsed();
|
let total_time = start_time.elapsed();
|
||||||
tracing::info!("Total request time: {:.2?}", total_time);
|
tracing::info!("Total request time: {:.2?}", total_time);
|
||||||
tracing::debug!(" - Tokenization: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
" - Tokenization: {:.2?} ({:.1}%)",
|
||||||
tokenize_time,
|
tokenize_time,
|
||||||
tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!(" - Generation: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
" - Generation: {:.2?} ({:.1}%)",
|
||||||
dt,
|
dt,
|
||||||
dt.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
dt.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!(" - Final decoding: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
" - Final decoding: {:.2?} ({:.1}%)",
|
||||||
decode_time,
|
decode_time,
|
||||||
decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
@@ -442,7 +461,12 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run text generation and write to a buffer
|
// Run text generation and write to a buffer
|
||||||
pub fn run_with_output(&mut self, prompt: &str, sample_len: usize, output: &mut Vec<u8>) -> Result<()> {
|
pub fn run_with_output(
|
||||||
|
&mut self,
|
||||||
|
prompt: &str,
|
||||||
|
sample_len: usize,
|
||||||
|
output: &mut Vec<u8>,
|
||||||
|
) -> Result<()> {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
// Track overall performance
|
// Track overall performance
|
||||||
@@ -488,7 +512,10 @@ impl TextGeneration {
|
|||||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => {
|
None => {
|
||||||
write!(output, "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup")?;
|
write!(
|
||||||
|
output,
|
||||||
|
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||||
|
)?;
|
||||||
eos_token
|
eos_token
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -533,7 +560,8 @@ impl TextGeneration {
|
|||||||
let token_start = std::time::Instant::now();
|
let token_start = std::time::Instant::now();
|
||||||
|
|
||||||
// Apply repeat penalty using optimized cached implementation
|
// Apply repeat penalty using optimized cached implementation
|
||||||
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (current_logits, repeat_time) =
|
||||||
|
self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
repeat_penalty_times.push(repeat_time);
|
repeat_penalty_times.push(repeat_time);
|
||||||
|
|
||||||
// Track token sampling
|
// Track token sampling
|
||||||
@@ -572,9 +600,16 @@ impl TextGeneration {
|
|||||||
|
|
||||||
// Calculate and log performance metrics
|
// Calculate and log performance metrics
|
||||||
Self::log_performance_metrics(
|
Self::log_performance_metrics(
|
||||||
dt, generated_tokens, &token_times, &forward_times,
|
dt,
|
||||||
&repeat_penalty_times, &sampling_times, tokenize_time,
|
generated_tokens,
|
||||||
std::time::Duration::from_secs(0), start_time, "API"
|
&token_times,
|
||||||
|
&forward_times,
|
||||||
|
&repeat_penalty_times,
|
||||||
|
&sampling_times,
|
||||||
|
tokenize_time,
|
||||||
|
std::time::Duration::from_secs(0),
|
||||||
|
start_time,
|
||||||
|
"API",
|
||||||
);
|
);
|
||||||
|
|
||||||
return Ok(());
|
return Ok(());
|
||||||
@@ -595,8 +630,11 @@ impl TextGeneration {
|
|||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
let ctxt = &tokens[start_pos..];
|
let ctxt = &tokens[start_pos..];
|
||||||
|
|
||||||
tracing::debug!("API standard model: Using sliding window context: {} tokens (from position {})",
|
tracing::debug!(
|
||||||
ctxt.len(), start_pos);
|
"API standard model: Using sliding window context: {} tokens (from position {})",
|
||||||
|
ctxt.len(),
|
||||||
|
start_pos
|
||||||
|
);
|
||||||
|
|
||||||
// Track tensor operations and model forward pass
|
// Track tensor operations and model forward pass
|
||||||
let forward_start = std::time::Instant::now();
|
let forward_start = std::time::Instant::now();
|
||||||
@@ -643,16 +681,28 @@ impl TextGeneration {
|
|||||||
|
|
||||||
// Log performance metrics
|
// Log performance metrics
|
||||||
Self::log_performance_metrics(
|
Self::log_performance_metrics(
|
||||||
dt, generated_tokens, &token_times, &forward_times,
|
dt,
|
||||||
&repeat_penalty_times, &sampling_times, tokenize_time,
|
generated_tokens,
|
||||||
decode_time, start_time, "API"
|
&token_times,
|
||||||
|
&forward_times,
|
||||||
|
&repeat_penalty_times,
|
||||||
|
&sampling_times,
|
||||||
|
tokenize_time,
|
||||||
|
decode_time,
|
||||||
|
start_time,
|
||||||
|
"API",
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run text generation with streaming callback for each token
|
// Run text generation with streaming callback for each token
|
||||||
pub async fn run_with_streaming<F>(&mut self, prompt: &str, sample_len: usize, mut token_callback: F) -> Result<String>
|
pub async fn run_with_streaming<F>(
|
||||||
|
&mut self,
|
||||||
|
prompt: &str,
|
||||||
|
sample_len: usize,
|
||||||
|
mut token_callback: F,
|
||||||
|
) -> Result<String>
|
||||||
where
|
where
|
||||||
F: FnMut(&str) -> Result<()>,
|
F: FnMut(&str) -> Result<()>,
|
||||||
{
|
{
|
||||||
@@ -695,7 +745,9 @@ impl TextGeneration {
|
|||||||
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => {
|
None => {
|
||||||
tracing::warn!("Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup");
|
tracing::warn!(
|
||||||
|
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||||
|
);
|
||||||
eos_token
|
eos_token
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -718,7 +770,9 @@ impl TextGeneration {
|
|||||||
|
|
||||||
// For Model2 and Model3, we need to use a special approach for shape compatibility
|
// For Model2 and Model3, we need to use a special approach for shape compatibility
|
||||||
if needs_special_handling {
|
if needs_special_handling {
|
||||||
tracing::debug!("Using special generation approach for gemma-2/gemma-3 models (streaming)");
|
tracing::debug!(
|
||||||
|
"Using special generation approach for gemma-2/gemma-3 models (streaming)"
|
||||||
|
);
|
||||||
tracing::debug!("Streaming: sample_len = {}", sample_len);
|
tracing::debug!("Streaming: sample_len = {}", sample_len);
|
||||||
|
|
||||||
// Initial generation with the full prompt
|
// Initial generation with the full prompt
|
||||||
@@ -731,13 +785,21 @@ impl TextGeneration {
|
|||||||
let forward_time = forward_start.elapsed();
|
let forward_time = forward_start.elapsed();
|
||||||
forward_times.push(forward_time);
|
forward_times.push(forward_time);
|
||||||
|
|
||||||
tracing::debug!("Streaming: About to enter generation loop with sample_len = {}", sample_len);
|
tracing::debug!(
|
||||||
|
"Streaming: About to enter generation loop with sample_len = {}",
|
||||||
|
sample_len
|
||||||
|
);
|
||||||
for gen_index in 0..sample_len {
|
for gen_index in 0..sample_len {
|
||||||
tracing::debug!("Streaming: Starting generation iteration {} / {}", gen_index + 1, sample_len);
|
tracing::debug!(
|
||||||
|
"Streaming: Starting generation iteration {} / {}",
|
||||||
|
gen_index + 1,
|
||||||
|
sample_len
|
||||||
|
);
|
||||||
let token_start = std::time::Instant::now();
|
let token_start = std::time::Instant::now();
|
||||||
|
|
||||||
// Apply repeat penalty using optimized cached implementation
|
// Apply repeat penalty using optimized cached implementation
|
||||||
let (current_logits, repeat_time) = self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (current_logits, repeat_time) =
|
||||||
|
self.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
repeat_penalty_times.push(repeat_time);
|
repeat_penalty_times.push(repeat_time);
|
||||||
|
|
||||||
// Track token sampling
|
// Track token sampling
|
||||||
@@ -749,8 +811,13 @@ impl TextGeneration {
|
|||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
|
|
||||||
tracing::debug!("Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
|
tracing::debug!(
|
||||||
next_token, next_token, eos_token, eot_token);
|
"Streaming: Generated token {} (id: {}), eos: {}, eot: {}",
|
||||||
|
next_token,
|
||||||
|
next_token,
|
||||||
|
eos_token,
|
||||||
|
eot_token
|
||||||
|
);
|
||||||
if next_token == eos_token || next_token == eot_token {
|
if next_token == eos_token || next_token == eot_token {
|
||||||
tracing::debug!("Streaming: Breaking due to end token");
|
tracing::debug!("Streaming: Breaking due to end token");
|
||||||
break;
|
break;
|
||||||
@@ -764,15 +831,21 @@ impl TextGeneration {
|
|||||||
|
|
||||||
// For the next iteration, use single token to avoid shape mismatch
|
// For the next iteration, use single token to avoid shape mismatch
|
||||||
let forward_start = std::time::Instant::now();
|
let forward_start = std::time::Instant::now();
|
||||||
tracing::debug!("Streaming: Preparing next forward pass with {} tokens", tokens.len());
|
tracing::debug!(
|
||||||
|
"Streaming: Preparing next forward pass with {} tokens",
|
||||||
|
tokens.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Use just the last token for subsequent iterations to avoid shape mismatch
|
// Use just the last token for subsequent iterations to avoid shape mismatch
|
||||||
// This is required for Gemma model's attention mechanism compatibility
|
// This is required for Gemma model's attention mechanism compatibility
|
||||||
let context_tokens = &tokens[(tokens.len()-1)..];
|
let context_tokens = &tokens[(tokens.len() - 1)..];
|
||||||
let start_pos = tokens.len() - 1;
|
let start_pos = tokens.len() - 1;
|
||||||
|
|
||||||
tracing::debug!("Streaming: Using single token context for Gemma: {} tokens (from position {})",
|
tracing::debug!(
|
||||||
context_tokens.len(), start_pos);
|
"Streaming: Using single token context for Gemma: {} tokens (from position {})",
|
||||||
|
context_tokens.len(),
|
||||||
|
start_pos
|
||||||
|
);
|
||||||
|
|
||||||
let new_input = match Tensor::new(context_tokens, &self.device) {
|
let new_input = match Tensor::new(context_tokens, &self.device) {
|
||||||
Ok(tensor) => tensor,
|
Ok(tensor) => tensor,
|
||||||
@@ -825,7 +898,10 @@ impl TextGeneration {
|
|||||||
|
|
||||||
let forward_time = forward_start.elapsed();
|
let forward_time = forward_start.elapsed();
|
||||||
forward_times.push(forward_time);
|
forward_times.push(forward_time);
|
||||||
tracing::debug!("Streaming: Forward pass completed for iteration {}", gen_index + 1);
|
tracing::debug!(
|
||||||
|
"Streaming: Forward pass completed for iteration {}",
|
||||||
|
gen_index + 1
|
||||||
|
);
|
||||||
|
|
||||||
let token_time = token_start.elapsed();
|
let token_time = token_start.elapsed();
|
||||||
token_times.push(token_time);
|
token_times.push(token_time);
|
||||||
@@ -849,8 +925,11 @@ impl TextGeneration {
|
|||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
let ctxt = &tokens[start_pos..];
|
let ctxt = &tokens[start_pos..];
|
||||||
|
|
||||||
tracing::debug!("Standard model: Using sliding window context: {} tokens (from position {})",
|
tracing::debug!(
|
||||||
ctxt.len(), start_pos);
|
"Standard model: Using sliding window context: {} tokens (from position {})",
|
||||||
|
ctxt.len(),
|
||||||
|
start_pos
|
||||||
|
);
|
||||||
|
|
||||||
// Track tensor operations and model forward pass
|
// Track tensor operations and model forward pass
|
||||||
let forward_start = std::time::Instant::now();
|
let forward_start = std::time::Instant::now();
|
||||||
@@ -903,9 +982,16 @@ impl TextGeneration {
|
|||||||
|
|
||||||
// Log performance metrics
|
// Log performance metrics
|
||||||
Self::log_performance_metrics(
|
Self::log_performance_metrics(
|
||||||
dt, generated_tokens, &token_times, &forward_times,
|
dt,
|
||||||
&repeat_penalty_times, &sampling_times, tokenize_time,
|
generated_tokens,
|
||||||
decode_time, start_time, "Streaming"
|
&token_times,
|
||||||
|
&forward_times,
|
||||||
|
&repeat_penalty_times,
|
||||||
|
&sampling_times,
|
||||||
|
tokenize_time,
|
||||||
|
decode_time,
|
||||||
|
start_time,
|
||||||
|
"Streaming",
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(full_output)
|
Ok(full_output)
|
||||||
@@ -945,7 +1031,8 @@ impl TextGeneration {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let avg_repeat_time = if !repeat_penalty_times.is_empty() {
|
let avg_repeat_time = if !repeat_penalty_times.is_empty() {
|
||||||
repeat_penalty_times.iter().sum::<std::time::Duration>() / repeat_penalty_times.len() as u32
|
repeat_penalty_times.iter().sum::<std::time::Duration>()
|
||||||
|
/ repeat_penalty_times.len() as u32
|
||||||
} else {
|
} else {
|
||||||
std::time::Duration::from_secs(0)
|
std::time::Duration::from_secs(0)
|
||||||
};
|
};
|
||||||
@@ -957,23 +1044,34 @@ impl TextGeneration {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Record detailed performance metrics
|
// Record detailed performance metrics
|
||||||
tracing::info!("{} Text generation completed in {:.2?}", prefix, generation_time);
|
tracing::info!(
|
||||||
|
"{} Text generation completed in {:.2?}",
|
||||||
|
prefix,
|
||||||
|
generation_time
|
||||||
|
);
|
||||||
tracing::info!("{} Tokens generated: {}", prefix, generated_tokens);
|
tracing::info!("{} Tokens generated: {}", prefix, generated_tokens);
|
||||||
tracing::info!("{} Generation speed: {:.2} tokens/second", prefix, tokens_per_second);
|
tracing::info!(
|
||||||
|
"{} Generation speed: {:.2} tokens/second",
|
||||||
|
prefix,
|
||||||
|
tokens_per_second
|
||||||
|
);
|
||||||
tracing::info!("{} Average time per token: {:.2?}", prefix, avg_token_time);
|
tracing::info!("{} Average time per token: {:.2?}", prefix, avg_token_time);
|
||||||
|
|
||||||
if !avg_token_time.is_zero() {
|
if !avg_token_time.is_zero() {
|
||||||
tracing::debug!("{} - Forward pass: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
"{} - Forward pass: {:.2?} ({:.1}%)",
|
||||||
prefix,
|
prefix,
|
||||||
avg_forward_time,
|
avg_forward_time,
|
||||||
avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
avg_forward_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!("{} - Repeat penalty: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
"{} - Repeat penalty: {:.2?} ({:.1}%)",
|
||||||
prefix,
|
prefix,
|
||||||
avg_repeat_time,
|
avg_repeat_time,
|
||||||
avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
avg_repeat_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!("{} - Sampling: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
"{} - Sampling: {:.2?} ({:.1}%)",
|
||||||
prefix,
|
prefix,
|
||||||
avg_sampling_time,
|
avg_sampling_time,
|
||||||
avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
avg_sampling_time.as_secs_f64() / avg_token_time.as_secs_f64() * 100.0
|
||||||
@@ -985,17 +1083,20 @@ impl TextGeneration {
|
|||||||
tracing::info!("{} Total request time: {:.2?}", prefix, total_time);
|
tracing::info!("{} Total request time: {:.2?}", prefix, total_time);
|
||||||
|
|
||||||
if !total_time.is_zero() {
|
if !total_time.is_zero() {
|
||||||
tracing::debug!("{} - Tokenization: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
"{} - Tokenization: {:.2?} ({:.1}%)",
|
||||||
prefix,
|
prefix,
|
||||||
tokenize_time,
|
tokenize_time,
|
||||||
tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
tokenize_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!("{} - Generation: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
"{} - Generation: {:.2?} ({:.1}%)",
|
||||||
prefix,
|
prefix,
|
||||||
generation_time,
|
generation_time,
|
||||||
generation_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
generation_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
||||||
);
|
);
|
||||||
tracing::debug!("{} - Final decoding: {:.2?} ({:.1}%)",
|
tracing::debug!(
|
||||||
|
"{} - Final decoding: {:.2?} ({:.1}%)",
|
||||||
prefix,
|
prefix,
|
||||||
decode_time,
|
decode_time,
|
||||||
decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
decode_time.as_secs_f64() / total_time.as_secs_f64() * 100.0
|
||||||
|
@@ -147,7 +147,8 @@ pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
|
|||||||
) -> Result<Vec<std::path::PathBuf>> {
|
) -> Result<Vec<std::path::PathBuf>> {
|
||||||
let path = path.as_ref();
|
let path = path.as_ref();
|
||||||
let jsfile = std::fs::File::open(path.join(json_file))?;
|
let jsfile = std::fs::File::open(path.join(json_file))?;
|
||||||
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
let json: serde_json::Value =
|
||||||
|
serde_json::from_reader(&jsfile).map_err(candle_core::Error::wrap)?;
|
||||||
let weight_map = match json.get("weight_map") {
|
let weight_map = match json.get("weight_map") {
|
||||||
None => candle_core::bail!("no weight map in {json_file:?}"),
|
None => candle_core::bail!("no weight map in {json_file:?}"),
|
||||||
Some(serde_json::Value::Object(map)) => map,
|
Some(serde_json::Value::Object(map)) => map,
|
||||||
|
@@ -9,7 +9,10 @@ mod tests {
|
|||||||
// Test a few representative model variants
|
// Test a few representative model variants
|
||||||
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
|
assert_eq!(Which::Base2B.to_model_id(), "google/gemma-2b");
|
||||||
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
|
assert_eq!(Which::Instruct7B.to_model_id(), "google/gemma-7b-it");
|
||||||
assert_eq!(Which::InstructV1_1_2B.to_model_id(), "google/gemma-1.1-2b-it");
|
assert_eq!(
|
||||||
|
Which::InstructV1_1_2B.to_model_id(),
|
||||||
|
"google/gemma-1.1-2b-it"
|
||||||
|
);
|
||||||
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
|
assert_eq!(Which::CodeBase2B.to_model_id(), "google/codegemma-2b");
|
||||||
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
|
assert_eq!(Which::BaseV2_2B.to_model_id(), "google/gemma-2-2b");
|
||||||
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
|
assert_eq!(Which::InstructV3_1B.to_model_id(), "google/gemma-3-1b-it");
|
||||||
|
@@ -174,7 +174,8 @@ mod tests {
|
|||||||
penalty_cache: HashMap::new(),
|
penalty_cache: HashMap::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (result_logits, _duration) =
|
||||||
|
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
let result_data = result_logits.to_vec1::<f32>()?;
|
let result_data = result_logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
// With no penalty, logits should be unchanged
|
// With no penalty, logits should be unchanged
|
||||||
@@ -245,7 +246,8 @@ mod tests {
|
|||||||
penalty_cache: HashMap::new(),
|
penalty_cache: HashMap::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (result_logits, _duration) =
|
||||||
|
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
let result_data = result_logits.to_vec1::<f32>()?;
|
let result_data = result_logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
|
// Tokens 1, 2, 3 should be penalized (divided by 2.0)
|
||||||
@@ -316,7 +318,8 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// First call should cache the penalty for token 1
|
// First call should cache the penalty for token 1
|
||||||
let (_result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (_result_logits, _duration) =
|
||||||
|
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
|
|
||||||
// Cache should contain the penalized value for token 1
|
// Cache should contain the penalized value for token 1
|
||||||
assert!(mock_gen.penalty_cache.contains_key(&1));
|
assert!(mock_gen.penalty_cache.contains_key(&1));
|
||||||
@@ -386,7 +389,8 @@ mod tests {
|
|||||||
penalty_cache: HashMap::new(),
|
penalty_cache: HashMap::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (result_logits, _duration) =
|
||||||
|
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
let result_data = result_logits.to_vec1::<f32>()?;
|
let result_data = result_logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
// With empty tokens, logits should be unchanged
|
// With empty tokens, logits should be unchanged
|
||||||
@@ -455,7 +459,8 @@ mod tests {
|
|||||||
penalty_cache: HashMap::new(),
|
penalty_cache: HashMap::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (result_logits, _duration) = mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
let (result_logits, _duration) =
|
||||||
|
mock_gen.apply_cached_repeat_penalty(logits.clone(), &tokens)?;
|
||||||
let result_data = result_logits.to_vec1::<f32>()?;
|
let result_data = result_logits.to_vec1::<f32>()?;
|
||||||
|
|
||||||
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
|
// Only token 1 should be penalized, out-of-bounds tokens should be ignored
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
use inference_engine::token_output_stream::TokenOutputStream;
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use inference_engine::token_output_stream::TokenOutputStream;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
@@ -66,7 +66,10 @@ mod tests {
|
|||||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||||
|
|
||||||
// Get some tokens
|
// Get some tokens
|
||||||
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
|
let hello_tokens = token_stream
|
||||||
|
.tokenizer()
|
||||||
|
.encode("Hello world", true)
|
||||||
|
.unwrap();
|
||||||
let token_ids = hello_tokens.get_ids();
|
let token_ids = hello_tokens.get_ids();
|
||||||
|
|
||||||
// Add tokens one by one
|
// Add tokens one by one
|
||||||
@@ -95,7 +98,10 @@ mod tests {
|
|||||||
let mut token_stream = TokenOutputStream::new(tokenizer);
|
let mut token_stream = TokenOutputStream::new(tokenizer);
|
||||||
|
|
||||||
// Get some tokens
|
// Get some tokens
|
||||||
let hello_tokens = token_stream.tokenizer().encode("Hello world", true).unwrap();
|
let hello_tokens = token_stream
|
||||||
|
.tokenizer()
|
||||||
|
.encode("Hello world", true)
|
||||||
|
.unwrap();
|
||||||
let token_ids = hello_tokens.get_ids();
|
let token_ids = hello_tokens.get_ids();
|
||||||
|
|
||||||
// Add tokens one by one
|
// Add tokens one by one
|
||||||
|
@@ -5,6 +5,25 @@ use leptos_router::{
|
|||||||
StaticSegment,
|
StaticSegment,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "hydrate")]
|
||||||
|
use async_openai_wasm::config::OpenAIConfig;
|
||||||
|
#[cfg(feature = "hydrate")]
|
||||||
|
use async_openai_wasm::types::{FinishReason, Role};
|
||||||
|
#[cfg(feature = "hydrate")]
|
||||||
|
use async_openai_wasm::{
|
||||||
|
types::{
|
||||||
|
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
|
||||||
|
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
|
||||||
|
Model as OpenAIModel,
|
||||||
|
},
|
||||||
|
Client,
|
||||||
|
};
|
||||||
|
#[cfg(feature = "hydrate")]
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
#[cfg(feature = "hydrate")]
|
||||||
|
use js_sys::Date;
|
||||||
|
#[cfg(feature = "hydrate")]
|
||||||
|
use leptos::task::spawn_local;
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
@@ -12,25 +31,7 @@ use std::collections::VecDeque;
|
|||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
use js_sys::Date;
|
|
||||||
#[cfg(feature = "hydrate")]
|
|
||||||
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
|
use web_sys::{HtmlInputElement, KeyboardEvent, SubmitEvent};
|
||||||
#[cfg(feature = "hydrate")]
|
|
||||||
use futures_util::StreamExt;
|
|
||||||
#[cfg(feature = "hydrate")]
|
|
||||||
use async_openai_wasm::{
|
|
||||||
types::{
|
|
||||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
|
|
||||||
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Model as OpenAIModel,
|
|
||||||
},
|
|
||||||
Client,
|
|
||||||
};
|
|
||||||
#[cfg(feature = "hydrate")]
|
|
||||||
use async_openai_wasm::config::OpenAIConfig;
|
|
||||||
#[cfg(feature = "hydrate")]
|
|
||||||
use async_openai_wasm::types::{Role, FinishReason};
|
|
||||||
#[cfg(feature = "hydrate")]
|
|
||||||
use leptos::task::spawn_local;
|
|
||||||
|
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -43,11 +44,15 @@ pub struct Message {
|
|||||||
|
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct MessageContent(pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>);
|
pub struct MessageContent(
|
||||||
|
pub either::Either<String, Vec<std::collections::HashMap<String, MessageInnerContent>>>,
|
||||||
|
);
|
||||||
|
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct MessageInnerContent(pub either::Either<String, std::collections::HashMap<String, String>>);
|
pub struct MessageInnerContent(
|
||||||
|
pub either::Either<String, std::collections::HashMap<String, String>>,
|
||||||
|
);
|
||||||
|
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -62,7 +67,9 @@ const DEFAULT_MODEL: &str = "default";
|
|||||||
|
|
||||||
#[cfg(feature = "hydrate")]
|
#[cfg(feature = "hydrate")]
|
||||||
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
|
async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
|
||||||
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1");
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] fetch_available_models: Starting model fetch from http://localhost:8080/v1"
|
||||||
|
);
|
||||||
|
|
||||||
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
|
let config = OpenAIConfig::new().with_api_base("http://localhost:8080/v1".to_string());
|
||||||
let client = Client::with_config(config);
|
let client = Client::with_config(config);
|
||||||
@@ -70,19 +77,30 @@ async fn fetch_available_models() -> Result<Vec<OpenAIModel>, String> {
|
|||||||
match client.models().list().await {
|
match client.models().list().await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let model_count = response.data.len();
|
let model_count = response.data.len();
|
||||||
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Successfully fetched {} models", model_count);
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] fetch_available_models: Successfully fetched {} models",
|
||||||
|
model_count
|
||||||
|
);
|
||||||
|
|
||||||
if model_count > 0 {
|
if model_count > 0 {
|
||||||
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
|
let model_names: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
|
||||||
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Available models: {:?}", model_names);
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] fetch_available_models: Available models: {:?}",
|
||||||
|
model_names
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: No models returned by server");
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] fetch_available_models: No models returned by server"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(response.data)
|
Ok(response.data)
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
leptos::logging::log!("[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}", e);
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] fetch_available_models: Failed to fetch models: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
Err(format!("Failed to fetch models: {}", e))
|
Err(format!("Failed to fetch models: {}", e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -335,7 +353,11 @@ fn ChatInterfaceImpl() -> impl IntoView {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
leptos::logging::log!("[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}", chunks_received, e);
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] send_message: Stream error after {} chunks: {:?}",
|
||||||
|
chunks_received,
|
||||||
|
e
|
||||||
|
);
|
||||||
set_messages.update(|msgs| {
|
set_messages.update(|msgs| {
|
||||||
msgs.push_back(Message {
|
msgs.push_back(Message {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
@@ -364,7 +386,10 @@ fn ChatInterfaceImpl() -> impl IntoView {
|
|||||||
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
|
leptos::logging::log!("[DEBUG_LOG] send_message: Stream completed successfully, received {} chunks", chunks_received);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
leptos::logging::log!("[DEBUG_LOG] send_message: Request failed with error: {:?}", e);
|
leptos::logging::log!(
|
||||||
|
"[DEBUG_LOG] send_message: Request failed with error: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
let error_message = Message {
|
let error_message = Message {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
@@ -404,7 +429,8 @@ fn ChatInterfaceImpl() -> impl IntoView {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let messages_list = move || {
|
let messages_list = move || {
|
||||||
messages.get()
|
messages
|
||||||
|
.get()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|message| {
|
.map(|message| {
|
||||||
let role_class = match message.role.as_str() {
|
let role_class = match message.role.as_str() {
|
||||||
|
@@ -10,10 +10,10 @@ pub fn hydrate() {
|
|||||||
|
|
||||||
#[cfg(feature = "ssr")]
|
#[cfg(feature = "ssr")]
|
||||||
pub fn create_leptos_router() -> axum::Router {
|
pub fn create_leptos_router() -> axum::Router {
|
||||||
|
use crate::app::*;
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use leptos::prelude::*;
|
use leptos::prelude::*;
|
||||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||||
use crate::app::*;
|
|
||||||
|
|
||||||
let conf = get_configuration(None).unwrap();
|
let conf = get_configuration(None).unwrap();
|
||||||
let leptos_options = conf.leptos_options;
|
let leptos_options = conf.leptos_options;
|
||||||
|
@@ -1,12 +1,11 @@
|
|||||||
|
|
||||||
#[cfg(feature = "ssr")]
|
#[cfg(feature = "ssr")]
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use leptos::logging::log;
|
use leptos::logging::log;
|
||||||
use leptos::prelude::*;
|
use leptos::prelude::*;
|
||||||
use leptos_axum::{generate_route_list, LeptosRoutes};
|
|
||||||
use leptos_app::app::*;
|
use leptos_app::app::*;
|
||||||
|
use leptos_axum::{generate_route_list, LeptosRoutes};
|
||||||
|
|
||||||
let conf = get_configuration(None).unwrap();
|
let conf = get_configuration(None).unwrap();
|
||||||
let addr = conf.leptos_options.site_addr;
|
let addr = conf.leptos_options.site_addr;
|
||||||
|
@@ -18,6 +18,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", features = ["
|
|||||||
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||||
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["metal"] }
|
||||||
|
|
||||||
|
[target.'cfg(not(target_os = "macos"))'.dependencies]
|
||||||
|
candle-core = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
|
||||||
|
candle-nn = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
|
||||||
|
candle-transformers = { git = "https://github.com/huggingface/candle.git", features = ["cuda"], optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||||
|
@@ -5,4 +5,3 @@ pub use llama_api::{run_llama_inference, LlamaInferenceConfig, WhichModel};
|
|||||||
|
|
||||||
// Re-export constants and types that might be needed
|
// Re-export constants and types that might be needed
|
||||||
pub const EOS_TOKEN: &str = "</s>";
|
pub const EOS_TOKEN: &str = "</s>";
|
||||||
|
|
||||||
|
@@ -1,14 +1,14 @@
|
|||||||
|
use crate::EOS_TOKEN;
|
||||||
use anyhow::{bail, Error as E};
|
use anyhow::{bail, Error as E};
|
||||||
use candle_core::{utils, DType, Device, Tensor};
|
use candle_core::{utils, DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
use candle_transformers::models::llama::{Llama, LlamaConfig};
|
|
||||||
use candle_transformers::models::llama as model;
|
use candle_transformers::models::llama as model;
|
||||||
|
use candle_transformers::models::llama::{Llama, LlamaConfig};
|
||||||
|
use clap::ValueEnum;
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
use hf_hub::{Repo, RepoType};
|
use hf_hub::{Repo, RepoType};
|
||||||
use std::sync::mpsc::{self, Receiver};
|
use std::sync::mpsc::{self, Receiver};
|
||||||
use clap::ValueEnum;
|
|
||||||
use crate::{EOS_TOKEN};
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum, Default)]
|
||||||
pub enum WhichModel {
|
pub enum WhichModel {
|
||||||
@@ -81,8 +81,8 @@ impl Default for LlamaInferenceConfig {
|
|||||||
max_tokens: 512,
|
max_tokens: 512,
|
||||||
|
|
||||||
// Performance flags
|
// Performance flags
|
||||||
no_kv_cache: false, // keep cache ON for speed
|
no_kv_cache: false, // keep cache ON for speed
|
||||||
use_flash_attn: true, // great speed boost if supported
|
use_flash_attn: true, // great speed boost if supported
|
||||||
|
|
||||||
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
|
// Precision: bf16 is a good default on Ampere+; fallback to fp16 if needed.
|
||||||
dtype: Some("bf16".to_string()),
|
dtype: Some("bf16".to_string()),
|
||||||
@@ -98,8 +98,6 @@ impl Default for LlamaInferenceConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
fn device(cpu: bool) -> anyhow::Result<Device> {
|
fn device(cpu: bool) -> anyhow::Result<Device> {
|
||||||
if cpu {
|
if cpu {
|
||||||
Ok(Device::Cpu)
|
Ok(Device::Cpu)
|
||||||
@@ -112,7 +110,6 @@ fn device(cpu: bool) -> anyhow::Result<Device> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn hub_load_safetensors(
|
fn hub_load_safetensors(
|
||||||
api: &hf_hub::api::sync::ApiRepo,
|
api: &hf_hub::api::sync::ApiRepo,
|
||||||
json_file: &str,
|
json_file: &str,
|
||||||
@@ -171,7 +168,7 @@ pub fn run_llama_inference(
|
|||||||
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
WhichModel::SmolLM2_1_7BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||||
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
WhichModel::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||||
}
|
}
|
||||||
.to_string()
|
.to_string()
|
||||||
});
|
});
|
||||||
println!("Loading model: {}", model_id);
|
println!("Loading model: {}", model_id);
|
||||||
let revision = cfg.revision.clone().unwrap_or("main".to_string());
|
let revision = cfg.revision.clone().unwrap_or("main".to_string());
|
||||||
@@ -334,4 +331,3 @@ pub fn run_llama_inference(
|
|||||||
|
|
||||||
Ok(rx)
|
Ok(rx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -88,7 +88,6 @@ impl Into<LlamaInferenceConfig> for Args {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn run_cli() -> anyhow::Result<()> {
|
pub fn run_cli() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let cfg = args.into();
|
let cfg = args.into();
|
||||||
|
@@ -2,8 +2,8 @@
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
mod llama_cli;
|
|
||||||
mod llama_api;
|
mod llama_api;
|
||||||
|
mod llama_cli;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
@@ -14,7 +14,6 @@ use crate::llama_cli::run_cli;
|
|||||||
|
|
||||||
const EOS_TOKEN: &str = "</s>";
|
const EOS_TOKEN: &str = "</s>";
|
||||||
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
run_cli()
|
run_cli()
|
||||||
}
|
}
|
@@ -1,7 +1,9 @@
 use serde::{Deserialize, Serialize};
 use std::env;
+use tracing::info;
+use tracing::log::error;
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Serialize, Deserialize, Clone, Debug)]
 #[serde(rename_all = "camelCase")]
 pub struct ServerConfig {
     #[serde(default = "default_server_host")]
@@ -10,14 +12,16 @@ pub struct ServerConfig {
     pub server_port: u16,
     pub server_mode: ServerMode,
     #[serde(default)]
-    pub services: Services,
+    pub services: Option<Services>,
 }
 
 fn default_server_host() -> String {
     "127.0.0.1".to_string()
 }
 
-fn default_server_port() -> u16 { 8080 }
+fn default_server_port() -> u16 {
+    8080
+}
 
 #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
 #[serde(rename_all = "PascalCase")]
@@ -34,17 +38,15 @@ impl Default for ServerMode {
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Services {
-    #[serde(default = "inference_service_url")]
-    pub inference_url: String,
-    #[serde(default = "embeddings_service_url")]
-    pub embeddings_url: String,
+    pub inference_url: Option<String>,
+    pub embeddings_url: Option<String>,
 }
 
 impl Default for Services {
     fn default() -> Self {
         Self {
-            inference_url: inference_service_url(),
-            embeddings_url: embeddings_service_url(),
+            inference_url: None,
+            embeddings_url: None,
         }
     }
 }
@@ -63,7 +65,7 @@ impl Default for ServerConfig {
|
|||||||
server_host: "127.0.0.1".to_string(),
|
server_host: "127.0.0.1".to_string(),
|
||||||
server_port: 8080,
|
server_port: 8080,
|
||||||
server_mode: ServerMode::Standalone,
|
server_mode: ServerMode::Standalone,
|
||||||
services: Services::default(),
|
services: Some(Services::default()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,21 +75,19 @@ impl ServerConfig {
|
|||||||
/// Falls back to default (Local mode) if not set or invalid
|
/// Falls back to default (Local mode) if not set or invalid
|
||||||
pub fn from_env() -> Self {
|
pub fn from_env() -> Self {
|
||||||
match env::var("SERVER_CONFIG") {
|
match env::var("SERVER_CONFIG") {
|
||||||
Ok(config_str) => {
|
Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
|
||||||
match serde_json::from_str::<ServerConfig>(&config_str) {
|
Ok(config) => {
|
||||||
Ok(config) => {
|
tracing::info!("Loaded server configuration: {:?}", config);
|
||||||
tracing::info!("Loaded server configuration: {:?}", config);
|
config
|
||||||
config
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(
|
|
||||||
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
|
|
||||||
e
|
|
||||||
);
|
|
||||||
ServerConfig::default()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
ServerConfig::default()
|
||||||
|
}
|
||||||
|
},
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
tracing::info!("SERVER_CONFIG not set, Standalone mode active");
|
tracing::info!("SERVER_CONFIG not set, Standalone mode active");
|
||||||
ServerConfig::default()
|
ServerConfig::default()
|
||||||
@@ -96,18 +96,52 @@ impl ServerConfig {
     }
 
     /// Check if the server should run in high availability mode
-    pub fn is_high_availability(&self) -> bool {
-        self.server_mode == ServerMode::HighAvailability
+    pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
+        if self.server_mode == ServerMode::HighAvailability {
+            let services_well_defined: bool = self.clone().services.is_some();
+
+            let inference_url_well_defined: bool =
+                services_well_defined && self.clone().services.unwrap().inference_url.is_some();
+
+            let embeddings_well_defined: bool =
+                services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();
+
+            let is_well_defined_for_ha =
+                services_well_defined && inference_url_well_defined && embeddings_well_defined;
+
+            if !is_well_defined_for_ha {
+                let config_string = serde_json::to_string_pretty(&self).unwrap();
+                error!(
+                    "HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
+                    config_string
+                );
+                let err = std::io::Error::new(
+                    std::io::ErrorKind::Other,
+                    "HighAvailability mode configured but services not well defined!",
+                );
+                return Err(err);
+            }
+        }
+
+        Ok(self.server_mode == ServerMode::HighAvailability)
     }
 
     /// Get the inference service URL for proxying
-    pub fn inference_url(&self) -> &str {
-        &self.services.inference_url
+    pub fn inference_url(&self) -> Option<String> {
+        if self.services.is_some() {
+            self.services.clone()?.inference_url
+        } else {
+            None
+        }
     }
 
     /// Get the embeddings service URL for proxying
-    pub fn embeddings_url(&self) -> &str {
-        &self.services.embeddings_url
+    pub fn embeddings_url(&self) -> Option<String> {
+        if self.services.is_some() {
+            self.services.clone()?.embeddings_url
+        } else {
+            None
+        }
     }
 }
 
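Illustrative aside (not part of the diff): with `is_high_availability` now returning `Result<bool, std::io::Error>` and the service URLs becoming `Option<String>`, callers must propagate the misconfiguration error rather than fall back to default URLs. A rough usage sketch; `select_mode` is a hypothetical caller, not the server's actual routing code:

```rust
fn select_mode(config: &ServerConfig) -> std::io::Result<()> {
    // Propagates the error when HighAvailability is requested without service URLs.
    if config.is_high_availability()? {
        // The Ok(true) path guarantees both URLs were present, so expect() is safe here.
        let inference = config.inference_url().expect("validated by is_high_availability");
        let embeddings = config.embeddings_url().expect("validated by is_high_availability");
        tracing::info!("Proxying inference to {} and embeddings to {}", inference, embeddings);
        // ... build proxy routes to the two services here ...
    } else {
        tracing::info!("Standalone mode: running inference and embeddings in-process");
    }
    Ok(())
}
```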
@@ -119,7 +153,7 @@ mod tests {
     fn test_default_config() {
         let config = ServerConfig::default();
         assert_eq!(config.server_mode, ServerMode::Standalone);
-        assert!(!config.is_high_availability());
+        assert!(!config.is_high_availability().unwrap());
     }
 
     #[test]
@@ -134,23 +168,26 @@ mod tests {
|
|||||||
|
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert_eq!(config.server_mode, ServerMode::HighAvailability);
|
assert_eq!(config.server_mode, ServerMode::HighAvailability);
|
||||||
assert!(config.is_high_availability());
|
assert!(config.is_high_availability().unwrap());
|
||||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
assert_eq!(
|
||||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
config.inference_url().unwrap(),
|
||||||
|
"http://inference-service:8080"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.embeddings_url().unwrap(),
|
||||||
|
"http://embeddings-service:8080"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_local_mode_config() {
|
fn test_local_mode_config() {
|
||||||
let config_json = r#"{
|
let config_json = r#"{
|
||||||
"serverMode": "Local"
|
"serverMode": "Standalone"
|
||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||||
assert!(!config.is_high_availability());
|
assert!(!config.is_high_availability().unwrap());
|
||||||
// Should use default URLs
|
|
||||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
|
||||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -164,17 +201,26 @@ mod tests {
|
|||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert_eq!(config.inference_url(), "http://custom-inference:9000");
|
assert_eq!(
|
||||||
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
|
config.inference_url().unwrap(),
|
||||||
|
"http://custom-inference:9000"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.embeddings_url().unwrap(),
|
||||||
|
"http://custom-embeddings:9001"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_minimal_high_availability_config() {
|
fn test_minimal_high_availability_config_error() {
|
||||||
let config_json = r#"{"serverMode": "HighAvailability"}"#;
|
let config_json = r#"{"serverMode": "HighAvailability"}"#;
|
||||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||||
assert!(config.is_high_availability());
|
|
||||||
// Should use default URLs
|
let is_high_availability = config.is_high_availability();
|
||||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
|
||||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
assert!(is_high_availability.is_err());
|
||||||
|
// // Should use default URLs
|
||||||
|
// assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
|
||||||
|
// assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
|
||||||
}
|
}
|
||||||
}
|
}
|
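For reference, a minimal sketch (not part of this commit) of the failure mode that the renamed test_minimal_high_availability_config_error pins down. The struct and field names are taken from the test code elsewhere in this diff; the assertions are illustrative only.

    let config = ServerConfig {
        server_mode: ServerMode::HighAvailability,
        services: None,
        ..ServerConfig::default()
    };
    // HighAvailability without service URLs is now rejected up front
    // instead of silently falling back to default URLs.
    assert!(config.is_high_availability().is_err());
    assert_eq!(config.inference_url(), None);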
@@ -1,7 +1,9 @@
 mod config;
 mod middleware;
 mod proxy;
+mod standalone;

+use crate::standalone::create_standalone_router;
 use axum::response::IntoResponse;
 use axum::routing::get;
 use axum::{Router, http::Uri, response::Html, serve};
@@ -11,6 +13,7 @@ use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
 use proxy::create_proxy_router;
 use rust_embed::Embed;
 use std::env;
+use std::path::Component::ParentDir;
 use tokio::net::TcpListener;
 use tower_http::classify::ServerErrorsFailureClass::StatusCode;
 use tower_http::cors::{Any, CorsLayer};
@@ -49,33 +52,19 @@ async fn main() {
     let default_host = server_config.server_host.clone();
     let default_port = server_config.server_port;

-    // Create router based on server mode
-    let service_router = if server_config.clone().is_high_availability() {
-        tracing::info!("Running in HighAvailability mode - proxying to external services");
-        tracing::info!(" Inference service URL: {}", server_config.inference_url());
-        tracing::info!(
-            " Embeddings service URL: {}",
-            server_config.embeddings_url()
-        );
-        // Use proxy router that forwards requests to external services
-        create_proxy_router(server_config.clone())
-    } else {
-        tracing::info!("Running in Standalone mode - using embedded services");
-
-        // Create unified router by merging embeddings and inference routers (existing behavior)
-        let embeddings_router = embeddings_engine::create_embeddings_router();
-
-        // Create AppState with correct model configuration
-        let app_state = AppState::default();
-
-        // Get the inference router directly from the inference engine
-        let inference_router = inference_engine::create_router(app_state);
-
-        // Merge the local routers
-        Router::new()
-            .merge(embeddings_router)
-            .merge(inference_router)
+    let service_router = match server_config.clone().is_high_availability() {
+        Ok(is_ha) => {
+            if is_ha {
+                log_config(server_config.clone());
+                create_proxy_router(server_config.clone())
+            } else {
+                log_config(server_config.clone());
+                create_standalone_router(server_config)
+            }
+        }
+        Err(error) => {
+            panic!("{}", error);
+        }
     };

     // Create CORS layer
@@ -124,5 +113,25 @@ async fn main() {
     serve(listener, app).await.unwrap();
 }

+fn log_config(config: ServerConfig) {
+    match config.is_high_availability() {
+        Ok(is_high) => {
+            if is_high {
+                tracing::info!("Running in HighAvailability mode - proxying to external services");
+                tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
+                tracing::info!(
+                    "Embeddings service URL: {}",
+                    config.embeddings_url().unwrap()
+                );
+            } else {
+                tracing::info!("Running in Standalone mode");
+            }
+        }
+        Err(error) => {
+            panic!("{}", error);
+        }
+    }
+}

 // Chat completions handler that properly uses the inference server crate's error handling
 // This function is no longer needed as we're using the inference_engine router directly
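By contrast, a fully specified HighAvailability configuration takes the Ok(true) branch that main() and log_config() rely on. A hedged sketch, mirroring the struct literal used in the proxy tests further down; the URLs are the same placeholder service addresses used in the config tests, not real endpoints.

    let config = ServerConfig {
        server_host: "127.0.0.1".to_string(),
        server_port: 8080,
        server_mode: ServerMode::HighAvailability,
        services: Some(Services {
            inference_url: Some("http://inference-service:8080".to_string()),
            embeddings_url: Some("http://embeddings-service:8080".to_string()),
        }),
    };
    assert!(config.is_high_availability().unwrap());
    // Both accessors now return Option<String>, so log_config() can unwrap them here.
    assert_eq!(
        config.inference_url().as_deref(),
        Some("http://inference-service:8080")
    );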
@@ -2,6 +2,8 @@ use axum::{
     extract::MatchedPath,
     http::{Request, Response},
 };
+use std::fmt;
+use std::task::ready;
 use std::{
     future::Future,
     pin::Pin,
@@ -12,8 +14,6 @@ use std::{
 use tokio::sync::Mutex;
 use tower::{Layer, Service};
 use tracing::{debug, info};
-use std::task::ready;
-use std::fmt;

 /// Performance metrics for a specific endpoint
 #[derive(Debug, Clone, Default)]
@@ -56,7 +56,10 @@ impl EndpointMetrics {
     pub fn summary(&self) -> String {
         format!(
             "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
-            self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
+            self.count,
+            self.avg_time_ms(),
+            self.min_time_ms,
+            self.max_time_ms
         )
     }
 }
@@ -79,7 +82,9 @@ impl MetricsStore {
     /// Record a request's timing information
     pub async fn record(&self, path: String, time_ms: u64) {
         let mut endpoints = self.endpoints.lock().await;
-        let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
+        let metrics = endpoints
+            .entry(path)
+            .or_insert_with(EndpointMetrics::default);
         metrics.add_response_time(time_ms);
     }

@@ -178,7 +183,9 @@ where
         let time_ms = time.as_millis() as u64;

         // Record the timing in our metrics store
-        metrics_store.record(format!("{} {}", method, path), time_ms).await;
+        metrics_store
+            .record(format!("{} {}", method, path), time_ms)
+            .await;

         // Log the request timing
         debug!("{} {} {} - {} ms", method, path, status, time_ms);
@@ -1,7 +1,3 @@
 pub mod metrics;

-pub use metrics::{
-    MetricsStore,
-    MetricsLoggerFuture,
-    MetricsLayer,
-};
+pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
@@ -1,10 +1,10 @@
 use axum::{
+    Router,
     body::Body,
     extract::{Request, State},
     http::{HeaderMap, Method, StatusCode, Uri},
     response::{IntoResponse, Response},
     routing::{get, post},
-    Router,
 };
 use reqwest::Client;
 use serde_json::Value;
@@ -47,7 +47,13 @@ async fn proxy_chat_completions(
     headers: HeaderMap,
     body: Body,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());
+    let target_url = format!(
+        "{}/v1/chat/completions",
+        proxy_client
+            .config
+            .inference_url()
+            .expect("Invalid Configuration")
+    );

     tracing::info!("Proxying chat completions request to: {}", target_url);

@@ -63,7 +69,9 @@ async fn proxy_chat_completions(
     // Check if this is a streaming request
     let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
         if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
-            json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
+            json.get("stream")
+                .and_then(|v| v.as_bool())
+                .unwrap_or(false)
         } else {
             false
         }
@@ -72,7 +80,8 @@ async fn proxy_chat_completions(
     };

     // Forward the request
-    let mut req_builder = proxy_client.client
+    let mut req_builder = proxy_client
+        .client
         .post(&target_url)
         .body(body_bytes.to_vec());

@@ -85,8 +94,7 @@ async fn proxy_chat_completions(

     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());

             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -99,14 +107,12 @@ async fn proxy_chat_completions(
             if is_streaming {
                 // For streaming, we need to forward the response as-is
                 match response.bytes().await {
-                    Ok(body) => {
-                        resp_builder
-                            .header("content-type", "text/plain; charset=utf-8")
-                            .header("cache-control", "no-cache")
-                            .header("connection", "keep-alive")
-                            .body(Body::from(body))
-                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                    }
+                    Ok(body) => resp_builder
+                        .header("content-type", "text/plain; charset=utf-8")
+                        .header("cache-control", "no-cache")
+                        .header("connection", "keep-alive")
+                        .body(Body::from(body))
+                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                     Err(e) => {
                         tracing::error!("Failed to read streaming response body: {}", e);
                         Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +121,9 @@ async fn proxy_chat_completions(
             } else {
                 // For non-streaming, forward the JSON response
                 match response.bytes().await {
-                    Ok(body) => {
-                        resp_builder
-                            .body(Body::from(body))
-                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                    }
+                    Ok(body) => resp_builder
+                        .body(Body::from(body))
+                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                     Err(e) => {
                         tracing::error!("Failed to read response body: {}", e);
                         Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,7 +143,13 @@ async fn proxy_models(
     State(proxy_client): State<ProxyClient>,
     headers: HeaderMap,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/models", proxy_client.config.inference_url());
+    let target_url = format!(
+        "{}/v1/models",
+        proxy_client
+            .config
+            .inference_url()
+            .expect("Invalid Configuration Detected")
+    );

     tracing::info!("Proxying models request to: {}", target_url);

@@ -154,8 +164,7 @@ async fn proxy_models(

     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());

             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -165,11 +174,9 @@ async fn proxy_models(
             }

             match response.bytes().await {
-                Ok(body) => {
-                    resp_builder
-                        .body(Body::from(body))
-                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                }
+                Ok(body) => resp_builder
+                    .body(Body::from(body))
+                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                 Err(e) => {
                     tracing::error!("Failed to read models response body: {}", e);
                     Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,7 +196,13 @@ async fn proxy_embeddings(
     headers: HeaderMap,
     body: Body,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());
+    let target_url = format!(
+        "{}/v1/embeddings",
+        proxy_client
+            .config
+            .embeddings_url()
+            .expect("Invalid Configuration Detected")
+    );

     tracing::info!("Proxying embeddings request to: {}", target_url);

@@ -203,7 +216,8 @@ async fn proxy_embeddings(
     };

     // Forward the request
-    let mut req_builder = proxy_client.client
+    let mut req_builder = proxy_client
+        .client
         .post(&target_url)
         .body(body_bytes.to_vec());

@@ -216,8 +230,7 @@ async fn proxy_embeddings(

     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());

             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -227,11 +240,9 @@ async fn proxy_embeddings(
             }

             match response.bytes().await {
-                Ok(body) => {
-                    resp_builder
-                        .body(Body::from(body))
-                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                }
+                Ok(body) => resp_builder
+                    .body(Body::from(body))
+                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                 Err(e) => {
                     tracing::error!("Failed to read embeddings response body: {}", e);
                     Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -250,7 +261,7 @@ fn should_forward_header(header_name: &str) -> bool {
     match header_name.to_lowercase().as_str() {
         "content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
         "host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
         _ => true, // Forward other headers by default
     }
 }

@@ -259,7 +270,7 @@ fn should_forward_response_header(header_name: &str) -> bool {
     match header_name.to_lowercase().as_str() {
         "content-type" | "content-length" | "cache-control" | "connection" => true,
         "server" | "date" => false, // Don't forward server-specific headers
         _ => true, // Forward other headers by default
     }
 }

@@ -290,14 +301,20 @@ mod tests {
             server_host: "127.0.0.1".to_string(),
             server_port: 8080,
             server_mode: ServerMode::HighAvailability,
-            services: Services {
-                inference_url: "http://test-inference:8080".to_string(),
-                embeddings_url: "http://test-embeddings:8080".to_string(),
-            },
+            services: Some(Services {
+                inference_url: Some("http://test-inference:8080".to_string()),
+                embeddings_url: Some("http://test-embeddings:8080".to_string()),
+            }),
         };

         let proxy_client = ProxyClient::new(config);
-        assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
-        assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
+        assert_eq!(
+            proxy_client.config.inference_url().unwrap().as_str(),
+            "http://test-inference:8080"
+        );
+        assert_eq!(
+            proxy_client.config.embeddings_url().unwrap().as_str(),
+            "http://test-embeddings:8080"
+        );
     }
 }
crates/predict-otron-9000/src/standalone.rs (Normal file, 19 lines)
@@ -0,0 +1,19 @@
use crate::config::ServerConfig;
use axum::Router;
use inference_engine::AppState;

pub fn create_standalone_router(server_config: ServerConfig) -> Router {
    // Create unified router by merging embeddings and inference routers (existing behavior)
    let embeddings_router = embeddings_engine::create_embeddings_router();

    // Create AppState with correct model configuration
    let app_state = AppState::default();

    // Get the inference router directly from the inference engine
    let inference_router = inference_engine::create_router(app_state);

    // Merge the local routers
    Router::new()
        .merge(embeddings_router)
        .merge(inference_router)
}
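A hedged usage sketch (not part of this commit) of the new helper: serving the standalone router on its own. The bind address is a placeholder; the real main() derives host and port from ServerConfig and layers CORS and metrics middleware on top of the merged router.

    use tokio::net::TcpListener;

    #[tokio::main]
    async fn main() {
        // Standalone mode: embeddings and inference routers merged into one app.
        let app = create_standalone_router(ServerConfig::default());
        let listener = TcpListener::bind("0.0.0.0:8080").await.unwrap();
        axum::serve(listener, app).await.unwrap();
    }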
scripts/build_all_platforms.sh (Executable file, 389 lines)
@@ -0,0 +1,389 @@
#!/bin/bash

# Cross-platform build script for predict-otron-9000
# Builds all workspace crates for common platforms

set -euo pipefail

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Configuration
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
BUILD_DIR="${PROJECT_ROOT}/build"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# Supported platforms
PLATFORMS=(
    "x86_64-unknown-linux-gnu"
    "x86_64-pc-windows-msvc"
    "x86_64-apple-darwin"
    "aarch64-apple-darwin"
    "aarch64-unknown-linux-gnu"
)

# Main binaries to build
MAIN_BINARIES=(
    "predict-otron-9000"
    "embeddings-engine"
)

# Inference engine binaries (with bin feature)
INFERENCE_BINARIES=(
    "gemma_inference"
    "llama_inference"
)

# Other workspace binaries
OTHER_BINARIES=(
    "helm-chart-tool"
)

print_header() {
    echo -e "${BLUE}================================${NC}"
    echo -e "${BLUE}$1${NC}"
    echo -e "${BLUE}================================${NC}"
}

print_info() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

print_warn() {
    echo -e "${YELLOW}[WARN]${NC} $1"
}

print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

check_dependencies() {
    print_header "Checking Dependencies"

    # Check rust
    if ! command -v cargo >/dev/null 2>&1; then
        print_error "Rust/Cargo is not installed"
        exit 1
    fi

    # Check cargo-leptos for WASM frontend
    if ! command -v cargo-leptos >/dev/null 2>&1; then
        print_warn "cargo-leptos not found. Installing..."
        cargo install cargo-leptos
    fi

    print_info "All dependencies available"
}

install_targets() {
    print_header "Installing Rust Targets"

    for platform in "${PLATFORMS[@]}"; do
        print_info "Installing target: $platform"
        rustup target add "$platform" || {
            print_warn "Failed to install target $platform (may not be available on this host)"
        }
    done

    # Add WASM target for leptos
    print_info "Installing wasm32-unknown-unknown target for Leptos"
    rustup target add wasm32-unknown-unknown
}

create_build_dirs() {
    print_header "Setting up Build Directory"

    rm -rf "$BUILD_DIR"
    mkdir -p "$BUILD_DIR"

    for platform in "${PLATFORMS[@]}"; do
        mkdir -p "$BUILD_DIR/$platform"
    done

    mkdir -p "$BUILD_DIR/web"
    print_info "Build directories created"
}

build_leptos_app() {
    print_header "Building Leptos Web Frontend"

    cd "$PROJECT_ROOT/crates/leptos-app"

    # Build the WASM frontend
    print_info "Building WASM frontend with cargo-leptos..."
    cargo leptos build --release || {
        print_error "Failed to build Leptos WASM frontend"
        return 1
    }

    # Copy built assets to build directory
    if [ -d "target/site" ]; then
        cp -r target/site/* "$BUILD_DIR/web/"
        print_info "Leptos frontend built and copied to $BUILD_DIR/web/"
    else
        print_error "Leptos build output not found at target/site"
        return 1
    fi

    cd "$PROJECT_ROOT"
}

get_platform_features() {
    local platform="$1"
    local features=""

    case "$platform" in
        *-apple-darwin)
            # macOS uses Metal but routes to CPU for Gemma stability
            features=""
            ;;
        *-unknown-linux-gnu|*-pc-windows-msvc)
            # Linux and Windows can use CUDA if available
            features=""
            ;;
        *)
            features=""
            ;;
    esac

    echo "$features"
}

build_binary_for_platform() {
    local binary_name="$1"
    local platform="$2"
    local package_name="$3"
    local additional_args="$4"

    print_info "Building $binary_name for $platform"

    local features=$(get_platform_features "$platform")
    local feature_flag=""
    if [ -n "$features" ]; then
        feature_flag="--features $features"
    fi

    # Build command
    local build_cmd="cargo build --release --target $platform --bin $binary_name"

    if [ -n "$package_name" ]; then
        build_cmd="$build_cmd --package $package_name"
    fi

    if [ -n "$additional_args" ]; then
        build_cmd="$build_cmd $additional_args"
    fi

    if [ -n "$feature_flag" ]; then
        build_cmd="$build_cmd $feature_flag"
    fi

    print_info "Running: $build_cmd"

    if eval "$build_cmd"; then
        # Copy binary to build directory
        local target_dir="target/$platform/release"
        local binary_file="$binary_name"

        # Add .exe extension for Windows
        if [[ "$platform" == *-pc-windows-msvc ]]; then
            binary_file="$binary_name.exe"
        fi

        if [ -f "$target_dir/$binary_file" ]; then
            cp "$target_dir/$binary_file" "$BUILD_DIR/$platform/"
            print_info "✓ $binary_name built and copied for $platform"
        else
            print_error "Binary not found: $target_dir/$binary_file"
            return 1
        fi
    else
        print_error "Failed to build $binary_name for $platform"
        return 1
    fi
}

build_for_platform() {
    local platform="$1"
    print_header "Building for $platform"

    local failed_builds=()

    # Build main binaries
    for binary in "${MAIN_BINARIES[@]}"; do
        if ! build_binary_for_platform "$binary" "$platform" "$binary" ""; then
            failed_builds+=("$binary")
        fi
    done

    # Build inference engine binaries with bin feature
    for binary in "${INFERENCE_BINARIES[@]}"; do
        if ! build_binary_for_platform "$binary" "$platform" "inference-engine" "--features bin"; then
            failed_builds+=("$binary")
        fi
    done

    # Build other workspace binaries
    for binary in "${OTHER_BINARIES[@]}"; do
        if ! build_binary_for_platform "$binary" "$platform" "$binary" ""; then
            failed_builds+=("$binary")
        fi
    done

    if [ ${#failed_builds[@]} -eq 0 ]; then
        print_info "✓ All binaries built successfully for $platform"
    else
        print_warn "Some builds failed for $platform: ${failed_builds[*]}"
    fi
}

create_archives() {
    print_header "Creating Release Archives"

    cd "$BUILD_DIR"

    for platform in "${PLATFORMS[@]}"; do
        if [ -d "$platform" ] && [ -n "$(ls -A "$platform" 2>/dev/null)" ]; then
            local archive_name="predict-otron-9000-${platform}-${TIMESTAMP}"

            print_info "Creating archive for $platform"

            # Create platform-specific directory with all files
            mkdir -p "$archive_name"
            cp -r "$platform"/* "$archive_name/"

            # Add web assets to each platform archive
            if [ -d "web" ]; then
                mkdir -p "$archive_name/web"
                cp -r web/* "$archive_name/web/"
            fi

            # Create README for the platform
            cat > "$archive_name/README.txt" << EOF
Predict-Otron-9000 - Platform: $platform
Build Date: $(date)
========================================

Binaries included:
$(ls -1 "$platform")

Web Frontend:
- Located in the 'web' directory
- Serve with any static file server on port 8788 or configure your server

Usage:
1. Start the main server: ./predict-otron-9000
2. Start embeddings service: ./embeddings-engine
3. Access web interface at http://localhost:8080 (served by main server)

For more information, visit: https://github.com/geoffsee/predict-otron-9000
EOF

            # Create tar.gz archive
            tar -czf "${archive_name}.tar.gz" "$archive_name"
            rm -rf "$archive_name"

            print_info "✓ Created ${archive_name}.tar.gz"
        else
            print_warn "No binaries found for $platform, skipping archive"
        fi
    done

    cd "$PROJECT_ROOT"
}

generate_build_report() {
    print_header "Build Report"

    echo "Build completed at: $(date)"
    echo "Build directory: $BUILD_DIR"
    echo ""
    echo "Archives created:"
    ls -la "$BUILD_DIR"/*.tar.gz 2>/dev/null || echo "No archives created"
    echo ""
    echo "Platform directories:"
    for platform in "${PLATFORMS[@]}"; do
        if [ -d "$BUILD_DIR/$platform" ]; then
            echo " $platform:"
            ls -la "$BUILD_DIR/$platform" | sed 's/^/ /'
        fi
    done

    if [ -d "$BUILD_DIR/web" ]; then
        echo ""
        echo "Web frontend assets:"
        ls -la "$BUILD_DIR/web" | head -10 | sed 's/^/ /'
        if [ $(ls -1 "$BUILD_DIR/web" | wc -l) -gt 10 ]; then
            echo " ... and $(( $(ls -1 "$BUILD_DIR/web" | wc -l) - 10 )) more files"
        fi
    fi
}

main() {
    print_header "Predict-Otron-9000 Cross-Platform Build Script"

    cd "$PROJECT_ROOT"

    check_dependencies
    install_targets
    create_build_dirs

    # Build Leptos web frontend first
    build_leptos_app

    # Build for each platform
    for platform in "${PLATFORMS[@]}"; do
        build_for_platform "$platform"
    done

    create_archives
    generate_build_report

    print_header "Build Complete!"
    print_info "All artifacts are available in: $BUILD_DIR"
}

# Handle command line arguments
case "${1:-}" in
    --help|-h)
        echo "Usage: $0 [options]"
        echo ""
        echo "Cross-platform build script for predict-otron-9000"
        echo ""
        echo "Options:"
        echo " --help, -h Show this help message"
        echo " --platforms Show supported platforms"
        echo " --clean Clean build directory before building"
        echo ""
        echo "Supported platforms:"
        for platform in "${PLATFORMS[@]}"; do
            echo " - $platform"
        done
        echo ""
        echo "Prerequisites:"
        echo " - Rust toolchain with rustup"
        echo " - cargo-leptos (will be installed if missing)"
        echo " - Platform-specific toolchains for cross-compilation"
        echo ""
        exit 0
        ;;
    --platforms)
        echo "Supported platforms:"
        for platform in "${PLATFORMS[@]}"; do
            echo " - $platform"
        done
        exit 0
        ;;
    --clean)
        print_info "Cleaning build directory..."
        rm -rf "$BUILD_DIR"
        print_info "Build directory cleaned"
        ;;
esac

main "$@"
scripts/build_cli.sh (Executable file, 19 lines)
@@ -0,0 +1,19 @@
#!/usr/bin/env sh
set -e

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

TEMP_DIR="$SCRIPT_DIR/temp"

mkdir -p "$TEMP_DIR"

cp "$SCRIPT_DIR/cli.ts" "$TEMP_DIR/cli.ts"
cp "$SCRIPT_DIR/../package.json" "$TEMP_DIR/package.json"

(
  cd "$TEMP_DIR"
  bun i
  bun build ./cli.ts --compile --outfile "$SCRIPT_DIR/cli"
)

rm -rf "$TEMP_DIR"
@@ -1,30 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-PROMPT=${1:-"Say hello in one short sentence."}
-MODEL=${2:-"meta-llama/Llama-3.2-1B-Instruct"}
-MAX_NEW=${3:-64}
-FORCE_CPU=${FORCE_CPU:-0}
-
-# Optional: keep HF cache local to repo if not already set
-export HF_HOME=${HF_HOME:-"$PWD/.hf-cache"}
-
-BIN="$(dirname "$0")/../target/release/llama_infer"
-
-if [[ ! -x "$BIN" ]]; then
-  echo "Building llama-runner (release)..."
-  cargo build -p llama-runner --release
-fi
-
-echo "Running llama inference..." >&2
-ARGS=(
-  --model-id "$MODEL"
-  --prompt "$PROMPT"
-  --max-new-tokens "$MAX_NEW"
-)
-
-if [[ "$FORCE_CPU" == "1" || "$FORCE_CPU" == "true" ]]; then
-  ARGS+=( --force-cpu )
-fi
-
-"$BIN" "${ARGS[@]}"