Mirror of https://github.com/geoffsee/predict-otron-9001.git (synced 2025-09-08 22:46:44 +00:00)

Commit: cleanup, add ci
crates/predict-otron-9000/src/config.rs (path inferred)

@@ -1,7 +1,9 @@
 use serde::{Deserialize, Serialize};
 use std::env;
+use tracing::info;
+use tracing::log::error;
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Serialize, Deserialize, Clone, Debug)]
 #[serde(rename_all = "camelCase")]
 pub struct ServerConfig {
     #[serde(default = "default_server_host")]
@@ -10,14 +12,16 @@ pub struct ServerConfig {
     pub server_port: u16,
     pub server_mode: ServerMode,
     #[serde(default)]
-    pub services: Services,
+    pub services: Option<Services>,
 }
 
 fn default_server_host() -> String {
     "127.0.0.1".to_string()
 }
 
-fn default_server_port() -> u16 { 8080 }
+fn default_server_port() -> u16 {
+    8080
+}
 
 #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
 #[serde(rename_all = "PascalCase")]
@@ -34,17 +38,15 @@ impl Default for ServerMode {
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Services {
-    #[serde(default = "inference_service_url")]
-    pub inference_url: String,
-    #[serde(default = "embeddings_service_url")]
-    pub embeddings_url: String,
+    pub inference_url: Option<String>,
+    pub embeddings_url: Option<String>,
 }
 
 impl Default for Services {
     fn default() -> Self {
         Self {
-            inference_url: inference_service_url(),
-            embeddings_url: embeddings_service_url(),
+            inference_url: None,
+            embeddings_url: None,
         }
     }
 }
@@ -63,7 +65,7 @@ impl Default for ServerConfig {
             server_host: "127.0.0.1".to_string(),
             server_port: 8080,
             server_mode: ServerMode::Standalone,
-            services: Services::default(),
+            services: Some(Services::default()),
         }
     }
 }
@@ -73,21 +75,19 @@ impl ServerConfig {
     /// Falls back to default (Local mode) if not set or invalid
     pub fn from_env() -> Self {
         match env::var("SERVER_CONFIG") {
-            Ok(config_str) => {
-                match serde_json::from_str::<ServerConfig>(&config_str) {
-                    Ok(config) => {
-                        tracing::info!("Loaded server configuration: {:?}", config);
-                        config
-                    }
-                    Err(e) => {
-                        tracing::warn!(
-                            "Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
-                            e
-                        );
-                        ServerConfig::default()
-                    }
-                }
-            }
+            Ok(config_str) => match serde_json::from_str::<ServerConfig>(&config_str) {
+                Ok(config) => {
+                    tracing::info!("Loaded server configuration: {:?}", config);
+                    config
+                }
+                Err(e) => {
+                    tracing::warn!(
+                        "Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
+                        e
+                    );
+                    ServerConfig::default()
+                }
+            },
             Err(_) => {
                 tracing::info!("SERVER_CONFIG not set, Standalone mode active");
                 ServerConfig::default()
@@ -96,18 +96,52 @@ impl ServerConfig {
     }
 
     /// Check if the server should run in high availability mode
-    pub fn is_high_availability(&self) -> bool {
-        self.server_mode == ServerMode::HighAvailability
+    pub fn is_high_availability(&self) -> Result<bool, std::io::Error> {
+        if self.server_mode == ServerMode::HighAvailability {
+            let services_well_defined: bool = self.clone().services.is_some();
+
+            let inference_url_well_defined: bool =
+                services_well_defined && self.clone().services.unwrap().inference_url.is_some();
+
+            let embeddings_well_defined: bool =
+                services_well_defined && self.clone().services.unwrap().embeddings_url.is_some();
+
+            let is_well_defined_for_ha =
+                services_well_defined && inference_url_well_defined && embeddings_well_defined;
+
+            if !is_well_defined_for_ha {
+                let config_string = serde_json::to_string_pretty(&self).unwrap();
+                error!(
+                    "HighAvailability mode configured but services not well defined! \n## Config Used:\n {}",
+                    config_string
+                );
+                let err = std::io::Error::new(
+                    std::io::ErrorKind::Other,
+                    "HighAvailability mode configured but services not well defined!",
+                );
+                return Err(err);
+            }
+        }
+
+        Ok(self.server_mode == ServerMode::HighAvailability)
     }
 
     /// Get the inference service URL for proxying
-    pub fn inference_url(&self) -> &str {
-        &self.services.inference_url
+    pub fn inference_url(&self) -> Option<String> {
+        if self.services.is_some() {
+            self.services.clone()?.inference_url
+        } else {
+            None
+        }
    }
 
     /// Get the embeddings service URL for proxying
-    pub fn embeddings_url(&self) -> &str {
-        &self.services.embeddings_url
+    pub fn embeddings_url(&self) -> Option<String> {
+        if self.services.is_some() {
+            self.services.clone()?.embeddings_url
+        } else {
+            None
+        }
     }
 }
@@ -119,7 +153,7 @@ mod tests {
     fn test_default_config() {
         let config = ServerConfig::default();
         assert_eq!(config.server_mode, ServerMode::Standalone);
-        assert!(!config.is_high_availability());
+        assert!(!config.is_high_availability().unwrap());
     }
 
     #[test]
@@ -134,23 +168,26 @@ mod tests {
 
         let config: ServerConfig = serde_json::from_str(config_json).unwrap();
         assert_eq!(config.server_mode, ServerMode::HighAvailability);
-        assert!(config.is_high_availability());
-        assert_eq!(config.inference_url(), "http://inference-service:8080");
-        assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
+        assert!(config.is_high_availability().unwrap());
+        assert_eq!(
+            config.inference_url().unwrap(),
+            "http://inference-service:8080"
+        );
+        assert_eq!(
+            config.embeddings_url().unwrap(),
+            "http://embeddings-service:8080"
+        );
     }
 
     #[test]
     fn test_local_mode_config() {
         let config_json = r#"{
-            "serverMode": "Local"
+            "serverMode": "Standalone"
         }"#;
 
         let config: ServerConfig = serde_json::from_str(config_json).unwrap();
         assert_eq!(config.server_mode, ServerMode::Standalone);
-        assert!(!config.is_high_availability());
-        // Should use default URLs
-        assert_eq!(config.inference_url(), "http://inference-service:8080");
-        assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
+        assert!(!config.is_high_availability().unwrap());
     }
 
     #[test]
@@ -164,17 +201,26 @@ mod tests {
     }"#;
 
         let config: ServerConfig = serde_json::from_str(config_json).unwrap();
-        assert_eq!(config.inference_url(), "http://custom-inference:9000");
-        assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
+        assert_eq!(
+            config.inference_url().unwrap(),
+            "http://custom-inference:9000"
+        );
+        assert_eq!(
+            config.embeddings_url().unwrap(),
+            "http://custom-embeddings:9001"
+        );
     }
 
     #[test]
-    fn test_minimal_high_availability_config() {
+    fn test_minimal_high_availability_config_error() {
         let config_json = r#"{"serverMode": "HighAvailability"}"#;
         let config: ServerConfig = serde_json::from_str(config_json).unwrap();
-        assert!(config.is_high_availability());
-        // Should use default URLs
-        assert_eq!(config.inference_url(), "http://inference-service:8080");
-        assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
+
+        let is_high_availability = config.is_high_availability();
+
+        assert!(is_high_availability.is_err());
+        // // Should use default URLs
+        // assert_eq!(config.inference_url().unwrap(), "http://inference-service:8080");
+        // assert_eq!(config.embeddings_url().unwrap(), "http://embeddings-service:8080");
     }
 }
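Note: the behavioral core of this change is that `is_high_availability()` now returns `Result<bool, std::io::Error>` and rejects a HighAvailability config that lacks service URLs, where the old code silently fell back to default URLs. A minimal sketch of the new semantics, mirroring the tests above (the struct-literal style is borrowed from the proxy tests later in this commit; the free function and its scope are illustrative, not part of the diff):

```rust
// Sketch only: assumes ServerConfig, ServerMode, and Services are in scope.
fn demo() -> Result<(), std::io::Error> {
    // HA mode with both service URLs present: Ok(true).
    let ha = ServerConfig {
        server_host: "127.0.0.1".to_string(),
        server_port: 8080,
        server_mode: ServerMode::HighAvailability,
        services: Some(Services {
            inference_url: Some("http://test-inference:8080".to_string()),
            embeddings_url: Some("http://test-embeddings:8080".to_string()),
        }),
    };
    assert!(ha.is_high_availability()?);

    // HA mode without services: now a hard Err instead of default URLs.
    let incomplete = ServerConfig {
        services: None,
        ..ha
    };
    assert!(incomplete.is_high_availability().is_err());

    // Standalone default: Ok(false).
    assert!(!ServerConfig::default().is_high_availability()?);
    Ok(())
}
```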
crates/predict-otron-9000/src/main.rs (path inferred)

@@ -1,7 +1,9 @@
 mod config;
 mod middleware;
 mod proxy;
+mod standalone;
 
+use crate::standalone::create_standalone_router;
 use axum::response::IntoResponse;
 use axum::routing::get;
 use axum::{Router, http::Uri, response::Html, serve};
@@ -11,6 +13,7 @@ use middleware::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
 use proxy::create_proxy_router;
 use rust_embed::Embed;
 use std::env;
+use std::path::Component::ParentDir;
 use tokio::net::TcpListener;
 use tower_http::classify::ServerErrorsFailureClass::StatusCode;
 use tower_http::cors::{Any, CorsLayer};
@@ -49,33 +52,19 @@ async fn main() {
     let default_host = server_config.server_host.clone();
     let default_port = server_config.server_port;
 
     // Create router based on server mode
-    let service_router = if server_config.clone().is_high_availability() {
-        tracing::info!("Running in HighAvailability mode - proxying to external services");
-        tracing::info!(" Inference service URL: {}", server_config.inference_url());
-        tracing::info!(
-            " Embeddings service URL: {}",
-            server_config.embeddings_url()
-        );
-
-        // Use proxy router that forwards requests to external services
-        create_proxy_router(server_config.clone())
-    } else {
-        tracing::info!("Running in Standalone mode - using embedded services");
-
-        // Create unified router by merging embeddings and inference routers (existing behavior)
-        let embeddings_router = embeddings_engine::create_embeddings_router();
-
-        // Create AppState with correct model configuration
-        let app_state = AppState::default();
-
-        // Get the inference router directly from the inference engine
-        let inference_router = inference_engine::create_router(app_state);
-
-        // Merge the local routers
-        Router::new()
-            .merge(embeddings_router)
-            .merge(inference_router)
+    let service_router = match server_config.clone().is_high_availability() {
+        Ok(is_ha) => {
+            if is_ha {
+                log_config(server_config.clone());
+                create_proxy_router(server_config.clone())
+            } else {
+                log_config(server_config.clone());
+                create_standalone_router(server_config)
+            }
+        }
+        Err(error) => {
+            panic!("{}", error);
+        }
     };
 
     // Create CORS layer
@@ -124,5 +113,25 @@ async fn main() {
     serve(listener, app).await.unwrap();
 }
 
+fn log_config(config: ServerConfig) {
+    match config.is_high_availability() {
+        Ok(is_high) => {
+            if is_high {
+                tracing::info!("Running in HighAvailability mode - proxying to external services");
+                tracing::info!("Inference service URL: {}", config.inference_url().unwrap());
+                tracing::info!(
+                    "Embeddings service URL: {}",
+                    config.embeddings_url().unwrap()
+                );
+            } else {
+                tracing::info!("Running in Standalone mode");
+            }
+        }
+        Err(error) => {
+            panic!("{}", error);
+        }
+    }
+}
+
 // Chat completions handler that properly uses the inference server crate's error handling
 // This function is no longer needed as we're using the inference_engine router directly
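Note: with this change, a HighAvailability deployment must supply both service URLs via `SERVER_CONFIG`, otherwise `main` hits the `Err` arm and panics at startup. A hypothetical value that would select the proxy path (key casing is inferred from the serde attributes visible in this diff: `camelCase` on `ServerConfig`, default snake_case on `Services`):

```rust
// Sketch only: shows what from_env() would parse; not part of the commit.
fn example() {
    let json = r#"{
        "serverMode": "HighAvailability",
        "services": {
            "inference_url": "http://inference-service:8080",
            "embeddings_url": "http://embeddings-service:8080"
        }
    }"#;
    // In deployment: set SERVER_CONFIG to this JSON, then call ServerConfig::from_env().
    let config: ServerConfig = serde_json::from_str(json).unwrap();
    assert!(config.is_high_availability().unwrap());
}
```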
crates/predict-otron-9000/src/middleware/metrics.rs (path inferred)

@@ -2,6 +2,8 @@ use axum::{
     extract::MatchedPath,
     http::{Request, Response},
 };
+use std::fmt;
+use std::task::ready;
 use std::{
     future::Future,
     pin::Pin,
@@ -12,8 +14,6 @@ use std::{
 use tokio::sync::Mutex;
 use tower::{Layer, Service};
 use tracing::{debug, info};
-use std::task::ready;
-use std::fmt;
 
 /// Performance metrics for a specific endpoint
 #[derive(Debug, Clone, Default)]
@@ -33,16 +33,16 @@ impl EndpointMetrics {
     pub fn add_response_time(&mut self, time_ms: u64) {
         self.count += 1;
         self.total_time_ms += time_ms;
-        
+
         if self.min_time_ms == 0 || time_ms < self.min_time_ms {
             self.min_time_ms = time_ms;
         }
-        
+
         if time_ms > self.max_time_ms {
             self.max_time_ms = time_ms;
         }
     }
-    
+
     /// Get the average response time in milliseconds
     pub fn avg_time_ms(&self) -> f64 {
         if self.count == 0 {
@@ -51,12 +51,15 @@ impl EndpointMetrics {
             self.total_time_ms as f64 / self.count as f64
         }
     }
-    
+
     /// Get a human-readable summary of the metrics
     pub fn summary(&self) -> String {
         format!(
             "requests: {}, avg: {:.2}ms, min: {}ms, max: {}ms",
-            self.count, self.avg_time_ms(), self.min_time_ms, self.max_time_ms
+            self.count,
+            self.avg_time_ms(),
+            self.min_time_ms,
+            self.max_time_ms
         )
     }
 }
@@ -75,14 +78,16 @@ impl MetricsStore {
             endpoints: Arc::new(Mutex::new(std::collections::HashMap::new())),
         }
     }
-    
+
     /// Record a request's timing information
     pub async fn record(&self, path: String, time_ms: u64) {
         let mut endpoints = self.endpoints.lock().await;
-        let metrics = endpoints.entry(path).or_insert_with(EndpointMetrics::default);
+        let metrics = endpoints
+            .entry(path)
+            .or_insert_with(EndpointMetrics::default);
         metrics.add_response_time(time_ms);
     }
-    
+
     /// Get metrics for all endpoints
     pub async fn get_all(&self) -> Vec<(String, EndpointMetrics)> {
         let endpoints = self.endpoints.lock().await;
@@ -91,12 +96,12 @@ impl MetricsStore {
             .map(|(k, v)| (k.clone(), v.clone()))
             .collect()
     }
-    
+
     /// Log a summary of all metrics
     pub async fn log_summary(&self) {
         let metrics = self.get_all().await;
         info!("Performance metrics summary:");
-        
+
         for (path, metric) in metrics {
             info!(" {}: {}", path, metric.summary());
         }
@@ -163,26 +168,28 @@ where
         } else {
             req.uri().path().to_string()
         };
-        
+
         let method = req.method().clone();
         let start = Instant::now();
        let metrics_store = self.metrics_store.clone();
-        
+
        let future = self.inner.call(req);
-        
+
        Box::pin(async move {
            let response = future.await?;
-            
+
            let time = start.elapsed();
            let status = response.status();
            let time_ms = time.as_millis() as u64;
-            
+
            // Record the timing in our metrics store
-            metrics_store.record(format!("{} {}", method, path), time_ms).await;
+            metrics_store
+                .record(format!("{} {}", method, path), time_ms)
+                .await;
 
            // Log the request timing
            debug!("{} {} {} - {} ms", method, path, status, time_ms);
-            
+
            Ok(response)
        })
    }
@@ -214,7 +221,7 @@ impl Future for MetricsLoggerFuture {
                 metrics_store.log_summary().await;
             });
         }
-        
+
         Poll::Pending
     }
 }
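Note: the metrics changes in this commit are formatting-only (rustfmt-style reflow, import reordering, trailing-whitespace removal); behavior is unchanged. For reference, a small sketch of how the store accumulates per-endpoint stats (assumes `MetricsStore::new()` as shown above and a Tokio runtime; not part of the commit):

```rust
#[tokio::test]
async fn metrics_sketch() {
    let store = MetricsStore::new();
    store.record("GET /v1/models".to_string(), 12).await;
    store.record("GET /v1/models".to_string(), 30).await;
    // avg = (12 + 30) / 2 = 21.00 ms, min = 12 ms, max = 30 ms, so log_summary
    // should log: "GET /v1/models: requests: 2, avg: 21.00ms, min: 12ms, max: 30ms"
    store.log_summary().await;
}
```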
crates/predict-otron-9000/src/middleware/mod.rs (path inferred)

@@ -1,7 +1,3 @@
 pub mod metrics;
 
-pub use metrics::{
-    MetricsStore,
-    MetricsLoggerFuture,
-    MetricsLayer,
-};
+pub use metrics::{MetricsLayer, MetricsLoggerFuture, MetricsStore};
crates/predict-otron-9000/src/proxy.rs (path inferred)

@@ -1,10 +1,10 @@
 use axum::{
-    Router,
     body::Body,
     extract::{Request, State},
     http::{HeaderMap, Method, StatusCode, Uri},
     response::{IntoResponse, Response},
     routing::{get, post},
+    Router,
 };
 use reqwest::Client;
 use serde_json::Value;
@@ -47,10 +47,16 @@ async fn proxy_chat_completions(
     headers: HeaderMap,
     body: Body,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());
+    let target_url = format!(
+        "{}/v1/chat/completions",
+        proxy_client
+            .config
+            .inference_url()
+            .expect("Invalid Configuration")
+    );
 
     tracing::info!("Proxying chat completions request to: {}", target_url);
-    
+
     // Extract body as bytes
     let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
         Ok(bytes) => bytes,
@@ -63,7 +69,9 @@ async fn proxy_chat_completions(
     // Check if this is a streaming request
     let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
         if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
-            json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
+            json.get("stream")
+                .and_then(|v| v.as_bool())
+                .unwrap_or(false)
         } else {
             false
         }
@@ -72,7 +80,8 @@ async fn proxy_chat_completions(
     };
 
     // Forward the request
-    let mut req_builder = proxy_client.client
+    let mut req_builder = proxy_client
+        .client
         .post(&target_url)
         .body(body_bytes.to_vec());
 
@@ -85,8 +94,7 @@ async fn proxy_chat_completions(
 
     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());
 
             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -99,14 +107,12 @@ async fn proxy_chat_completions(
             if is_streaming {
                 // For streaming, we need to forward the response as-is
                 match response.bytes().await {
-                    Ok(body) => {
-                        resp_builder
-                            .header("content-type", "text/plain; charset=utf-8")
-                            .header("cache-control", "no-cache")
-                            .header("connection", "keep-alive")
-                            .body(Body::from(body))
-                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                    }
+                    Ok(body) => resp_builder
+                        .header("content-type", "text/plain; charset=utf-8")
+                        .header("cache-control", "no-cache")
+                        .header("connection", "keep-alive")
+                        .body(Body::from(body))
+                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                     Err(e) => {
                         tracing::error!("Failed to read streaming response body: {}", e);
                         Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -115,11 +121,9 @@ async fn proxy_chat_completions(
             } else {
                 // For non-streaming, forward the JSON response
                 match response.bytes().await {
-                    Ok(body) => {
-                        resp_builder
-                            .body(Body::from(body))
-                            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                    }
+                    Ok(body) => resp_builder
+                        .body(Body::from(body))
+                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                     Err(e) => {
                         tracing::error!("Failed to read response body: {}", e);
                         Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -139,10 +143,16 @@ async fn proxy_models(
     State(proxy_client): State<ProxyClient>,
     headers: HeaderMap,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/models", proxy_client.config.inference_url());
+    let target_url = format!(
+        "{}/v1/models",
+        proxy_client
+            .config
+            .inference_url()
+            .expect("Invalid Configuration Detected")
+    );
 
     tracing::info!("Proxying models request to: {}", target_url);
-    
+
     let mut req_builder = proxy_client.client.get(&target_url);
 
     // Forward relevant headers
@@ -154,8 +164,7 @@ async fn proxy_models(
 
     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());
 
             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -165,11 +174,9 @@ async fn proxy_models(
             }
 
             match response.bytes().await {
-                Ok(body) => {
-                    resp_builder
-                        .body(Body::from(body))
-                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                }
+                Ok(body) => resp_builder
+                    .body(Body::from(body))
+                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                 Err(e) => {
                     tracing::error!("Failed to read models response body: {}", e);
                     Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -189,10 +196,16 @@ async fn proxy_embeddings(
     headers: HeaderMap,
     body: Body,
 ) -> Result<Response, StatusCode> {
-    let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());
+    let target_url = format!(
+        "{}/v1/embeddings",
+        proxy_client
+            .config
+            .embeddings_url()
+            .expect("Invalid Configuration Detected")
+    );
 
     tracing::info!("Proxying embeddings request to: {}", target_url);
-    
+
     // Extract body as bytes
     let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
         Ok(bytes) => bytes,
@@ -203,7 +216,8 @@ async fn proxy_embeddings(
     };
 
     // Forward the request
-    let mut req_builder = proxy_client.client
+    let mut req_builder = proxy_client
+        .client
         .post(&target_url)
         .body(body_bytes.to_vec());
 
@@ -216,8 +230,7 @@ async fn proxy_embeddings(
 
     match req_builder.send().await {
         Ok(response) => {
-            let mut resp_builder = Response::builder()
-                .status(response.status());
+            let mut resp_builder = Response::builder().status(response.status());
 
             // Forward response headers
             for (name, value) in response.headers().iter() {
@@ -227,11 +240,9 @@ async fn proxy_embeddings(
             }
 
             match response.bytes().await {
-                Ok(body) => {
-                    resp_builder
-                        .body(Body::from(body))
-                        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
-                }
+                Ok(body) => resp_builder
+                    .body(Body::from(body))
+                    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
                 Err(e) => {
                     tracing::error!("Failed to read embeddings response body: {}", e);
                     Err(StatusCode::INTERNAL_SERVER_ERROR)
@@ -250,7 +261,7 @@ fn should_forward_header(header_name: &str) -> bool {
     match header_name.to_lowercase().as_str() {
         "content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
         "host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
-        _ => true,  // Forward other headers by default
+        _ => true, // Forward other headers by default
     }
 }
 
@@ -259,7 +270,7 @@ fn should_forward_response_header(header_name: &str) -> bool {
     match header_name.to_lowercase().as_str() {
         "content-type" | "content-length" | "cache-control" | "connection" => true,
         "server" | "date" => false, // Don't forward server-specific headers
-        _ => true,  // Forward other headers by default
+        _ => true, // Forward other headers by default
     }
 }
 
@@ -290,14 +301,20 @@ mod tests {
             server_host: "127.0.0.1".to_string(),
             server_port: 8080,
             server_mode: ServerMode::HighAvailability,
-            services: Services {
-                inference_url: "http://test-inference:8080".to_string(),
-                embeddings_url: "http://test-embeddings:8080".to_string(),
-            },
+            services: Some(Services {
+                inference_url: Some("http://test-inference:8080".to_string()),
+                embeddings_url: Some("http://test-embeddings:8080".to_string()),
+            }),
         };
 
         let proxy_client = ProxyClient::new(config);
-        assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
-        assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
+        assert_eq!(
+            proxy_client.config.inference_url().unwrap().as_str(),
+            "http://test-inference:8080"
+        );
+        assert_eq!(
+            proxy_client.config.embeddings_url().unwrap().as_str(),
+            "http://test-embeddings:8080"
+        );
     }
 }
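Note: the proxy changes are formatting plus `.expect(...)` on the now-`Option` URL accessors; the header predicates are untouched apart from comment spacing. A quick sketch of what they allow through (hypothetical assertions, matching the match arms above):

```rust
fn header_filter_sketch() {
    // Request headers: connection-specific headers are dropped.
    assert!(should_forward_header("authorization"));
    assert!(should_forward_header("Content-Type")); // compared case-insensitively
    assert!(!should_forward_header("host"));
    // Response headers: server-identifying headers are dropped.
    assert!(should_forward_response_header("content-type"));
    assert!(!should_forward_response_header("server"));
}
```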
crates/predict-otron-9000/src/standalone.rs (new file, 19 lines)

@@ -0,0 +1,19 @@
+use crate::config::ServerConfig;
+use axum::Router;
+use inference_engine::AppState;
+
+pub fn create_standalone_router(server_config: ServerConfig) -> Router {
+    // Create unified router by merging embeddings and inference routers (existing behavior)
+    let embeddings_router = embeddings_engine::create_embeddings_router();
+
+    // Create AppState with correct model configuration
+    let app_state = AppState::default();
+
+    // Get the inference router directly from the inference engine
+    let inference_router = inference_engine::create_router(app_state);
+
+    // Merge the local routers
+    Router::new()
+        .merge(embeddings_router)
+        .merge(inference_router)
+}
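Note: `create_standalone_router` extracts the router-merging block that previously lived inline in `main`. A sketch of how it would be served (mirrors the `serve(listener, app)` call in main.rs; the bind address here is illustrative):

```rust
async fn serve_standalone(config: ServerConfig) {
    let app = create_standalone_router(config);
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```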