Refactor agent function names and streamline imports

Unified the naming convention for agent functions across modules to `agent` for consistency. Adjusted relevant imports and cleaned up unused imports in `webhooks.rs` to improve readability and maintainability.
This commit is contained in:
geoffsee
2025-05-27 13:05:05 -04:00
parent 3256f254ad
commit 77d5a51e76
6 changed files with 31 additions and 31 deletions

View File

@@ -1,7 +1,7 @@
use crate::utils::utils::run_agent; use crate::utils::utils::run_agent;
use tokio::process::Child; use tokio::process::Child;
pub async fn image_generator(stream_id: &str, input: &str) -> Result<Child, String> { pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
tracing::debug!( tracing::debug!(
"Running image generator, \ninput: {}", "Running image generator, \ninput: {}",
input input

View File

@@ -1,4 +1,4 @@
pub mod news; pub(crate) mod news;
pub mod scrape; pub(crate) mod scrape;
pub mod search; pub(crate) mod search;
pub mod image_generator; pub(crate) mod image_generator;

View File

@@ -1,6 +1,6 @@
use crate::utils::utils::run_agent; use crate::utils::utils::run_agent;
use tokio::process::Child; use tokio::process::Child;
pub async fn news_agent(stream_id: &str, input: &str) -> Result<Child, String> { pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await run_agent(stream_id, input, "./packages/genaiscript/genaisrc/news-search.genai.mts", 10).await
} }

View File

@@ -1,6 +1,6 @@
use crate::utils::utils::run_agent; use crate::utils::utils::run_agent;
use tokio::process::Child; use tokio::process::Child;
pub async fn scrape_agent(stream_id: &str, input: &str) -> Result<Child, String> { pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-scrape.genai.mts", 10).await
} }

View File

@@ -3,7 +3,7 @@ use tracing;
use crate::utils::utils::run_agent; use crate::utils::utils::run_agent;
pub async fn search_agent(stream_id: &str, input: &str) -> Result<Child, String> { pub async fn agent(stream_id: &str, input: &str) -> Result<Child, String> {
run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await run_agent(stream_id, input, "./packages/genaiscript/genaisrc/web-search.genai.mts", 10).await
} }
@@ -11,13 +11,13 @@ pub async fn search_agent(stream_id: &str, input: &str) -> Result<Child, String>
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::fmt::Debug; use std::fmt::Debug;
use crate::agents::search::search_agent; use crate::agents::search::agent;
#[tokio::test] #[tokio::test]
async fn test_search_execution() { async fn test_search_execution() {
let input = "Who won the 2024 presidential election?"; let input = "Who won the 2024 presidential election?";
let mut command = search_agent("test-stream", input).await.unwrap(); let mut command = agent("test-stream", input).await.unwrap();
// command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap(); // command.stdout.take().unwrap().read_to_string(&mut String::new()).await.unwrap();
// Optionally, you can capture and inspect stdout if needed: // Optionally, you can capture and inspect stdout if needed:

View File

@@ -1,7 +1,3 @@
use crate::agents;
use crate::agents::news::news_agent;
use crate::agents::scrape::scrape_agent;
use crate::agents::search::search_agent;
use axum::response::Response; use axum::response::Response;
use axum::{ use axum::{
body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json, body::Body, extract::Path, extract::Query, http::StatusCode, response::IntoResponse, Json,
@@ -18,7 +14,6 @@ use std::time::Duration;
use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command; use tokio::process::Command;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::agents::image_generator::image_generator;
// init sled // init sled
lazy_static! { lazy_static! {
@@ -31,7 +26,6 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
let db = DB.lock().await; let db = DB.lock().await;
match db.get(&agent_id) { match db.get(&agent_id) {
Ok(Some(data)) => { Ok(Some(data)) => {
let mut info: StreamInfo = match serde_json::from_slice(&data) { let mut info: StreamInfo = match serde_json::from_slice(&data) {
Ok(info) => info, Ok(info) => info,
Err(e) => { Err(e) => {
@@ -40,7 +34,6 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
} }
}; };
// Increment the call_count in the database // Increment the call_count in the database
info.call_count += 1; info.call_count += 1;
let updated_info_bytes = match serde_json::to_vec(&info) { let updated_info_bytes = match serde_json::to_vec(&info) {
@@ -54,7 +47,10 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
match db.insert(&agent_id, updated_info_bytes) { match db.insert(&agent_id, updated_info_bytes) {
Ok(_) => { Ok(_) => {
if let Err(e) = db.flush_async().await { if let Err(e) = db.flush_async().await {
tracing::error!("Failed to persist updated call_count to the database: {}", e); tracing::error!(
"Failed to persist updated call_count to the database: {}",
e
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response(); return StatusCode::INTERNAL_SERVER_ERROR.into_response();
} }
} }
@@ -96,10 +92,12 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
); );
let cmd = match resource.as_str() { let cmd = match resource.as_str() {
"web-search" => search_agent(agent_id.as_str(), &*input).await, "web-search" => crate::agents::search::agent(agent_id.as_str(), &*input).await,
"news-search" => news_agent(agent_id.as_str(), &*input).await, "news-search" => crate::agents::news::agent(agent_id.as_str(), &*input).await,
"image-generator" => image_generator(agent_id.as_str(), &*input).await, "image-generator" => {
"web-scrape" => scrape_agent(agent_id.as_str(), &*input).await, crate::agents::image_generator::agent(agent_id.as_str(), &*input).await
}
"web-scrape" => crate::agents::scrape::agent(agent_id.as_str(), &*input).await,
_ => { _ => {
tracing::error!("Unsupported resource type: {}", resource); tracing::error!("Unsupported resource type: {}", resource);
return StatusCode::BAD_REQUEST.into_response(); return StatusCode::BAD_REQUEST.into_response();
@@ -131,7 +129,7 @@ pub async fn handle_webhooks(Path(agent_id): Path<String>) -> impl IntoResponse
.header("Connection", "keep-alive") .header("Connection", "keep-alive")
.header("X-Accel-Buffering", "yes") .header("X-Accel-Buffering", "yes")
.body(Body::from_stream(sse_stream)) .body(Body::from_stream(sse_stream))
.unwrap() .unwrap();
} }
Ok(None) => { Ok(None) => {
tracing::error!("Stream ID not found: {}", agent_id); tracing::error!("Stream ID not found: {}", agent_id);
@@ -183,7 +181,6 @@ struct StreamInfo {
call_count: i32, call_count: i32,
} }
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
pub struct WebhookPostRequest { pub struct WebhookPostRequest {
id: String, id: String,
@@ -207,7 +204,7 @@ pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> im
resource: payload.resource.clone(), resource: payload.resource.clone(),
payload: payload.payload, payload: payload.payload,
parent: payload.parent.clone(), parent: payload.parent.clone(),
call_count: 0 call_count: 0,
}; };
let info_bytes = match serde_json::to_vec(&info) { let info_bytes = match serde_json::to_vec(&info) {
@@ -232,19 +229,22 @@ pub async fn handle_webhooks_post(Json(payload): Json<WebhookPostRequest>) -> im
match db.get(&stream_id) { match db.get(&stream_id) {
Ok(Some(_)) => { Ok(Some(_)) => {
let stream_url = format!("/webhooks/{}", stream_id); let stream_url = format!("/webhooks/{}", stream_id);
tracing::info!("Successfully created and verified stream URL: {}", stream_url); tracing::info!(
"Successfully created and verified stream URL: {}",
stream_url
);
Json(WebhookPostResponse { stream_url }).into_response() Json(WebhookPostResponse { stream_url }).into_response()
}, }
Ok(None) => { Ok(None) => {
tracing::error!("Failed to verify stream creation: {}", stream_id); tracing::error!("Failed to verify stream creation: {}", stream_id);
StatusCode::INTERNAL_SERVER_ERROR.into_response() StatusCode::INTERNAL_SERVER_ERROR.into_response()
}, }
Err(e) => { Err(e) => {
tracing::error!("Error verifying stream creation: {}", e); tracing::error!("Error verifying stream creation: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response() StatusCode::INTERNAL_SERVER_ERROR.into_response()
} }
} }
}, }
Err(e) => { Err(e) => {
tracing::error!("Failed to flush DB: {}", e); tracing::error!("Failed to flush DB: {}", e);
StatusCode::INTERNAL_SERVER_ERROR.into_response() StatusCode::INTERNAL_SERVER_ERROR.into_response()