Skip to main content

max / makenotwork

7.6 KB · 237 lines History Blame Raw
1 //! CORS preflight verification — sends OPTIONS requests and checks Access-Control headers.
2
3 use tracing::instrument;
4
5 use crate::config::CorsCheck;
6 use crate::types::CorsCheckResult;
7
8 /// Send a CORS preflight OPTIONS request and verify the response allows the expected origin.
9 /// Returns one `CorsCheckResult` per `CorsCheck` in the input.
10 #[instrument(skip_all)]
11 pub async fn check_cors(target: &str, checks: &[CorsCheck]) -> Vec<CorsCheckResult> {
12 let client = reqwest::Client::builder()
13 .timeout(std::time::Duration::from_secs(10))
14 .redirect(reqwest::redirect::Policy::none())
15 .build()
16 .unwrap();
17
18 let mut results = Vec::with_capacity(checks.len());
19 for check in checks {
20 results.push(run_preflight(target, &client, check).await);
21 }
22 results
23 }
24
25 async fn run_preflight(
26 target: &str,
27 client: &reqwest::Client,
28 check: &CorsCheck,
29 ) -> CorsCheckResult {
30 let now = chrono::Utc::now().to_rfc3339();
31
32 let response = client
33 .request(reqwest::Method::OPTIONS, &check.url)
34 .header("Origin", &check.origin)
35 .header("Access-Control-Request-Method", &check.method)
36 .send()
37 .await;
38
39 match response {
40 Ok(resp) => {
41 let status = resp.status().as_u16();
42 let allow_origin = resp
43 .headers()
44 .get("access-control-allow-origin")
45 .and_then(|v| v.to_str().ok())
46 .unwrap_or("")
47 .to_string();
48 let allow_methods = resp
49 .headers()
50 .get("access-control-allow-methods")
51 .and_then(|v| v.to_str().ok())
52 .unwrap_or("")
53 .to_string();
54
55 let (passes, error) = evaluate_preflight(
56 status,
57 &allow_origin,
58 &allow_methods,
59 &check.origin,
60 &check.method,
61 );
62
63 CorsCheckResult {
64 target: target.to_string(),
65 url: check.url.clone(),
66 origin: check.origin.clone(),
67 method: check.method.clone(),
68 passes,
69 checked_at: now,
70 error,
71 }
72 }
73 Err(e) => CorsCheckResult {
74 target: target.to_string(),
75 url: check.url.clone(),
76 origin: check.origin.clone(),
77 method: check.method.clone(),
78 passes: false,
79 checked_at: now,
80 error: Some(format!("preflight request failed: {e}")),
81 },
82 }
83 }
84
/// Evaluate CORS preflight response headers against expected values.
///
/// A preflight passes only when all of the following hold, mirroring the browser-side checks of
/// the Fetch standard's CORS-preflight fetch:
///
/// * `status` is a 2xx ("ok") status — informational, redirect and error statuses all fail;
/// * `allow_origin` is exactly `expected_origin`, or the wildcard `*`;
/// * `allow_methods` (a comma-separated list) contains `expected_method`, compared
///   ASCII-case-insensitively, or contains the wildcard `*`.
///
/// Both `*` wildcards are honoured as they would be for a request sent without credentials.
///
/// Returns `(passes, error_message)`. `error_message` is `None` on success; otherwise it lists
/// every failed condition, joined with `"; "`.
fn evaluate_preflight(
    status: u16,
    allow_origin: &str,
    allow_methods: &str,
    expected_origin: &str,
    expected_method: &str,
) -> (bool, Option<String>) {
    // A browser rejects any preflight response outside 200..=299. Redirects are never followed
    // for preflights, so a 3xx must count as a failure rather than a pass.
    let status_ok = (200..300).contains(&status);
    let origin_ok = allow_origin == expected_origin || allow_origin == "*";
    let method_ok = allow_methods
        .split(',')
        .map(str::trim)
        .any(|m| m == "*" || m.eq_ignore_ascii_case(expected_method));

    if status_ok && origin_ok && method_ok {
        return (true, None);
    }

    let mut reasons = Vec::new();
    if !status_ok {
        reasons.push(format!("HTTP {status}"));
    }
    if !origin_ok {
        reasons.push(format!(
            "Access-Control-Allow-Origin: {allow_origin:?} (expected {expected_origin:?} or \"*\")",
        ));
    }
    if !method_ok {
        reasons.push(format!(
            "Access-Control-Allow-Methods: {allow_methods:?} (expected {expected_method:?})",
        ));
    }
    (false, Some(reasons.join("; ")))
}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Origin used by the checks that care about an exact origin match.
    const SITE: &str = "https://makenot.work";
    /// Origin used by the method-focused checks, where the origin is wildcarded.
    const ANY_SITE: &str = "https://x.com";

    /// Evaluate a preflight that must fail and return its error message.
    fn failure_message(
        status: u16,
        allow_origin: &str,
        allow_methods: &str,
        expected_origin: &str,
        expected_method: &str,
    ) -> String {
        let (passes, error) = evaluate_preflight(
            status,
            allow_origin,
            allow_methods,
            expected_origin,
            expected_method,
        );
        assert!(!passes);
        error.unwrap()
    }

    /// Fixed `CorsCheckResult` fixture that varies only in `passes` and `error`.
    fn fixture(passes: bool, error: Option<&str>) -> CorsCheckResult {
        CorsCheckResult {
            target: "mnw".to_string(),
            url: "https://s3.example.com/bucket/test".to_string(),
            origin: SITE.to_string(),
            method: "PUT".to_string(),
            passes,
            checked_at: "2026-03-28T00:00:00Z".to_string(),
            error: error.map(str::to_string),
        }
    }

    /// Serialize `result` to JSON and parse it back.
    fn roundtrip(result: &CorsCheckResult) -> CorsCheckResult {
        let json = serde_json::to_string(result).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    #[test]
    fn preflight_exact_origin_match() {
        let (passes, error) = evaluate_preflight(200, SITE, "PUT", SITE, "PUT");
        assert!(passes);
        assert!(error.is_none());
    }

    #[test]
    fn preflight_wildcard_origin() {
        let (passes, error) = evaluate_preflight(200, "*", "GET", SITE, "GET");
        assert!(passes);
        assert!(error.is_none());
    }

    #[test]
    fn preflight_origin_mismatch() {
        let msg = failure_message(200, "https://other.com", "PUT", SITE, "PUT");
        assert!(msg.contains("Access-Control-Allow-Origin"));
        assert!(msg.contains("https://other.com"));
    }

    #[test]
    fn preflight_method_case_insensitive() {
        assert!(evaluate_preflight(200, "*", "put", ANY_SITE, "PUT").0);
    }

    #[test]
    fn preflight_method_comma_separated() {
        assert!(evaluate_preflight(200, "*", "GET, PUT, DELETE", ANY_SITE, "PUT").0);
    }

    #[test]
    fn preflight_method_with_whitespace() {
        assert!(evaluate_preflight(200, "*", "GET , PUT , DELETE", ANY_SITE, "PUT").0);
    }

    #[test]
    fn preflight_method_mismatch() {
        let msg = failure_message(200, "*", "GET, POST", ANY_SITE, "PUT");
        assert!(msg.contains("Access-Control-Allow-Methods"));
    }

    #[test]
    fn preflight_status_400_fails() {
        let msg = failure_message(403, SITE, "PUT", SITE, "PUT");
        assert!(msg.contains("HTTP 403"));
    }

    #[test]
    fn preflight_multiple_failures() {
        let msg = failure_message(500, "https://wrong.com", "GET", SITE, "PUT");
        for needle in [
            "HTTP 500",
            "Access-Control-Allow-Origin",
            "Access-Control-Allow-Methods",
            "; ",
        ] {
            assert!(msg.contains(needle), "missing {needle:?} in {msg:?}");
        }
    }

    #[test]
    fn preflight_missing_headers() {
        let msg = failure_message(200, "", "", SITE, "PUT");
        assert!(msg.contains("Access-Control-Allow-Origin"));
        assert!(msg.contains("Access-Control-Allow-Methods"));
    }

    #[test]
    fn cors_check_result_serde_roundtrip() {
        let parsed = roundtrip(&fixture(true, None));
        assert_eq!(parsed.target, "mnw");
        assert!(parsed.passes);
        assert!(parsed.error.is_none());
    }

    #[test]
    fn cors_check_result_with_error() {
        let parsed = roundtrip(&fixture(false, Some("HTTP 403")));
        assert!(!parsed.passes);
        assert_eq!(parsed.error.as_deref(), Some("HTTP 403"));
    }
}
237