- Introduced ServerConfig for handling deployment modes and services.

- Added HighAvailability mode for proxying requests to external services.
- Maintained Local mode for embedded services.
- Updated `README.md` and added `SERVER_CONFIG.md` with detailed documentation.
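The `config.rs` module itself is not shown in this excerpt; the sketch below is a minimal, hypothetical shape inferred from the call sites in `main.rs` (`ServerConfig::from_env()`, `is_high_availability()`, `inference_url()`, `embeddings_url()`, `server_host`, `server_port`). The environment variable names, defaults, and mode-selection logic are assumptions for illustration — the actual format is documented in `SERVER_CONFIG.md`.

```rust
// Hypothetical sketch of config.rs; names and defaults are assumptions.
use std::env;

#[derive(Clone, Debug, PartialEq)]
pub enum DeploymentMode {
    Local,
    HighAvailability,
}

#[derive(Clone, Debug)]
pub struct ServerConfig {
    pub server_host: String,
    pub server_port: u16,
    pub mode: DeploymentMode,
    inference_url: String,
    embeddings_url: String,
}

impl ServerConfig {
    /// Builds the configuration from environment variables, falling back to
    /// Local mode with default host/port when nothing is set.
    pub fn from_env() -> Self {
        let mode = match env::var("SERVER_MODE").as_deref() {
            Ok("HighAvailability") | Ok("high-availability") => DeploymentMode::HighAvailability,
            _ => DeploymentMode::Local,
        };
        ServerConfig {
            server_host: env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
            server_port: env::var("SERVER_PORT")
                .ok()
                .and_then(|p| p.parse::<u16>().ok())
                .unwrap_or(8080),
            mode,
            inference_url: env::var("INFERENCE_URL")
                .unwrap_or_else(|_| "http://localhost:8081".to_string()),
            embeddings_url: env::var("EMBEDDINGS_URL")
                .unwrap_or_else(|_| "http://localhost:8082".to_string()),
        }
    }

    pub fn is_high_availability(&self) -> bool {
        self.mode == DeploymentMode::HighAvailability
    }

    pub fn inference_url(&self) -> &str {
        &self.inference_url
    }

    pub fn embeddings_url(&self) -> &str {
        &self.embeddings_url
    }
}
```

With a shape like this, HighAvailability mode would be selected by setting the (assumed) `SERVER_MODE` variable before startup, while leaving it unset keeps the existing embedded-services behavior.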
geoffsee
2025-08-28 09:55:39 -04:00
parent c96831d494
commit 45d7cd8819
7 changed files with 823 additions and 29 deletions


@@ -1,4 +1,6 @@
mod middleware;
mod config;
mod proxy;
use axum::{
Router,
@@ -12,9 +14,9 @@ use tower_http::cors::{Any, CorsLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use inference_engine::AppState;
use middleware::{MetricsStore, MetricsLoggerFuture, MetricsLayer};
use config::ServerConfig;
use proxy::create_proxy_router;
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
const DEFAULT_SERVER_PORT: &str = "8080";
#[tokio::main]
async fn main() {
@@ -42,26 +44,49 @@ async fn main() {
// Spawn the metrics logger in a background task
tokio::spawn(metrics_logger);
// Create unified router by merging embeddings and inference routers
let embeddings_router = embeddings_engine::create_embeddings_router();
// Create AppState with correct model configuration
use inference_engine::server::{PipelineArgs, build_pipeline};
use inference_engine::Which;
let mut pipeline_args = PipelineArgs::default();
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
pipeline_args.which = Which::InstructV3_1B;
let text_generation = build_pipeline(pipeline_args.clone());
let app_state = AppState {
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
model_id: "google/gemma-3-1b-it".to_string(),
build_args: pipeline_args,
// Load server configuration from environment variable
let server_config = ServerConfig::from_env();
// Extract the server_host and server_port before potentially moving server_config
let default_host = server_config.server_host.clone();
let default_port = server_config.server_port;
// Create router based on server mode
let service_router = if server_config.clone().is_high_availability() {
tracing::info!("Running in HighAvailability mode - proxying to external services");
tracing::info!(" Inference service URL: {}", server_config.inference_url());
tracing::info!(" Embeddings service URL: {}", server_config.embeddings_url());
// Use proxy router that forwards requests to external services
create_proxy_router(server_config.clone())
} else {
tracing::info!("Running in Local mode - using embedded services");
// Create unified router by merging embeddings and inference routers (existing behavior)
let embeddings_router = embeddings_engine::create_embeddings_router();
// Create AppState with correct model configuration
use inference_engine::server::{PipelineArgs, build_pipeline};
use inference_engine::Which;
let mut pipeline_args = PipelineArgs::default();
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
pipeline_args.which = Which::InstructV3_1B;
let text_generation = build_pipeline(pipeline_args.clone());
let app_state = AppState {
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
model_id: "google/gemma-3-1b-it".to_string(),
build_args: pipeline_args,
};
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);
// Merge the local routers
Router::new()
.merge(embeddings_router)
.merge(inference_router)
};
// Get the inference router directly from the inference engine
let inference_router = inference_engine::create_router(app_state);
// Create CORS layer
let cors = CorsLayer::new()
@@ -73,20 +98,26 @@ async fn main() {
// Create metrics layer
let metrics_layer = MetricsLayer::new(metrics_store);
// Merge the routers and add middleware layers
// Merge the service router with base routes and add middleware layers
let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.route("/", get(|| async { "API ready. This can serve the Leptos web app, but it doesn't." }))
.route("/health", get(|| async { "ok" }))
.merge(embeddings_router)
.merge(inference_router)
.merge(service_router)
.layer(metrics_layer) // Add metrics tracking
.layer(cors)
.layer(TraceLayer::new_for_http());
// Server configuration
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| {
String::from(default_host)
});
let server_port = env::var("SERVER_PORT").map(|v| v.parse::<u16>().unwrap_or(default_port)).unwrap_or_else(|_| {
default_port
});
let server_address = format!("{}:{}", server_host, server_port);
let listener = TcpListener::bind(&server_address).await.unwrap();
tracing::info!("Unified predict-otron-9000 server listening on {}", listener.local_addr().unwrap());