//! CORS preflight verification — sends OPTIONS requests and checks Access-Control headers. use tracing::instrument; use crate::config::CorsCheck; use crate::types::CorsCheckResult; /// Send a CORS preflight OPTIONS request and verify the response allows the expected origin. /// Returns one `CorsCheckResult` per `CorsCheck` in the input. #[instrument(skip_all)] pub async fn check_cors(target: &str, checks: &[CorsCheck]) -> Vec { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(10)) .redirect(reqwest::redirect::Policy::none()) .build() .unwrap(); let mut results = Vec::with_capacity(checks.len()); for check in checks { results.push(run_preflight(target, &client, check).await); } results } async fn run_preflight( target: &str, client: &reqwest::Client, check: &CorsCheck, ) -> CorsCheckResult { let now = chrono::Utc::now().to_rfc3339(); let response = client .request(reqwest::Method::OPTIONS, &check.url) .header("Origin", &check.origin) .header("Access-Control-Request-Method", &check.method) .send() .await; match response { Ok(resp) => { let status = resp.status().as_u16(); let allow_origin = resp .headers() .get("access-control-allow-origin") .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); let allow_methods = resp .headers() .get("access-control-allow-methods") .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); let (passes, error) = evaluate_preflight( status, &allow_origin, &allow_methods, &check.origin, &check.method, ); CorsCheckResult { target: target.to_string(), url: check.url.clone(), origin: check.origin.clone(), method: check.method.clone(), passes, checked_at: now, error, } } Err(e) => CorsCheckResult { target: target.to_string(), url: check.url.clone(), origin: check.origin.clone(), method: check.method.clone(), passes: false, checked_at: now, error: Some(format!("preflight request failed: {e}")), }, } } /// Evaluate CORS preflight response headers against expected values. 
/// Returns `(passes, error_message)`.
///
/// A preflight passes only when all of the following hold:
///
/// - `status` is a 2xx code. Redirects (3xx) are rejected as well as 4xx/5xx,
///   because a preflight response must carry an ok status to succeed.
/// - `allow_origin` is exactly `expected_origin`, or the wildcard `*`.
/// - `allow_methods`, split on commas with surrounding whitespace trimmed,
///   contains `expected_method` (ASCII case-insensitive) or the wildcard `*`.
///
/// Missing headers are represented by empty strings. On failure the message
/// lists every failed condition, joined with `"; "`; on success it is `None`.
fn evaluate_preflight(
    status: u16,
    allow_origin: &str,
    allow_methods: &str,
    expected_origin: &str,
    expected_method: &str,
) -> (bool, Option<String>) {
    let status_ok = (200..300).contains(&status);
    let origin_ok = allow_origin == expected_origin || allow_origin == "*";
    let method_ok = allow_methods
        .split(',')
        .map(str::trim)
        .any(|m| m == "*" || m.eq_ignore_ascii_case(expected_method));

    // Collect every failed condition so one run reports all problems at once.
    let mut reasons = Vec::new();
    if !status_ok {
        reasons.push(format!("HTTP {status}"));
    }
    if !origin_ok {
        reasons.push(format!(
            "Access-Control-Allow-Origin: {allow_origin:?} (expected {expected_origin:?} or \"*\")",
        ));
    }
    if !method_ok {
        reasons.push(format!(
            "Access-Control-Allow-Methods: {allow_methods:?} (expected {expected_method:?} or \"*\")",
        ));
    }

    if reasons.is_empty() {
        (true, None)
    } else {
        (false, Some(reasons.join("; ")))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn preflight_exact_origin_match() {
        let (passes, error) =
            evaluate_preflight(200, "https://makenot.work", "PUT", "https://makenot.work", "PUT");
        assert!(passes);
        assert!(error.is_none());
    }

    #[test]
    fn preflight_wildcard_origin() {
        let (passes, error) = evaluate_preflight(200, "*", "GET", "https://makenot.work", "GET");
        assert!(passes);
        assert!(error.is_none());
    }

    #[test]
    fn preflight_origin_mismatch() {
        let (passes, error) =
            evaluate_preflight(200, "https://other.com", "PUT", "https://makenot.work", "PUT");
        assert!(!passes);
        let msg = error.unwrap();
        assert!(msg.contains("Access-Control-Allow-Origin"));
        assert!(msg.contains("https://other.com"));
    }

    #[test]
    fn preflight_method_case_insensitive() {
        let (passes, _) = evaluate_preflight(200, "*", "put", "https://x.com", "PUT");
        assert!(passes);
    }

    #[test]
    fn preflight_method_comma_separated() {
        let (passes, _) = evaluate_preflight(200, "*", "GET, PUT, DELETE", "https://x.com", "PUT");
        assert!(passes);
    }

    #[test]
    fn preflight_method_with_whitespace() {
        let (passes, _) =
            evaluate_preflight(200, "*", "GET , PUT , DELETE", "https://x.com", "PUT");
        assert!(passes);
    }

    #[test]
    fn preflight_method_mismatch() {
        let (passes, error) = evaluate_preflight(200, "*", "GET, POST", "https://x.com", "PUT");
        assert!(!passes);
        let msg = error.unwrap();
        assert!(msg.contains("Access-Control-Allow-Methods"));
    }

    #[test]
    fn preflight_method_wildcard() {
        let (passes, error) = evaluate_preflight(204, "*", "*", "https://x.com", "PUT");
        assert!(passes);
        assert!(error.is_none());
    }

    #[test]
    fn preflight_status_400_fails() {
        let (passes, error) =
            evaluate_preflight(403, "https://makenot.work", "PUT", "https://makenot.work", "PUT");
        assert!(!passes);
        assert!(error.unwrap().contains("HTTP 403"));
    }

    #[test]
    fn preflight_redirect_status_fails() {
        let (passes, error) =
            evaluate_preflight(301, "https://makenot.work", "PUT", "https://makenot.work", "PUT");
        assert!(!passes);
        assert_eq!(error.as_deref(), Some("HTTP 301"));
    }

    #[test]
    fn preflight_no_content_status_passes() {
        let (passes, error) =
            evaluate_preflight(204, "https://makenot.work", "PUT", "https://makenot.work", "PUT");
        assert!(passes);
        assert!(error.is_none());
    }

    #[test]
    fn preflight_multiple_failures() {
        let (passes, error) =
            evaluate_preflight(500, "https://wrong.com", "GET", "https://makenot.work", "PUT");
        assert!(!passes);
        let msg = error.unwrap();
        assert!(msg.contains("HTTP 500"));
        assert!(msg.contains("Access-Control-Allow-Origin"));
        assert!(msg.contains("Access-Control-Allow-Methods"));
        assert!(msg.contains("; "));
    }

    #[test]
    fn preflight_missing_headers() {
        let (passes, error) = evaluate_preflight(200, "", "", "https://makenot.work", "PUT");
        assert!(!passes);
        let msg = error.unwrap();
        assert!(msg.contains("Access-Control-Allow-Origin"));
        assert!(msg.contains("Access-Control-Allow-Methods"));
    }

    #[test]
    fn cors_check_result_serde_roundtrip() {
        let result = CorsCheckResult {
            target: "mnw".to_string(),
            url: "https://s3.example.com/bucket/test".to_string(),
            origin: "https://makenot.work".to_string(),
            method: "PUT".to_string(),
            passes: true,
            checked_at: "2026-03-28T00:00:00Z".to_string(),
            error: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        let parsed: CorsCheckResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.target, "mnw");
        assert!(parsed.passes);
        assert!(parsed.error.is_none());
    }

    #[test]
    fn cors_check_result_with_error() {
        let result = CorsCheckResult {
            target: "mnw".to_string(),
            url: "https://s3.example.com/bucket/test".to_string(),
            origin: "https://makenot.work".to_string(),
            method: "PUT".to_string(),
            passes: false,
            checked_at: "2026-03-28T00:00:00Z".to_string(),
            error: Some("HTTP 403".to_string()),
        };
        let json = serde_json::to_string(&result).unwrap();
        let parsed: CorsCheckResult = serde_json::from_str(&json).unwrap();
        assert!(!parsed.passes);
        assert_eq!(parsed.error.as_deref(), Some("HTTP 403"));
    }
}