Skip to main content

max / multithreaded

3.6 KB · 129 lines History Blame Raw
1 //! CSRF (Cross-Site Request Forgery) protection.
2 //!
3 //! Synchronizer token pattern: generate random token on session start,
4 //! include in meta tag, validate on state-changing requests via X-CSRF-Token header.
5
6 use axum::{
7 extract::Request,
8 http::StatusCode,
9 middleware::Next,
10 response::{IntoResponse, Response},
11 };
12 use rand::RngCore;
13 use tower_sessions::Session;
14
/// Number of random bytes in a token; hex encoding turns this into a
/// 64-character string.
const CSRF_TOKEN_LENGTH: usize = 32;

/// Session key under which the CSRF token is stored and looked up.
const CSRF_SESSION_KEY: &str = "csrf_token";
17
18 /// Generate a new random CSRF token (64-char hex string).
19 pub fn generate_token() -> String {
20 let mut token = [0u8; CSRF_TOKEN_LENGTH];
21 rand::thread_rng().fill_bytes(&mut token);
22 hex::encode(token)
23 }
24
25 /// Get the existing CSRF token from the session, or create and store a new one.
26 pub async fn get_or_create_token(session: &Session) -> String {
27 if let Ok(Some(token)) = session.get::<String>(CSRF_SESSION_KEY).await {
28 return token;
29 }
30 let token = generate_token();
31 let _ = session.insert(CSRF_SESSION_KEY, &token).await;
32 token
33 }
34
/// Compare two strings byte-by-byte without stopping at the first mismatch.
///
/// Strings of different byte length are rejected immediately. For equal
/// lengths every byte pair is XOR-ed and OR-ed into an accumulator, so the
/// loop always visits all bytes; the inputs are equal exactly when the
/// accumulator ends up zero. Intended to hinder timing attacks on token
/// comparison.
pub fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }

    let mut diff = 0u8;
    for (left, right) in lhs.iter().zip(rhs.iter()) {
        diff |= *left ^ *right;
    }
    diff == 0
}
45
46 /// Middleware: validate X-CSRF-Token header on POST/PUT/PATCH/DELETE.
47 ///
48 /// Exempt paths: `/auth/`, `/api/health`, `/_test/`.
49 pub async fn csrf_middleware(request: Request, next: Next) -> Response {
50 let method = request.method().clone();
51
52 if !["POST", "PUT", "PATCH", "DELETE"].contains(&method.as_str()) {
53 return next.run(request).await;
54 }
55
56 let path = request.uri().path().to_string();
57
58 let exempt_prefixes = ["/auth/", "/api/health", "/_test/"];
59 if exempt_prefixes.iter().any(|p| path.starts_with(p)) {
60 return next.run(request).await;
61 }
62
63 let session = match request.extensions().get::<Session>() {
64 Some(s) => s.clone(),
65 None => {
66 tracing::warn!("CSRF check failed: no session");
67 return (StatusCode::FORBIDDEN, "CSRF validation failed").into_response();
68 }
69 };
70
71 let provided_token = request
72 .headers()
73 .get("X-CSRF-Token")
74 .and_then(|v| v.to_str().ok())
75 .map(|s| s.to_string());
76
77 let token = match provided_token {
78 Some(t) => t,
79 None => {
80 tracing::warn!(path = %path, "CSRF token missing");
81 return (StatusCode::FORBIDDEN, "CSRF token required").into_response();
82 }
83 };
84
85 let session_token: Option<String> = session
86 .get(CSRF_SESSION_KEY)
87 .await
88 .ok()
89 .flatten();
90
91 match session_token {
92 Some(ref expected) if constant_time_compare(expected, &token) => {
93 next.run(request).await
94 }
95 _ => {
96 tracing::warn!(path = %path, "CSRF token mismatch");
97 (StatusCode::FORBIDDEN, "Invalid CSRF token").into_response()
98 }
99 }
100 }
101
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated token is 64 characters long and consists of hex digits only.
    #[test]
    fn token_length_and_hex() {
        let generated = generate_token();
        assert_eq!(generated.len(), 64);
        assert!(generated.bytes().all(|byte| byte.is_ascii_hexdigit()));
    }

    /// Two consecutively generated tokens differ.
    #[test]
    fn tokens_are_unique() {
        let (first, second) = (generate_token(), generate_token());
        assert_ne!(first, second);
    }

    /// Equal strings compare true; differing content or length compares false.
    #[test]
    fn constant_time_compare_works() {
        let cases = [
            ("abc", "abc", true),
            ("abc", "abd", false),
            ("abc", "abcd", false),
            ("", "a", false),
            ("", "", true),
        ];
        for (left, right, expected) in cases {
            assert_eq!(
                constant_time_compare(left, right),
                expected,
                "comparing {left:?} with {right:?}"
            );
        }
    }
}
129