mirror of
https://github.com/geoffsee/predict-otron-9001.git
synced 2025-09-08 22:46:44 +00:00
move cli into crates and stage for release
This commit is contained in:
434
crates/predict-otron-9000/src/ha_mode.rs
Normal file
434
crates/predict-otron-9000/src/ha_mode.rs
Normal file
@@ -0,0 +1,434 @@
|
||||
use axum::{
|
||||
Router,
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, Method, StatusCode, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
};
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::config::ServerConfig;
|
||||
|
||||
/// # Generating `SERVER_CONFIG` for TOML using Node.js
|
||||
///
|
||||
/// You can still use the Node.js REPL to build the JSON, but when pasting into
|
||||
/// a `.toml` file you must follow TOML's string rules. Below are the safest patterns.
|
||||
///
|
||||
/// ## 1) Generate the JSON in Node
|
||||
/// ```bash
|
||||
/// node
|
||||
/// ```
|
||||
/// ```javascript
|
||||
/// const myobject = {
|
||||
/// serverMode: "HighAvailability",
|
||||
/// services: {
|
||||
/// inference_url: "http://custom-inference:9000",
|
||||
/// embeddings_url: "http://custom-embeddings:9001"
|
||||
/// }
|
||||
/// };
|
||||
/// const json = JSON.stringify(myobject);
|
||||
/// json
|
||||
/// // -> '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
|
||||
/// ```
|
||||
///
|
||||
/// ## 2) Put it into `.toml`
|
||||
///
|
||||
/// ### Option A (recommended): single-quoted TOML *literal* string
|
||||
/// Single quotes in TOML mean "no escaping", so your inner double quotes are safe.
|
||||
/// ```toml
|
||||
/// SERVER_CONFIG = '{"serverMode":"HighAvailability","services":{"inference_url":"http://custom-inference:9000","embeddings_url":"http://custom-embeddings:9001"}}'
|
||||
/// ```
|
||||
///
|
||||
/// ### Option B: double-quoted TOML string (must escape inner quotes)
|
||||
/// If you *must* use double quotes in TOML, escape all `"` inside the JSON.
|
||||
/// You can have Node do this for you:
|
||||
/// ```javascript
|
||||
/// // In Node:
|
||||
/// const jsonForToml = JSON.stringify(myobject).replace(/"/g, '\\"');
|
||||
/// jsonForToml
|
||||
/// // -> {\"serverMode\":\"HighAvailability\",...}  (exact text; the REPL echo doubles each backslash, so use console.log(jsonForToml) to print it verbatim)
|
||||
/// ```
|
||||
/// Then paste into TOML:
|
||||
/// ```toml
|
||||
/// SERVER_CONFIG = "{\"serverMode\":\"HighAvailability\",\"services\":{\"inference_url\":\"http://custom-inference:9000\",\"embeddings_url\":\"http://custom-embeddings:9001\"}}"
|
||||
/// ```
|
||||
///
|
||||
/// ### Option C: multi-line literal (for pretty JSON)
|
||||
/// If you want pretty-printed JSON in the file, use TOML's triple single quotes:
|
||||
/// ```javascript
|
||||
/// // In Node (pretty with 2 spaces):
|
||||
/// const pretty = JSON.stringify(myobject, null, 2);
|
||||
/// ```
|
||||
/// ```toml
|
||||
/// SERVER_CONFIG = '''{
|
||||
/// "serverMode": "HighAvailability",
|
||||
/// "services": {
|
||||
/// "inference_url": "http://custom-inference:9000",
|
||||
/// "embeddings_url": "http://custom-embeddings:9001"
|
||||
/// }
|
||||
/// }'''
|
||||
/// ```
|
||||
///
|
||||
/// ## 3) Reading it in Rust
|
||||
///
|
||||
/// If `SERVER_CONFIG` is stored as a **string** in TOML (Options A/B/C):
|
||||
/// ```rust
|
||||
/// use serde_json::Value;
|
||||
///
|
||||
/// // Suppose you've already loaded your .toml into a struct or a toml::Value:
|
||||
/// // e.g., struct FileCfg { pub SERVER_CONFIG: String }
|
||||
/// fn parse_server_config(raw: &str) -> anyhow::Result<Value> {
|
||||
/// let v: Value = serde_json::from_str(raw)?;
|
||||
/// Ok(v)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ### Alternative: store it as TOML tables and serialize to JSON at runtime
|
||||
/// Instead of a JSON string, you can make the TOML first-class tables:
|
||||
/// ```toml
|
||||
/// [SERVER_CONFIG]
|
||||
/// serverMode = "HighAvailability"
|
||||
///
|
||||
/// [SERVER_CONFIG.services]
|
||||
/// inference_url = "http://custom-inference:9000"
|
||||
/// embeddings_url = "http://custom-embeddings:9001"
|
||||
/// ```
|
||||
/// ```rust
|
||||
/// use serde::{Deserialize, Serialize};
|
||||
/// use serde_json::Value;
|
||||
///
|
||||
/// #[derive(Debug, Serialize, Deserialize)]
|
||||
/// struct Services {
|
||||
/// inference_url: String,
|
||||
/// embeddings_url: String,
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Debug, Serialize, Deserialize)]
|
||||
/// struct ServerConfig {
|
||||
/// serverMode: String,
|
||||
/// services: Services,
|
||||
/// }
|
||||
///
|
||||
/// // After loading the .toml (e.g., via `toml::from_str`):
|
||||
/// // let cfg: ServerConfig = toml::from_str(toml_str)?;
|
||||
/// // Convert to JSON if needed:
|
||||
/// fn to_json(cfg: &ServerConfig) -> serde_json::Result<Value> {
|
||||
/// Ok(serde_json::to_value(cfg)?)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ## Gotchas
|
||||
/// - Prefer **single-quoted** TOML strings for raw JSON to avoid escaping.
|
||||
/// - If you use **double-quoted** TOML strings, escape every inner `"` in the JSON.
|
||||
/// - Pretty JSON is fine in TOML using `''' ... '''`, but remember the newlines are part of the string.
|
||||
/// - If you control the consumer, TOML tables (the alternative above) are more ergonomic than embedding JSON.
|
||||
|
||||
/// HTTP client configured for proxying requests
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyClient {
|
||||
client: Client,
|
||||
config: ServerConfig,
|
||||
}
|
||||
|
||||
impl ProxyClient {
|
||||
pub fn new(config: ServerConfig) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(300)) // 5 minute timeout for long-running inference
|
||||
.build()
|
||||
.expect("Failed to create HTTP client for proxy");
|
||||
|
||||
Self { client, config }
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a router that proxies requests to external services in HighAvailability mode
|
||||
pub fn create_ha_router(config: ServerConfig) -> Router {
|
||||
let proxy_client = ProxyClient::new(config.clone());
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(proxy_chat_completions))
|
||||
.route("/v1/models", get(proxy_models))
|
||||
.route("/v1/embeddings", post(proxy_embeddings))
|
||||
.with_state(proxy_client)
|
||||
}
|
||||
|
||||
/// Proxy handler for POST /v1/chat/completions
|
||||
async fn proxy_chat_completions(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!(
|
||||
"{}/v1/chat/completions",
|
||||
proxy_client
|
||||
.config
|
||||
.inference_url()
|
||||
.expect("Invalid Configuration")
|
||||
);
|
||||
|
||||
tracing::info!("Proxying chat completions request to: {}", target_url);
|
||||
|
||||
// Extract body as bytes
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read request body: {}", e);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
|
||||
// Check if this is a streaming request
|
||||
let is_streaming = if let Ok(body_str) = String::from_utf8(body_bytes.to_vec()) {
|
||||
if let Ok(json) = serde_json::from_str::<Value>(&body_str) {
|
||||
json.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Forward the request
|
||||
let mut req_builder = proxy_client
|
||||
.client
|
||||
.post(&target_url)
|
||||
.body(body_bytes.to_vec());
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder().status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Handle streaming vs non-streaming responses
|
||||
if is_streaming {
|
||||
// For streaming, we need to forward the response as-is
|
||||
match response.bytes().await {
|
||||
Ok(body) => resp_builder
|
||||
.header("content-type", "text/plain; charset=utf-8")
|
||||
.header("cache-control", "no-cache")
|
||||
.header("connection", "keep-alive")
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read streaming response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-streaming, forward the JSON response
|
||||
match response.bytes().await {
|
||||
Ok(body) => resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy chat completions request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy handler for GET /v1/models
|
||||
async fn proxy_models(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!(
|
||||
"{}/v1/models",
|
||||
proxy_client
|
||||
.config
|
||||
.inference_url()
|
||||
.expect("Invalid Configuration Detected")
|
||||
);
|
||||
|
||||
tracing::info!("Proxying models request to: {}", target_url);
|
||||
|
||||
let mut req_builder = proxy_client.client.get(&target_url);
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder().status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(body) => resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read models response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy models request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy handler for POST /v1/embeddings
|
||||
async fn proxy_embeddings(
|
||||
State(proxy_client): State<ProxyClient>,
|
||||
headers: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let target_url = format!(
|
||||
"{}/v1/embeddings",
|
||||
proxy_client
|
||||
.config
|
||||
.embeddings_url()
|
||||
.expect("Invalid Configuration Detected")
|
||||
);
|
||||
|
||||
tracing::info!("Proxying embeddings request to: {}", target_url);
|
||||
|
||||
// Extract body as bytes
|
||||
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read request body: {}", e);
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
};
|
||||
|
||||
// Forward the request
|
||||
let mut req_builder = proxy_client
|
||||
.client
|
||||
.post(&target_url)
|
||||
.body(body_bytes.to_vec());
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if should_forward_header(name.as_str()) {
|
||||
req_builder = req_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
let mut resp_builder = Response::builder().status(response.status());
|
||||
|
||||
// Forward response headers
|
||||
for (name, value) in response.headers().iter() {
|
||||
if should_forward_response_header(name.as_str()) {
|
||||
resp_builder = resp_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match response.bytes().await {
|
||||
Ok(body) => resp_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to read embeddings response body: {}", e);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to proxy embeddings request: {}", e);
|
||||
Err(StatusCode::BAD_GATEWAY)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if a request header should be forwarded to the target service.
///
/// Every header is forwarded except `host` and the hop-by-hop headers, which
/// describe only the connection between the client and this proxy:
/// `connection`, `keep-alive`, `proxy-connection`, `te`, `trailer`,
/// `transfer-encoding` and `upgrade`. The comparison is ASCII
/// case-insensitive and does not allocate.
///
/// Returns `true` when the header may be copied onto the upstream request.
fn should_forward_header(header_name: &str) -> bool {
    // `host` must name the upstream, not this proxy; the rest are
    // connection-scoped and would misdescribe the re-sent, buffered request.
    const NOT_FORWARDED: [&str; 8] = [
        "host",
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    !NOT_FORWARDED
        .iter()
        .any(|blocked| header_name.eq_ignore_ascii_case(blocked))
}
|
||||
|
||||
/// Determine if a response header should be forwarded back to the client.
///
/// Every header is forwarded except:
///
/// - `server` and `date`, which are specific to the upstream server;
/// - the hop-by-hop headers `connection`, `keep-alive`, `proxy-connection`,
///   `trailer`, `transfer-encoding` and `upgrade`, which describe only the
///   connection between this proxy and the upstream.
///
/// The comparison is ASCII case-insensitive and does not allocate.
///
/// Returns `true` when the header may be copied onto the client response.
fn should_forward_response_header(header_name: &str) -> bool {
    // The proxy handlers buffer the upstream body and send it again, so the
    // upstream's connection and framing headers no longer apply.
    const NOT_FORWARDED: [&str; 8] = [
        "server",
        "date",
        "connection",
        "keep-alive",
        "proxy-connection",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    !NOT_FORWARDED
        .iter()
        .any(|blocked| header_name.eq_ignore_ascii_case(blocked))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ServerMode, Services};

    /// Request headers: payload/auth headers pass, connection-specific ones
    /// are dropped, unknown headers pass by default, matching ignores case.
    #[test]
    fn test_should_forward_header() {
        for forwarded in ["content-type", "authorization", "x-request-id"] {
            assert!(
                should_forward_header(forwarded),
                "{forwarded} should be forwarded"
            );
        }

        for dropped in ["host", "connection", "upgrade", "Host", "CONNECTION"] {
            assert!(
                !should_forward_header(dropped),
                "{dropped} should not be forwarded"
            );
        }
    }

    /// Response headers: content headers pass, server-specific ones are
    /// dropped, unknown headers pass by default, matching ignores case.
    #[test]
    fn test_should_forward_response_header() {
        for forwarded in ["content-type", "cache-control", "x-ratelimit-remaining"] {
            assert!(
                should_forward_response_header(forwarded),
                "{forwarded} should be forwarded"
            );
        }

        for dropped in ["server", "date", "Server", "DATE"] {
            assert!(
                !should_forward_response_header(dropped),
                "{dropped} should not be forwarded"
            );
        }
    }

    /// A client built from a HighAvailability config exposes both service URLs.
    #[test]
    fn test_proxy_client_creation() {
        let config = ServerConfig {
            server_host: "127.0.0.1".to_string(),
            server_port: 8080,
            server_mode: ServerMode::HighAvailability,
            services: Some(Services {
                inference_url: Some("http://test-inference:8080".to_string()),
                embeddings_url: Some("http://test-embeddings:8080".to_string()),
            }),
        };

        let proxy_client = ProxyClient::new(config);
        assert_eq!(
            proxy_client.config.inference_url().unwrap().as_str(),
            "http://test-inference:8080"
        );
        assert_eq!(
            proxy_client.config.embeddings_url().unwrap().as_str(),
            "http://test-embeddings:8080"
        );
    }
}
|
Reference in New Issue
Block a user