//! Cookie-aware in-process HTTP client for integration tests. use axum::body::Body; use axum::extract::ConnectInfo; use axum::http::{header, Method, Request, StatusCode}; use axum::Router; use http_body_util::BodyExt; use std::collections::HashMap; use std::net::SocketAddr; use tower::ServiceExt; pub struct TestClient { app: Router, cookies: HashMap, csrf_token: Option, } impl TestClient { pub fn new(app: Router) -> Self { TestClient { app, cookies: HashMap::new(), csrf_token: None, } } pub async fn get(&mut self, uri: &str) -> TestResponse { self.send(Method::GET, uri, None, None).await } pub async fn post_form(&mut self, uri: &str, body: &str) -> TestResponse { self.send( Method::POST, uri, Some("application/x-www-form-urlencoded"), Some(body.to_string()), ) .await } /// POST without injecting the CSRF token. Used to test CSRF rejection. pub async fn post_form_no_csrf(&mut self, uri: &str, body: &str) -> TestResponse { self.send_raw( Method::POST, uri, Some("application/x-www-form-urlencoded"), Some(body.to_string()), false, ) .await } /// POST with a specific (wrong) CSRF token. pub async fn post_form_with_token( &mut self, uri: &str, body: &str, token: &str, ) -> TestResponse { self.send_raw_with_token( Method::POST, uri, Some("application/x-www-form-urlencoded"), Some(body.to_string()), token, ) .await } pub async fn post_json(&mut self, uri: &str, body: &str) -> TestResponse { self.send( Method::POST, uri, Some("application/json"), Some(body.to_string()), ) .await } pub async fn post_multipart( &mut self, uri: &str, file_data: &[u8], content_type: &str, filename: &str, ) -> TestResponse { let boundary = "----TestBoundary1234567890"; let mut body = Vec::new(); body.extend_from_slice(format!("--{boundary}\r\n").as_bytes()); body.extend_from_slice( format!( "Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n\ Content-Type: {content_type}\r\n\r\n" ) .as_bytes(), ); body.extend_from_slice(file_data); body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes()); let ct = format!("multipart/form-data; boundary={boundary}"); self.send_bytes(Method::POST, uri, &ct, body).await } pub fn csrf_token(&self) -> Option<&str> { self.csrf_token.as_deref() } async fn send( &mut self, method: Method, uri: &str, content_type: Option<&str>, body: Option, ) -> TestResponse { self.send_raw(method, uri, content_type, body, true).await } async fn send_bytes( &mut self, method: Method, uri: &str, content_type: &str, body: Vec, ) -> TestResponse { let mut builder = Request::builder() .method(&method) .uri(uri) .header(header::CONTENT_TYPE, content_type); if matches!(method, Method::POST | Method::PUT | Method::PATCH | Method::DELETE) && let Some(ref token) = self.csrf_token { builder = builder.header("X-CSRF-Token", token.as_str()); } if !self.cookies.is_empty() { let cookie_header: String = self .cookies .iter() .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join("; "); builder = builder.header(header::COOKIE, cookie_header); } let mut request = builder.body(Body::from(body)).expect("Failed to build request"); request.extensions_mut().insert(ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0)))); let response = self.app.clone().oneshot(request).await.expect("Failed to send request"); let status = response.status(); let headers = response.headers().clone(); for value in headers.get_all(header::SET_COOKIE) { if let Ok(cookie_str) = value.to_str() && let Some(nv) = cookie_str.split(';').next() && let Some((name, val)) = nv.split_once('=') { self.cookies.insert(name.trim().to_string(), val.trim().to_string()); } } let body_bytes = response.into_body().collect().await .expect("Failed to read response body").to_bytes(); let text = String::from_utf8_lossy(&body_bytes).to_string(); if let Some(token) = extract_csrf_from_html(&text) { self.csrf_token = Some(token); } TestResponse { status, text, headers } } async fn send_raw( &mut self, method: Method, uri: &str, content_type: Option<&str>, body: Option, inject_csrf: bool, ) -> TestResponse { let body_data = body.unwrap_or_default(); let mut builder = Request::builder().method(&method).uri(uri); if let Some(ct) = content_type { builder = builder.header(header::CONTENT_TYPE, ct); } if inject_csrf && matches!(method, Method::POST | Method::PUT | Method::PATCH | Method::DELETE) && let Some(ref token) = self.csrf_token { builder = builder.header("X-CSRF-Token", token.as_str()); } if !self.cookies.is_empty() { let cookie_header: String = self .cookies .iter() .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join("; "); builder = builder.header(header::COOKIE, cookie_header); } let mut request = builder.body(Body::from(body_data)).expect("Failed to build request"); // Provide ConnectInfo so SmartIpKeyExtractor works in tests request.extensions_mut().insert(ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0)))); let response = self .app .clone() .oneshot(request) .await .expect("Failed to send request"); let status = response.status(); let headers = response.headers().clone(); for value in headers.get_all(header::SET_COOKIE) { if let Ok(cookie_str) = value.to_str() && let Some(nv) = cookie_str.split(';').next() && let Some((name, val)) = nv.split_once('=') { self.cookies .insert(name.trim().to_string(), val.trim().to_string()); } } let body_bytes = response .into_body() .collect() .await .expect("Failed to read response body") .to_bytes(); let text = String::from_utf8_lossy(&body_bytes).to_string(); if let Some(token) = extract_csrf_from_html(&text) { self.csrf_token = Some(token); } TestResponse { status, text, headers, } } async fn send_raw_with_token( &mut self, method: Method, uri: &str, content_type: Option<&str>, body: Option, token: &str, ) -> TestResponse { let body_data = body.unwrap_or_default(); let mut builder = Request::builder().method(&method).uri(uri); if let Some(ct) = content_type { builder = builder.header(header::CONTENT_TYPE, ct); } builder = builder.header("X-CSRF-Token", token); if !self.cookies.is_empty() { let cookie_header: String = self .cookies .iter() .map(|(k, v)| format!("{}={}", k, v)) .collect::>() .join("; "); builder = builder.header(header::COOKIE, cookie_header); } let mut request = builder.body(Body::from(body_data)).expect("Failed to build request"); request.extensions_mut().insert(ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0)))); let response = self .app .clone() .oneshot(request) .await .expect("Failed to send request"); let status = response.status(); let headers = response.headers().clone(); for value in headers.get_all(header::SET_COOKIE) { if let Ok(cookie_str) = value.to_str() && let Some(nv) = cookie_str.split(';').next() && let Some((name, val)) = nv.split_once('=') { self.cookies .insert(name.trim().to_string(), val.trim().to_string()); } } let body_bytes = response .into_body() .collect() .await .expect("Failed to read response body") .to_bytes(); let text = String::from_utf8_lossy(&body_bytes).to_string(); if let Some(token) = extract_csrf_from_html(&text) { self.csrf_token = Some(token); } TestResponse { status, text, headers, } } } #[allow(dead_code)] pub struct TestResponse { pub status: StatusCode, pub text: String, pub headers: axum::http::HeaderMap, } #[allow(dead_code)] impl TestResponse { pub fn json(&self) -> T { serde_json::from_str(&self.text) .unwrap_or_else(|e| panic!("Failed to parse JSON: {}\nBody: {}", e, &self.text)) } pub fn header(&self, name: &str) -> Option<&str> { self.headers.get(name).and_then(|v| v.to_str().ok()) } } fn extract_csrf_from_html(html: &str) -> Option { let marker = "csrf-token\" content=\""; let start = html.find(marker)? + marker.len(); let end = html[start..].find('"')? + start; Some(html[start..end].to_string()) }