//! Per-IP authentication rate limiting.
//!
//! russh's `auth_rejection_time` only delays within a single connection.
//! Parallel connections bypass it. This module tracks failed auth attempts
//! per IP and rejects early when a threshold is exceeded.

use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant};

/// Number of failures within the window at which an IP starts being rejected.
const MAX_FAILURES: usize = 10;
/// Length of the sliding window, in seconds, over which failures are counted.
const WINDOW_SECS: u64 = 60;
/// Initial map size that must be exceeded before stale IPs are swept out.
const PRUNE_THRESHOLD: usize = 1000;

/// Tracks recent authentication failures per source IP.
///
/// All state sits behind a single [`Mutex`], so one limiter can be shared
/// across connection handlers.
#[derive(Debug)]
pub struct AuthRateLimiter {
    state: Mutex<State>,
}

/// Lock-protected state of an [`AuthRateLimiter`].
#[derive(Debug)]
struct State {
    /// Failure timestamps per IP, oldest first. Each list is capped at
    /// `MAX_FAILURES` entries because older ones can never change `check`.
    failures: HashMap<IpAddr, Vec<Instant>>,
    /// `record_failure` sweeps stale IPs once the map grows past this size.
    prune_at: usize,
}

impl AuthRateLimiter {
    /// Creates a limiter with no recorded failures.
    pub fn new() -> Self {
        Self {
            state: Mutex::new(State {
                failures: HashMap::new(),
                prune_at: PRUNE_THRESHOLD,
            }),
        }
    }

    /// Returns `true` if the IP is allowed to attempt auth.
    /// Returns `false` if the IP has `MAX_FAILURES` or more failures within
    /// the last `WINDOW_SECS` seconds.
    ///
    /// As a side effect, discards this IP's expired failures and removes its
    /// entry entirely once none remain.
    pub fn check(&self, ip: IpAddr) -> bool {
        let mut state = self.lock();
        let now = Instant::now();
        let recent = match state.failures.get_mut(&ip) {
            Some(times) => {
                retain_recent(times, now);
                times.len()
            }
            None => return true,
        };
        if recent == 0 {
            state.failures.remove(&ip);
        }
        recent < MAX_FAILURES
    }

    /// Record a failed auth attempt for the given IP.
    ///
    /// When the map has grown past the current prune threshold, IPs whose
    /// failures have all expired are swept out first.
    pub fn record_failure(&self, ip: IpAddr) {
        let mut guard = self.lock();
        // Reborrow the guard so both fields can be borrowed independently.
        let state = &mut *guard;
        // Taken under the lock so timestamps are appended in nondecreasing order.
        let now = Instant::now();

        if state.failures.len() > state.prune_at {
            state.failures.retain(|_, times| {
                retain_recent(times, now);
                !times.is_empty()
            });
            // With a fixed threshold, many distinct IPs failing at once would
            // force a full sweep on every call. Scaling the threshold with the
            // surviving size keeps sweeps amortized O(1) per recorded failure.
            state.prune_at = PRUNE_THRESHOLD.max(state.failures.len().saturating_mul(2));
        }

        let times = state.failures.entry(ip).or_default();
        times.push(now);
        // The front holds the oldest timestamps. Only the newest MAX_FAILURES
        // can decide `check`, so drop the rest to bound memory per IP.
        if times.len() > MAX_FAILURES {
            let excess = times.len() - MAX_FAILURES;
            times.drain(..excess);
        }
    }

    /// Locks the state, recovering it if another thread panicked while
    /// holding the lock. The state is only timestamps, so a partially applied
    /// update at worst leaves a few extra or missing failures.
    fn lock(&self) -> MutexGuard<'_, State> {
        self.state.lock().unwrap_or_else(PoisonError::into_inner)
    }
}

impl Default for AuthRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}

/// Drops timestamps that are `WINDOW_SECS` or more older than `now`.
///
/// Uses a saturating difference instead of computing `now - window`, which
/// may panic when that point in time is not representable by `Instant`.
fn retain_recent(times: &mut Vec<Instant>, now: Instant) {
    let window = Duration::from_secs(WINDOW_SECS);
    times.retain(|&t| now.saturating_duration_since(t) < window);
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    #[test]
    fn allows_under_threshold() {
        let limiter = AuthRateLimiter::new();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        for _ in 0..MAX_FAILURES - 1 {
            assert!(limiter.check(ip));
            limiter.record_failure(ip);
        }
        assert!(limiter.check(ip));
    }

    #[test]
    fn blocks_at_threshold() {
        let limiter = AuthRateLimiter::new();
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        for _ in 0..MAX_FAILURES {
            limiter.record_failure(ip);
        }
        assert!(!limiter.check(ip));
    }

    #[test]
    fn independent_ips() {
        let limiter = AuthRateLimiter::new();
        let ip_a = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
        let ip_b = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2));
        for _ in 0..MAX_FAILURES {
            limiter.record_failure(ip_a);
        }
        assert!(!limiter.check(ip_a));
        assert!(limiter.check(ip_b));
    }

    #[test]
    fn caps_history_per_ip() {
        let limiter = AuthRateLimiter::new();
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3));
        for _ in 0..MAX_FAILURES * 5 {
            limiter.record_failure(ip);
        }
        assert!(!limiter.check(ip));
        assert_eq!(limiter.lock().failures[&ip].len(), MAX_FAILURES);
    }

    #[test]
    fn raises_prune_threshold_when_entries_are_live() {
        let limiter = AuthRateLimiter::new();
        for raw in (0u32..).take(PRUNE_THRESHOLD + 2) {
            limiter.record_failure(IpAddr::V4(Ipv4Addr::from(raw)));
        }
        let state = limiter.lock();
        assert_eq!(state.failures.len(), PRUNE_THRESHOLD + 2);
        assert!(state.prune_at > PRUNE_THRESHOLD);
    }
}