add basic router tests

This commit is contained in:
geoffsee
2025-06-05 14:06:00 -04:00
committed by Geoff Seemueller
parent 72583e5f5b
commit 22ef371c5b
3 changed files with 73 additions and 11 deletions

1
Cargo.lock generated
View File

@@ -746,6 +746,7 @@ dependencies = [
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tower",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",

View File

@@ -28,6 +28,7 @@ bytes = "1.8.0"
lazy_static = "1.5.0" lazy_static = "1.5.0"
sled = "0.34.7" sled = "0.34.7"
tower-http = { version = "0.6.2", features = ["trace", "cors"] } tower-http = { version = "0.6.2", features = ["trace", "cors"] }
tower = "0.5.2"
anyhow = "1.0.97" anyhow = "1.0.97"
base64 = "0.22.1" base64 = "0.22.1"
fips204 = "0.4.6" fips204 = "0.4.6"

View File

@@ -6,26 +6,18 @@ use tracing::Level;
use rmcp::transport::streamable_http_server::{ use rmcp::transport::streamable_http_server::{
StreamableHttpService, session::local::LocalSessionManager, StreamableHttpService, session::local::LocalSessionManager,
}; };
use crate::counter::Counter;
use crate::agents::Agents; use crate::agents::Agents;
pub fn create_router() -> Router { pub fn create_router() -> Router {
let counter_service = StreamableHttpService::new( let mcp_service = StreamableHttpService::new(
Counter::new,
LocalSessionManager::default().into(),
Default::default(),
);
let agents_service = StreamableHttpService::new(
Agents::new, Agents::new,
LocalSessionManager::default().into(), LocalSessionManager::default().into(),
Default::default(), Default::default(),
); );
Router::new() Router::new()
.nest_service("/mcp/counter", counter_service) .nest_service("/mcp", mcp_service)
.nest_service("/mcp/agents", agents_service)
.route("/", get(serve_ui)) .route("/", get(serve_ui))
.route("/health", get(health)) .route("/health", get(health))
.layer( .layer(
@@ -40,3 +32,71 @@ pub fn create_router() -> Router {
async fn health() -> String { async fn health() -> String {
return "ok".to_string(); return "ok".to_string();
} }
#[cfg(test)]
mod tests {
use super::*;
use axum::body::{Body, Bytes};
use axum::http::{Request, StatusCode};
use axum::response::Response;
use tower::ServiceExt;
#[tokio::test]
async fn test_health_endpoint() {
// Call the health function directly
let response = health().await;
assert_eq!(response, "ok".to_string());
}
#[tokio::test]
async fn test_health_route() {
// Create the router
let app = create_router();
// Create a request to the health endpoint
let request = Request::builder()
.uri("/health")
.method("GET")
.body(Body::empty())
.unwrap();
// Process the request
let response = app.oneshot(request).await.unwrap();
// Check the response status
assert_eq!(response.status(), StatusCode::OK);
// Check the response body
let body = response_body_bytes(response).await;
assert_eq!(&body[..], b"ok");
}
#[tokio::test]
async fn test_not_found_route() {
// Create the router
let app = create_router();
// Create a request to a non-existent endpoint
let request = Request::builder()
.uri("/non-existent")
.method("GET")
.body(Body::empty())
.unwrap();
// Process the request
let response = app.oneshot(request).await.unwrap();
// Check the response status
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
// Helper function to extract bytes from a response body
async fn response_body_bytes(response: Response) -> Bytes {
let body = response.into_body();
// Use a reasonable size limit for the body (16MB)
let bytes = axum::body::to_bytes(body, 16 * 1024 * 1024)
.await
.expect("Failed to read response body");
bytes
}
}