Refactor `apply_cached_repeat_penalty` for optimized caching and reuse, add extensive unit tests, and integrate special handling for Gemma-specific models.
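A minimal sketch of the cached repeat-penalty idea, assuming a plain `f32` logits slice and `u32` token ids (the crate's real signature may differ): the per-step cache ensures each token id in the history is penalized only once, so repeated tokens cost no extra work.

```rust
use std::collections::HashSet;

/// Illustrative only: apply a repeat penalty over `logits`, skipping token ids
/// that were already penalized during this step.
fn apply_cached_repeat_penalty(logits: &mut [f32], history: &[u32], penalty: f32) {
    let mut penalized: HashSet<u32> = HashSet::new(); // per-step cache
    for &token in history {
        // `insert` returns false when the id is already cached.
        if !penalized.insert(token) {
            continue;
        }
        if let Some(logit) = logits.get_mut(token as usize) {
            // Standard repeat-penalty rule: shrink positive logits, scale negatives down.
            *logit = if *logit >= 0.0 { *logit / penalty } else { *logit * penalty };
        }
    }
}
```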

Removed `test_request.sh`, deprecated functionality, and unused imports; introduced a new CLI tool (`cli.ts`) for testing the inference engine and adjusted handling of non-streaming/streaming chat completions.

- Add CPU fallback support for text generation when primary device is unsupported
- Introduce `execute_with_fallback` method to handle device compatibility and shape mismatch errors (see the sketch after this list)
- Extend unit tests to reproduce tensor shape mismatch errors specific to model configurations
- Increase HTTP timeout limits in `curl_chat_stream.sh` script for reliable API testing
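A hedged sketch of the fallback flow described above; the `Device` enum, the error variants, and the closure-based signature are assumptions for illustration, not the engine's actual types:

```rust
#[derive(Clone, Copy, Debug)]
enum Device {
    Primary,
    Cpu,
}

#[derive(Debug)]
enum InferenceError {
    UnsupportedDevice(String),
    ShapeMismatch(String),
    Other(String),
}

/// Illustrative only: run `op` on the primary device, retrying on CPU for the
/// two error classes the fallback is meant to cover.
fn execute_with_fallback<T>(
    op: impl Fn(Device) -> Result<T, InferenceError>,
) -> Result<T, InferenceError> {
    match op(Device::Primary) {
        Err(InferenceError::UnsupportedDevice(msg)) | Err(InferenceError::ShapeMismatch(msg)) => {
            eprintln!("primary device failed ({msg}); falling back to CPU");
            op(Device::Cpu)
        }
        other => other,
    }
}
```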

Chat completion endpoint now functions with gemma3 (no streaming).

Add benchmarking guide with HTML reporting, Leptos chat crate, and middleware for metrics tracking
geoffsee
2025-08-26 01:30:26 -04:00
parent 7dd23213c9
commit 8338750beb
64 changed files with 14997 additions and 220 deletions

View File

@@ -3,6 +3,10 @@ name = "predict-otron-9000"
version = "0.1.0"
edition = "2024"
[[bin]]
name = "predict-otron-9000"
path = "src/main.rs"
[dependencies]
# Axum web framework
axum = "0.8.4"

View File

@@ -1,12 +1,19 @@
-use axum::{Router, serve, http::StatusCode};
+mod middleware;
+use axum::{
+    Router,
+    serve,
+};
use std::env;
use axum::routing::get;
use tokio::net::TcpListener;
use tower::Service;
use tower_http::trace::TraceLayer;
use tower_http::cors::{Any, CorsLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use inference_engine::AppState;
+use middleware::{MetricsStore, MetricsLoggerFuture, MetricsLayer};
const DEFAULT_SERVER_HOST: &str = "0.0.0.0";
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
#[tokio::main]
@@ -25,23 +32,53 @@ async fn main() {
        .with(tracing_subscriber::fmt::layer())
        .init();

+    // Initialize metrics store for performance tracking
+    let metrics_store = MetricsStore::new();

+    // Create a metrics logger that will periodically log metrics (every 60 seconds)
+    let metrics_logger = MetricsLoggerFuture::new(metrics_store.clone(), 60);

+    // Spawn the metrics logger in a background task
+    tokio::spawn(metrics_logger);

    // Create unified router by merging embeddings and inference routers
    let embeddings_router = embeddings_engine::create_embeddings_router();

+    // Create AppState with correct model configuration
+    use inference_engine::server::{PipelineArgs, build_pipeline};
+    use inference_engine::Which;
+    let mut pipeline_args = PipelineArgs::default();
+    pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
+    pipeline_args.which = Which::InstructV3_1B;
+    let text_generation = build_pipeline(pipeline_args);
+    let app_state = AppState {
+        text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
+        model_id: "google/gemma-3-1b-it".to_string(),
+    };

    // Get the inference router directly from the inference engine
-    let inference_router = inference_engine::create_inference_router();
+    let inference_router = inference_engine::create_router(app_state);

    // Create CORS layer
    let cors = CorsLayer::new()
-        .allow_headers(Any)
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

+    // Create metrics layer
+    let metrics_layer = MetricsLayer::new(metrics_store);

-    // Merge the routers
+    // Merge the routers and add middleware layers
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        .route("/health", get(|| async { "ok" }))
        .merge(embeddings_router)
        .merge(inference_router)
+        .layer(metrics_layer) // Add metrics tracking
        .layer(cors)
        .layer(TraceLayer::new_for_http());
@@ -52,6 +89,7 @@ async fn main() {
    let listener = TcpListener::bind(&server_address).await.unwrap();
    tracing::info!("Unified predict-otron-9000 server listening on {}", listener.local_addr().unwrap());
+    tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
    tracing::info!("Available endpoints:");
    tracing::info!(" GET / - Root endpoint from embeddings-engine");
    tracing::info!(" POST /v1/embeddings - Text embeddings");
@@ -60,5 +98,7 @@ async fn main() {
    serve(listener, app).await.unwrap();
}

-// Chat completions handler that properly uses the inference server crate's error handling
+// This function is no longer needed as we're using the inference_engine router directly

View File

@@ -0,0 +1,220 @@
use axum::{
    extract::MatchedPath,
    http::{Request, Response},
};
use std::{
    future::Future,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Instant,
};
use tokio::sync::Mutex;
use tower::{Layer, Service};
use tracing::{debug, info};
use std::task::ready;
use std::fmt;
/// Performance metrics for a specific endpoint
#[derive(Debug, Clone, Default)]
pub struct EndpointMetrics {
    /// Total number of requests
    pub count: usize,
    /// Total response time in milliseconds
    pub total_time_ms: u64,
    /// Minimum response time in milliseconds
    pub min_time_ms: u64,
    /// Maximum response time in milliseconds
    pub max_time_ms: u64,
}

impl EndpointMetrics {
    /// Add a new response time to the metrics
    pub fn add_response_time(&mut self, time_ms: u64) {
        self.count += 1;
        self.total_time_ms += time_ms;

        if self.min_time_ms == 0 || time_ms < self.min_time_ms {
            self.min_time_ms = time_ms;
        }

        if time_ms > self.max_time_ms {
            self.max_time_ms = time_ms;
        }
    }

    /// Get the average response time in milliseconds
    pub fn avg_time_ms(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_time_ms as f64 / self.count as f64
        }
    }

    /// Get a human-readable summary of the metrics
    pub fn summary(&self) -> String {
        format!(
            "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
            self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
        )
    }
}
/// Global metrics storage
#[derive(Debug, Clone, Default)]
pub struct MetricsStore {
    /// Metrics per endpoint
    endpoints: Arc<Mutex<std::collections::HashMap<String, EndpointMetrics>>>,
}

impl MetricsStore {
    /// Create a new metrics store
    pub fn new() -> Self {
        Self {
            endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
        }
    }

    /// Record a request's timing information
    pub async fn record(&self, path: String, time_ms: u64) {
        let mut endpoints = self.endpoints.lock().await;
        let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
        metrics.add_response_time(time_ms);
    }

    /// Get metrics for all endpoints
    pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
        let endpoints = self.endpoints.lock().await;
        endpoints
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect()
    }

    /// Log a summary of all metrics
    pub async fn log_summary(&self) {
        let metrics = self.get_all().await;

        info!("Performance metrics summary:");
        for (path, metric) in metrics {
            info!("  {}: {}", path, metric.summary());
        }
    }
}
// Define a Layer for metrics tracking
#[derive(Debug, Clone)]
pub struct MetricsLayer {
    metrics_store: MetricsStore,
}

impl MetricsLayer {
    pub fn new(metrics_store: MetricsStore) -> Self {
        Self { metrics_store }
    }
}

impl<S> Layer<S> for MetricsLayer {
    type Service = MetricsService<S>;

    fn layer(&self, service: S) -> Self::Service {
        MetricsService {
            inner: service,
            metrics_store: self.metrics_store.clone(),
        }
    }
}

// Define a Service for metrics tracking
#[derive(Clone)]
pub struct MetricsService<S> {
    inner: S,
    metrics_store: MetricsStore,
}

impl<S> fmt::Debug for MetricsService<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MetricsService")
            .field("metrics_store", &self.metrics_store)
            .finish()
    }
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for MetricsService<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.inner.poll_ready(cx))?;
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
            matched_path.as_str().to_string()
        } else {
            req.uri().path().to_string()
        };

        let method = req.method().clone();
        let start = Instant::now();
        let metrics_store = self.metrics_store.clone();

        let future = self.inner.call(req);

        Box::pin(async move {
            let response = future.await?;
            let time = start.elapsed();
            let status = response.status();
            let time_ms = time.as_millis() as u64;

            // Record the timing in our metrics store
            metrics_store.record(format!("{} {}", method, path), time_ms).await;

            // Log the request timing
            debug!("{} {} {} - {} ms", method, path, status, time_ms);

            Ok(response)
        })
    }
}
/// Future that periodically logs metrics summaries
pub struct MetricsLoggerFuture {
    metrics_store: MetricsStore,
    interval: tokio::time::Interval,
}

impl MetricsLoggerFuture {
    pub fn new(metrics_store: MetricsStore, interval_secs: u64) -> Self {
        let interval = tokio::time::interval(tokio::time::Duration::from_secs(interval_secs));
        Self {
            metrics_store,
            interval,
        }
    }
}

impl Future for MetricsLoggerFuture {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.interval.poll_tick(cx).is_ready() {
            let metrics_store = self.metrics_store.clone();
            tokio::spawn(async move {
                metrics_store.log_summary().await;
            });
        }

        // Never resolves: the logger is meant to run for the server's lifetime,
        // waking again on each interval tick.
        Poll::Pending
    }
}
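A test-style usage sketch of the store (assuming a Tokio test runtime; the timings are chosen so the average is exact):

```rust
#[tokio::test]
async fn records_and_averages() {
    let store = MetricsStore::new();
    store.record("GET /health".to_string(), 10).await;
    store.record("GET /health".to_string(), 30).await;

    let all = store.get_all().await;
    let (path, m) = &all[0];
    assert_eq!(path.as_str(), "GET /health");
    assert_eq!(m.count, 2);
    assert_eq!(m.avg_time_ms(), 20.0); // (10 + 30) / 2
    assert_eq!(m.min_time_ms, 10);
    assert_eq!(m.max_time_ms, 30);
}
```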

View File

@@ -0,0 +1,7 @@
pub mod metrics;

pub use metrics::{
    MetricsStore,
    MetricsLoggerFuture,
    MetricsLayer,
};