| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
use axum::{ |
| 7 |
extract::Request, |
| 8 |
http::StatusCode, |
| 9 |
middleware::Next, |
| 10 |
response::{IntoResponse, Response}, |
| 11 |
}; |
| 12 |
use rand::RngCore; |
| 13 |
use tower_sessions::Session; |
| 14 |
|
| 15 |
/// Number of random bytes in each CSRF token; the hex-encoded token is twice this long.
const CSRF_TOKEN_LENGTH: usize = 32;
/// Session key under which the per-session CSRF token is stored and looked up.
const CSRF_SESSION_KEY: &str = "csrf_token";
| 17 |
|
| 18 |
|
| 19 |
pub fn generate_token() -> String { |
| 20 |
let mut token = [0u8; CSRF_TOKEN_LENGTH]; |
| 21 |
rand::thread_rng().fill_bytes(&mut token); |
| 22 |
hex::encode(token) |
| 23 |
} |
| 24 |
|
| 25 |
|
| 26 |
pub async fn get_or_create_token(session: &Session) -> String { |
| 27 |
if let Ok(Some(token)) = session.get::<String>(CSRF_SESSION_KEY).await { |
| 28 |
return token; |
| 29 |
} |
| 30 |
let token = generate_token(); |
| 31 |
let _ = session.insert(CSRF_SESSION_KEY, &token).await; |
| 32 |
token |
| 33 |
} |
| 34 |
|
| 35 |
|
| 36 |
/// Checks two strings for equality without stopping at the first differing byte.
///
/// Strings of different length are rejected immediately, so timing can reveal only
/// the lengths, not the contents. Equal-length inputs are compared by XOR-ing every
/// byte pair and OR-ing the results into one accumulator, which is zero exactly when
/// all bytes match.
///
/// NOTE(review): the language does not guarantee the optimizer keeps this loop free of
/// early exits; if a hard constant-time guarantee is required, consider a dedicated
/// crate such as `subtle`.
pub fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() == rhs.len() {
        let mut diff = 0u8;
        for (x, y) in lhs.iter().zip(rhs) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
| 45 |
|
| 46 |
|
| 47 |
|
| 48 |
|
| 49 |
pub async fn csrf_middleware(request: Request, next: Next) -> Response { |
| 50 |
let method = request.method().clone(); |
| 51 |
|
| 52 |
if !["POST", "PUT", "PATCH", "DELETE"].contains(&method.as_str()) { |
| 53 |
return next.run(request).await; |
| 54 |
} |
| 55 |
|
| 56 |
let path = request.uri().path().to_string(); |
| 57 |
|
| 58 |
let exempt_prefixes = ["/auth/", "/api/health", "/_test/"]; |
| 59 |
if exempt_prefixes.iter().any(|p| path.starts_with(p)) { |
| 60 |
return next.run(request).await; |
| 61 |
} |
| 62 |
|
| 63 |
let session = match request.extensions().get::<Session>() { |
| 64 |
Some(s) => s.clone(), |
| 65 |
None => { |
| 66 |
tracing::warn!("CSRF check failed: no session"); |
| 67 |
return (StatusCode::FORBIDDEN, "CSRF validation failed").into_response(); |
| 68 |
} |
| 69 |
}; |
| 70 |
|
| 71 |
let provided_token = request |
| 72 |
.headers() |
| 73 |
.get("X-CSRF-Token") |
| 74 |
.and_then(|v| v.to_str().ok()) |
| 75 |
.map(|s| s.to_string()); |
| 76 |
|
| 77 |
let token = match provided_token { |
| 78 |
Some(t) => t, |
| 79 |
None => { |
| 80 |
tracing::warn!(path = %path, "CSRF token missing"); |
| 81 |
return (StatusCode::FORBIDDEN, "CSRF token required").into_response(); |
| 82 |
} |
| 83 |
}; |
| 84 |
|
| 85 |
let session_token: Option<String> = session |
| 86 |
.get(CSRF_SESSION_KEY) |
| 87 |
.await |
| 88 |
.ok() |
| 89 |
.flatten(); |
| 90 |
|
| 91 |
match session_token { |
| 92 |
Some(ref expected) if constant_time_compare(expected, &token) => { |
| 93 |
next.run(request).await |
| 94 |
} |
| 95 |
_ => { |
| 96 |
tracing::warn!(path = %path, "CSRF token mismatch"); |
| 97 |
(StatusCode::FORBIDDEN, "Invalid CSRF token").into_response() |
| 98 |
} |
| 99 |
} |
| 100 |
} |
| 101 |
|
| 102 |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_length_and_hex() {
        let token = generate_token();
        // Two hex characters per random byte.
        assert_eq!(token.len(), CSRF_TOKEN_LENGTH * 2);
        assert!(token.bytes().all(|b| b.is_ascii_hexdigit()));
    }

    #[test]
    fn tokens_are_unique() {
        let (first, second) = (generate_token(), generate_token());
        assert_ne!(first, second);
    }

    #[test]
    fn constant_time_compare_works() {
        let cases = [
            ("abc", "abc", true),
            ("abc", "abd", false),
            ("abc", "abcd", false),
            ("", "a", false),
            ("", "", true),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(constant_time_compare(lhs, rhs), expected);
        }
    }
}
| 129 |
|