Skip to main content

max / multithreaded

4.3 KB · 131 lines History Blame Raw
//! HMAC-SHA256 authentication for internal API requests from MNW.
//!
//! MNW signs requests with `HMAC-SHA256(timestamp + "\n" + body, secret)`, where `timestamp`
//! is Unix time in seconds. The lowercase hex-encoded signature and the timestamp are sent in
//! the `X-Internal-Signature` and `X-Internal-Timestamp` headers. Requests whose timestamp is
//! more than 60 seconds older *or newer* than the server clock are rejected.
6
7 use axum::{
8 body::Bytes,
9 extract::{FromRequest, Request},
10 http::StatusCode,
11 response::{IntoResponse, Response},
12 };
13 use hmac::{Hmac, Mac};
14 use sha2::Sha256;
15
16 use crate::AppState;
17
/// Maximum allowed difference, in seconds, between an internal request's
/// `X-Internal-Timestamp` and the server clock.
///
/// The check is symmetric: timestamps this far in the past *or* in the future are rejected.
/// This limits how long a captured request stays replayable while tolerating a small clock
/// skew between the signer and this service.
const MAX_TIMESTAMP_AGE_SECS: i64 = 60;
20
21 /// Axum extractor that validates HMAC-SHA256 signatures on internal API requests.
22 /// Extracts the raw request body as `Bytes` after successful verification.
23 pub struct InternalAuth(pub Bytes);
24
25 impl FromRequest<AppState> for InternalAuth {
26 type Rejection = Response;
27
28 async fn from_request(req: Request, state: &AppState) -> Result<Self, Self::Rejection> {
29 let secret = state
30 .config
31 .internal_shared_secret
32 .as_deref()
33 .ok_or_else(|| {
34 tracing::warn!("internal API called but INTERNAL_SHARED_SECRET not configured");
35 StatusCode::SERVICE_UNAVAILABLE.into_response()
36 })?;
37
38 let timestamp_str = req
39 .headers()
40 .get("X-Internal-Timestamp")
41 .and_then(|v| v.to_str().ok())
42 .ok_or_else(|| {
43 (StatusCode::UNAUTHORIZED, "Missing X-Internal-Timestamp").into_response()
44 })?
45 .to_string();
46
47 let signature = req
48 .headers()
49 .get("X-Internal-Signature")
50 .and_then(|v| v.to_str().ok())
51 .ok_or_else(|| {
52 (StatusCode::UNAUTHORIZED, "Missing X-Internal-Signature").into_response()
53 })?
54 .to_string();
55
56 // Verify timestamp freshness
57 let timestamp: i64 = timestamp_str.parse().map_err(|_| {
58 (StatusCode::UNAUTHORIZED, "Invalid timestamp").into_response()
59 })?;
60
61 let now = chrono::Utc::now().timestamp();
62 if (now - timestamp).abs() > MAX_TIMESTAMP_AGE_SECS {
63 return Err(
64 (StatusCode::UNAUTHORIZED, "Timestamp too old or too far in the future")
65 .into_response(),
66 );
67 }
68
69 // Read body
70 let body = Bytes::from_request(req, state).await.map_err(|e| {
71 tracing::error!(error = %e, "failed to read request body");
72 StatusCode::BAD_REQUEST.into_response()
73 })?;
74
75 // Verify HMAC
76 let message = format!("{}\n{}", timestamp_str, std::str::from_utf8(&body).unwrap_or(""));
77 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
78 .expect("HMAC-SHA256 accepts any key length");
79 mac.update(message.as_bytes());
80 let expected = hex::encode(mac.finalize().into_bytes());
81
82 if !constant_time_eq(expected.as_bytes(), signature.as_bytes()) {
83 return Err((StatusCode::UNAUTHORIZED, "Invalid signature").into_response());
84 }
85
86 Ok(InternalAuth(body))
87 }
88 }
89
/// Byte-slice equality that does not stop at the first mismatching byte.
///
/// When the lengths match, every byte pair is XOR-ed and the differences are OR-ed into one
/// accumulator. The work done therefore depends only on the length, not on where the slices
/// differ. This is meant to blunt timing attacks on signature comparison; note that the
/// compiler gives no hard constant-time guarantee. Slices of different lengths are reported
/// unequal immediately, because the length is not treated as secret.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (lhs, rhs) in a.iter().zip(b) {
            diff |= lhs ^ rhs;
        }
        diff == 0
    } else {
        false
    }
}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase-hex HMAC-SHA256 of `data` under `key`, computed with a single `update`.
    fn hmac_hex(key: &[u8], data: &[u8]) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(key).unwrap();
        mac.update(data);
        hex::encode(mac.finalize().into_bytes())
    }

    #[test]
    fn constant_time_eq_works() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"a"));
    }

    /// Pins the HMAC-SHA256 + hex wiring to a published known-answer vector, so a wrong
    /// algorithm or encoding cannot go unnoticed (unlike comparing a MAC with itself).
    #[test]
    fn hmac_matches_rfc4231_vector() {
        // RFC 4231, test case 2.
        assert_eq!(
            hmac_hex(b"Jefe", b"what do ya want for nothing?"),
            "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
        );
    }

    #[test]
    fn hmac_signature_roundtrip() {
        let secret = "test-secret";
        let timestamp = "1234567890";
        let body = r#"{"name":"test"}"#;
        let message = format!("{}\n{}", timestamp, body);
        let sig = hmac_hex(secret.as_bytes(), message.as_bytes());

        // Feeding timestamp, separator and body as separate updates must give the same MAC
        // as hashing the concatenated signed message.
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(timestamp.as_bytes());
        mac.update(b"\n");
        mac.update(body.as_bytes());
        let streamed = hex::encode(mac.finalize().into_bytes());
        assert!(constant_time_eq(sig.as_bytes(), streamed.as_bytes()));

        // A different body under the same timestamp must not verify.
        let other = hmac_hex(secret.as_bytes(), b"1234567890\n{}");
        assert!(!constant_time_eq(sig.as_bytes(), other.as_bytes()));
    }
}
131