Skip to main content

max / makenotwork

25.4 KB · 752 lines History Blame Raw
1 //! HTTP API for serve mode — exposes health check data to consumers like MNW.
2
3 use std::collections::{HashMap, HashSet};
4 use std::sync::Arc;
5 use std::sync::atomic::{AtomicU64, Ordering};
6
7 use axum::extract::{Path, Request, State as AxumState};
8 use axum::http::StatusCode;
9 use axum::middleware::{self, Next};
10 use axum::response::IntoResponse;
11 use axum::routing::get;
12 use axum::{Json, Router};
13 use serde::Serialize;
14
15 use crate::checks::http::{compute_test_staleness, detect_test_duration_drift};
16 use crate::config::Config;
17 use crate::db;
18 use crate::peer::SharedMeshState;
19 use crate::types::{HealthSnapshot, LatencyBucket, LatencyStats, TestStaleness};
20
/// Fixed-window rate limiter: allows `max_per_window` requests per `window_duration`.
///
/// Clones share the same counter and window (both live behind `Arc`), so one limiter
/// cloned into several places enforces a single shared budget.
///
/// A window is the half-open interval `[window_start, window_start + window_duration)`:
/// the first call at or after its end opens a new window. A limiter built with
/// `max_per_window == 0` never admits anything.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Requests admitted in the current window. Only read or written while
    /// `window_start` is locked, so the mutex already orders every access.
    count: Arc<AtomicU64>,
    /// Instant at which the current window opened.
    window_start: Arc<std::sync::Mutex<std::time::Instant>>,
    max_per_window: u64,
    window_duration: std::time::Duration,
}

impl RateLimiter {
    /// Create a limiter whose first window opens now.
    pub fn new(max_per_window: u64, window_duration: std::time::Duration) -> Self {
        Self {
            count: Arc::new(AtomicU64::new(0)),
            window_start: Arc::new(std::sync::Mutex::new(std::time::Instant::now())),
            max_per_window,
            window_duration,
        }
    }

    /// Try to take one slot from the current window.
    ///
    /// Returns `true` if the request is admitted and `false` if the window's budget
    /// is already spent. Rejected calls do not consume budget.
    pub fn try_acquire(&self) -> bool {
        // The guarded value is a plain `Instant` that is never left half-written, so a
        // poisoned lock is safe to recover from instead of panicking every later caller.
        let mut start = self
            .window_start
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = std::time::Instant::now();

        // `>=` so the window's end instant already belongs to the next window; with a
        // zero-length window every call therefore starts a fresh one.
        if now.duration_since(*start) >= self.window_duration {
            *start = now;
            self.count.store(0, Ordering::Relaxed);
        }

        // Relaxed is enough: all accesses happen while `start` is locked.
        let admitted = self.count.load(Ordering::Relaxed);
        if admitted < self.max_per_window {
            self.count.store(admitted + 1, Ordering::Relaxed);
            true
        } else {
            false
        }
    }
}
53
54 /// Shared state for the API server.
55 #[derive(Clone)]
56 pub struct ApiState {
57 pub pool: sqlx::SqlitePool,
58 pub config: Arc<Config>,
59 pub mesh: Option<SharedMeshState>,
60 pub rate_limiter: RateLimiter,
61 }
62
63 /// Rate limiting middleware. Returns 429 if the request rate exceeds the limit.
64 async fn rate_limit(
65 AxumState(state): AxumState<ApiState>,
66 req: Request,
67 next: Next,
68 ) -> impl IntoResponse {
69 if state.rate_limiter.try_acquire() {
70 Ok(next.run(req).await)
71 } else {
72 Err((StatusCode::TOO_MANY_REQUESTS, Json(serde_json::json!({
73 "error": "rate limit exceeded"
74 }))))
75 }
76 }
77
78 /// Bearer token authentication middleware.
79 /// If `api_token` is configured, requires `Authorization: Bearer <token>` on every request.
80 /// If no token is configured, all requests pass through.
81 async fn require_bearer_token(
82 AxumState(state): AxumState<ApiState>,
83 req: Request,
84 next: Next,
85 ) -> impl IntoResponse {
86 let expected = state.config.serve.api_token.as_deref();
87 let Some(expected) = expected else {
88 return Ok(next.run(req).await);
89 };
90
91 let auth_header = req.headers().get("authorization").and_then(|v| v.to_str().ok());
92 match auth_header {
93 Some(header) if header.starts_with("Bearer ") => {
94 let token = &header[7..];
95 // Constant-time comparison to prevent timing side-channels
96 use subtle::ConstantTimeEq;
97 if token.as_bytes().ct_eq(expected.as_bytes()).into() {
98 Ok(next.run(req).await)
99 } else {
100 Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({
101 "error": "invalid bearer token"
102 }))))
103 }
104 }
105 _ => Err((StatusCode::UNAUTHORIZED, Json(serde_json::json!({
106 "error": "missing or malformed Authorization header"
107 })))),
108 }
109 }
110
111 /// `GET /api/health` — simple health endpoint for PoM itself.
112 /// Not behind auth — allows external monitoring without credentials.
113 async fn self_health() -> impl IntoResponse {
114 Json(serde_json::json!({
115 "status": "operational",
116 "version": env!("CARGO_PKG_VERSION"),
117 }))
118 }
119
120 /// Build the axum router for the PoM API.
121 pub fn router(pool: sqlx::SqlitePool, config: Config, mesh: Option<SharedMeshState>) -> Router {
122 let state = ApiState {
123 pool,
124 config: Arc::new(config),
125 mesh,
126 rate_limiter: RateLimiter::new(60, std::time::Duration::from_secs(60)),
127 };
128
129 // Authenticated routes (behind bearer token + rate limit)
130 let authenticated = Router::new()
131 .route("/api/status", get(status_all))
132 .route("/api/status/{target}", get(status_target))
133 .route("/api/trends/{target}", get(trends))
134 .route("/api/peer/info", get(peer_info))
135 .route("/api/peer/status", get(peer_status))
136 .route("/api/mesh", get(mesh_view))
137 .layer(middleware::from_fn_with_state(state.clone(), require_bearer_token))
138 .layer(middleware::from_fn_with_state(state.clone(), rate_limit));
139
140 // Public routes (no auth, no rate limit)
141 let public = Router::new()
142 .route("/api/health", get(self_health));
143
144 let mut app = public.merge(authenticated);
145
146 if state.config.serve.dashboard {
147 app = app.route("/", get(crate::dashboard::dashboard_handler));
148 }
149
150 app.with_state(state)
151 }
152
153 // --- Response types ---
154
/// Response body of `GET /api/status`.
#[derive(Serialize)]
struct StatusResponse {
    /// Per-target status summaries, keyed by target config name.
    targets: HashMap<String, TargetStatus>,
}
160
/// Full status summary for one target, as returned by `GET /api/status` and
/// `GET /api/status/{target}`.
///
/// Field declaration order is the JSON key order. Every database lookup behind these
/// fields is best-effort: a failed query yields the field's empty value.
#[derive(Serialize)]
struct TargetStatus {
    /// Human-readable display label for this target.
    label: String,
    /// Most recent health check snapshot (the first element of `recent`).
    /// `None` if no checks have been recorded yet.
    latest: Option<SnapshotJson>,
    /// Up to the 10 most recent health check snapshots, most recent first.
    recent: Vec<SnapshotJson>,
    /// Uptime percentage over the last 24 hours. `None` if no checks in that window
    /// or if the query failed.
    uptime_24h: Option<f64>,
    /// Uptime percentage over the last 7 days (168 hours). `None` if no checks in that
    /// window or if the query failed.
    uptime_7d: Option<f64>,
    /// Latency statistics over the last 24 hours, computed from positive response times only.
    /// Omitted if no such checks exist.
    #[serde(skip_serializing_if = "Option::is_none")]
    latency_24h: Option<LatencyStats>,
    /// Latest TLS certificate check result. Omitted when no TLS check is stored
    /// (presumably because TLS monitoring is not configured) or the lookup fails.
    #[serde(skip_serializing_if = "Option::is_none")]
    tls: Option<db::TlsCheckRow>,
    /// Test staleness assessment. Omitted if the target has no `tests` config.
    #[serde(skip_serializing_if = "Option::is_none")]
    test_staleness: Option<TestStaleness>,
    /// Currently open incident. Omitted if the target is not in an incident state.
    #[serde(skip_serializing_if = "Option::is_none")]
    current_incident: Option<db::IncidentRow>,
    /// Recent resolved and open incidents (up to 10). Omitted if empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    incidents: Vec<db::IncidentRow>,
    /// Latest route check results, restricted to paths listed in the target's
    /// `expected_routes`. Omitted if empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    route_status: Vec<RouteStatusJson>,
    /// Latest DNS check results, restricted to records still listed in the target's
    /// DNS config. Omitted if empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    dns_status: Vec<DnsStatusJson>,
    /// Latest WHOIS check result. Omitted when none is stored (presumably WHOIS
    /// monitoring is not configured) or the lookup fails.
    #[serde(skip_serializing_if = "Option::is_none")]
    whois: Option<db::WhoisCheckRow>,
    /// Test duration drift warning. Omitted if no drift detected or no test config.
    #[serde(skip_serializing_if = "Option::is_none")]
    test_duration_drift: Option<String>,
}
201
/// Latest result of one configured DNS check, as reported in `TargetStatus::dns_status`.
///
/// Field declaration order is the JSON key order.
#[derive(Serialize)]
struct DnsStatusJson {
    /// DNS name that was queried.
    name: String,
    /// Record type as stored with the check (e.g. as rendered by the config's record type).
    record_type: String,
    /// Expected record values, decoded from the stored JSON array (empty if it fails to parse).
    expected: Vec<String>,
    /// Values actually resolved, decoded from the stored JSON array (empty if it fails to parse).
    actual: Vec<String>,
    /// Whether the stored check recorded a match between expected and actual values.
    matches: bool,
    /// Timestamp of the check, as stored in the database.
    checked_at: String,
}
211
/// Latest result of one expected route check, as reported in `TargetStatus::route_status`.
///
/// Field declaration order is the JSON key order; values are copied from the stored row.
#[derive(Serialize)]
struct RouteStatusJson {
    /// Route path that was checked.
    path: String,
    /// HTTP status code recorded for the check.
    status_code: i64,
    /// Whether the stored check was considered successful.
    ok: bool,
    /// Timestamp of the check, as stored in the database.
    checked_at: String,
    /// Response time recorded for the check, in milliseconds.
    response_time_ms: i64,
}
220
/// API representation of one health check snapshot, built from a `HealthSnapshot`
/// through the `From` impl. Field declaration order is the JSON key order.
#[derive(Serialize)]
struct SnapshotJson {
    /// Health status as a lowercase string (e.g. "operational", "degraded").
    status: String,
    /// Timestamp of the check in RFC 3339 format.
    checked_at: String,
    /// Round-trip response time in milliseconds.
    response_time_ms: i64,
    /// Structured health details from the endpoint. Omitted when unavailable.
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<serde_json::Value>,
    /// Error message if the check failed. Omitted on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
236
237 impl From<HealthSnapshot> for SnapshotJson {
238 fn from(s: HealthSnapshot) -> Self {
239 Self {
240 status: s.status.to_string(),
241 checked_at: s.checked_at,
242 response_time_ms: s.response_time_ms,
243 details: s.details.map(|d| serde_json::to_value(d).unwrap_or_default()),
244 error: s.error,
245 }
246 }
247 }
248
249 /// Build a `TargetStatus` for a single target.
250 async fn build_target_status(
251 pool: &sqlx::SqlitePool,
252 name: &str,
253 label: &str,
254 config: &Config,
255 ) -> TargetStatus {
256 let recent = db::get_health_history(pool, Some(name), 10)
257 .await
258 .unwrap_or_default();
259
260 // Extract the version info we need before consuming the snapshots.
261 let latest_version = recent.first()
262 .and_then(|s| s.details.as_ref())
263 .and_then(|d| d.version.clone());
264 let latest = recent.first().cloned().map(SnapshotJson::from);
265 let recent_json: Vec<SnapshotJson> = recent.into_iter().map(SnapshotJson::from).collect();
266
267 let uptime_24h = db::get_uptime_percent(pool, name, 24)
268 .await
269 .unwrap_or(None);
270 let uptime_7d = db::get_uptime_percent(pool, name, 168)
271 .await
272 .unwrap_or(None);
273
274 // Compute 24h latency stats from operational checks
275 let latency_24h = {
276 let cutoff = (chrono::Utc::now() - chrono::Duration::hours(24)).to_rfc3339();
277 let times = db::get_response_times(pool, name, &cutoff)
278 .await
279 .unwrap_or_default();
280 let operational_times: Vec<i64> = times.iter()
281 .filter(|(_, ms)| *ms > 0)
282 .map(|(_, ms)| *ms)
283 .collect();
284 LatencyStats::from_times(&operational_times)
285 };
286
287 let tls = db::get_latest_tls_check(pool, name)
288 .await
289 .unwrap_or(None);
290
291 // Compute test staleness for targets with test config
292 let test_staleness = if let Some(target_config) = config.get_target(name)
293 && let Some(tests_config) = &target_config.tests
294 {
295 let current_version = latest_version.clone();
296
297 let latest_test = db::get_latest_test_run(pool, name).await.unwrap_or(None);
298
299 let tested_version = if let Some(ref test) = latest_test {
300 db::get_version_at_time(pool, name, &test.started_at)
301 .await
302 .unwrap_or(None)
303 } else {
304 None
305 };
306
307 let staleness = compute_test_staleness(
308 current_version.as_deref(),
309 tested_version.as_deref(),
310 latest_test.as_ref().map(|t| t.started_at.as_str()),
311 tests_config.staleness_days,
312 );
313 Some(staleness)
314 } else {
315 None
316 };
317
318 // Compute test duration drift for targets with test config
319 let test_duration_drift = if config.get_target(name).and_then(|t| t.tests.as_ref()).is_some() {
320 let durations = db::get_test_durations(pool, name, 13)
321 .await
322 .unwrap_or_default();
323 detect_test_duration_drift(&durations, 10, 3, 1.5)
324 } else {
325 None
326 };
327
328 let current_incident = db::get_open_incident(pool, name)
329 .await
330 .unwrap_or(None);
331
332 let incidents = db::get_recent_incidents(pool, name, 10)
333 .await
334 .unwrap_or_default();
335
336 let route_checks = db::get_latest_route_checks(pool, name)
337 .await
338 .unwrap_or_default();
339 let expected_routes: HashSet<&str> = config.get_target(name)
340 .map(|t| t.expected_routes.iter().map(|s| s.as_str()).collect())
341 .unwrap_or_default();
342 let route_status: Vec<RouteStatusJson> = route_checks
343 .into_iter()
344 .filter(|r| expected_routes.contains(r.path.as_str()))
345 .map(|r| RouteStatusJson {
346 path: r.path,
347 status_code: r.status_code,
348 ok: r.ok,
349 checked_at: r.checked_at,
350 response_time_ms: r.response_time_ms,
351 })
352 .collect();
353
354 let dns_checks = db::get_latest_dns_checks(pool, name)
355 .await
356 .unwrap_or_default();
357 let expected_dns: HashSet<(String, String)> = config.get_target(name)
358 .map(|t| t.dns.iter().map(|d| (d.name.clone(), d.record_type.to_string())).collect())
359 .unwrap_or_default();
360 let dns_status: Vec<DnsStatusJson> = dns_checks
361 .into_iter()
362 .filter(|r| expected_dns.contains(&(r.name.clone(), r.record_type.clone())))
363 .map(|r| DnsStatusJson {
364 name: r.name,
365 record_type: r.record_type,
366 expected: serde_json::from_str(&r.expected).unwrap_or_default(),
367 actual: serde_json::from_str(&r.actual).unwrap_or_default(),
368 matches: r.matches,
369 checked_at: r.checked_at,
370 })
371 .collect();
372
373 let whois = db::get_latest_whois_check(pool, name)
374 .await
375 .unwrap_or(None);
376
377 TargetStatus {
378 label: label.to_string(),
379 latest,
380 recent: recent_json,
381 uptime_24h,
382 uptime_7d,
383 latency_24h,
384 tls,
385 test_staleness,
386 test_duration_drift,
387 current_incident,
388 incidents,
389 route_status,
390 dns_status,
391 whois,
392 }
393 }
394
395 /// `GET /api/status` — JSON summary for all targets.
396 async fn status_all(
397 AxumState(state): AxumState<ApiState>,
398 ) -> impl IntoResponse {
399 let mut targets = HashMap::new();
400
401 for name in state.config.target_names() {
402 if let Some(target_config) = state.config.get_target(&name) {
403 let status = build_target_status(&state.pool, &name, &target_config.label, &state.config).await;
404 targets.insert(name, status);
405 }
406 }
407
408 Json(StatusResponse { targets })
409 }
410
411 /// `GET /api/status/{target}` — JSON summary for a single target.
412 async fn status_target(
413 AxumState(state): AxumState<ApiState>,
414 Path(target): Path<String>,
415 ) -> impl IntoResponse {
416 let Some(target_config) = state.config.get_target(&target) else {
417 return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({
418 "error": format!("unknown target: {target}")
419 }))));
420 };
421
422 let status = build_target_status(&state.pool, &target, &target_config.label, &state.config).await;
423 Ok(Json(status))
424 }
425
426 // --- Peer endpoints ---
427
428 /// `GET /api/peer/info` — Returns this instance's identity info.
429 async fn peer_info(
430 AxumState(state): AxumState<ApiState>,
431 ) -> impl IntoResponse {
432 let Some(ref mesh) = state.mesh else {
433 return Err((StatusCode::SERVICE_UNAVAILABLE, Json(serde_json::json!({
434 "error": "peer mesh not enabled"
435 }))));
436 };
437
438 let mesh_state = mesh.read().await;
439 Ok(Json(serde_json::to_value(&mesh_state.instance).unwrap_or_default()))
440 }
441
442 /// `GET /api/peer/status` — This instance's full view: own info + target statuses + peer summaries.
443 async fn peer_status(
444 AxumState(state): AxumState<ApiState>,
445 ) -> impl IntoResponse {
446 let Some(ref mesh) = state.mesh else {
447 return Err((StatusCode::SERVICE_UNAVAILABLE, Json(serde_json::json!({
448 "error": "peer mesh not enabled"
449 }))));
450 };
451
452 // Collect mesh data under lock, then drop lock before DB queries
453 let (instance, peers) = {
454 let mesh_state = mesh.read().await;
455 let instance = mesh_state.instance.clone();
456 let peers: HashMap<String, serde_json::Value> = mesh_state.peers.iter().map(|(name, peer)| {
457 (name.clone(), serde_json::json!({
458 "status": peer.status,
459 "last_seen": peer.last_seen,
460 "latency_ms": peer.latency_ms,
461 }))
462 }).collect();
463 (instance, peers)
464 };
465
466 // Build target statuses (DB queries with no lock held)
467 let mut targets = HashMap::new();
468 for name in state.config.target_names() {
469 if let Some(target_config) = state.config.get_target(&name)
470 && let Ok(Some(latest)) = db::get_latest_health(&state.pool, &name).await
471 {
472 targets.insert(name, serde_json::json!({
473 "label": target_config.label,
474 "status": latest.status.to_string(),
475 "response_time_ms": latest.response_time_ms,
476 "checked_at": latest.checked_at,
477 }));
478 }
479 }
480
481 Ok(Json(serde_json::json!({
482 "instance": instance,
483 "targets": targets,
484 "peers": peers,
485 })))
486 }
487
488 /// `GET /api/mesh` — Aggregated view: self + each peer's cached status.
489 async fn mesh_view(
490 AxumState(state): AxumState<ApiState>,
491 ) -> impl IntoResponse {
492 let Some(ref mesh) = state.mesh else {
493 return Err((StatusCode::SERVICE_UNAVAILABLE, Json(serde_json::json!({
494 "error": "peer mesh not enabled"
495 }))));
496 };
497
498 // Collect all mesh data under lock, then drop lock before DB queries
499 let (instance, own_peers_json, peer_entries) = {
500 let mesh_state = mesh.read().await;
501 let instance = mesh_state.instance.clone();
502 let own_peers: HashMap<String, serde_json::Value> = mesh_state.peers.iter().map(|(name, peer)| {
503 (name.clone(), serde_json::json!({
504 "status": peer.status,
505 "last_seen": peer.last_seen,
506 "latency_ms": peer.latency_ms,
507 }))
508 }).collect();
509 let peer_entries: Vec<(String, Option<serde_json::Value>, serde_json::Value)> = mesh_state.peers.iter().map(|(name, peer)| {
510 let fallback = serde_json::json!({
511 "status": peer.status,
512 "last_seen": peer.last_seen,
513 "error": "no status data cached",
514 });
515 (name.clone(), peer.status_data.clone(), fallback)
516 }).collect();
517 (instance, own_peers, peer_entries)
518 };
519
520 // Build target statuses (DB queries with no lock held)
521 let mut targets = HashMap::new();
522 for name in state.config.target_names() {
523 if let Some(target_config) = state.config.get_target(&name)
524 && let Ok(Some(latest)) = db::get_latest_health(&state.pool, &name).await
525 {
526 targets.insert(name, serde_json::json!({
527 "label": target_config.label,
528 "status": latest.status.to_string(),
529 "response_time_ms": latest.response_time_ms,
530 "checked_at": latest.checked_at,
531 }));
532 }
533 }
534
535 let self_entry = serde_json::json!({
536 "instance": instance,
537 "targets": targets,
538 "peers": own_peers_json,
539 });
540
541 let mut instances = serde_json::Map::new();
542 instances.insert(instance.name.clone(), self_entry);
543
544 for (name, status_data, fallback) in peer_entries {
545 instances.insert(name, status_data.unwrap_or(fallback));
546 }
547
548 Ok(Json(serde_json::json!({
549 "instances": instances,
550 })))
551 }
552
553 // --- Trends endpoint ---
554
/// Response body of `GET /api/trends/{target}`.
///
/// Field declaration order is the JSON key order. Only positive response times feed
/// the statistics. `overall` and `baseline` carry no skip attribute, so they serialize
/// as `null` when there is no data.
#[derive(Serialize)]
struct TrendResponse {
    /// Target config name this trend data belongs to.
    target: String,
    /// Requested time window in hours (from query param, default 24).
    window_hours: u64,
    /// Requested bucket width in minutes (from query param, default 60).
    bucket_minutes: u64,
    /// Per-bucket latency statistics within the requested window.
    buckets: Vec<LatencyBucket>,
    /// Aggregate latency statistics across the entire requested window.
    overall: Option<LatencyStats>,
    /// Latency statistics over the last 168 hours (7 days), independent of
    /// `window_hours`, for drift comparison.
    baseline: Option<LatencyStats>,
}
570
571 /// `GET /api/trends/{target}?hours=24&bucket_minutes=60` — latency trend data.
572 async fn trends(
573 AxumState(state): AxumState<ApiState>,
574 Path(target): Path<String>,
575 axum::extract::Query(params): axum::extract::Query<TrendQueryParams>,
576 ) -> impl IntoResponse {
577 let Some(_target_config) = state.config.get_target(&target) else {
578 return Err((StatusCode::NOT_FOUND, Json(serde_json::json!({
579 "error": format!("unknown target: {target}")
580 }))));
581 };
582
583 let hours = params.hours.unwrap_or(24);
584 let bucket_minutes = params.bucket_minutes.unwrap_or(60);
585
586 let cutoff = (chrono::Utc::now() - chrono::Duration::hours(hours as i64)).to_rfc3339();
587 let times = db::get_response_times(&state.pool, &target, &cutoff)
588 .await
589 .unwrap_or_default();
590
591 let operational_times: Vec<i64> = times.iter()
592 .filter(|(_, ms)| *ms > 0)
593 .map(|(_, ms)| *ms)
594 .collect();
595 let overall = LatencyStats::from_times(&operational_times);
596
597 let operational_data: Vec<(String, i64)> = times.into_iter()
598 .filter(|(_, ms)| *ms > 0)
599 .collect();
600 let buckets = LatencyStats::bucket_by_time(&operational_data, bucket_minutes);
601
602 // 7d baseline for reference
603 let baseline_cutoff = (chrono::Utc::now() - chrono::Duration::hours(168)).to_rfc3339();
604 let baseline_times = db::get_response_times(&state.pool, &target, &baseline_cutoff)
605 .await
606 .unwrap_or_default();
607 let baseline_operational: Vec<i64> = baseline_times.iter()
608 .filter(|(_, ms)| *ms > 0)
609 .map(|(_, ms)| *ms)
610 .collect();
611 let baseline = LatencyStats::from_times(&baseline_operational);
612
613 Ok(Json(TrendResponse {
614 target,
615 window_hours: hours,
616 bucket_minutes,
617 buckets,
618 overall,
619 baseline,
620 }))
621 }
622
623 #[derive(serde::Deserialize)]
624 struct TrendQueryParams {
625 /// Time window to query, in hours. Defaults to 24 if omitted.
626 hours: Option<u64>,
627 /// Width of each latency bucket, in minutes. Defaults to 60 if omitted.
628 bucket_minutes: Option<u64>,
629 }
630
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request as HttpRequest;
    use tower::ServiceExt;

    /// Minimal config with no targets or peers and the given API token.
    fn test_config(api_token: Option<&str>) -> Config {
        let mut config = Config {
            serve: crate::config::ServeConfig::default(),
            instance: Default::default(),
            targets: HashMap::new(),
            peers: HashMap::new(),
            alerts: None,
        };
        config.serve.api_token = api_token.map(|s| s.to_string());
        config
    }

    /// Router over a fresh in-memory database, with no peer mesh.
    async fn test_app(api_token: Option<&str>) -> Router {
        let pool = crate::db::connect_in_memory().await.unwrap();
        router(pool, test_config(api_token), None)
    }

    /// Send `GET uri`, optionally with a raw `Authorization` header value, and return the status.
    async fn get_status(app: Router, uri: &str, authorization: Option<&str>) -> StatusCode {
        let mut builder = HttpRequest::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        let req = builder.body(Body::empty()).unwrap();
        app.oneshot(req).await.unwrap().status()
    }

    #[tokio::test]
    async fn no_token_configured_allows_all_requests() {
        let app = test_app(None).await;
        assert_eq!(get_status(app, "/api/status", None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_token_allows_request() {
        let app = test_app(Some("secret123")).await;
        assert_eq!(
            get_status(app, "/api/status", Some("Bearer secret123")).await,
            StatusCode::OK
        );
    }

    #[tokio::test]
    async fn wrong_token_returns_401() {
        let app = test_app(Some("secret123")).await;
        assert_eq!(
            get_status(app, "/api/status", Some("Bearer wrong-token")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let app = test_app(Some("secret123")).await;
        assert_eq!(
            get_status(app, "/api/status", None).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn malformed_header_returns_401() {
        let app = test_app(Some("secret123")).await;
        assert_eq!(
            get_status(app, "/api/status", Some("Basic dXNlcjpwYXNz")).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn peer_endpoints_require_token() {
        let app = test_app(Some("secret123")).await;
        assert_eq!(
            get_status(app, "/api/mesh", None).await,
            StatusCode::UNAUTHORIZED
        );
    }

    #[tokio::test]
    async fn health_endpoint_is_public() {
        // Registered outside the authenticated group: no token needed even when one is configured.
        let app = test_app(Some("secret123")).await;
        assert_eq!(get_status(app, "/api/health", None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn unknown_target_returns_404() {
        let app = test_app(None).await;
        assert_eq!(
            get_status(app.clone(), "/api/status/nope", None).await,
            StatusCode::NOT_FOUND
        );
        assert_eq!(
            get_status(app, "/api/trends/nope", None).await,
            StatusCode::NOT_FOUND
        );
    }

    #[tokio::test]
    async fn peer_endpoints_unavailable_without_mesh() {
        let app = test_app(None).await;
        for uri in ["/api/peer/info", "/api/peer/status", "/api/mesh"] {
            assert_eq!(
                get_status(app.clone(), uri, None).await,
                StatusCode::SERVICE_UNAVAILABLE,
                "{uri}"
            );
        }
    }

    #[test]
    fn rate_limiter_allows_within_limit() {
        let limiter = RateLimiter::new(3, std::time::Duration::from_secs(60));
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::new(2, std::time::Duration::from_secs(60));
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    #[tokio::test]
    async fn rate_limiter_resets_after_window() {
        let limiter = RateLimiter::new(1, std::time::Duration::from_millis(10));
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
        tokio::time::sleep(std::time::Duration::from_millis(15)).await;
        assert!(limiter.try_acquire());
    }

    #[test]
    fn rate_limiter_counter_starts_at_one() {
        let limiter = RateLimiter::new(1, std::time::Duration::from_millis(0));
        // A zero-length window has already elapsed by the time of each call,
        // so every call opens a fresh window and is admitted.
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
    }
}
752