Remove session-based identification and refactor routing
Eliminated `session_identify.rs` and related session-based logic to streamline the codebase. Refactored webhooks routes to use `agent_id` instead of `stream_id` for improved clarity. Adjusted configuration and dependencies to align with these changes.
This commit is contained in:
@@ -1,4 +1,3 @@
|
|||||||
// src/config.rs
|
|
||||||
pub struct AppConfig {
|
pub struct AppConfig {
|
||||||
pub env_vars: Vec<String>,
|
pub env_vars: Vec<String>,
|
||||||
}
|
}
|
||||||
@@ -6,7 +5,7 @@ pub struct AppConfig {
|
|||||||
|
|
||||||
impl AppConfig {
|
impl AppConfig {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
// Load .env file if it exists
|
// automatic configuration between local/docker environments
|
||||||
match dotenv::dotenv() {
|
match dotenv::dotenv() {
|
||||||
Ok(_) => tracing::debug!("Loaded .env file successfully"),
|
Ok(_) => tracing::debug!("Loaded .env file successfully"),
|
||||||
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
|
Err(e) => tracing::debug!("No .env file found or error loading it: {}", e),
|
||||||
@@ -15,8 +14,6 @@ impl AppConfig {
|
|||||||
Self {
|
Self {
|
||||||
env_vars: vec![
|
env_vars: vec![
|
||||||
"OPENAI_API_KEY".to_string(),
|
"OPENAI_API_KEY".to_string(),
|
||||||
"BING_SEARCH_API_KEY".to_string(),
|
|
||||||
"TAVILY_API_KEY".to_string(),
|
|
||||||
"GENAISCRIPT_MODEL_LARGE".to_string(),
|
"GENAISCRIPT_MODEL_LARGE".to_string(),
|
||||||
"GENAISCRIPT_MODEL_SMALL".to_string(),
|
"GENAISCRIPT_MODEL_SMALL".to_string(),
|
||||||
"SEARXNG_API_BASE_URL".to_string(),
|
"SEARXNG_API_BASE_URL".to_string(),
|
||||||
|
@@ -27,9 +27,9 @@ lazy_static! {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse {
|
pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse {
|
||||||
let db = DB.lock().await;
|
let db = DB.lock().await;
|
||||||
match db.get(&stream_id) {
|
match db.get(&agent_id) {
|
||||||
Ok(Some(data)) => {
|
Ok(Some(data)) => {
|
||||||
|
|
||||||
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
let mut info: StreamInfo = match serde_json::from_slice(&data) {
|
||||||
@@ -51,7 +51,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match db.insert(&stream_id, updated_info_bytes) {
|
match db.insert(&agent_id, updated_info_bytes) {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
if let Err(e) = db.flush_async().await {
|
if let Err(e) = db.flush_async().await {
|
||||||
tracing::error!("Failed to persist updated call_count to the database: {}", e);
|
tracing::error!("Failed to persist updated call_count to the database: {}", e);
|
||||||
@@ -64,7 +64,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let info: StreamInfo = match db.get(&stream_id) {
|
let info: StreamInfo = match db.get(&agent_id) {
|
||||||
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
|
Ok(Some(updated_data)) => match serde_json::from_slice(&updated_data) {
|
||||||
Ok(info) => info,
|
Ok(info) => info,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -73,7 +73,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
tracing::error!("Stream ID not found after update: {}", stream_id);
|
tracing::error!("Stream ID not found after update: {}", agent_id);
|
||||||
return StatusCode::NOT_FOUND.into_response();
|
return StatusCode::NOT_FOUND.into_response();
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -92,14 +92,14 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
|
|||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Processing webhook - Resource: {}, Stream ID: {}",
|
"Processing webhook - Resource: {}, Stream ID: {}",
|
||||||
resource,
|
resource,
|
||||||
stream_id
|
agent_id
|
||||||
);
|
);
|
||||||
|
|
||||||
let cmd = match resource.as_str() {
|
let cmd = match resource.as_str() {
|
||||||
"web-search" => search_agent(stream_id.as_str(), &*input).await,
|
"web-search" => search_agent(agent_id.as_str(), &*input).await,
|
||||||
"news-search" => news_agent(stream_id.as_str(), &*input).await,
|
"news-search" => news_agent(agent_id.as_str(), &*input).await,
|
||||||
"image-generator" => image_generator(stream_id.as_str(), &*input).await,
|
"image-generator" => image_generator(agent_id.as_str(), &*input).await,
|
||||||
"web-scrape" => scrape_agent(stream_id.as_str(), &*input).await,
|
"web-scrape" => scrape_agent(agent_id.as_str(), &*input).await,
|
||||||
_ => {
|
_ => {
|
||||||
tracing::error!("Unsupported resource type: {}", resource);
|
tracing::error!("Unsupported resource type: {}", resource);
|
||||||
return StatusCode::BAD_REQUEST.into_response();
|
return StatusCode::BAD_REQUEST.into_response();
|
||||||
@@ -123,7 +123,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
|
|||||||
};
|
};
|
||||||
|
|
||||||
let reader = BufReader::new(stdout);
|
let reader = BufReader::new(stdout);
|
||||||
let sse_stream = reader_to_stream(reader, stream_id.clone());
|
let sse_stream = reader_to_stream(reader, agent_id.clone());
|
||||||
|
|
||||||
return Response::builder()
|
return Response::builder()
|
||||||
.header("Content-Type", "text/event-stream")
|
.header("Content-Type", "text/event-stream")
|
||||||
@@ -134,7 +134,7 @@ pub async fn handle_webhooks(Path(stream_id): Path<String>) -> impl IntoResponse
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
tracing::error!("Stream ID not found: {}", stream_id);
|
tracing::error!("Stream ID not found: {}", agent_id);
|
||||||
StatusCode::NOT_FOUND.into_response()
|
StatusCode::NOT_FOUND.into_response()
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@@ -10,7 +10,6 @@ mod handlers;
|
|||||||
mod agents;
|
mod agents;
|
||||||
mod genaiscript;
|
mod genaiscript;
|
||||||
mod utils;
|
mod utils;
|
||||||
mod session_identify;
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
@@ -18,7 +17,7 @@ async fn main() {
|
|||||||
init_logging();
|
init_logging();
|
||||||
|
|
||||||
// init server configuration
|
// init server configuration
|
||||||
let config = AppConfig::new();
|
let _ = AppConfig::new();
|
||||||
|
|
||||||
// Create router with all routes
|
// Create router with all routes
|
||||||
let app = create_router();
|
let app = create_router();
|
||||||
|
@@ -1,105 +1,26 @@
|
|||||||
use crate::handlers::webhooks::handle_webhooks_post;
|
use crate::handlers::webhooks::handle_webhooks_post;
|
||||||
use crate::handlers::{
|
use crate::handlers::{error::handle_not_found, ui::serve_ui, webhooks::handle_webhooks};
|
||||||
error::handle_not_found,
|
|
||||||
ui::serve_ui
|
|
||||||
,
|
|
||||||
webhooks::handle_webhooks,
|
|
||||||
};
|
|
||||||
use crate::session_identify::session_identify;
|
|
||||||
use axum::extract::Request;
|
|
||||||
use axum::response::Response;
|
|
||||||
use axum::routing::post;
|
use axum::routing::post;
|
||||||
// src/routes.rs
|
|
||||||
use axum::routing::{get, Router};
|
use axum::routing::{get, Router};
|
||||||
use http::header::AUTHORIZATION;
|
|
||||||
use http::StatusCode;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Number;
|
|
||||||
use std::fmt;
|
|
||||||
use tower_http::trace::{self, TraceLayer};
|
use tower_http::trace::{self, TraceLayer};
|
||||||
use tracing::Level;
|
use tracing::Level;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
||||||
pub struct CurrentUser {
|
|
||||||
pub(crate) sub: String,
|
|
||||||
pub name: String,
|
|
||||||
pub email: String,
|
|
||||||
pub exp: Number,
|
|
||||||
pub id: String,
|
|
||||||
pub aud: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for CurrentUser {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"CurrentUser {{ id: {}, name: {}, email: {}, sub: {}, aud: {}, exp: {} }}",
|
|
||||||
self.id, self.name, self.email, self.sub, self.aud, self.exp
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_router() -> Router {
|
pub fn create_router() -> Router {
|
||||||
|
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/", get(serve_ui))
|
.route("/", get(serve_ui))
|
||||||
// request a stream resource
|
// create an agent
|
||||||
.route("/api/webhooks", post(handle_webhooks_post))
|
.route("/api/agents", post(handle_webhooks_post))
|
||||||
// consume a stream resource
|
// connect the agent
|
||||||
.route("/webhooks/:stream_id", get(handle_webhooks))
|
.route("/agents/:agent_id", get(handle_webhooks))
|
||||||
// .route_layer(axum::middleware::from_fn(auth)) // uncomment to implement your own auth
|
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.layer(
|
.layer(
|
||||||
TraceLayer::new_for_http()
|
TraceLayer::new_for_http()
|
||||||
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
|
.make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO))
|
||||||
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
|
.on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
|
||||||
)
|
)
|
||||||
// left for smoke testing
|
|
||||||
// .route("/api/status", get(handle_status))
|
|
||||||
.fallback(handle_not_found)
|
.fallback(handle_not_found)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health() -> String {
|
async fn health() -> String {
|
||||||
return "ok".to_string();
|
return "ok".to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn auth(mut req: Request, next: axum::middleware::Next) -> Result<Response, StatusCode> {
|
|
||||||
let session_token_header = req
|
|
||||||
.headers()
|
|
||||||
.get(AUTHORIZATION)
|
|
||||||
.and_then(|header_value| header_value.to_str().ok());
|
|
||||||
|
|
||||||
let session_token_parts= session_token_header.expect("No credentials").split(" ").collect::<Vec<&str>>();
|
|
||||||
|
|
||||||
let session_token = session_token_parts.get(1);
|
|
||||||
|
|
||||||
|
|
||||||
// log::info!("session_token: {:?}", session_token);
|
|
||||||
|
|
||||||
let session_token = session_token.expect("Unauthorized: No credentials supplied");
|
|
||||||
|
|
||||||
let result =
|
|
||||||
if let Some(current_user) = authorize_current_user(&*session_token).await {
|
|
||||||
// info!("current user: {}", current_user);
|
|
||||||
// insert the current user into a request extension so the handler can
|
|
||||||
// extract it
|
|
||||||
req.extensions_mut().insert(current_user);
|
|
||||||
Ok(next.run(req).await)
|
|
||||||
} else {
|
|
||||||
Err(StatusCode::UNAUTHORIZED)
|
|
||||||
};
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async fn authorize_current_user(
|
|
||||||
session_token: &str,
|
|
||||||
) -> Option<CurrentUser> {
|
|
||||||
let session_identity = session_identify(session_token)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// println!("current_user: {:?}", session_identity.user);
|
|
||||||
|
|
||||||
Some(serde_json::from_value::<CurrentUser>(session_identity.user).unwrap())
|
|
||||||
}
|
|
@@ -1,57 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use serde_json::Value;
|
|
||||||
use serde_json::json;
|
|
||||||
use base64::Engine;
|
|
||||||
use fips204::ml_dsa_44::{PrivateKey, PublicKey};
|
|
||||||
use fips204::traits::{SerDes, Signer, Verifier};
|
|
||||||
use crate::utils::base64::B64_ENCODER;
|
|
||||||
|
|
||||||
pub struct SessionIdentity {
|
|
||||||
pub message: String,
|
|
||||||
pub signature: String,
|
|
||||||
pub target: String,
|
|
||||||
pub session_id: String,
|
|
||||||
pub user: Value
|
|
||||||
}
|
|
||||||
|
|
||||||
// for a production setup, use a 3rd party host to verify the signature
|
|
||||||
// I removed in this version because the identity server I built is not open source yet
|
|
||||||
pub async fn session_identify(session_token: &str) -> Result<SessionIdentity> {
|
|
||||||
let session_data_base64 = session_token.split('.').nth(0).ok_or_else(|| anyhow::anyhow!("Invalid session data format"))?;
|
|
||||||
// println!("session_data_base64: {}", session_data_base64);
|
|
||||||
let session_data: Value = serde_json::de::from_slice(&*B64_ENCODER.b64_decode_payload(session_data_base64).map_err(|e| anyhow::anyhow!("Failed to decode session data: {}", e))?).map_err(|e| anyhow::anyhow!("Failed to parse session data: {}", e))?;
|
|
||||||
// println!("session_data: {:?}", session_data);
|
|
||||||
|
|
||||||
|
|
||||||
let signature_base64 = session_token.split('.').nth(1).ok_or_else(|| anyhow::anyhow!("Invalid session token format"))?;
|
|
||||||
// println!("signature_base64: {}", signature_base64);
|
|
||||||
|
|
||||||
let target = session_data.get("aud")
|
|
||||||
.and_then(|e| e.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Session data missing audience"))?;
|
|
||||||
|
|
||||||
let target = target.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse target to String: {}", e))?;
|
|
||||||
|
|
||||||
let session_id = session_data.get("id")
|
|
||||||
.and_then(|e| e.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Session data missing id"))?;
|
|
||||||
|
|
||||||
let session_id = session_id.parse::<String>().map_err(|e| anyhow::anyhow!("Failed to parse session_id to String: {}", e))?;
|
|
||||||
|
|
||||||
// let request_payload: Value = json!({
|
|
||||||
// "message": session_data_base64,
|
|
||||||
// "signature": signature_base64,
|
|
||||||
// "target": target,
|
|
||||||
// "session_id": session_id,
|
|
||||||
// });
|
|
||||||
|
|
||||||
let result = SessionIdentity {
|
|
||||||
message: session_data_base64.to_string(),
|
|
||||||
signature: signature_base64.to_string(),
|
|
||||||
target,
|
|
||||||
session_id,
|
|
||||||
user: session_data.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
Reference in New Issue
Block a user