Remove session-based identification and refactor routing

Eliminated `session_identify.rs` and related session-based logic to streamline the codebase. Refactored webhooks routes to use `agent_id` instead of `stream_id` for improved clarity. Adjusted configuration and dependencies to align with these changes.
This commit is contained in:
geoffsee
2025-05-27 12:48:34 -04:00
parent 07b76723c2
commit 2ea92c2ef1
5 changed files with 19 additions and 159 deletions

View File

@@ -1,4 +1,3 @@
// src/config.rs
pub struct AppConfig { pub struct AppConfig {
pub env_vars: Vec<String>, pub env_vars: Vec<String>,
} }
@@ -6,7 +5,7 @@ pub struct AppConfig {
impl AppConfig { impl AppConfig {
pub fn new() -> Self { pub fn new() -> Self {
// Load .env file if it exists // automatic configuration between local/docker environments
match dotenv::dotenv() { match dotenv::dotenv() {
Ok(_) => tracing::debug!("Loaded .env file successfully"), Ok(_) => tracing::debug!("Loaded .env file successfully"),
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e), Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
@@ -15,8 +14,6 @@ impl AppConfig {
Self { Self {
env_vars: vec![ env_vars: vec![
"OPENAI_API_KEY".to_string(), "OPENAI_API_KEY".to_string(),
"BING_SEARCH_API_KEY".to_string(),
"TAVILY_API_KEY".to_string(),
"GENAISCRIPT_MODEL_LARGE".to_string(), "GENAISCRIPT_MODEL_LARGE".to_string(),
"GENAISCRIPT_MODEL_SMALL".to_string(), "GENAISCRIPT_MODEL_SMALL".to_string(),
"SEARXNG_API_BASE_URL".to_string(), "SEARXNG_API_BASE_URL".to_string(),

View File

@@ -27,9 +27,9 @@ lazy_static! {
)); ));
} }
pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse { pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse {
let db = DB.lock().await; let db = DB.lock().await;
match db.get(&stream_id) { match db.get(&agent_id) {
Ok(Some(data)) => { Ok(Some(data)) => {
let mut info: StreamInfo = match serde_json::from_slice(&data) { let mut info: StreamInfo = match serde_json::from_slice(&data) {
@@ -51,7 +51,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
} }
}; };
match db.insert(&stream_id, updated_info_bytes) { match db.insert(&agent_id, updated_info_bytes) {
Ok(_) => { Ok(_) => {
if let Err(e) = db.flush_async().await { if let Err(e) = db.flush_async().await {
tracing::error!("Failed to persist updated call_count to the database: {}", e); tracing::error!("Failed to persist updated call_count to the database: {}", e);
@@ -64,7 +64,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
} }
}; };
let info: StreamInfo = match db.get(&stream_id) { let info: StreamInfo = match db.get(&agent_id) {
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) { Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
Ok(info) => info, Ok(info) => info,
Err(e) => { Err(e) => {
@@ -73,7 +73,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
} }
}, },
Ok(None) => { Ok(None) => {
tracing::error!("Stream ID not found after update: {}", stream_id); tracing::error!("Stream ID not found after update: {}", agent_id);
return StatusCode::NOT_FOUND.into_response(); return StatusCode::NOT_FOUND.into_response();
} }
Err(e) => { Err(e) => {
@@ -92,14 +92,14 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
tracing::debug!( tracing::debug!(
"Processing webhook - Resource: {}, Stream ID: {}", "Processing webhook - Resource: {}, Stream ID: {}",
resource, resource,
stream_id agent_id
); );
let cmd = match resource.as_str() { let cmd = match resource.as_str() {
"web-search" => search_agent(stream_id.as_str(), &*input).await, "web-search" => search_agent(agent_id.as_str(), &*input).await,
"news-search" => news_agent(stream_id.as_str(), &*input).await, "news-search" => news_agent(agent_id.as_str(), &*input).await,
"image-generator" => image_generator(stream_id.as_str(), &*input).await, "image-generator" => image_generator(agent_id.as_str(), &*input).await,
"web-scrape" => scrape_agent(stream_id.as_str(), &*input).await, "web-scrape" => scrape_agent(agent_id.as_str(), &*input).await,
_ => { _ => {
tracing::error!("Unsupported resource type: {}", resource); tracing::error!("Unsupported resource type: {}", resource);
return StatusCode::BAD_REQUEST.into_response(); return StatusCode::BAD_REQUEST.into_response();
@@ -123,7 +123,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
}; };
let reader = BufReader::new(stdout); let reader = BufReader::new(stdout);
let sse_stream = reader_to_stream(reader, stream_id.clone()); let sse_stream = reader_to_stream(reader, agent_id.clone());
return Response::builder() return Response::builder()
.header("Content-Type", "text/event-stream") .header("Content-Type", "text/event-stream")
@@ -134,7 +134,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
.unwrap() .unwrap()
} }
Ok(None) => { Ok(None) => {
tracing::error!("Stream ID not found: {}", stream_id); tracing::error!("Stream ID not found: {}", agent_id);
StatusCode::NOT_FOUND.into_response() StatusCode::NOT_FOUND.into_response()
} }
Err(e) => { Err(e) => {

View File

@@ -10,7 +10,6 @@ mod handlers;
mod agents; mod agents;
mod genaiscript; mod genaiscript;
mod utils; mod utils;
mod session_identify;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@@ -18,7 +17,7 @@ async fn main() {
init_logging(); init_logging();
// init server configuration // init server configuration
let config = AppConfig::new(); let _ = AppConfig::new();
// Create router with all routes // Create router with all routes
let app = create_router(); let app = create_router();

View File

@@ -1,105 +1,26 @@
use crate::handlers::webhooks::handle_webhooks_post; use crate::handlers::webhooks::handle_webhooks_post;
use crate::handlers::{ use crate::handlers::{error::handle_not_found, ui::serve_ui, webhooks::handle_webhooks};
error::handle_not_found,
ui::serve_ui
,
webhooks::handle_webhooks,
};
use crate::session_identify::session_identify;
use axum::extract::Request;
use axum::response::Response;
use axum::routing::post; use axum::routing::post;
// src/routes.rs
use axum::routing::{get, Router}; use axum::routing::{get, Router};
use http::header::AUTHORIZATION;
use http::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Number;
use std::fmt;
use tower_http::trace::{self, TraceLayer}; use tower_http::trace::{self, TraceLayer};
use tracing::Level; use tracing::Level;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CurrentUser {
pub(crate) sub: String,
pub name: String,
pub email: String,
pub exp: Number,
pub id: String,
pub aud: String,
}
impl fmt::Display for CurrentUser {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"CurrentUser {{ id: {}, name: {}, email: {}, sub: {}, aud: {}, exp: {} }}",
self.id, self.name, self.email, self.sub, self.aud, self.exp
)
}
}
pub fn create_router() -> Router { pub fn create_router() -> Router {
Router::new() Router::new()
.route("/", get(serve_ui)) .route("/", get(serve_ui))
// request a stream resource // create an agent
.route("/api/webhooks", post(handle_webhooks_post)) .route("/api/agents", post(handle_webhooks_post))
// consume a stream resource // connect the agent
.route("/webhooks/:stream_id", get(handle_webhooks)) .route("/agents/:agent_id", get(handle_webhooks))
// .route_layer(axum::middleware::from_fn(auth)) // uncomment to implement your own auth
.route("/health", get(health)) .route("/health", get(health))
.layer( .layer(
TraceLayer::new_for_http() TraceLayer::new_for_http()
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)), .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
) )
// left for smoke testing
// .route("/api/status", get(handle_status))
.fallback(handle_not_found) .fallback(handle_not_found)
} }
async fn health() -> String { async fn health() -> String {
return "ok".to_string(); return "ok".to_string();
} }
async fn auth(mut req: Request, next: axum::middleware::Next) -> Result<Response, StatusCode> {
let session_token_header = req
.headers()
.get(AUTHORIZATION)
.and_then(|header_value| header_value.to_str().ok());
let session_token_parts= session_token_header.expect("No credentials").split(" ").collect::<Vec<&str>>();
let session_token = session_token_parts.get(1);
// log::info!("session_token: {:?}", session_token);
let session_token = session_token.expect("Unauthorized: No credentials supplied");
let result =
if let Some(current_user) = authorize_current_user(&*session_token).await {
// info!("current user: {}", current_user);
// insert the current user into a request extension so the handler can
// extract it
req.extensions_mut().insert(current_user);
Ok(next.run(req).await)
} else {
Err(StatusCode::UNAUTHORIZED)
};
result
}
async fn authorize_current_user(
session_token: &str,
) -> Option<CurrentUser> {
let session_identity = session_identify(session_token)
.await
.unwrap();
// println!("current_user: {:?}", session_identity.user);
Some(serde_json::from_value::<CurrentUser>(session_identity.user).unwrap())
}

View File

@@ -1,57 +0,0 @@
use anyhow::Result;
use serde_json::Value;
use serde_json::json;
use base64::Engine;
use fips204::ml_dsa_44::{PrivateKey, PublicKey};
use fips204::traits::{SerDes, Signer, Verifier};
use crate::utils::base64::B64_ENCODER;
pub struct SessionIdentity {
pub message: String,
pub signature: String,
pub target: String,
pub session_id: String,
pub user: Value
}
// for a production setup, use a 3rd party host to verify the signature
// I removed in this version because the identity server I built is not open source yet
pub async fn session_identify(session_token: &str) -> Result<SessionIdentity> {
let session_data_base64 = session_token.split('.').nth(0).ok_or_else(|| anyhow::anyhow!("Invalid session data format"))?;
// println!("session_data_base64: {}", session_data_base64);
let session_data: Value = serde_json::de::from_slice(&*B64_ENCODER.b64_decode_payload(session_data_base64).map_err(|e| anyhow::anyhow!("Failed to decode session data: {}", e))?).map_err(|e| anyhow::anyhow!("Failed to parse session data: {}", e))?;
// println!("session_data: {:?}", session_data);
let signature_base64 = session_token.split('.').nth(1).ok_or_else(|| anyhow::anyhow!("Invalid session token format"))?;
// println!("signature_base64: {}", signature_base64);
let target = session_data.get("aud")
.and_then(|e| e.as_str())
.ok_or_else(|| anyhow::anyhow!("Session data missing audience"))?;
let target = target.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse target to String: {}", e))?;
let session_id = session_data.get("id")
.and_then(|e| e.as_str())
.ok_or_else(|| anyhow::anyhow!("Session data missing id"))?;
let session_id = session_id.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse session_id to String: {}", e))?;
// let request_payload: Value = json!({
// "message": session_data_base64,
// "signature": signature_base64,
// "target": target,
// "session_id": session_id,
// });
let result = SessionIdentity {
message: session_data_base64.to_string(),
signature: signature_base64.to_string(),
target,
session_id,
user: session_data.clone()
};
Ok(result)
}