Skip to main content

max / pom

4.8 KB · 145 lines History Blame Raw
1 //! CORS preflight verification — sends OPTIONS requests and checks Access-Control headers.
2
3 use tracing::instrument;
4
5 use crate::config::CorsCheck;
6 use crate::types::CorsCheckResult;
7
8 /// Send a CORS preflight OPTIONS request and verify the response allows the expected origin.
9 /// Returns one `CorsCheckResult` per `CorsCheck` in the input.
10 #[instrument(skip_all)]
11 pub async fn check_cors(target: &str, checks: &[CorsCheck]) -> Vec<CorsCheckResult> {
12 let client = reqwest::Client::builder()
13 .timeout(std::time::Duration::from_secs(10))
14 .redirect(reqwest::redirect::Policy::none())
15 .build()
16 .unwrap();
17
18 let mut results = Vec::with_capacity(checks.len());
19 for check in checks {
20 results.push(run_preflight(target, &client, check).await);
21 }
22 results
23 }
24
25 async fn run_preflight(
26 target: &str,
27 client: &reqwest::Client,
28 check: &CorsCheck,
29 ) -> CorsCheckResult {
30 let now = chrono::Utc::now().to_rfc3339();
31
32 let response = client
33 .request(reqwest::Method::OPTIONS, &check.url)
34 .header("Origin", &check.origin)
35 .header("Access-Control-Request-Method", &check.method)
36 .send()
37 .await;
38
39 match response {
40 Ok(resp) => {
41 let status = resp.status().as_u16();
42 let allow_origin = resp
43 .headers()
44 .get("access-control-allow-origin")
45 .and_then(|v| v.to_str().ok())
46 .unwrap_or("")
47 .to_string();
48 let allow_methods = resp
49 .headers()
50 .get("access-control-allow-methods")
51 .and_then(|v| v.to_str().ok())
52 .unwrap_or("")
53 .to_string();
54
55 let origin_ok = allow_origin == check.origin || allow_origin == "*";
56 let method_ok = allow_methods
57 .split(',')
58 .any(|m| m.trim().eq_ignore_ascii_case(&check.method));
59
60 let passes = status < 400 && origin_ok && method_ok;
61
62 let error = if passes {
63 None
64 } else {
65 let mut reasons = Vec::new();
66 if status >= 400 {
67 reasons.push(format!("HTTP {status}"));
68 }
69 if !origin_ok {
70 reasons.push(format!(
71 "Access-Control-Allow-Origin: {allow_origin:?} (expected {:?} or \"*\")",
72 check.origin
73 ));
74 }
75 if !method_ok {
76 reasons.push(format!(
77 "Access-Control-Allow-Methods: {allow_methods:?} (expected {:?})",
78 check.method
79 ));
80 }
81 Some(reasons.join("; "))
82 };
83
84 CorsCheckResult {
85 target: target.to_string(),
86 url: check.url.clone(),
87 origin: check.origin.clone(),
88 method: check.method.clone(),
89 passes,
90 checked_at: now,
91 error,
92 }
93 }
94 Err(e) => CorsCheckResult {
95 target: target.to_string(),
96 url: check.url.clone(),
97 origin: check.origin.clone(),
98 method: check.method.clone(),
99 passes: false,
100 checked_at: now,
101 error: Some(format!("preflight request failed: {e}")),
102 },
103 }
104 }
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the shared fixture, varying only the outcome fields under test.
    fn sample(passes: bool, error: Option<&str>) -> CorsCheckResult {
        CorsCheckResult {
            target: "mnw".to_string(),
            url: "https://s3.example.com/bucket/test".to_string(),
            origin: "https://makenot.work".to_string(),
            method: "PUT".to_string(),
            passes,
            checked_at: "2026-03-28T00:00:00Z".to_string(),
            error: error.map(str::to_string),
        }
    }

    /// Serialize `result` to JSON and deserialize it back.
    fn roundtrip(result: &CorsCheckResult) -> CorsCheckResult {
        let encoded = serde_json::to_string(result).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn cors_check_result_serde_roundtrip() {
        let decoded = roundtrip(&sample(true, None));
        assert_eq!(decoded.target, "mnw");
        assert!(decoded.passes);
        assert!(decoded.error.is_none());
    }

    #[test]
    fn cors_check_result_with_error() {
        let decoded = roundtrip(&sample(false, Some("HTTP 403")));
        assert!(!decoded.passes);
        assert_eq!(decoded.error.as_deref(), Some("HTTP 403"));
    }
}
145