//! CSRF (Cross-Site Request Forgery) protection //! //! Uses the synchronizer token pattern: //! 1. Generate random token on session start //! 2. Include token in forms/meta tag //! 3. Validate token on state-changing requests use axum::{ extract::{FromRequestParts, Request}, handler::Handler, http::{header::HeaderMap, request::Parts, StatusCode}, middleware::{from_fn, Next}, response::{IntoResponse, Response}, routing::{delete, patch, post, put, MethodRouter}, Router, }; use rand::RngCore; use tower_sessions::Session; use crate::error::{AppError, ResultExt}; /// Session key for storing CSRF token pub const CSRF_SESSION_KEY: &str = "csrf_token"; /// CSRF token length in bytes (32 bytes = 256 bits) const CSRF_TOKEN_LENGTH: usize = 32; /// Generate a new CSRF token pub fn generate_token() -> String { let mut token = [0u8; CSRF_TOKEN_LENGTH]; rand::rng().fill_bytes(&mut token); hex::encode(token) } /// Get or create a CSRF token for the session. /// /// `tower-sessions`' `insert` is last-write-wins, so two concurrent first /// requests (e.g. the user opens two tabs while not yet having a token) /// can each generate a fresh token and clobber each other — the first /// form to post then fails with a 403 because its rendered token has /// already been overwritten. /// /// Re-check via `get` after insert: if a different token landed between /// our get and our insert, prefer THAT value over ours so all renders /// from this point forward agree. `String` is `Clone`-cheap, and the /// race window is small enough that the duplicate insert is negligible. pub async fn get_or_create_token(session: &Session) -> Result { if let Some(token) = session .get::(CSRF_SESSION_KEY) .await .context("session error")? { return Ok(token); } let candidate = generate_token(); session .insert(CSRF_SESSION_KEY, &candidate) .await .context("session insert")?; // Re-fetch so a concurrent insert wins consistently — whichever caller // wrote last is what every subsequent render will see, and we hand // back the same value here. let final_token: String = session .get(CSRF_SESSION_KEY) .await .context("session error")? .unwrap_or(candidate); Ok(final_token) } /// Validate a CSRF token against the session token pub async fn validate_token(session: &Session, provided_token: &str) -> Result { let session_token: Option = session .get(CSRF_SESSION_KEY) .await .context("session error")?; match session_token { Some(token) => Ok(crate::helpers::constant_time_compare(&token, provided_token)), None => Ok(false), } } /// Extract CSRF token from request (header or form field) pub fn extract_token_from_request(headers: &HeaderMap, body: Option<&str>) -> Option { // First, try the X-CSRF-Token header (used by HTMX) if let Some(token) = headers .get("X-CSRF-Token") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) { return Some(token); } // Fall back to the `_csrf` field in form-encoded body (vanilla HTML // forms). We use a proper urlencoded parser instead of `split('&')` // so a textarea containing `&_csrf=attacker-token` can't sneak past // a later field with the wrong value — the parser respects field // ordering and won't conflate textarea content with form fields // because the form encoder percent-encodes `&` inside text values. if let Some(body_str) = body { for (key, value) in url::form_urlencoded::parse(body_str.as_bytes()) { if key == "_csrf" { return Some(value.into_owned()); } } } None } /// Extractor for CSRF token from session pub struct CsrfToken(pub String); impl FromRequestParts for CsrfToken where S: Send + Sync, { type Rejection = AppError; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let session = parts .extensions .get::() .ok_or(AppError::Internal(anyhow::anyhow!("Session not found")))?; let token = get_or_create_token(session).await?; Ok(CsrfToken(token)) } } /// Per-route CSRF posture, declared at the route registration site via the /// `{post,put,patch,delete}_csrf*` helpers. Carried in the helper signatures /// so the choice (and its reason) lives next to the route, not in a sibling /// allowlist file. Not stored at runtime — the reason strings exist for /// source-level documentation and grep, while the structural guarantee comes /// from `CsrfRouter` only accepting `PostureMethodRouter` values. #[derive(Clone, Copy, Debug)] pub enum CsrfPosture { /// Standard validation layer runs (header or form `_csrf`). Auto, /// Handler validates the token itself and proves it with the /// `CsrfManuallyValidated` witness. Reason documents why the /// standard layer can't apply (e.g. "multipart upload"). Manual(&'static str), /// No CSRF check applies. Reason documents why (webhook signature, /// signed link, pre-auth, etc.). Skip(&'static str), } /// Witness type proving a handler ran the standard CSRF validation path. /// The only public way to obtain one is `validate_token_consuming`, which /// performs the check. The private field with a private-module constructor /// makes the value un-fabricable from outside this module — `Default`, /// struct-literal, and `Clone` are all impossible for callers. pub use sealed::CsrfManuallyValidated; mod sealed { pub struct CsrfManuallyValidated { _private: (), } pub(super) fn make_validated() -> CsrfManuallyValidated { CsrfManuallyValidated { _private: () } } } /// Validate a token and return a sealed witness on success. Used by /// handlers registered with `post_csrf_manual` (and method variants) /// that need to validate inside the handler body — typically because the /// global middleware can't read the token for this content type (e.g. /// multipart) or because validation is conditional on request state. pub async fn validate_token_consuming( session: &Session, provided_token: &str, ) -> Result { if validate_token(session, provided_token).await? { Ok(sealed::make_validated()) } else { Err(AppError::Forbidden) } } // Manual-posture runtime assertion (dev/test only): attempted via a tokio // task-local flag set in `validate_token_consuming` and checked in a per- // route layer. Backed out 2026-05-27 — false-positive density was too high: // rendered error pages return 200, rate-limit and form-extraction // short-circuit before the handler, and the audit explicitly marked this // follow-up as "not blocking — only matters if Manual grows beyond one // route". Compile-time discipline (the `CsrfManuallyValidated` witness type // bound as `_validated`) stays the convention. /// Wrap a method-router with the Auto-posture validation layer. /// Runs `validate_auto` on every request that reaches the route. fn attach_auto_layer(method_router: MethodRouter) -> MethodRouter where S: Clone + Send + Sync + 'static, { method_router.layer(from_fn(|req: Request, next: Next| async move { let path = req.uri().path().to_string(); validate_auto(req, next, &path).await })) } /// A `MethodRouter` that has been through one of the CSRF helpers. Field /// is private and constructible only inside this module, so /// `CsrfRouter::route` will not accept a bare `axum::routing::post(handler)` /// — route files have to use the helpers, by construction. pub use posture_router::PostureMethodRouter; mod posture_router { use super::*; pub struct PostureMethodRouter(pub(super) MethodRouter); impl PostureMethodRouter where S: Clone + Send + Sync + 'static, { pub(super) fn new(inner: MethodRouter) -> Self { Self(inner) } pub(super) fn into_inner(self) -> MethodRouter { self.0 } /// Attach an additional tower layer (e.g. a rate limiter) to the /// underlying method router. Returns `Self` so callers don't lose /// the posture stamp. pub fn layer(self, layer: L) -> Self where L: tower::Layer + Clone + Send + Sync + 'static, L::Service: tower::Service + Clone + Send + Sync + 'static, >::Response: axum::response::IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, { Self(self.0.layer(layer)) } } } macro_rules! csrf_auto_helper { ($name:ident, $axum_fn:ident) => { pub fn $name(handler: H) -> PostureMethodRouter where H: Handler, T: 'static, S: Clone + Send + Sync + 'static, { posture_router::PostureMethodRouter::new(attach_auto_layer($axum_fn(handler))) } }; } macro_rules! csrf_passthrough_helper { ($name:ident, $axum_fn:ident, $variant:ident) => { pub fn $name(reason: &'static str, handler: H) -> PostureMethodRouter where H: Handler, T: 'static, S: Clone + Send + Sync + 'static, { let _ = CsrfPosture::$variant(reason); posture_router::PostureMethodRouter::new($axum_fn(handler)) } }; } // Auto posture: standard CSRF validation (header or form `_csrf`). csrf_auto_helper!(post_csrf, post); csrf_auto_helper!(put_csrf, put); csrf_auto_helper!(patch_csrf, patch); csrf_auto_helper!(delete_csrf, delete); // Manual posture: handler validates via `validate_token_consuming`. csrf_passthrough_helper!(post_csrf_manual, post, Manual); csrf_passthrough_helper!(put_csrf_manual, put, Manual); csrf_passthrough_helper!(patch_csrf_manual, patch, Manual); csrf_passthrough_helper!(delete_csrf_manual, delete, Manual); // Skip posture: no CSRF check. Reason documents why. csrf_passthrough_helper!(post_csrf_skip, post, Skip); csrf_passthrough_helper!(put_csrf_skip, put, Skip); csrf_passthrough_helper!(patch_csrf_skip, patch, Skip); csrf_passthrough_helper!(delete_csrf_skip, delete, Skip); // --- Wrappers for multi-method routes ------------------------------------ // // A handful of routes register multiple HTTP methods on one path // (e.g. `get(list).post(create)`). The handler-taking helpers above can't // compose with these because the chain is already a `MethodRouter`. These // wrappers take a pre-built `MethodRouter` and stamp it as a // `PostureMethodRouter`. Read methods (GET/HEAD) are unaffected — the // Auto validation layer only intercepts state-changing methods at the // per-route level because that's what the helper attached to. /// Wrap a multi-method chain with the Auto-posture validation layer. pub fn with_csrf(method_router: MethodRouter) -> PostureMethodRouter where S: Clone + Send + Sync + 'static, { posture_router::PostureMethodRouter::new(attach_auto_layer(method_router)) } /// Stamp a multi-method chain as Manual — handler is responsible for /// calling `validate_token_consuming`. pub fn with_csrf_manual( reason: &'static str, method_router: MethodRouter, ) -> PostureMethodRouter where S: Clone + Send + Sync + 'static, { let _ = CsrfPosture::Manual(reason); posture_router::PostureMethodRouter::new(method_router) } /// Stamp a multi-method chain as Skip — no CSRF check applies. pub fn with_csrf_skip( reason: &'static str, method_router: MethodRouter, ) -> PostureMethodRouter where S: Clone + Send + Sync + 'static, { let _ = CsrfPosture::Skip(reason); posture_router::PostureMethodRouter::new(method_router) } // --- CsrfRouter: structural enforcement ---------------------------------- // // `CsrfRouter` is the only way to register a mutation route in this // codebase. Its `route` method takes a `PostureMethodRouter`, whose // constructor is private to this module, so the only producers are the // helpers above. A bare `Router::route(path, post(handler))` cannot // reach a mounted `CsrfRouter` without going through `finalize()` first, // which is only called once in `build_app`. pub struct CsrfRouter(Router); impl Default for CsrfRouter where S: Clone + Send + Sync + 'static, { fn default() -> Self { Self::new() } } impl CsrfRouter where S: Clone + Send + Sync + 'static, { pub fn new() -> Self { Self(Router::new()) } pub fn route(self, path: &str, posture: PostureMethodRouter) -> Self { Self(self.0.route(path, posture.into_inner())) } /// Register a read-only route (GET / HEAD / OPTIONS). The structural /// guarantee only constrains state-changing methods, so read-only /// `MethodRouter`s pass through unchanged. Calling this with a /// `MethodRouter` that includes POST/PUT/PATCH/DELETE compiles, but /// readers can see the intent at the call site — and any mutation /// route registered through `route_get` is a bug visible in review. pub fn route_get(self, path: &str, method_router: MethodRouter) -> Self { Self(self.0.route(path, method_router)) } pub fn merge(self, other: Self) -> Self { Self(self.0.merge(other.0)) } pub fn nest(self, path: &str, other: Self) -> Self { Self(self.0.nest(path, other.0)) } pub fn layer(self, layer: L) -> Self where L: tower::Layer + Clone + Send + Sync + 'static, L::Service: tower::Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, { Self(self.0.layer(layer)) } pub fn route_layer(self, layer: L) -> Self where L: tower::Layer + Clone + Send + Sync + 'static, L::Service: tower::Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, { Self(self.0.route_layer(layer)) } /// Drop the structural envelope and return the underlying `Router`. /// Called once in `build_app` after all mutation routes have been /// registered; downstream code may then attach global layers, mount /// static-file services, and add GET-only routes. pub fn finalize(self) -> Router { self.0 } } /// Standard CSRF validation: header `X-CSRF-Token` first, then form-body /// `_csrf` for authenticated users. Used by `CsrfPosture::Auto` routes /// and by the path-allowlist fallback during the L2 migration. async fn validate_auto(request: Request, next: Next, path: &str) -> Response { // Safe methods (RFC 9110 §9.2.1) are read-only by definition — never // CSRF-check them. This matters for multi-method routes wrapped by // `with_csrf(get(load).post(save))`: a bare GET should not require a // token (and the harness doesn't send one for GETs). if !matches!(*request.method(), axum::http::Method::POST | axum::http::Method::PUT | axum::http::Method::PATCH | axum::http::Method::DELETE) { return next.run(request).await; } // Get session from extensions let session = match request.extensions().get::() { Some(s) => s.clone(), None => { tracing::warn!("CSRF check failed: no session"); return (StatusCode::FORBIDDEN, "CSRF validation failed").into_response(); } }; // Try header first (HTMX requests) let header_token = request .headers() .get("X-CSRF-Token") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); if let Some(ref token) = header_token { return match validate_token(&session, token).await { Ok(true) => next.run(request).await, Ok(false) => { tracing::warn!(path = %path, "CSRF token mismatch"); crate::error::AppError::Forbidden.into_response() } Err(e) => { tracing::error!(error = ?e, "CSRF validation error"); crate::error::AppError::Internal(anyhow::anyhow!("CSRF validation error")).into_response() } }; } // No header token — check if the user is authenticated let has_user: bool = session .get::("user") .await .ok() .flatten() .is_some(); if !has_user { return next.run(request).await; } // Authenticated user without header token — check form body for `_csrf`. // We only parse `application/x-www-form-urlencoded`. Other content // types are rejected here: // - `multipart/form-data` is the closest near-miss: it has its own // `_csrf` part but parsing it would mean pulling in a multipart // decoder and buffering the entire upload body, defeating the // upload-size limit. The codebase doesn't currently use multipart // forms (uploads go through HTMX + fetch, which attach // `X-CSRF-Token` on the header path above), so rejecting here is // the explicit boundary. If multipart adoption ever becomes // necessary, register the route with `post_csrf_manual` and have // the handler stream the body through a multipart parser before // calling `validate_token_consuming`. // - `application/json` and others must use the `X-CSRF-Token` // header — anything that can set a custom header can set this one. let content_type = request .headers() .get("content-type") .and_then(|v| v.to_str().ok()) .unwrap_or(""); let is_form = content_type.starts_with("application/x-www-form-urlencoded"); if !is_form { let is_multipart = content_type.starts_with("multipart/form-data"); tracing::warn!( path = %path, content_type, is_multipart, "CSRF token missing for authenticated non-form request" ); return crate::error::AppError::Forbidden.into_response(); } // Buffer the body to extract _csrf, then reconstruct the request. // Limit matches the global RequestBodyLimitLayer (1 MB) so that any // form body accepted by the server can have its CSRF token extracted. let (parts, body) = request.into_parts(); let bytes = match axum::body::to_bytes(body, 1024 * 1024).await { Ok(b) => b, Err(_) => { return (StatusCode::BAD_REQUEST, "Request body too large").into_response(); } }; let body_str = String::from_utf8_lossy(&bytes); let body_token = extract_token_from_request(&HeaderMap::new(), Some(&body_str)); let token = match body_token { Some(t) => t, None => { tracing::warn!(path = %path, "CSRF token missing from form body"); return crate::error::AppError::Forbidden.into_response(); } }; match validate_token(&session, &token).await { Ok(true) => { // Reconstruct request with the buffered body let request = Request::from_parts(parts, axum::body::Body::from(bytes)); next.run(request).await } Ok(false) => { tracing::warn!(path = %path, "CSRF token mismatch"); (StatusCode::FORBIDDEN, "Invalid CSRF token").into_response() } Err(e) => { tracing::error!(error = ?e, "CSRF validation error"); (StatusCode::INTERNAL_SERVER_ERROR, "CSRF validation error").into_response() } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_generate_token() { let token1 = generate_token(); let token2 = generate_token(); // Tokens should be 64 hex characters (32 bytes) assert_eq!(token1.len(), 64); assert_eq!(token2.len(), 64); // Tokens should be different assert_ne!(token1, token2); } #[test] fn test_constant_time_compare() { use crate::helpers::constant_time_compare; assert!(constant_time_compare("abc", "abc")); assert!(!constant_time_compare("abc", "abd")); assert!(!constant_time_compare("abc", "abcd")); assert!(!constant_time_compare("", "a")); } #[test] fn test_generate_token_is_hex() { let token = generate_token(); // Should be valid hex assert!(token.chars().all(|c| c.is_ascii_hexdigit())); } #[test] fn test_extract_token_from_header() { let mut headers = HeaderMap::new(); headers.insert("X-CSRF-Token", "abc123".parse().unwrap()); let token = extract_token_from_request(&headers, None); assert_eq!(token.as_deref(), Some("abc123")); } #[test] fn test_extract_token_from_form_body() { let headers = HeaderMap::new(); let body = "name=value&_csrf=mytoken123&other=data"; let token = extract_token_from_request(&headers, Some(body)); assert_eq!(token.as_deref(), Some("mytoken123")); } #[test] fn test_extract_token_missing() { let headers = HeaderMap::new(); let token = extract_token_from_request(&headers, None); assert!(token.is_none()); } #[test] fn test_generate_token_unique_across_many() { let tokens: Vec = (0..100).map(|_| generate_token()).collect(); let unique: std::collections::HashSet<&String> = tokens.iter().collect(); assert_eq!(unique.len(), 100, "all 100 tokens should be unique"); } #[test] fn test_generate_token_correct_byte_length() { let token = generate_token(); let bytes = hex::decode(&token).expect("token should be valid hex"); assert_eq!(bytes.len(), CSRF_TOKEN_LENGTH); } #[test] fn test_extract_token_header_takes_priority_over_body() { let mut headers = HeaderMap::new(); headers.insert("X-CSRF-Token", "header_token".parse().unwrap()); let body = "_csrf=body_token"; let token = extract_token_from_request(&headers, Some(body)); assert_eq!(token.as_deref(), Some("header_token")); } #[test] fn test_extract_token_from_body_url_encoded() { let headers = HeaderMap::new(); let body = "_csrf=token%20with%20spaces&other=val"; let token = extract_token_from_request(&headers, Some(body)); assert_eq!(token.as_deref(), Some("token with spaces")); } #[test] fn test_extract_token_csrf_at_start_of_body() { let headers = HeaderMap::new(); let body = "_csrf=firstfield&name=value"; let token = extract_token_from_request(&headers, Some(body)); assert_eq!(token.as_deref(), Some("firstfield")); } #[test] fn test_extract_token_csrf_at_end_of_body() { let headers = HeaderMap::new(); let body = "name=value&_csrf=lastfield"; let token = extract_token_from_request(&headers, Some(body)); assert_eq!(token.as_deref(), Some("lastfield")); } #[test] fn test_extract_token_empty_body() { let headers = HeaderMap::new(); let token = extract_token_from_request(&headers, Some("")); assert!(token.is_none()); } #[test] fn test_extract_token_body_without_csrf_field() { let headers = HeaderMap::new(); let body = "name=value&other=data"; let token = extract_token_from_request(&headers, Some(body)); assert!(token.is_none()); } #[test] fn test_extract_token_csrf_prefix_mismatch() { let headers = HeaderMap::new(); // Field named "_csrfx" should NOT match "_csrf=" let body = "_csrfx=notreal"; let token = extract_token_from_request(&headers, Some(body)); assert!(token.is_none()); } #[test] fn test_extract_token_empty_csrf_value() { let headers = HeaderMap::new(); let body = "_csrf=&other=val"; let token = extract_token_from_request(&headers, Some(body)); assert_eq!(token.as_deref(), Some("")); } #[test] fn test_constant_time_compare_empty_strings() { use crate::helpers::constant_time_compare; assert!(constant_time_compare("", "")); } #[test] fn test_constant_time_compare_near_miss() { use crate::helpers::constant_time_compare; let token = generate_token(); // Flip last character let mut tampered = token.clone(); let last = tampered.pop().unwrap(); tampered.push(if last == '0' { '1' } else { '0' }); assert!(!constant_time_compare(&token, &tampered)); } #[test] fn csrf_manually_validated_marker_is_zero_sized() { assert_eq!(std::mem::size_of::(), 0); } #[test] fn csrf_posture_is_copyable_and_carries_reason() { let p = CsrfPosture::Skip("webhook: stripe signature"); let copy = p; match copy { CsrfPosture::Skip(r) => assert_eq!(r, "webhook: stripe signature"), _ => panic!("variant mismatch"), } } #[test] fn test_constant_time_compare_truncated() { use crate::helpers::constant_time_compare; let token = generate_token(); let truncated = &token[..token.len() - 1]; assert!(!constant_time_compare(&token, truncated)); } }