//! CSRF (Cross-Site Request Forgery) protection. //! //! Synchronizer token pattern: generate random token on session start, //! include in meta tag, validate on state-changing requests via X-CSRF-Token header. use axum::{ extract::Request, http::StatusCode, middleware::Next, response::{IntoResponse, Response}, }; use rand::RngCore; use tower_sessions::Session; const CSRF_SESSION_KEY: &str = "csrf_token"; const CSRF_TOKEN_LENGTH: usize = 32; /// Generate a new random CSRF token (64-char hex string). pub fn generate_token() -> String { let mut token = [0u8; CSRF_TOKEN_LENGTH]; rand::thread_rng().fill_bytes(&mut token); hex::encode(token) } /// Get the existing CSRF token from the session, or create and store a new one. pub async fn get_or_create_token(session: &Session) -> String { if let Ok(Some(token)) = session.get::(CSRF_SESSION_KEY).await { return token; } let token = generate_token(); let _ = session.insert(CSRF_SESSION_KEY, &token).await; token } /// Constant-time comparison to prevent timing attacks. pub fn constant_time_compare(a: &str, b: &str) -> bool { if a.len() != b.len() { return false; } a.bytes() .zip(b.bytes()) .fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0 } /// Middleware: validate X-CSRF-Token header on POST/PUT/PATCH/DELETE. /// /// Exempt paths: `/auth/`, `/api/health`, `/_test/`. pub async fn csrf_middleware(request: Request, next: Next) -> Response { let method = request.method().clone(); if !["POST", "PUT", "PATCH", "DELETE"].contains(&method.as_str()) { return next.run(request).await; } let path = request.uri().path().to_string(); let exempt_prefixes = ["/auth/", "/api/health", "/_test/"]; if exempt_prefixes.iter().any(|p| path.starts_with(p)) { return next.run(request).await; } let session = match request.extensions().get::() { Some(s) => s.clone(), None => { tracing::warn!("CSRF check failed: no session"); return (StatusCode::FORBIDDEN, "CSRF validation failed").into_response(); } }; let provided_token = request .headers() .get("X-CSRF-Token") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); let token = match provided_token { Some(t) => t, None => { tracing::warn!(path = %path, "CSRF token missing"); return (StatusCode::FORBIDDEN, "CSRF token required").into_response(); } }; let session_token: Option = session .get(CSRF_SESSION_KEY) .await .ok() .flatten(); match session_token { Some(ref expected) if constant_time_compare(expected, &token) => { next.run(request).await } _ => { tracing::warn!(path = %path, "CSRF token mismatch"); (StatusCode::FORBIDDEN, "Invalid CSRF token").into_response() } } } #[cfg(test)] mod tests { use super::*; #[test] fn token_length_and_hex() { let token = generate_token(); assert_eq!(token.len(), 64); assert!(token.chars().all(|c| c.is_ascii_hexdigit())); } #[test] fn tokens_are_unique() { let a = generate_token(); let b = generate_token(); assert_ne!(a, b); } #[test] fn constant_time_compare_works() { assert!(constant_time_compare("abc", "abc")); assert!(!constant_time_compare("abc", "abd")); assert!(!constant_time_compare("abc", "abcd")); assert!(!constant_time_compare("", "a")); assert!(constant_time_compare("", "")); } }