| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
use std::collections::HashMap; |
| 8 |
use std::net::IpAddr; |
| 9 |
use std::sync::Mutex; |
| 10 |
use std::time::Instant; |
| 11 |
|
| 12 |
/// Number of failures within the window at which an IP becomes blocked.
const MAX_FAILURES: usize = 10;
/// Length of the sliding window, in seconds, over which failures are counted.
const WINDOW_SECS: u64 = 60;
/// Once more than this many IPs are tracked, recording a failure first sweeps
/// out IPs whose failures have all expired.
const PRUNE_THRESHOLD: usize = 1000;
| 15 |
|
| 16 |
/// Per-IP limiter for failed authentication attempts.
///
/// Keeps, for each client IP, the timestamps of its recent failed attempts
/// behind a [`Mutex`], so a shared reference can be used from several threads.
#[derive(Debug, Default)]
pub struct AuthRateLimiter {
    /// Failure timestamps per IP, appended in the order they were recorded.
    failures: Mutex<HashMap<IpAddr, Vec<Instant>>>,
}
| 19 |
|
| 20 |
impl AuthRateLimiter { |
| 21 |
pub fn new() -> Self { |
| 22 |
Self { |
| 23 |
failures: Mutex::new(HashMap::new()), |
| 24 |
} |
| 25 |
} |
| 26 |
|
| 27 |
|
| 28 |
|
| 29 |
pub fn check(&self, ip: IpAddr) -> bool { |
| 30 |
let mut map = self.failures.lock().unwrap(); |
| 31 |
let cutoff = Instant::now() - std::time::Duration::from_secs(WINDOW_SECS); |
| 32 |
|
| 33 |
if let Some(times) = map.get_mut(&ip) { |
| 34 |
times.retain(|t| *t > cutoff); |
| 35 |
times.len() < MAX_FAILURES |
| 36 |
} else { |
| 37 |
true |
| 38 |
} |
| 39 |
} |
| 40 |
|
| 41 |
|
| 42 |
pub fn record_failure(&self, ip: IpAddr) { |
| 43 |
let mut map = self.failures.lock().unwrap(); |
| 44 |
|
| 45 |
|
| 46 |
if map.len() > PRUNE_THRESHOLD { |
| 47 |
let cutoff = Instant::now() - std::time::Duration::from_secs(WINDOW_SECS); |
| 48 |
map.retain(|_, times| { |
| 49 |
times.retain(|t| *t > cutoff); |
| 50 |
!times.is_empty() |
| 51 |
}); |
| 52 |
} |
| 53 |
|
| 54 |
map.entry(ip).or_default().push(Instant::now()); |
| 55 |
} |
| 56 |
} |
| 57 |
|
| 58 |
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an IPv4 address for use as a test fixture.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    /// Records `count` consecutive failures for `addr`.
    fn fail_n_times(limiter: &AuthRateLimiter, addr: IpAddr, count: usize) {
        (0..count).for_each(|_| limiter.record_failure(addr));
    }

    #[test]
    fn allows_under_threshold() {
        let limiter = AuthRateLimiter::new();
        let addr = v4(192, 168, 1, 1);

        // Every attempt made while below the limit must be let through.
        for _ in 1..MAX_FAILURES {
            assert!(limiter.check(addr));
            limiter.record_failure(addr);
        }
        assert!(limiter.check(addr));
    }

    #[test]
    fn blocks_at_threshold() {
        let limiter = AuthRateLimiter::new();
        let addr = v4(10, 0, 0, 1);

        fail_n_times(&limiter, addr, MAX_FAILURES);
        assert!(!limiter.check(addr));
    }

    #[test]
    fn independent_ips() {
        let limiter = AuthRateLimiter::new();
        let (blocked, untouched) = (v4(10, 0, 0, 1), v4(10, 0, 0, 2));

        fail_n_times(&limiter, blocked, MAX_FAILURES);
        assert!(!limiter.check(blocked));
        assert!(limiter.check(untouched));
    }
}
| 99 |
|