Skip to main content

max / multithreaded

10.2 KB · 343 lines History Blame Raw
1 //! Cookie-aware in-process HTTP client for integration tests.
2
3 use axum::body::Body;
4 use axum::extract::ConnectInfo;
5 use axum::http::{header, Method, Request, StatusCode};
6 use axum::Router;
7 use http_body_util::BodyExt;
8 use std::collections::HashMap;
9 use std::net::SocketAddr;
10 use tower::ServiceExt;
11
12 pub struct TestClient {
13 app: Router,
14 cookies: HashMap<String, String>,
15 csrf_token: Option<String>,
16 }
17
18 impl TestClient {
19 pub fn new(app: Router) -> Self {
20 TestClient {
21 app,
22 cookies: HashMap::new(),
23 csrf_token: None,
24 }
25 }
26
27 pub async fn get(&mut self, uri: &str) -> TestResponse {
28 self.send(Method::GET, uri, None, None).await
29 }
30
31 pub async fn post_form(&mut self, uri: &str, body: &str) -> TestResponse {
32 self.send(
33 Method::POST,
34 uri,
35 Some("application/x-www-form-urlencoded"),
36 Some(body.to_string()),
37 )
38 .await
39 }
40
41 /// POST without injecting the CSRF token. Used to test CSRF rejection.
42 pub async fn post_form_no_csrf(&mut self, uri: &str, body: &str) -> TestResponse {
43 self.send_raw(
44 Method::POST,
45 uri,
46 Some("application/x-www-form-urlencoded"),
47 Some(body.to_string()),
48 false,
49 )
50 .await
51 }
52
53 /// POST with a specific (wrong) CSRF token.
54 pub async fn post_form_with_token(
55 &mut self,
56 uri: &str,
57 body: &str,
58 token: &str,
59 ) -> TestResponse {
60 self.send_raw_with_token(
61 Method::POST,
62 uri,
63 Some("application/x-www-form-urlencoded"),
64 Some(body.to_string()),
65 token,
66 )
67 .await
68 }
69
70 pub async fn post_json(&mut self, uri: &str, body: &str) -> TestResponse {
71 self.send(
72 Method::POST,
73 uri,
74 Some("application/json"),
75 Some(body.to_string()),
76 )
77 .await
78 }
79
80 pub async fn post_multipart(
81 &mut self,
82 uri: &str,
83 file_data: &[u8],
84 content_type: &str,
85 filename: &str,
86 ) -> TestResponse {
87 let boundary = "----TestBoundary1234567890";
88 let mut body = Vec::new();
89 body.extend_from_slice(format!("--{boundary}\r\n").as_bytes());
90 body.extend_from_slice(
91 format!(
92 "Content-Disposition: form-data; name=\"file\"; filename=\"{filename}\"\r\n\
93 Content-Type: {content_type}\r\n\r\n"
94 )
95 .as_bytes(),
96 );
97 body.extend_from_slice(file_data);
98 body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
99
100 let ct = format!("multipart/form-data; boundary={boundary}");
101 self.send_bytes(Method::POST, uri, &ct, body).await
102 }
103
104 pub fn csrf_token(&self) -> Option<&str> {
105 self.csrf_token.as_deref()
106 }
107
108 async fn send(
109 &mut self,
110 method: Method,
111 uri: &str,
112 content_type: Option<&str>,
113 body: Option<String>,
114 ) -> TestResponse {
115 self.send_raw(method, uri, content_type, body, true).await
116 }
117
118 async fn send_bytes(
119 &mut self,
120 method: Method,
121 uri: &str,
122 content_type: &str,
123 body: Vec<u8>,
124 ) -> TestResponse {
125 let mut builder = Request::builder()
126 .method(&method)
127 .uri(uri)
128 .header(header::CONTENT_TYPE, content_type);
129
130 if matches!(method, Method::POST | Method::PUT | Method::PATCH | Method::DELETE)
131 && let Some(ref token) = self.csrf_token
132 {
133 builder = builder.header("X-CSRF-Token", token.as_str());
134 }
135
136 if !self.cookies.is_empty() {
137 let cookie_header: String = self
138 .cookies
139 .iter()
140 .map(|(k, v)| format!("{}={}", k, v))
141 .collect::<Vec<_>>()
142 .join("; ");
143 builder = builder.header(header::COOKIE, cookie_header);
144 }
145
146 let mut request = builder.body(Body::from(body)).expect("Failed to build request");
147 request.extensions_mut().insert(ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0))));
148
149 let response = self.app.clone().oneshot(request).await.expect("Failed to send request");
150 let status = response.status();
151 let headers = response.headers().clone();
152
153 for value in headers.get_all(header::SET_COOKIE) {
154 if let Ok(cookie_str) = value.to_str()
155 && let Some(nv) = cookie_str.split(';').next()
156 && let Some((name, val)) = nv.split_once('=')
157 {
158 self.cookies.insert(name.trim().to_string(), val.trim().to_string());
159 }
160 }
161
162 let body_bytes = response.into_body().collect().await
163 .expect("Failed to read response body").to_bytes();
164 let text = String::from_utf8_lossy(&body_bytes).to_string();
165
166 if let Some(token) = extract_csrf_from_html(&text) {
167 self.csrf_token = Some(token);
168 }
169
170 TestResponse { status, text, headers }
171 }
172
173 async fn send_raw(
174 &mut self,
175 method: Method,
176 uri: &str,
177 content_type: Option<&str>,
178 body: Option<String>,
179 inject_csrf: bool,
180 ) -> TestResponse {
181 let body_data = body.unwrap_or_default();
182 let mut builder = Request::builder().method(&method).uri(uri);
183
184 if let Some(ct) = content_type {
185 builder = builder.header(header::CONTENT_TYPE, ct);
186 }
187
188 if inject_csrf
189 && matches!(method, Method::POST | Method::PUT | Method::PATCH | Method::DELETE)
190 && let Some(ref token) = self.csrf_token
191 {
192 builder = builder.header("X-CSRF-Token", token.as_str());
193 }
194
195 if !self.cookies.is_empty() {
196 let cookie_header: String = self
197 .cookies
198 .iter()
199 .map(|(k, v)| format!("{}={}", k, v))
200 .collect::<Vec<_>>()
201 .join("; ");
202 builder = builder.header(header::COOKIE, cookie_header);
203 }
204
205 let mut request = builder.body(Body::from(body_data)).expect("Failed to build request");
206 // Provide ConnectInfo so SmartIpKeyExtractor works in tests
207 request.extensions_mut().insert(ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0))));
208
209 let response = self
210 .app
211 .clone()
212 .oneshot(request)
213 .await
214 .expect("Failed to send request");
215
216 let status = response.status();
217 let headers = response.headers().clone();
218
219 for value in headers.get_all(header::SET_COOKIE) {
220 if let Ok(cookie_str) = value.to_str()
221 && let Some(nv) = cookie_str.split(';').next()
222 && let Some((name, val)) = nv.split_once('=')
223 {
224 self.cookies
225 .insert(name.trim().to_string(), val.trim().to_string());
226 }
227 }
228
229 let body_bytes = response
230 .into_body()
231 .collect()
232 .await
233 .expect("Failed to read response body")
234 .to_bytes();
235 let text = String::from_utf8_lossy(&body_bytes).to_string();
236
237 if let Some(token) = extract_csrf_from_html(&text) {
238 self.csrf_token = Some(token);
239 }
240
241 TestResponse {
242 status,
243 text,
244 headers,
245 }
246 }
247
248 async fn send_raw_with_token(
249 &mut self,
250 method: Method,
251 uri: &str,
252 content_type: Option<&str>,
253 body: Option<String>,
254 token: &str,
255 ) -> TestResponse {
256 let body_data = body.unwrap_or_default();
257 let mut builder = Request::builder().method(&method).uri(uri);
258
259 if let Some(ct) = content_type {
260 builder = builder.header(header::CONTENT_TYPE, ct);
261 }
262
263 builder = builder.header("X-CSRF-Token", token);
264
265 if !self.cookies.is_empty() {
266 let cookie_header: String = self
267 .cookies
268 .iter()
269 .map(|(k, v)| format!("{}={}", k, v))
270 .collect::<Vec<_>>()
271 .join("; ");
272 builder = builder.header(header::COOKIE, cookie_header);
273 }
274
275 let mut request = builder.body(Body::from(body_data)).expect("Failed to build request");
276 request.extensions_mut().insert(ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 0))));
277
278 let response = self
279 .app
280 .clone()
281 .oneshot(request)
282 .await
283 .expect("Failed to send request");
284
285 let status = response.status();
286 let headers = response.headers().clone();
287
288 for value in headers.get_all(header::SET_COOKIE) {
289 if let Ok(cookie_str) = value.to_str()
290 && let Some(nv) = cookie_str.split(';').next()
291 && let Some((name, val)) = nv.split_once('=')
292 {
293 self.cookies
294 .insert(name.trim().to_string(), val.trim().to_string());
295 }
296 }
297
298 let body_bytes = response
299 .into_body()
300 .collect()
301 .await
302 .expect("Failed to read response body")
303 .to_bytes();
304 let text = String::from_utf8_lossy(&body_bytes).to_string();
305
306 if let Some(token) = extract_csrf_from_html(&text) {
307 self.csrf_token = Some(token);
308 }
309
310 TestResponse {
311 status,
312 text,
313 headers,
314 }
315 }
316 }
317
318 #[allow(dead_code)]
319 pub struct TestResponse {
320 pub status: StatusCode,
321 pub text: String,
322 pub headers: axum::http::HeaderMap,
323 }
324
325 #[allow(dead_code)]
326 impl TestResponse {
327 pub fn json<T: serde::de::DeserializeOwned>(&self) -> T {
328 serde_json::from_str(&self.text)
329 .unwrap_or_else(|e| panic!("Failed to parse JSON: {}\nBody: {}", e, &self.text))
330 }
331
332 pub fn header(&self, name: &str) -> Option<&str> {
333 self.headers.get(name).and_then(|v| v.to_str().ok())
334 }
335 }
336
/// Pulls a CSRF token out of an HTML page.
///
/// Finds the first occurrence of `csrf-token" content="` (as produced by
/// `<meta name="csrf-token" content="...">`) and returns the text up to the next `"`.
/// Returns `None` when the marker is missing or the attribute value is unterminated.
/// Only this exact double-quoted layout is recognized, and HTML entities are not decoded.
fn extract_csrf_from_html(html: &str) -> Option<String> {
    const MARKER: &str = "csrf-token\" content=\"";
    let (_, rest) = html.split_once(MARKER)?;
    rest.split_once('"').map(|(token, _)| token.to_owned())
}
343