Mirror of https://github.com/geoffsee/predict-otron-9001.git, synced 2025-09-08 22:46:44 +00:00
Refactor `apply_cached_repeat_penalty` for optimized caching and reuse

Refactor `apply_cached_repeat_penalty` so computed penalties are cached and reused, add extensive unit tests, and integrate special handling for Gemma-specific models.
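The refactor itself is not shown in this diff. Purely as an illustrative sketch of the caching idea, here is one shape such a function could take, assuming plain `f32` logits; `PenaltyCache` and every other name below are hypothetical, not the crate's actual API:

use std::collections::HashMap;

/// Hypothetical per-step cache: token id -> penalized logit. Caching means a
/// token that appears several times in the history is only penalized once.
pub struct PenaltyCache {
    penalized: HashMap<u32, f32>,
}

impl PenaltyCache {
    pub fn new() -> Self {
        Self { penalized: HashMap::new() }
    }

    /// Clear between decoding steps, since logits are recomputed each step.
    pub fn reset(&mut self) {
        self.penalized.clear();
    }
}

/// Apply a CTRL-style repeat penalty to `logits` in place, reusing cached
/// results for token ids already handled in this step.
pub fn apply_cached_repeat_penalty(
    logits: &mut [f32],
    history: &[u32],
    penalty: f32,
    cache: &mut PenaltyCache,
) {
    for &token in history {
        let idx = token as usize;
        if idx >= logits.len() {
            continue; // token id outside the vocabulary: skip defensively
        }
        let value = *cache.penalized.entry(token).or_insert_with(|| {
            let l = logits[idx];
            // Shrink positive logits and grow negative ones so the token
            // becomes less likely to be sampled again.
            if l >= 0.0 { l / penalty } else { l * penalty }
        });
        logits[idx] = value;
    }
}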
Further changes rolled into this commit:

- Remove `test_request.sh`, deprecated functionality, and unused imports.
- Introduce a new CLI tool (`cli.ts`) for testing the inference engine, and adjust handling of non-streaming and streaming chat completions.
- Add CPU fallback support for text generation when the primary device is unsupported, via a new `execute_with_fallback` method that handles device-compatibility and shape-mismatch errors (see the sketch below).
- Extend unit tests to reproduce tensor shape-mismatch errors specific to certain model configurations.
- Increase HTTP timeout limits in the `curl_chat_stream.sh` script for reliable API testing.
- Make the chat completion endpoint work with gemma3 (non-streaming).
- Add a benchmarking guide with HTML reporting, a Leptos chat crate, and middleware for metrics tracking.
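The fallback method is likewise not part of this diff. A minimal sketch of the idea, retrying on CPU only for the two error classes the commit message names; `Device`, `GenError`, and `run_generation` are stand-ins, not the crate's real types:

/// Illustrative only: try the primary device, fall back to CPU on
/// device-compatibility or shape-mismatch errors.
#[derive(Clone, Copy, Debug)]
pub enum Device {
    Gpu,
    Cpu,
}

#[derive(Debug)]
pub enum GenError {
    UnsupportedDevice(String),
    ShapeMismatch(String),
    Other(String),
}

fn run_generation(device: Device, prompt: &str) -> Result<String, GenError> {
    // Placeholder for the real forward pass on `device`.
    let _ = (device, prompt);
    Err(GenError::UnsupportedDevice("missing GPU kernel".into()))
}

pub fn execute_with_fallback(primary: Device, prompt: &str) -> Result<String, GenError> {
    match run_generation(primary, prompt) {
        // Only retry on CPU for the recoverable error classes;
        // anything else is propagated unchanged.
        Err(GenError::UnsupportedDevice(_)) | Err(GenError::ShapeMismatch(_)) => {
            run_generation(Device::Cpu, prompt)
        }
        other => other,
    }
}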
@@ -3,6 +3,10 @@ name = "predict-otron-9000"
 version = "0.1.0"
 edition = "2024"

 [[bin]]
 name = "predict-otron-9000"
 path = "src/main.rs"

 [dependencies]
+# Axum web framework
+axum = "0.8.4"
@@ -1,12 +1,19 @@
-use axum::{Router, serve, http::StatusCode};
+mod middleware;
+
+use axum::{
+    Router,
+    serve,
+};
 use std::env;
 use axum::routing::get;
 use tokio::net::TcpListener;
+use tower::Service;
 use tower_http::trace::TraceLayer;
+use tower_http::cors::{Any, CorsLayer};
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+use inference_engine::AppState;
+use middleware::{MetricsStore, MetricsLoggerFuture, MetricsLayer};

-const DEFAULT_SERVER_HOST: &str = "0.0.0.0";
+const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
 const DEFAULT_SERVER_PORT: &str = "8080";

 #[tokio::main]
@@ -25,23 +32,53 @@ async fn main() {
         .with(tracing_subscriber::fmt::layer())
         .init();

+    // Initialize metrics store for performance tracking
+    let metrics_store = MetricsStore::new();
+
+    // Create a metrics logger that will periodically log metrics (every 60 seconds)
+    let metrics_logger = MetricsLoggerFuture::new(metrics_store.clone(), 60);
+
+    // Spawn the metrics logger in a background task
+    tokio::spawn(metrics_logger);
+
     // Create unified router by merging embeddings and inference routers
     let embeddings_router = embeddings_engine::create_embeddings_router();

+    // Create AppState with the correct model configuration
+    use inference_engine::server::{PipelineArgs, build_pipeline};
+    use inference_engine::Which;
+    let mut pipeline_args = PipelineArgs::default();
+    pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
+    pipeline_args.which = Which::InstructV3_1B;
+
+    let text_generation = build_pipeline(pipeline_args);
+    let app_state = AppState {
+        text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
+        model_id: "google/gemma-3-1b-it".to_string(),
+    };
+
     // Get the inference router directly from the inference engine
-    let inference_router = inference_engine::create_inference_router();
+    let inference_router = inference_engine::create_router(app_state);
+
+    // Create CORS layer
+    let cors = CorsLayer::new()
+        .allow_origin(Any)
+        .allow_methods(Any)
+        .allow_headers(Any);
+
+    // Create metrics layer
+    let metrics_layer = MetricsLayer::new(metrics_store);

-    // Merge the routers
+    // Merge the routers and add middleware layers
     let app = Router::new()
         .route("/", get(|| async { "Hello, World!" }))
         .route("/health", get(|| async { "ok" }))
         .merge(embeddings_router)
         .merge(inference_router)
+        .layer(metrics_layer) // Add metrics tracking
+        .layer(cors)
         .layer(TraceLayer::new_for_http());
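Note on the layer order above: with axum's `Router::layer`, the layer added last is the outermost, so a request passes through `TraceLayer`, then the CORS layer, then the metrics service before reaching a handler. The recorded timings therefore measure route latency without the tracing and CORS overhead.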
@@ -52,6 +89,7 @@ async fn main() {

     let listener = TcpListener::bind(&server_address).await.unwrap();
     tracing::info!("Unified predict-otron-9000 server listening on {}", listener.local_addr().unwrap());
+    tracing::info!("Performance metrics tracking enabled - summary logs every 60 seconds");
     tracing::info!("Available endpoints:");
     tracing::info!("  GET  / - Root endpoint from embeddings-engine");
     tracing::info!("  POST /v1/embeddings - Text embeddings");
@@ -60,5 +98,7 @@ async fn main() {
     serve(listener, app).await.unwrap();
 }

-// Chat completions handler that properly uses the inference server crate's error handling
+// This function is no longer needed as we're using the inference_engine router directly
crates/predict-otron-9000/src/middleware/metrics.rs (new file, 220 lines)
@@ -0,0 +1,220 @@
+use axum::{
+    extract::MatchedPath,
+    http::{Request, Response},
+};
+use std::{
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+    time::Instant,
+};
+use tokio::sync::Mutex;
+use tower::{Layer, Service};
+use tracing::{debug, info};
+use std::task::ready;
+use std::fmt;
+
+/// Performance metrics for a specific endpoint
+#[derive(Debug, Clone, Default)]
+pub struct EndpointMetrics {
+    /// Total number of requests
+    pub count: usize,
+    /// Total response time in milliseconds
+    pub total_time_ms: u64,
+    /// Minimum response time in milliseconds
+    pub min_time_ms: u64,
+    /// Maximum response time in milliseconds
+    pub max_time_ms: u64,
+}
+
+impl EndpointMetrics {
+    /// Add a new response time to the metrics
+    pub fn add_response_time(&mut self, time_ms: u64) {
+        self.count += 1;
+        self.total_time_ms += time_ms;
+
+        if self.min_time_ms == 0 || time_ms < self.min_time_ms {
+            self.min_time_ms = time_ms;
+        }
+
+        if time_ms > self.max_time_ms {
+            self.max_time_ms = time_ms;
+        }
+    }
+
+    /// Get the average response time in milliseconds
+    pub fn avg_time_ms(&self) -> f64 {
+        if self.count == 0 {
+            0.0
+        } else {
+            self.total_time_ms as f64 / self.count as f64
+        }
+    }
+
+    /// Get a human-readable summary of the metrics
+    pub fn summary(&self) -> String {
+        format!(
+            "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
+            self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
+        )
+    }
+}
+
+/// Global metrics storage
+#[derive(Debug, Clone, Default)]
+pub struct MetricsStore {
+    /// Metrics per endpoint
+    endpoints: Arc<Mutex<std::collections::HashMap<String, EndpointMetrics>>>,
+}
+
+impl MetricsStore {
+    /// Create a new metrics store
+    pub fn new() -> Self {
+        Self {
+            endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
+        }
+    }
+
+    /// Record a request's timing information
+    pub async fn record(&self, path: String, time_ms: u64) {
+        let mut endpoints = self.endpoints.lock().await;
+        let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
+        metrics.add_response_time(time_ms);
+    }
+
+    /// Get metrics for all endpoints
+    pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
+        let endpoints = self.endpoints.lock().await;
+        endpoints
+            .iter()
+            .map(|(k, v)| (k.clone(), v.clone()))
+            .collect()
+    }
+
+    /// Log a summary of all metrics
+    pub async fn log_summary(&self) {
+        let metrics = self.get_all().await;
+        info!("Performance metrics summary:");
+
+        for (path, metric) in metrics {
+            info!("  {}: {}", path, metric.summary());
+        }
+    }
+}
+
+// Define a Layer for metrics tracking
+#[derive(Debug, Clone)]
+pub struct MetricsLayer {
+    metrics_store: MetricsStore,
+}
+
+impl MetricsLayer {
+    pub fn new(metrics_store: MetricsStore) -> Self {
+        Self { metrics_store }
+    }
+}
+
+impl<S> Layer<S> for MetricsLayer {
+    type Service = MetricsService<S>;
+
+    fn layer(&self, service: S) -> Self::Service {
+        MetricsService {
+            inner: service,
+            metrics_store: self.metrics_store.clone(),
+        }
+    }
+}
+
+// Define a Service for metrics tracking
+#[derive(Clone)]
+pub struct MetricsService<S> {
+    inner: S,
+    metrics_store: MetricsStore,
+}
+
+impl<S> fmt::Debug for MetricsService<S> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("MetricsService")
+            .field("metrics_store", &self.metrics_store)
+            .finish()
+    }
+}
+
+impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for MetricsService<S>
+where
+    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
+    S::Future: Send + 'static,
+    ReqBody: Send + 'static,
+    ResBody: Send + 'static,
+{
+    type Response = S::Response;
+    type Error = S::Error;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        ready!(self.inner.poll_ready(cx))?;
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
+        let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() {
+            matched_path.as_str().to_string()
+        } else {
+            req.uri().path().to_string()
+        };
+
+        let method = req.method().clone();
+        let start = Instant::now();
+        let metrics_store = self.metrics_store.clone();
+
+        let future = self.inner.call(req);
+
+        Box::pin(async move {
+            let response = future.await?;
+
+            let time = start.elapsed();
+            let status = response.status();
+            let time_ms = time.as_millis() as u64;
+
+            // Record the timing in our metrics store
+            metrics_store.record(format!("{} {}", method, path), time_ms).await;
+
+            // Log the request timing
+            debug!("{} {} {} - {} ms", method, path, status, time_ms);
+
+            Ok(response)
+        })
+    }
+}
+
+/// Future that periodically logs metrics summaries
+pub struct MetricsLoggerFuture {
+    metrics_store: MetricsStore,
+    interval: tokio::time::Interval,
+}
+
+impl MetricsLoggerFuture {
+    pub fn new(metrics_store: MetricsStore, interval_secs: u64) -> Self {
+        let interval = tokio::time::interval(tokio::time::Duration::from_secs(interval_secs));
+        Self {
+            metrics_store,
+            interval,
+        }
+    }
+}
+
+impl Future for MetricsLoggerFuture {
+    type Output = ();
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        // Drain due ticks with `while` (not `if`) so the final Pending
+        // poll_tick registers the waker; otherwise the future may never
+        // be woken again after its first tick.
+        while self.interval.poll_tick(cx).is_ready() {
+            let metrics_store = self.metrics_store.clone();
+            tokio::spawn(async move {
+                metrics_store.log_summary().await;
+            });
+        }
+
+        Poll::Pending
+    }
+}
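For reference, a minimal smoke test of the store's public API as defined above; the `#[tokio::main]` harness and the endpoint key are illustrative, not part of the crate:

#[tokio::main]
async fn main() {
    let store = MetricsStore::new();

    // Record two fake timings against one endpoint key.
    store.record("GET /health".to_string(), 3).await;
    store.record("GET /health".to_string(), 9).await;

    // Prints: GET /health: requests: 2, avg: 6.00ms, min: 3ms, max: 9ms
    for (path, metrics) in store.get_all().await {
        println!("{}: {}", path, metrics.summary());
    }
}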
crates/predict-otron-9000/src/middleware/mod.rs (new file, 7 lines)
@@ -0,0 +1,7 @@
+pub mod metrics;
+
+pub use metrics::{
+    MetricsStore,
+    MetricsLoggerFuture,
+    MetricsLayer,
+};