mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
- Introduced ServerConfig
for handling deployment modes and services.
- Added HighAvailability mode for proxying requests to external services. - Maintained Local mode for embedded services. - Updated `README.md` and included `SERVER_CONFIG.md` for detailed documentation.
This commit is contained in:
@@ -18,6 +18,8 @@ serde_json = "1.0.140"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.7.0", features = ["v4"] }
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
rust-embed = "8.7.2"
|
||||
|
||||
# Dependencies for embeddings functionality
|
||||
embeddings-engine = { path = "../embeddings-engine" }
|
||||
@@ -36,4 +38,5 @@ port = 8080
|
||||
[package.metadata.kube]
|
||||
image = "ghcr.io/geoffsee/predict-otron-9000:latest"
|
||||
replicas = 1
|
||||
port = 8080
|
||||
port = 8080
|
||||
env = { SERVER_CONFIG = "" }
|
180
crates/predict-otron-9000/src/config.rs
Normal file
180
crates/predict-otron-9000/src/config.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_server_host")]
|
||||
pub server_host: String,
|
||||
#[serde(default = "default_server_port")]
|
||||
pub server_port: u16,
|
||||
pub server_mode: ServerMode,
|
||||
#[serde(default)]
|
||||
pub services: Services,
|
||||
}
|
||||
|
||||
fn default_server_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_server_port() -> u16 { 8080 }
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
pub enum ServerMode {
|
||||
Standalone,
|
||||
HighAvailability,
|
||||
}
|
||||
|
||||
impl Default for ServerMode {
|
||||
fn default() -> Self {
|
||||
Self::Standalone
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Services {
|
||||
#[serde(default = "inference_service_url")]
|
||||
pub inference_url: String,
|
||||
#[serde(default = "embeddings_service_url")]
|
||||
pub embeddings_url: String,
|
||||
}
|
||||
|
||||
impl Default for Services {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inference_url: inference_service_url(),
|
||||
embeddings_url: embeddings_service_url(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn inference_service_url() -> String {
|
||||
"http://inference-service:8080".to_string()
|
||||
}
|
||||
|
||||
fn embeddings_service_url() -> String {
|
||||
"http://embeddings-service:8080".to_string()
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server_host: "127.0.0.1".to_string(),
|
||||
server_port: 8080,
|
||||
server_mode: ServerMode::Standalone,
|
||||
services: Services::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerConfig {
|
||||
/// Load configuration from SERVER_CONFIG environment variable
|
||||
/// Falls back to default (Local mode) if not set or invalid
|
||||
pub fn from_env() -> Self {
|
||||
match env::var("SERVER_CONFIG") {
|
||||
Ok(config_str) => {
|
||||
match serde_json::from_str::<ServerConfig>(&config_str) {
|
||||
Ok(config) => {
|
||||
tracing::info!("Loaded server configuration: {:?}", config);
|
||||
config
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to parse SERVER_CONFIG environment variable: {}. Using default configuration.",
|
||||
e
|
||||
);
|
||||
ServerConfig::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::info!("SERVER_CONFIG not set, using default Local mode");
|
||||
ServerConfig::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the server should run in high availability mode
|
||||
pub fn is_high_availability(&self) -> bool {
|
||||
self.server_mode == ServerMode::HighAvailability
|
||||
}
|
||||
|
||||
/// Get the inference service URL for proxying
|
||||
pub fn inference_url(&self) -> &str {
|
||||
&self.services.inference_url
|
||||
}
|
||||
|
||||
/// Get the embeddings service URL for proxying
|
||||
pub fn embeddings_url(&self) -> &str {
|
||||
&self.services.embeddings_url
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = ServerConfig::default();
|
||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||
assert!(!config.is_high_availability());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_high_availability_config() {
|
||||
let config_json = r#"{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://inference-service:8080",
|
||||
"embeddings_url": "http://embeddings-service:8080"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert_eq!(config.server_mode, ServerMode::HighAvailability);
|
||||
assert!(config.is_high_availability());
|
||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_local_mode_config() {
|
||||
let config_json = r#"{
|
||||
"serverMode": "Local"
|
||||
}"#;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert_eq!(config.server_mode, ServerMode::Standalone);
|
||||
assert!(!config.is_high_availability());
|
||||
// Should use default URLs
|
||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_urls() {
|
||||
let config_json = r#"{
|
||||
"serverMode": "HighAvailability",
|
||||
"services": {
|
||||
"inference_url": "http://custom-inference:9000",
|
||||
"embeddings_url": "http://custom-embeddings:9001"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert_eq!(config.inference_url(), "http://custom-inference:9000");
|
||||
assert_eq!(config.embeddings_url(), "http://custom-embeddings:9001");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_minimal_high_availability_config() {
|
||||
let config_json = r#"{"serverMode": "HighAvailability"}"#;
|
||||
let config: ServerConfig = serde_json::from_str(config_json).unwrap();
|
||||
assert!(config.is_high_availability());
|
||||
// Should use default URLs
|
||||
assert_eq!(config.inference_url(), "http://inference-service:8080");
|
||||
assert_eq!(config.embeddings_url(), "http://embeddings-service:8080");
|
||||
}
|
||||
}
|
@@ -1,4 +1,6 @@
|
||||
mod middleware;
|
||||
mod config;
|
||||
mod proxy;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
@@ -12,9 +14,9 @@ use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use inference_engine::AppState;
|
||||
use middleware::{MetricsStore, MetricsLoggerFuture, MetricsLayer};
|
||||
use config::ServerConfig;
|
||||
use proxy::create_proxy_router;
|
||||
|
||||
const DEFAULT_SERVER_HOST: &str = "127.0.0.1";
|
||||
const DEFAULT_SERVER_PORT: &str = "8080";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@@ -42,26 +44,49 @@ async fn main() {
|
||||
// Spawn the metrics logger in a background task
|
||||
tokio::spawn(metrics_logger);
|
||||
|
||||
// Create unified router by merging embeddings and inference routers
|
||||
let embeddings_router = embeddings_engine::create_embeddings_router();
|
||||
|
||||
|
||||
// Create AppState with correct model configuration
|
||||
use inference_engine::server::{PipelineArgs, build_pipeline};
|
||||
use inference_engine::Which;
|
||||
let mut pipeline_args = PipelineArgs::default();
|
||||
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
|
||||
pipeline_args.which = Which::InstructV3_1B;
|
||||
|
||||
let text_generation = build_pipeline(pipeline_args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
|
||||
model_id: "google/gemma-3-1b-it".to_string(),
|
||||
build_args: pipeline_args,
|
||||
// Load server configuration from environment variable
|
||||
let server_config = ServerConfig::from_env();
|
||||
|
||||
// Extract the server_host and server_port before potentially moving server_config
|
||||
let default_host = server_config.server_host.clone();
|
||||
let default_port = server_config.server_port;
|
||||
|
||||
// Create router based on server mode
|
||||
let service_router = if server_config.clone().is_high_availability() {
|
||||
tracing::info!("Running in HighAvailability mode - proxying to external services");
|
||||
tracing::info!(" Inference service URL: {}", server_config.inference_url());
|
||||
tracing::info!(" Embeddings service URL: {}", server_config.embeddings_url());
|
||||
|
||||
// Use proxy router that forwards requests to external services
|
||||
create_proxy_router(server_config.clone())
|
||||
} else {
|
||||
tracing::info!("Running in Local mode - using embedded services");
|
||||
|
||||
// Create unified router by merging embeddings and inference routers (existing behavior)
|
||||
let embeddings_router = embeddings_engine::create_embeddings_router();
|
||||
|
||||
// Create AppState with correct model configuration
|
||||
use inference_engine::server::{PipelineArgs, build_pipeline};
|
||||
use inference_engine::Which;
|
||||
let mut pipeline_args = PipelineArgs::default();
|
||||
pipeline_args.model_id = "google/gemma-3-1b-it".to_string();
|
||||
pipeline_args.which = Which::InstructV3_1B;
|
||||
|
||||
let text_generation = build_pipeline(pipeline_args.clone());
|
||||
let app_state = AppState {
|
||||
text_generation: std::sync::Arc::new(tokio::sync::Mutex::new(text_generation)),
|
||||
model_id: "google/gemma-3-1b-it".to_string(),
|
||||
build_args: pipeline_args,
|
||||
};
|
||||
|
||||
// Get the inference router directly from the inference engine
|
||||
let inference_router = inference_engine::create_router(app_state);
|
||||
|
||||
// Merge the local routers
|
||||
Router::new()
|
||||
.merge(embeddings_router)
|
||||
.merge(inference_router)
|
||||
};
|
||||
|
||||
// Get the inference router directly from the inference engine
|
||||
let inference_router = inference_engine::create_router(app_state);
|
||||
|
||||
// Create CORS layer
|
||||
let cors = CorsLayer::new()
|
||||
@@ -73,20 +98,26 @@ async fn main() {
|
||||
// Create metrics layer
|
||||
let metrics_layer = MetricsLayer::new(metrics_store);
|
||||
|
||||
// Merge the routers and add middleware layers
|
||||
// Merge the service router with base routes and add middleware layers
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World!" }))
|
||||
.route("/", get(|| async { "API ready. This can serve the Leptos web app, but it doesn't." }))
|
||||
.route("/health", get(|| async { "ok" }))
|
||||
.merge(embeddings_router)
|
||||
.merge(inference_router)
|
||||
.merge(service_router)
|
||||
.layer(metrics_layer) // Add metrics tracking
|
||||
.layer(cors)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
// Server configuration
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| DEFAULT_SERVER_HOST.to_string());
|
||||
let server_port = env::var("SERVER_PORT").unwrap_or_else(|_| DEFAULT_SERVER_PORT.to_string());
|
||||
let server_host = env::var("SERVER_HOST").unwrap_or_else(|_| {
|
||||
String::from(default_host)
|
||||
});
|
||||
|
||||
let server_port = env::var("SERVER_PORT").map(|v| v.parse::<u16>().unwrap_or(default_port)).unwrap_or_else(|_| {
|
||||
default_port
|
||||
});
|
||||
|
||||
let server_address = format!("{}:{}", server_host, server_port);
|
||||
|
||||
|
||||
let listener = TcpListener::bind(&server_address).await.unwrap();
|
||||
tracing::info!("Unified predict-otron-9000 server listening on {}", listener.local_addr().unwrap());
|
||||
|
303
crates/predict-otron-9000/src/proxy.rs
Normal file
303
crates/predict-otron-9000/src/proxy.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, Method, StatusCode, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::ServerConfig;
|
||||
|
||||
/// HTTP client configured for proxying requests
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyClient {
|
||||
client: Client,
|
||||
config: ServerConfig,
|
||||
}
|
||||
|
||||
impl ProxyClient {
|
||||
pub fn new(config: ServerConfig) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(300)) // 5 minute timeout for long-running inference
|
||||
.build()
|
||||
.expect("Failed to create HTTP client for proxy");
|
||||
|
||||
Self { client, config }
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a router that proxies requests to external services in HighAvailability mode
|
||||
pub fn create_proxy_router(config: ServerConfig) -> Router {
|
||||
let proxy_client = ProxyClient::new(config.clone());
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(proxy_chat_completions))
|
||||
.route("/v1/models", get(proxy_models))
|
||||
.route("/v1/embeddings", post(proxy_embeddings))
|
||||
.with_state(proxy_client)
|
||||
}
|
||||
|
||||
/// Proxy handler for POST /v1/chat/completions
|
||||
async fn proxy_chat_completions(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!("{}/v1/chat/completions", proxy_client.config.inference_url());
|
||||
|
||||
tracing::info!("Proxying chat completions request to: {}", target_url);
|
||||
|
||||
// Extract body as bytes
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read request body: {}", e);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
|
||||
// Check if this is a streaming request
|
||||
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
|
||||
if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
|
||||
json.get("stream").and_then(|v| v.as_bool()).unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Forward the request
|
||||
let mut req_builder = proxy_client.client
|
||||
.post(&target_url)
|
||||
.body(body_bytes.to_vec());
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming vs non-streaming responses
|
||||
if is_streaming {
|
||||
// For streaming, we need to forward the response as-is
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.header("content-type", "text/plain; charset=utf-8")
|
||||
.header("cache-control", "no-cache")
|
||||
.header("connection", "keep-alive")
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read streaming response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-streaming, forward the JSON response
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy chat completions request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy handler for GET /v1/models
|
||||
async fn proxy_models(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!("{}/v1/models", proxy_client.config.inference_url());
|
||||
|
||||
tracing::info!("Proxying models request to: {}", target_url);
|
||||
|
||||
let mut req_builder = proxy_client.client.get(&target_url);
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read models response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy models request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy handler for POST /v1/embeddings
|
||||
async fn proxy_embeddings(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!("{}/v1/embeddings", proxy_client.config.embeddings_url());
|
||||
|
||||
tracing::info!("Proxying embeddings request to: {}", target_url);
|
||||
|
||||
// Extract body as bytes
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read request body: {}", e);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
|
||||
// Forward the request
|
||||
let mut req_builder = proxy_client.client
|
||||
.post(&target_url)
|
||||
.body(body_bytes.to_vec());
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(body) => {
|
||||
resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read embeddings response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy embeddings request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if a request header should be forwarded to the target service
|
||||
fn should_forward_header(header_name: &str) -> bool {
|
||||
match header_name.to_lowercase().as_str() {
|
||||
"content-type" | "content-length" | "authorization" | "user-agent" | "accept" => true,
|
||||
"host" | "connection" | "upgrade" => false, // Don't forward connection-specific headers
|
||||
_ => true, // Forward other headers by default
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if a response header should be forwarded back to the client
|
||||
fn should_forward_response_header(header_name: &str) -> bool {
|
||||
match header_name.to_lowercase().as_str() {
|
||||
"content-type" | "content-length" | "cache-control" | "connection" => true,
|
||||
"server" | "date" => false, // Don't forward server-specific headers
|
||||
_ => true, // Forward other headers by default
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{ServerMode, Services};
|
||||
|
||||
#[test]
|
||||
fn test_should_forward_header() {
|
||||
assert!(should_forward_header("content-type"));
|
||||
assert!(should_forward_header("authorization"));
|
||||
assert!(!should_forward_header("host"));
|
||||
assert!(!should_forward_header("connection"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_forward_response_header() {
|
||||
assert!(should_forward_response_header("content-type"));
|
||||
assert!(should_forward_response_header("cache-control"));
|
||||
assert!(!should_forward_response_header("server"));
|
||||
assert!(!should_forward_response_header("date"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_proxy_client_creation() {
|
||||
let config = ServerConfig {
|
||||
server_host: "127.0.0.1".to_string(),
|
||||
server_port: 8080,
|
||||
server_mode: ServerMode::HighAvailability,
|
||||
services: Services {
|
||||
inference_url: "http://test-inference:8080".to_string(),
|
||||
embeddings_url: "http://test-embeddings:8080".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let proxy_client = ProxyClient::new(config);
|
||||
assert_eq!(proxy_client.config.inference_url(), "http://test-inference:8080");
|
||||
assert_eq!(proxy_client.config.embeddings_url(), "http://test-embeddings:8080");
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user