Skip to main content

max / makenotwork

2.7 KB · 99 lines History Blame Raw
1 //! Per-IP authentication rate limiting.
2 //!
3 //! russh's `auth_rejection_time` only delays within a single connection.
4 //! Parallel connections bypass it. This module tracks failed auth attempts
5 //! per IP and rejects early when a threshold is exceeded.
6
7 use std::collections::HashMap;
8 use std::net::IpAddr;
9 use std::sync::Mutex;
10 use std::time::Instant;
11
/// Failed attempts an IP may accumulate within [`WINDOW_SECS`] before
/// [`AuthRateLimiter::check`] starts rejecting it.
const MAX_FAILURES: usize = 10;
/// Length of the sliding window, in seconds, over which failures are counted.
const WINDOW_SECS: u64 = 60;
/// Minimum number of tracked IPs before stale entries are swept from the map.
const PRUNE_THRESHOLD: usize = 1000;

/// Mutable state guarded by [`AuthRateLimiter`]'s mutex.
struct State {
    /// Failure timestamps per IP, oldest first (they are pushed under the lock
    /// with a monotonic clock). Each list keeps at most `MAX_FAILURES` entries,
    /// because only the newest ones can influence `check`.
    failures: HashMap<IpAddr, Vec<Instant>>,
    /// Map size that triggers the next sweep of stale entries. It is raised
    /// after each sweep. Otherwise a map full of *live* entries would be
    /// rescanned on every insert, making `record_failure` O(n) per call during
    /// a distributed attack.
    prune_at: usize,
}

/// Sliding-window counter of failed authentication attempts per client IP.
///
/// All state sits behind a [`Mutex`], so one instance can be shared across
/// connections. That lets it see failures spread over parallel connections.
pub struct AuthRateLimiter {
    state: Mutex<State>,
}

impl Default for AuthRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

impl AuthRateLimiter {
    /// Creates a limiter with no recorded failures.
    pub fn new() -> Self {
        Self {
            state: Mutex::new(State {
                failures: HashMap::new(),
                prune_at: PRUNE_THRESHOLD,
            }),
        }
    }

    /// Returns `true` if the IP is allowed to attempt auth.
    /// Returns `false` if the IP has recorded `MAX_FAILURES` or more failures
    /// within the last `WINDOW_SECS` seconds.
    pub fn check(&self, ip: IpAddr) -> bool {
        let mut guard = self.lock();
        let state = &mut *guard;
        let now = Instant::now();

        let times = match state.failures.get_mut(&ip) {
            Some(times) => times,
            None => return true,
        };
        retain_recent(times, now);
        let allowed = times.len() < MAX_FAILURES;
        if times.is_empty() {
            // Every failure has expired: drop the entry rather than keep an empty Vec.
            state.failures.remove(&ip);
        }
        allowed
    }

    /// Records a failed auth attempt for the given IP.
    pub fn record_failure(&self, ip: IpAddr) {
        let mut guard = self.lock();
        let state = &mut *guard;
        let now = Instant::now();

        // Sweep stale entries once the map has grown past the current mark.
        if state.failures.len() > state.prune_at {
            state.failures.retain(|_, times| {
                retain_recent(times, now);
                !times.is_empty()
            });
            // The map must double before the next sweep. That keeps sweeping
            // amortized O(1) per call even when every remaining entry is live.
            state.prune_at = PRUNE_THRESHOLD.max(state.failures.len().saturating_mul(2));
        }

        let times = state.failures.entry(ip).or_default();
        retain_recent(times, now);
        times.push(now);
        // Only the newest MAX_FAILURES timestamps can affect `check`. Dropping the
        // oldest keeps a persistently failing IP from growing its list without bound.
        if times.len() > MAX_FAILURES {
            let excess = times.len() - MAX_FAILURES;
            times.drain(..excess);
        }
    }

    /// Locks the state, recovering from mutex poisoning. The map is a
    /// best-effort cache, so a panic in one holder must not make every later
    /// auth attempt panic as well.
    fn lock(&self) -> std::sync::MutexGuard<'_, State> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}

/// Drops timestamps that are `WINDOW_SECS` or more older than `now`.
///
/// Uses `saturating_duration_since` rather than computing `now - window`,
/// because `Instant - Duration` can panic when `now` is closer than
/// `WINDOW_SECS` to the platform clock's origin.
fn retain_recent(times: &mut Vec<Instant>, now: Instant) {
    let window = std::time::Duration::from_secs(WINDOW_SECS);
    times.retain(|t| now.saturating_duration_since(*t) < window);
}
57
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Shorthand for an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::from(Ipv4Addr::new(a, b, c, d))
    }

    /// Records `count` consecutive failures for `ip`.
    fn fail_times(limiter: &AuthRateLimiter, ip: IpAddr, count: usize) {
        (0..count).for_each(|_| limiter.record_failure(ip));
    }

    /// An IP one failure short of the limit may still attempt auth.
    #[test]
    fn allows_under_threshold() {
        let limiter = AuthRateLimiter::new();
        let ip = v4(192, 168, 1, 1);

        for _ in 1..MAX_FAILURES {
            assert!(limiter.check(ip));
            limiter.record_failure(ip);
        }
        assert!(limiter.check(ip));
    }

    /// Reaching `MAX_FAILURES` failures blocks the IP.
    #[test]
    fn blocks_at_threshold() {
        let limiter = AuthRateLimiter::new();
        let ip = v4(10, 0, 0, 1);

        fail_times(&limiter, ip, MAX_FAILURES);
        assert!(!limiter.check(ip));
    }

    /// Failures from one IP do not affect another.
    #[test]
    fn independent_ips() {
        let limiter = AuthRateLimiter::new();
        let (noisy, quiet) = (v4(10, 0, 0, 1), v4(10, 0, 0, 2));

        fail_times(&limiter, noisy, MAX_FAILURES);
        assert!(!limiter.check(noisy));
        assert!(limiter.check(quiet));
    }
}
99