| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
use base64::Engine; |
| 5 |
use tower_governor::errors::GovernorError; |
| 6 |
use tower_governor::key_extractor::KeyExtractor; |
| 7 |
|
| 8 |
use crate::db::SyncAppId; |
| 9 |
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
|
| 15 |
|
| 16 |
|
| 17 |
|
| 18 |
|
| 19 |
|
| 20 |
|
| 21 |
|
| 22 |
|
| 23 |
|
| 24 |
|
| 25 |
|
| 26 |
|
| 27 |
|
| 28 |
#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| 29 |
pub struct CloudflareIpKeyExtractor; |
| 30 |
|
| 31 |
impl KeyExtractor for CloudflareIpKeyExtractor { |
| 32 |
type Key = std::net::IpAddr; |
| 33 |
|
| 34 |
fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, GovernorError> { |
| 35 |
if let Some(ip) = req |
| 36 |
.headers() |
| 37 |
.get("cf-connecting-ip") |
| 38 |
.and_then(|v: &axum::http::HeaderValue| v.to_str().ok()) |
| 39 |
.and_then(|s: &str| s.trim().parse::<std::net::IpAddr>().ok()) |
| 40 |
{ |
| 41 |
return Ok(ip); |
| 42 |
} |
| 43 |
|
| 44 |
tower_governor::key_extractor::PeerIpKeyExtractor.extract(req) |
| 45 |
} |
| 46 |
} |
| 47 |
|
| 48 |
|
| 49 |
|
| 50 |
|
| 51 |
|
| 52 |
|
| 53 |
|
| 54 |
|
| 55 |
|
| 56 |
|
| 57 |
|
| 58 |
|
| 59 |
|
| 60 |
|
| 61 |
|
| 62 |
|
| 63 |
|
| 64 |
#[derive(Debug, Clone)] |
| 65 |
pub struct SyncAppKeyExtractor { |
| 66 |
|
| 67 |
|
| 68 |
|
| 69 |
secret: Option<std::sync::Arc<String>>, |
| 70 |
} |
| 71 |
|
| 72 |
impl SyncAppKeyExtractor { |
| 73 |
pub fn new(secret: Option<std::sync::Arc<String>>) -> Self { |
| 74 |
Self { secret } |
| 75 |
} |
| 76 |
|
| 77 |
|
| 78 |
|
| 79 |
fn verify_app(secret: &str, token: &str) -> Option<SyncAppId> { |
| 80 |
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; |
| 81 |
|
| 82 |
#[derive(serde::Deserialize)] |
| 83 |
struct AppClaim { |
| 84 |
app: SyncAppId, |
| 85 |
} |
| 86 |
|
| 87 |
let mut validation = Validation::new(Algorithm::HS256); |
| 88 |
|
| 89 |
|
| 90 |
|
| 91 |
validation.validate_exp = false; |
| 92 |
validation.required_spec_claims.clear(); |
| 93 |
|
| 94 |
decode::<AppClaim>(token, &DecodingKey::from_secret(secret.as_bytes()), &validation) |
| 95 |
.ok() |
| 96 |
.map(|data| data.claims.app) |
| 97 |
} |
| 98 |
|
| 99 |
|
| 100 |
fn parse_app_unverified(token: &str) -> Option<SyncAppId> { |
| 101 |
let payload_b64 = token.split('.').nth(1)?; |
| 102 |
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD |
| 103 |
.decode(payload_b64) |
| 104 |
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(payload_b64)) |
| 105 |
.ok()?; |
| 106 |
|
| 107 |
#[derive(serde::Deserialize)] |
| 108 |
struct AppClaim { |
| 109 |
app: SyncAppId, |
| 110 |
} |
| 111 |
serde_json::from_slice::<AppClaim>(&payload_bytes).ok().map(|c| c.app) |
| 112 |
} |
| 113 |
} |
| 114 |
|
| 115 |
impl KeyExtractor for SyncAppKeyExtractor { |
| 116 |
type Key = SyncAppId; |
| 117 |
|
| 118 |
fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, GovernorError> { |
| 119 |
let token = match req |
| 120 |
.headers() |
| 121 |
.get("authorization") |
| 122 |
.and_then(|v| v.to_str().ok()) |
| 123 |
.and_then(|s| s.strip_prefix("Bearer ")) |
| 124 |
{ |
| 125 |
Some(t) => t, |
| 126 |
|
| 127 |
|
| 128 |
None => return Ok(SyncAppId::nil()), |
| 129 |
}; |
| 130 |
|
| 131 |
let app = match &self.secret { |
| 132 |
Some(secret) => Self::verify_app(secret, token), |
| 133 |
None => Self::parse_app_unverified(token), |
| 134 |
}; |
| 135 |
|
| 136 |
|
| 137 |
|
| 138 |
Ok(app.unwrap_or_else(SyncAppId::nil)) |
| 139 |
} |
| 140 |
} |
| 141 |
|
| 142 |
|
| 143 |
|
| 144 |
|
| 145 |
|
| 146 |
|
| 147 |
|
| 148 |
|
| 149 |
|
| 150 |
|
| 151 |
|
| 152 |
|
| 153 |
|
| 154 |
|
| 155 |
|
| 156 |
static GOVERNOR_SWEEPERS: std::sync::Mutex<Vec<Box<dyn Fn() -> usize + Send + Sync>>> = |
| 157 |
std::sync::Mutex::new(Vec::new()); |
| 158 |
|
| 159 |
|
| 160 |
|
| 161 |
fn register_for_sweep(hook: impl Fn() -> usize + Send + Sync + 'static) { |
| 162 |
if let Ok(mut hooks) = GOVERNOR_SWEEPERS.lock() { |
| 163 |
hooks.push(Box::new(hook)); |
| 164 |
} |
| 165 |
} |
| 166 |
|
| 167 |
|
| 168 |
|
| 169 |
|
| 170 |
pub fn start_governor_sweeper() { |
| 171 |
static STARTED: std::sync::Once = std::sync::Once::new(); |
| 172 |
STARTED.call_once(|| { |
| 173 |
tokio::spawn(async { |
| 174 |
let interval = |
| 175 |
std::time::Duration::from_secs(crate::constants::GOVERNOR_SWEEP_INTERVAL_SECS); |
| 176 |
loop { |
| 177 |
tokio::time::sleep(interval).await; |
| 178 |
|
| 179 |
let (limiters, retained) = { |
| 180 |
let Ok(hooks) = GOVERNOR_SWEEPERS.lock() else { |
| 181 |
continue; |
| 182 |
}; |
| 183 |
let retained: usize = hooks.iter().map(|hook| hook()).sum(); |
| 184 |
(hooks.len(), retained) |
| 185 |
}; |
| 186 |
tracing::debug!(limiters, retained_keys = retained, "swept governor bucket maps"); |
| 187 |
} |
| 188 |
}); |
| 189 |
}); |
| 190 |
} |
| 191 |
|
| 192 |
|
| 193 |
pub fn rate_limiter_ms( |
| 194 |
ms: u64, |
| 195 |
burst: u32, |
| 196 |
) -> std::sync::Arc< |
| 197 |
tower_governor::governor::GovernorConfig< |
| 198 |
CloudflareIpKeyExtractor, |
| 199 |
::governor::middleware::StateInformationMiddleware, |
| 200 |
>, |
| 201 |
> { |
| 202 |
let config = std::sync::Arc::new( |
| 203 |
tower_governor::governor::GovernorConfigBuilder::default() |
| 204 |
.key_extractor(CloudflareIpKeyExtractor) |
| 205 |
.per_millisecond(ms) |
| 206 |
.burst_size(burst) |
| 207 |
.use_headers() |
| 208 |
.finish() |
| 209 |
.expect("rate limiter config"), |
| 210 |
); |
| 211 |
let limiter = config.limiter().clone(); |
| 212 |
register_for_sweep(move || { |
| 213 |
limiter.retain_recent(); |
| 214 |
limiter.len() |
| 215 |
}); |
| 216 |
config |
| 217 |
} |
| 218 |
|
| 219 |
|
| 220 |
pub fn rate_limiter_per_sec( |
| 221 |
per_sec: u64, |
| 222 |
burst: u32, |
| 223 |
) -> std::sync::Arc< |
| 224 |
tower_governor::governor::GovernorConfig< |
| 225 |
CloudflareIpKeyExtractor, |
| 226 |
::governor::middleware::StateInformationMiddleware, |
| 227 |
>, |
| 228 |
> { |
| 229 |
let config = std::sync::Arc::new( |
| 230 |
tower_governor::governor::GovernorConfigBuilder::default() |
| 231 |
.key_extractor(CloudflareIpKeyExtractor) |
| 232 |
.per_second(per_sec) |
| 233 |
.burst_size(burst) |
| 234 |
.use_headers() |
| 235 |
.finish() |
| 236 |
.expect("rate limiter config"), |
| 237 |
); |
| 238 |
let limiter = config.limiter().clone(); |
| 239 |
register_for_sweep(move || { |
| 240 |
limiter.retain_recent(); |
| 241 |
limiter.len() |
| 242 |
}); |
| 243 |
config |
| 244 |
} |
| 245 |
|
| 246 |
|
| 247 |
|
| 248 |
|
| 249 |
|
| 250 |
pub fn synckit_app_rate_limiter_ms( |
| 251 |
secret: Option<std::sync::Arc<String>>, |
| 252 |
ms: u64, |
| 253 |
burst: u32, |
| 254 |
) -> std::sync::Arc< |
| 255 |
tower_governor::governor::GovernorConfig< |
| 256 |
SyncAppKeyExtractor, |
| 257 |
::governor::middleware::StateInformationMiddleware, |
| 258 |
>, |
| 259 |
> { |
| 260 |
let config = std::sync::Arc::new( |
| 261 |
tower_governor::governor::GovernorConfigBuilder::default() |
| 262 |
.key_extractor(SyncAppKeyExtractor::new(secret)) |
| 263 |
.per_millisecond(ms) |
| 264 |
.burst_size(burst) |
| 265 |
.use_headers() |
| 266 |
.finish() |
| 267 |
.expect("synckit app rate limiter config"), |
| 268 |
); |
| 269 |
let limiter = config.limiter().clone(); |
| 270 |
register_for_sweep(move || { |
| 271 |
limiter.retain_recent(); |
| 272 |
limiter.len() |
| 273 |
}); |
| 274 |
config |
| 275 |
} |
| 276 |
|
| 277 |
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;
    use tower_governor::key_extractor::KeyExtractor;

    /// HS256 secret shared by the verified-mode tests.
    const TEST_SECRET: &str = "test-secret-key-for-synckit-jwt";

    /// Encodes `input` as unpadded base64url, the JWT segment encoding.
    fn b64url(input: &str) -> String {
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(input)
    }

    /// Claims carried by every generated test token; the extractor only reads `app`.
    fn claims_for(app_id: &SyncAppId) -> serde_json::Value {
        serde_json::json!({
            "sub": "00000000-0000-0000-0000-000000000001",
            "app": app_id,
            "iss": "makenotwork-synckit",
            "exp": 9999999999_i64,
            "iat": 1000000000_i64,
        })
    }

    /// A JWT-shaped token whose signature is garbage, so it never verifies.
    fn fake_jwt(app_id: &SyncAppId) -> String {
        let header = b64url(r#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = b64url(&claims_for(app_id).to_string());
        let signature = b64url("fakesig");
        format!("{header}.{payload}.{signature}")
    }

    /// A genuinely HS256-signed token for `app_id` under `secret`.
    fn signed_jwt(secret: &str, app_id: &SyncAppId) -> String {
        use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
        let key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::new(Algorithm::HS256), &claims_for(app_id), &key).unwrap()
    }

    /// A body-less request carrying exactly one header.
    fn request_with(name: &str, value: &str) -> Request<()> {
        Request::builder().header(name, value).body(()).unwrap()
    }

    /// A body-less request with `Authorization: Bearer <token>`.
    fn bearer_request(token: &str) -> Request<()> {
        request_with("authorization", &format!("Bearer {token}"))
    }

    /// An extractor that verifies token signatures with [`TEST_SECRET`].
    fn verifying_extractor() -> SyncAppKeyExtractor {
        SyncAppKeyExtractor::new(Some(std::sync::Arc::new(TEST_SECRET.to_string())))
    }

    #[test]
    fn unverified_extracts_app_id_from_jwt_when_no_secret() {
        // Without a secret, the app claim is taken at face value even though the signature is bogus.
        let app_id = SyncAppId::new();
        let req = bearer_request(&fake_jwt(&app_id));

        let extracted = SyncAppKeyExtractor::new(None).extract(&req).unwrap();
        assert_eq!(extracted, app_id);
    }

    #[test]
    fn verified_extracts_app_id_from_validly_signed_jwt() {
        let app_id = SyncAppId::new();
        let req = bearer_request(&signed_jwt(TEST_SECRET, &app_id));

        assert_eq!(verifying_extractor().extract(&req).unwrap(), app_id);
    }

    #[test]
    fn verified_forged_token_collapses_to_nil_bucket() {
        // With a secret, a token that fails verification must not earn its claimed app's bucket.
        let attacker_app = SyncAppId::new();
        let req = bearer_request(&fake_jwt(&attacker_app));

        assert_eq!(verifying_extractor().extract(&req).unwrap(), SyncAppId::nil());
    }

    #[test]
    fn verified_spray_of_forged_apps_all_share_one_bucket() {
        // Forging many distinct app ids must not yield many distinct buckets.
        let extractor = verifying_extractor();
        for _ in 0..5 {
            let req = bearer_request(&fake_jwt(&SyncAppId::new()));
            assert_eq!(extractor.extract(&req).unwrap(), SyncAppId::nil());
        }
    }

    #[test]
    fn missing_auth_header_returns_nil_sentinel() {
        let req = Request::builder().body(()).unwrap();
        assert_eq!(SyncAppKeyExtractor::new(None).extract(&req).unwrap(), SyncAppId::nil());
    }

    #[test]
    fn non_bearer_auth_returns_nil_sentinel() {
        let req = request_with("authorization", "Basic dXNlcjpwYXNz");
        assert_eq!(SyncAppKeyExtractor::new(None).extract(&req).unwrap(), SyncAppId::nil());
    }

    #[test]
    fn malformed_jwt_collapses_to_nil_sentinel() {
        // A bearer value that is not JWT-shaped still yields a key rather than an error.
        let req = request_with("authorization", "Bearer not-a-jwt");
        assert_eq!(SyncAppKeyExtractor::new(None).extract(&req).unwrap(), SyncAppId::nil());
    }

    #[test]
    fn jwt_missing_app_claim_collapses_to_nil_sentinel() {
        let header = b64url(r#"{"alg":"HS256"}"#);
        let payload = b64url(r#"{"sub":"user","iss":"test"}"#);
        let signature = b64url("sig");
        let req = bearer_request(&format!("{header}.{payload}.{signature}"));

        assert_eq!(SyncAppKeyExtractor::new(None).extract(&req).unwrap(), SyncAppId::nil());
    }

    #[test]
    fn cf_connecting_ip_is_used_when_present() {
        let req = request_with("cf-connecting-ip", "203.0.113.7");
        let expected: std::net::IpAddr = "203.0.113.7".parse().unwrap();

        assert_eq!(CloudflareIpKeyExtractor.extract(&req).unwrap(), expected);
    }

    #[test]
    fn forged_x_forwarded_for_is_not_trusted() {
        // Only cf-connecting-ip (or the real peer address) may choose the bucket. Spoofable
        // proxy headers must be ignored; with no peer info, extraction fails.
        let req = Request::builder()
            .header("x-forwarded-for", "1.2.3.4")
            .header("x-real-ip", "1.2.3.4")
            .body(())
            .unwrap();

        let result = CloudflareIpKeyExtractor.extract(&req);
        assert!(
            result.is_err(),
            "forged XFF/X-Real-IP must not yield a per-IP bucket; got {result:?}"
        );
    }

    #[test]
    fn different_apps_get_different_keys() {
        let extractor = SyncAppKeyExtractor::new(None);

        let first = extractor.extract(&bearer_request(&fake_jwt(&SyncAppId::new()))).unwrap();
        let second = extractor.extract(&bearer_request(&fake_jwt(&SyncAppId::new()))).unwrap();

        assert_ne!(first, second);
    }
}
| 479 |
|