Skip to main content

max / makenotwork

15.0 KB · 466 lines History Blame Raw
1 //! Cookie-aware in-process HTTP client.
2 //!
3 //! Wraps an Axum `Router` and uses `tower::ServiceExt::oneshot` for each
4 //! request. Manages cookies across requests and auto-injects CSRF tokens.
5
6 use axum::body::Body;
7 use axum::http::{header, Method, Request, StatusCode};
8 use axum::Router;
9 use http_body_util::BodyExt;
10 use std::collections::HashMap;
11 use std::sync::atomic::{AtomicU32, Ordering};
12 use tower::ServiceExt;
13
/// Process-wide sequence number that `TestClient::new` turns into a distinct
/// `10.1.x.y` forwarded IP for every client it builds. Starting at 1 makes the
/// first client's address `10.1.0.1`.
static IP_COUNTER: AtomicU32 = AtomicU32::new(1);
16
17 /// A test HTTP client that talks to the app in-process.
18 pub struct TestClient {
19 app: Router,
20 cookies: HashMap<String, String>,
21 csrf_token: Option<String>,
22 forwarded_ip: String,
23 bearer_token: Option<String>,
24 }
25
26 impl TestClient {
27 pub fn new(app: Router) -> Self {
28 let n = IP_COUNTER.fetch_add(1, Ordering::Relaxed);
29 let octet3 = (n / 256) % 256;
30 let octet4 = n % 256;
31 TestClient {
32 app,
33 cookies: HashMap::new(),
34 csrf_token: None,
35 forwarded_ip: format!("10.1.{}.{}", octet3, octet4),
36 bearer_token: None,
37 }
38 }
39
40 /// Set the IP address used in the X-Forwarded-For header.
41 #[allow(dead_code)]
42 pub fn set_forwarded_ip(&mut self, ip: &str) {
43 self.forwarded_ip = ip.to_string();
44 }
45
46 /// Set a bearer token for subsequent requests (used by SyncKit JWT auth).
47 #[allow(dead_code)]
48 pub fn set_bearer_token(&mut self, token: &str) {
49 self.bearer_token = Some(token.to_string());
50 }
51
52 /// Clear the bearer token.
53 #[allow(dead_code)]
54 pub fn clear_bearer_token(&mut self) {
55 self.bearer_token = None;
56 }
57
58 /// Access the current CSRF token (if any).
59 #[allow(dead_code)]
60 pub fn csrf_token(&self) -> Option<&str> {
61 self.csrf_token.as_deref()
62 }
63
64 /// GET request.
65 pub async fn get(&mut self, uri: &str) -> TestResponse {
66 self.request(Method::GET, uri, None, None).await
67 }
68
69 /// POST with form-encoded body.
70 pub async fn post_form(&mut self, uri: &str, body: &str) -> TestResponse {
71 self.request(
72 Method::POST,
73 uri,
74 Some("application/x-www-form-urlencoded"),
75 Some(body.to_string()),
76 )
77 .await
78 }
79
80 /// POST with JSON body.
81 #[allow(dead_code)]
82 pub async fn post_json(&mut self, uri: &str, body: &str) -> TestResponse {
83 self.request(Method::POST, uri, Some("application/json"), Some(body.to_string()))
84 .await
85 }
86
87 /// PUT with form-encoded body.
88 pub async fn put_form(&mut self, uri: &str, body: &str) -> TestResponse {
89 self.request(
90 Method::PUT,
91 uri,
92 Some("application/x-www-form-urlencoded"),
93 Some(body.to_string()),
94 )
95 .await
96 }
97
98 /// PUT with JSON body.
99 #[allow(dead_code)]
100 pub async fn put_json(&mut self, uri: &str, body: &str) -> TestResponse {
101 self.request(Method::PUT, uri, Some("application/json"), Some(body.to_string()))
102 .await
103 }
104
105 /// DELETE request.
106 #[allow(dead_code)]
107 pub async fn delete(&mut self, uri: &str) -> TestResponse {
108 self.request(Method::DELETE, uri, None, None).await
109 }
110
111 /// PATCH with JSON body.
112 #[allow(dead_code)]
113 pub async fn patch_json(&mut self, uri: &str, body: &str) -> TestResponse {
114 self.request(Method::PATCH, uri, Some("application/json"), Some(body.to_string()))
115 .await
116 }
117
118 /// DELETE with form-encoded body.
119 #[allow(dead_code)]
120 pub async fn delete_form(&mut self, uri: &str, body: &str) -> TestResponse {
121 self.request(
122 Method::DELETE,
123 uri,
124 Some("application/x-www-form-urlencoded"),
125 Some(body.to_string()),
126 )
127 .await
128 }
129
130 /// POST with multipart/form-data body. Fields are (name, value) pairs.
131 /// Supports repeated field names for Vec<T> deserialization.
132 #[allow(dead_code)]
133 pub async fn post_multipart(&mut self, uri: &str, fields: &[(&str, &str)]) -> TestResponse {
134 let boundary = "----TestBoundary7MA4YWxkTrZu0gW";
135 let mut body = String::new();
136 for (name, value) in fields {
137 body.push_str(&format!("--{}\r\n", boundary));
138 body.push_str(&format!(
139 "Content-Disposition: form-data; name=\"{}\"\r\n\r\n{}\r\n",
140 name, value
141 ));
142 }
143 body.push_str(&format!("--{}--\r\n", boundary));
144
145 let content_type = format!("multipart/form-data; boundary={}", boundary);
146 self.request(Method::POST, uri, Some(&content_type), Some(body))
147 .await
148 }
149
150 /// HTMX GET request (includes `HX-Request: true` header).
151 #[allow(dead_code)]
152 pub async fn htmx_get(&mut self, uri: &str) -> TestResponse {
153 self.request_htmx(Method::GET, uri, None, None).await
154 }
155
156 /// HTMX POST with form-encoded body.
157 #[allow(dead_code)]
158 pub async fn htmx_post_form(&mut self, uri: &str, body: &str) -> TestResponse {
159 self.request_htmx(
160 Method::POST,
161 uri,
162 Some("application/x-www-form-urlencoded"),
163 Some(body.to_string()),
164 )
165 .await
166 }
167
168 /// HTMX PUT with form-encoded body.
169 #[allow(dead_code)]
170 pub async fn htmx_put_form(&mut self, uri: &str, body: &str) -> TestResponse {
171 self.request_htmx(
172 Method::PUT,
173 uri,
174 Some("application/x-www-form-urlencoded"),
175 Some(body.to_string()),
176 )
177 .await
178 }
179
180 /// HTMX DELETE request (includes `HX-Request: true` header).
181 #[allow(dead_code)]
182 pub async fn htmx_delete(&mut self, uri: &str) -> TestResponse {
183 self.request_htmx(Method::DELETE, uri, None, None).await
184 }
185
186 /// Fetch the CSRF token by loading the /login page and extracting it
187 /// from `<meta name="csrf-token" content="...">`.
188 pub async fn fetch_csrf_token(&mut self) {
189 let resp = self.get("/login").await;
190 if let Some(token) = extract_csrf_from_html(&resp.text) {
191 self.csrf_token = Some(token);
192 }
193 }
194
195 /// Build and send a regular request (no HTMX header).
196 async fn request(
197 &mut self,
198 method: Method,
199 uri: &str,
200 content_type: Option<&str>,
201 body: Option<String>,
202 ) -> TestResponse {
203 self.send(method, uri, content_type, body, false).await
204 }
205
206 /// Build and send a request with the `HX-Request: true` header.
207 #[allow(dead_code)]
208 async fn request_htmx(
209 &mut self,
210 method: Method,
211 uri: &str,
212 content_type: Option<&str>,
213 body: Option<String>,
214 ) -> TestResponse {
215 self.send(method, uri, content_type, body, true).await
216 }
217
218 /// Build and send a request through `oneshot`, optionally with the HTMX header.
219 async fn send(
220 &mut self,
221 method: Method,
222 uri: &str,
223 content_type: Option<&str>,
224 body: Option<String>,
225 htmx: bool,
226 ) -> TestResponse {
227 let body_data = body.unwrap_or_default();
228 let mut builder = Request::builder()
229 .method(&method)
230 .uri(uri)
231 // Required: SmartIpKeyExtractor needs an IP; oneshot has no ConnectInfo
232 .header("X-Forwarded-For", &self.forwarded_ip)
233 // Production reads `CF-Connecting-IP` (the only header origin clients
234 // can't spoof through Caddy). Send both so the harness matches the
235 // production proxy chain and `extract_client_ip` finds a value.
236 .header("CF-Connecting-IP", &self.forwarded_ip);
237
238 if htmx {
239 builder = builder.header("HX-Request", "true");
240 }
241
242 // Inject bearer token if set
243 if let Some(ref token) = self.bearer_token {
244 builder = builder.header(header::AUTHORIZATION, format!("Bearer {}", token));
245 }
246
247 // Set content type
248 if let Some(ct) = content_type {
249 builder = builder.header(header::CONTENT_TYPE, ct);
250 }
251
252 // Inject CSRF token for mutating methods
253 if matches!(method, Method::POST | Method::PUT | Method::PATCH | Method::DELETE)
254 && let Some(ref token) = self.csrf_token
255 {
256 builder = builder.header("X-CSRF-Token", token.as_str());
257 }
258
259 // Attach cookies
260 if !self.cookies.is_empty() {
261 let cookie_header: String = self
262 .cookies
263 .iter()
264 .map(|(k, v)| format!("{}={}", k, v))
265 .collect::<Vec<_>>()
266 .join("; ");
267 builder = builder.header(header::COOKIE, cookie_header);
268 }
269
270 let request = builder.body(Body::from(body_data)).expect("Failed to build request");
271
272 let response = self
273 .app
274 .clone()
275 .oneshot(request)
276 .await
277 .expect("Failed to send request");
278
279 // Collect set-cookie headers before consuming response
280 let status = response.status();
281 let headers = response.headers().clone();
282
283 // Store cookies from response
284 for value in headers.get_all(header::SET_COOKIE) {
285 if let Ok(cookie_str) = value.to_str() {
286 // Parse "name=value; ..." — take only the name=value part
287 if let Some(nv) = cookie_str.split(';').next()
288 && let Some((name, val)) = nv.split_once('=')
289 {
290 self.cookies
291 .insert(name.trim().to_string(), val.trim().to_string());
292 }
293 }
294 }
295
296 // Read body
297 let body_bytes = response
298 .into_body()
299 .collect()
300 .await
301 .expect("Failed to read response body")
302 .to_bytes();
303 let text = String::from_utf8_lossy(&body_bytes).to_string();
304
305 // Auto-extract CSRF token from HTML responses for convenience
306 if let Some(token) = extract_csrf_from_html(&text) {
307 self.csrf_token = Some(token);
308 }
309
310 TestResponse {
311 status,
312 text,
313 headers,
314 }
315 }
316
317 /// Raw request with custom headers. Used for webhook tests where we need
318 /// to set the stripe-signature header and bypass CSRF.
319 #[allow(dead_code)]
320 pub async fn request_with_headers(
321 &mut self,
322 method: &str,
323 uri: &str,
324 body: Option<&str>,
325 extra_headers: &[(&str, &str)],
326 ) -> TestResponse {
327 let body_data = body.unwrap_or_default().to_string();
328 let mut builder = Request::builder()
329 .method(method)
330 .uri(uri)
331 .header("X-Forwarded-For", &self.forwarded_ip)
332 // Production reads `CF-Connecting-IP` (the only header origin clients
333 // can't spoof through Caddy). Send both so the harness matches the
334 // production proxy chain and `extract_client_ip` finds a value.
335 .header("CF-Connecting-IP", &self.forwarded_ip);
336
337 for (name, value) in extra_headers {
338 builder = builder.header(*name, *value);
339 }
340
341 // Attach cookies
342 if !self.cookies.is_empty() {
343 let cookie_header: String = self
344 .cookies
345 .iter()
346 .map(|(k, v)| format!("{}={}", k, v))
347 .collect::<Vec<_>>()
348 .join("; ");
349 builder = builder.header(header::COOKIE, cookie_header);
350 }
351
352 let request = builder.body(Body::from(body_data)).expect("Failed to build request");
353
354 let response = self
355 .app
356 .clone()
357 .oneshot(request)
358 .await
359 .expect("Failed to send request");
360
361 let status = response.status();
362 let resp_headers = response.headers().clone();
363
364 for value in resp_headers.get_all(header::SET_COOKIE) {
365 if let Ok(cookie_str) = value.to_str()
366 && let Some(nv) = cookie_str.split(';').next()
367 && let Some((name, val)) = nv.split_once('=')
368 {
369 self.cookies
370 .insert(name.trim().to_string(), val.trim().to_string());
371 }
372 }
373
374 let body_bytes = response
375 .into_body()
376 .collect()
377 .await
378 .expect("Failed to read response body")
379 .to_bytes();
380 let text = String::from_utf8_lossy(&body_bytes).to_string();
381
382 TestResponse {
383 status,
384 text,
385 headers: resp_headers,
386 }
387 }
388
389 /// Send a GET request and return status + headers WITHOUT reading the body.
390 /// Use for streaming endpoints (SSE) where the body never ends.
391 #[allow(dead_code)]
392 pub async fn get_streaming(&mut self, uri: &str) -> TestResponse {
393 let mut builder = Request::builder()
394 .method(Method::GET)
395 .uri(uri)
396 .header("X-Forwarded-For", &self.forwarded_ip)
397 // Production reads `CF-Connecting-IP` (the only header origin clients
398 // can't spoof through Caddy). Send both so the harness matches the
399 // production proxy chain and `extract_client_ip` finds a value.
400 .header("CF-Connecting-IP", &self.forwarded_ip);
401
402 if let Some(ref token) = self.bearer_token {
403 builder = builder.header(header::AUTHORIZATION, format!("Bearer {}", token));
404 }
405
406 if !self.cookies.is_empty() {
407 let cookie_header: String = self
408 .cookies
409 .iter()
410 .map(|(k, v)| format!("{}={}", k, v))
411 .collect::<Vec<_>>()
412 .join("; ");
413 builder = builder.header(header::COOKIE, cookie_header);
414 }
415
416 let request = builder.body(Body::empty()).expect("Failed to build request");
417
418 let response = self
419 .app
420 .clone()
421 .oneshot(request)
422 .await
423 .expect("Failed to send request");
424
425 let status = response.status();
426 let headers = response.headers().clone();
427
428 // Do NOT read the body — return immediately
429 TestResponse {
430 status,
431 text: String::new(),
432 headers,
433 }
434 }
435 }
436
437 /// Response wrapper with convenience methods.
438 #[allow(dead_code)]
439 pub struct TestResponse {
440 pub status: StatusCode,
441 pub text: String,
442 pub headers: axum::http::HeaderMap,
443 }
444
445 #[allow(dead_code)]
446 impl TestResponse {
447 /// Parse the body as JSON.
448 pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
449 serde_json::from_str(&self.text)
450 .unwrap_or_else(|e| panic!("Failed to parse JSON: {}\nBody: {}", e, &self.text))
451 }
452
453 /// Get a header value as a string.
454 pub fn header(&self, name: &str) -> Option<&str> {
455 self.headers.get(name).and_then(|v| v.to_str().ok())
456 }
457 }
458
/// Extract the CSRF token from `<meta name="csrf-token" content="...">`.
///
/// Matching is by the literal text `csrf-token" content="`. This relies on
/// double-quoted attributes with `content` directly after the name. The
/// value runs up to the next `"`.
///
/// HTML character references in the value (`&amp;`, `&#x2F;`, `&#43;`, ...)
/// are decoded, giving the same string a browser exposes as `meta.content`.
/// Unknown or malformed references are kept verbatim. Returns `None` when
/// the marker or the closing quote is missing.
fn extract_csrf_from_html(html: &str) -> Option<String> {
    /// Decode the character references HTML escapers emit. Anything that
    /// is not a recognised reference is copied through unchanged.
    fn decode_html_entities(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        let mut rest = raw;
        while let Some(amp) = rest.find('&') {
            out.push_str(&rest[..amp]);
            let tail = &rest[amp..];
            // `tail` starts with '&', so a ';' can only be found at index >= 1.
            let decoded = tail
                .find(';')
                .and_then(|semi| decode_reference(&tail[1..semi]).map(|c| (c, semi + 1)));
            match decoded {
                Some((c, consumed)) => {
                    out.push(c);
                    rest = &tail[consumed..];
                }
                None => {
                    out.push('&');
                    rest = &tail[1..];
                }
            }
        }
        out.push_str(rest);
        out
    }

    /// Map a reference body (text between `&` and `;`) to its character.
    fn decode_reference(entity: &str) -> Option<char> {
        match entity {
            "amp" => Some('&'),
            "lt" => Some('<'),
            "gt" => Some('>'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            _ => {
                let number = entity.strip_prefix('#')?;
                let (digits, radix) = match number.strip_prefix(['x', 'X']) {
                    Some(hex) => (hex, 16),
                    None => (number, 10),
                };
                // Reject signs and empty bodies that `from_str_radix` would
                // otherwise accept or mis-handle.
                if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
                    return None;
                }
                u32::from_str_radix(digits, radix).ok().and_then(char::from_u32)
            }
        }
    }

    const MARKER: &str = "csrf-token\" content=\"";
    let start = html.find(MARKER)? + MARKER.len();
    let value = &html[start..];
    let end = value.find('"')?;
    Some(decode_html_entities(&value[..end]))
}
466