//! Prometheus metrics: HTTP request tracking, error counters, DB pool gauges. //! //! Call [`init`] once at startup to install the Prometheus recorder. //! Call [`render`] from the `/metrics` endpoint to produce the scrape output. //! The HTTP middleware in [`metrics_middleware`] records per-request metrics. use axum::{ extract::{MatchedPath, Request, State}, middleware::Next, response::{IntoResponse, Response}, }; use metrics::{counter, gauge, histogram}; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use std::time::Instant; /// Install the global Prometheus recorder. Returns the handle used to render /// the scrape output. Call once from `main`. pub fn init() -> PrometheusHandle { PrometheusBuilder::new() .install_recorder() .expect("failed to install Prometheus recorder") } /// Render all collected metrics in Prometheus exposition format. pub async fn render(State(handle): State) -> impl IntoResponse { handle.render() } /// Axum middleware that sets `Cache-Control` headers based on route path. /// /// Public content pages get CDN-friendly caching (Cloudflare caches for 60s, /// browsers always revalidate). Dashboard, API, and auth routes get no caching. 
pub async fn cache_control_middleware(request: Request, next: Next) -> Response { let path = request.uri().path().to_string(); let mut response = next.run(request).await; // Don't override if a handler already set Cache-Control if response.headers().contains_key(axum::http::header::CACHE_CONTROL) { return response; } let value = if is_public_page(&path) { // CDN caches 60s, browser always revalidates, stale content served while refreshing "public, max-age=0, s-maxage=60, stale-while-revalidate=300" } else if path.starts_with("/api/") || path.starts_with("/stripe/") || path.starts_with("/postmark/") { "no-store" } else { // Dashboard, admin, auth — private, always revalidate "private, no-cache" }; response.headers_mut().insert( axum::http::header::CACHE_CONTROL, axum::http::HeaderValue::from_static(value), ); // Add API version header to all /api/* responses if path.starts_with("/api/") { response.headers_mut().insert( axum::http::HeaderName::from_static("mnw-version"), axum::http::HeaderValue::from_static("2026-04-23"), ); } response } /// Returns true for public content pages that benefit from CDN caching. fn is_public_page(path: &str) -> bool { matches!(path, "/" | "/discover" | "/pricing" | "/source") || path.starts_with("/p/") || path.starts_with("/i/") || path.starts_with("/u/") || path.starts_with("/c/") || path.starts_with("/docs") || path.starts_with("/discover/") || path.starts_with("/source/") || path.starts_with("/feed") } /// Axum middleware that records HTTP request metrics. /// /// For every request, records: /// - `http_requests_total` counter with labels: method, path, status /// - `http_request_duration_seconds` histogram with labels: method, path, status /// /// `path` uses the matched route pattern (e.g. `/api/items/:id`) to keep /// cardinality bounded. Unmatched routes are grouped under ``. 
pub async fn metrics_middleware(request: Request, next: Next) -> Response { let method = request.method().to_string(); let path = request .extensions() .get::() .map(|p| p.as_str().to_string()) .unwrap_or_else(|| "".to_string()); let start = Instant::now(); let response = next.run(request).await; let duration = start.elapsed().as_secs_f64(); let status = status_class(response.status().as_u16()); let labels = [ ("method", method), ("path", path), ("status", status.to_string()), ]; counter!("http_requests_total", &labels).increment(1); histogram!("http_request_duration_seconds", &labels).record(duration); response } /// Collapse HTTP status codes into classes to keep label cardinality low. fn status_class(code: u16) -> &'static str { match code { 200..=299 => "2xx", 300..=399 => "3xx", 400..=499 => "4xx", 500..=599 => "5xx", _ => "other", } } /// Record current DB pool statistics as gauges. Call periodically from the /// health monitor or a dedicated task. pub fn record_db_pool_stats(pool: &sqlx::PgPool) { let size = pool.size() as f64; let idle = pool.num_idle() as f64; gauge!("db_pool_connections_max").set(size); gauge!("db_pool_connections_idle").set(idle); gauge!("db_pool_connections_active").set(size - idle); } /// Record server-wide Postgres saturation as Prometheus gauges. Sibling of /// the local-pool gauges — this looks at the SHARED Postgres (MNW + MT + /// ad hoc clients) so a dashboard can see the global ceiling, not just our /// pool's share. Cheap query (single row from `pg_stat_activity`). /// /// Returns `(active_backends, max_connections)` so the caller can also fire /// the existing WAM alert when utilization climbs past the threshold. 
#[tracing::instrument(skip_all)] pub async fn record_pg_stat_activity(pool: &sqlx::PgPool) -> Option<(i64, i64)> { let row: Result<(i64, i64), _> = sqlx::query_as( "SELECT \ (SELECT count(*) FROM pg_stat_activity \ WHERE state IS NOT NULL AND backend_type = 'client backend')::bigint, \ current_setting('max_connections')::bigint", ) .fetch_one(pool) .await; match row { Ok((active, max_conn)) if max_conn > 0 => { gauge!("pg_stat_activity_active_backends").set(active as f64); gauge!("pg_stat_activity_max_connections").set(max_conn as f64); gauge!("pg_stat_activity_utilization_ratio").set(active as f64 / max_conn as f64); Some((active, max_conn)) } Ok(_) => None, Err(e) => { tracing::debug!(error = ?e, "pg_stat_activity gauge update failed"); None } } } /// Aggregated storage fill metrics across all paying creators. /// /// Emits three gauges so the dashboard can compute fill ratio without /// joining queries client-side: /// - `creator_storage_used_bytes_total` — sum of `users.storage_used_bytes` /// for users with an active creator subscription. /// - `creator_storage_cap_bytes_total` — sum of the corresponding tier caps. /// - `creator_storage_fill_ratio` — used / cap. /// /// Pricing economics assume ~20% fill; 60%+ is 3× projection and warrants /// re-pricing. This gauge is the canonical input for that threshold. #[tracing::instrument(skip_all)] pub async fn record_storage_fill_stats(pool: &sqlx::PgPool) { // Tier cap table: kept in SQL because the values rarely change and // mirror `CreatorTier::max_storage_bytes`. Update both sites if a tier // cap moves. Tiers without a creator_subscriptions row are excluded — // a user without an active subscription has no committed storage cap. 
let row: Result<(i64, i64), _> = sqlx::query_as( r#" WITH tier_caps(tier, cap_bytes) AS ( VALUES ('basic', 50::bigint * 1024 * 1024 * 1024), ('small_files', 250::bigint * 1024 * 1024 * 1024), ('big_files', 1024::bigint * 1024 * 1024 * 1024), ('everything', 5120::bigint * 1024 * 1024 * 1024) ) SELECT COALESCE(SUM(u.storage_used_bytes), 0)::bigint AS used, COALESCE(SUM(tc.cap_bytes), 0)::bigint AS cap FROM users u JOIN creator_subscriptions cs ON cs.user_id = u.id AND cs.status = 'active' JOIN tier_caps tc ON tc.tier = cs.tier "#, ) .fetch_one(pool) .await; match row { Ok((used, cap)) => { gauge!("creator_storage_used_bytes_total").set(used as f64); gauge!("creator_storage_cap_bytes_total").set(cap as f64); let ratio = if cap > 0 { used as f64 / cap as f64 } else { 0.0 }; gauge!("creator_storage_fill_ratio").set(ratio); } Err(e) => { tracing::debug!(error = ?e, "storage fill stats query failed"); } } } /// Emit the current `domain_cache` size as a gauge so dashboards can track /// cache growth + correlate with `caddy_ask_total{outcome="cache_hit"}`. pub fn record_domain_cache_size(size: usize) { gauge!("domain_cache_entries").set(size as f64); } /// Axum middleware that implements idempotency keys for POST endpoints. /// /// If the request includes an `Idempotency-Key` header and the user is /// authenticated, checks for a cached response. If found, returns the cached /// response immediately. Otherwise, runs the handler and caches the result. /// /// Skips silently if no `Idempotency-Key` header is present or if the user /// is not authenticated (no session user). 
pub async fn idempotency_middleware(
    // NOTE(review): the type argument of `State` appears to have been lost
    // here; the body only shows that the state exposes a cloneable `db` field.
    State(state): State,
    request: Request,
    next: Next,
) -> Response {
    use axum::http::StatusCode;

    // Only applies to POST/PUT methods
    if !matches!(*request.method(), axum::http::Method::POST | axum::http::Method::PUT) {
        return next.run(request).await;
    }

    // Extract idempotency key from header
    let idem_key = request
        .headers()
        .get("idempotency-key")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    // Keys that are missing, not valid header text, empty, or longer than
    // 256 bytes are all treated the same way: the request passes straight through.
    let idem_key = match idem_key {
        Some(k) if !k.is_empty() && k.len() <= 256 => k,
        _ => return next.run(request).await, // No key — proceed normally
    };

    // Extract user ID from session (must be authenticated)
    // NOTE(review): the type arguments of both `get::` calls and of `Option`
    // below appear to be missing; `user_id` is later used as the
    // `crate::db::UserId` half of `NegKey`.
    let session = request.extensions().get::().cloned();
    let user_id: Option = if let Some(ref session) = session {
        session
            .get::("user")
            .await
            .ok()
            .flatten()
            .map(|u| u.id)
    } else {
        None
    };
    let user_id = match user_id {
        Some(id) => id,
        None => return next.run(request).await, // Not authenticated — skip
    };

    // Owned copies: `request` is moved into `next.run` further down.
    let method = request.method().to_string();
    let path = request.uri().path().to_string();

    // In-memory negative cache: every POST/PUT with an Idempotency-Key was
    // previously taking a pool conn for `get_cached_response` even when the
    // key had never been seen — measurable cost on a hot POST that already
    // makes 2-5 DB queries. Keys that recently returned None are tracked
    // here so the SELECT is skipped. Single-process correctness: the DB
    // cache table is written ONLY by this middleware's success path, which
    // also evicts the key from this map.
    //
    // Note the map is keyed by (key, user) only, while the DB lookup below is
    // additionally scoped by method and path.
    type NegKey = (String, crate::db::UserId);
    // NOTE(review): the type argument of `OnceLock` appears truncated; usage
    // below implies a `dashmap::DashMap` from `NegKey` to `std::time::Instant`.
    static NEG_CACHE: std::sync::OnceLock> = std::sync::OnceLock::new();
    const NEG_TTL_SECS: u64 = 60;
    let neg_cache = NEG_CACHE.get_or_init(dashmap::DashMap::new);
    let neg_key = (idem_key.clone(), user_id);
    let recently_negative = neg_cache
        .get(&neg_key)
        .map(|e| e.elapsed().as_secs() < NEG_TTL_SECS)
        .unwrap_or(false);

    // Periodic GC to keep the map bounded under sustained burst (every ~1k
    // misses we walk the map and drop expired entries).
    // The counter starts at 0, so the very first keyed request also sweeps.
    static GC_TICK: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    if GC_TICK.fetch_add(1, std::sync::atomic::Ordering::Relaxed).is_multiple_of(1024) {
        neg_cache.retain(|_, t| t.elapsed().as_secs() < NEG_TTL_SECS);
    }

    // Check for cached response (scoped to key + user + method + path)
    // A lookup error is treated like a miss: the handler runs normally.
    if !recently_negative
        && let Ok(Some(cached)) = crate::db::idempotency::get_cached_response(&state.db, &idem_key, user_id, &method, &path).await
    {
        tracing::debug!(key = %idem_key, "returning cached idempotency response");
        // The replay is rebuilt from the stored status code and body only;
        // a status code that is not a valid HTTP status falls back to 200 OK.
        let status = StatusCode::from_u16(cached.status_code as u16).unwrap_or(StatusCode::OK);
        return (status, cached.response_body).into_response();
    }

    // Cache miss (or skipped via negative cache) — record the miss timestamp
    // so subsequent calls in the next NEG_TTL_SECS can skip the DB SELECT.
    // This insert is unconditional, so a skipped lookup also refreshes the
    // timestamp of the existing entry.
    neg_cache.insert(neg_key.clone(), std::time::Instant::now());

    // Run the handler
    let response = next.run(request).await;

    // Cache the response (fire-and-forget — don't block the response on DB write)
    let status_code = response.status().as_u16();

    // Only cache successful responses (2xx/3xx) to avoid caching transient errors
    if status_code < 400 {
        // Only cache when content-length is present AND <= 1MB. We must decide
        // BEFORE consuming the body, otherwise a chunked / unknown-length response
        // that exceeds the cap would be silently truncated to empty — a correctness
        // landmine, since the status + headers would still claim success.
        // NOTE(review): the type argument of `parse::` appears to be missing;
        // the value is only compared against `1024 * 1024` and logged.
        let content_length = response.headers()
            .get(axum::http::header::CONTENT_LENGTH)
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::().ok());
        let Some(len) = content_length else {
            tracing::debug!(
                key = %idem_key,
                method = %method,
                path = %path,
                "no content-length on response; skipping idempotency cache (body left intact)"
            );
            return response;
        };
        if len > 1024 * 1024 {
            tracing::info!(
                key = %idem_key,
                method = %method,
                path = %path,
                len,
                "response body exceeds 1MB; skipping idempotency cache"
            );
            return response;
        }

        // Extract body bytes to cache. Content-length confirms <= 1MB, so this
        // should not exceed the limit; if it does, that's a header/body mismatch
        // and we surface 500 rather than silently dropping the body.
        let (parts, body) = response.into_parts();
        let body_bytes = match axum::body::to_bytes(body, 1024 * 1024).await {
            Ok(b) => b,
            // Any `to_bytes` failure takes this branch and is logged with the
            // size-mismatch message below.
            Err(e) => {
                tracing::error!(
                    key = %idem_key,
                    method = %method,
                    path = %path,
                    error = ?e,
                    "response body exceeded 1MB despite content-length <= 1MB; failing closed"
                );
                return axum::response::Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(axum::body::Body::from("internal error"))
                    .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response());
            }
        };

        // Only cache UTF-8 responses — skip binary content to avoid corruption
        if let Ok(body_str) = std::str::from_utf8(&body_bytes) {
            let body_owned = body_str.to_owned();
            let db = state.db.clone();
            let key = idem_key.clone();
            // Evict the negative-cache entry now that this key has a real
            // cached response — subsequent requests should hit the DB and get
            // the cached body, not keep skipping via the stale negative.
            // The eviction happens before the spawned store below has run.
            neg_cache.remove(&neg_key);
            tokio::spawn(async move {
                if let Err(e) = crate::db::idempotency::store_response(
                    &db, &key, user_id, &method, &path, status_code, &body_owned,
                ).await {
                    tracing::warn!(key = %key, error = ?e, "failed to store idempotency key");
                }
            });
        }

        // Reassemble the response from the original head and the buffered body.
        axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes))
    } else {
        response
    }
}

/// Snapshot of current metrics for the admin dashboard.
pub struct MetricsSnapshot {
    /// Sum of every `http_requests_total` sample in the scrape output.
    pub total_requests: u64,
    /// Sum of the `http_requests_total` samples whose `status` label is `5xx`.
    pub total_5xx: u64,
    /// Sum of every `http_errors_total` sample in the scrape output.
    pub total_errors: u64,
    /// (method, path, status, count) sorted by count descending
    pub top_routes: Vec<(String, String, String, u64)>,
    /// (kind, count) sorted by count descending
    pub error_breakdown: Vec<(String, u64)>,
}

/// Parse the Prometheus text output into a structured snapshot.
/// This avoids adding a dependency on the prometheus data model — just
/// string-parses the exposition format we produce ourselves.
///
/// Only `http_requests_total{…}` and `http_errors_total{…}` lines are read.
/// `top_routes` is cut to the 20 highest counts, while the totals cover
/// every parsed sample.
pub fn snapshot(handle: &PrometheusHandle) -> MetricsSnapshot {
    let text = handle.render();
    let mut total_requests: u64 = 0;
    let mut total_5xx: u64 = 0;
    let mut routes: Vec<(String, String, String, u64)> = Vec::new();
    let mut errors: Vec<(String, u64)> = Vec::new();

    for line in text.lines() {
        // Skip `# HELP` / `# TYPE` comment lines and blank lines.
        if line.starts_with('#') || line.is_empty() {
            continue;
        }
        if let Some(rest) = line.strip_prefix("http_requests_total{") {
            // Split at the LAST `} ` so the label set and the sample value separate.
            if let Some((labels, value)) = rest.rsplit_once("} ") {
                // A value that does not parse as u64 is counted as 0.
                let count: u64 = value.parse().unwrap_or(0);
                let method = extract_label(labels, "method");
                let path = extract_label(labels, "path");
                let status = extract_label(labels, "status");
                total_requests += count;
                if status == "5xx" {
                    total_5xx += count;
                }
                routes.push((method, path, status, count));
            }
        } else if let Some(rest) = line.strip_prefix("http_errors_total{")
            && let Some((labels, value)) = rest.rsplit_once("} ")
        {
            let count: u64 = value.parse().unwrap_or(0);
            let kind = extract_label(labels, "kind");
            errors.push((kind, count));
        }
    }

    // Busiest routes first, then keep only the top 20.
    routes.sort_by_key(|r| std::cmp::Reverse(r.3));
    routes.truncate(20);
    errors.sort_by_key(|e| std::cmp::Reverse(e.1));

    let total_errors = errors.iter().map(|(_, c)| c).sum();

    MetricsSnapshot {
        total_requests,
        total_5xx,
        total_errors,
        top_routes: routes,
        error_breakdown: errors,
    }
}

/// Extract a label value from a Prometheus label string like `method="GET",path="/",status="2xx"`.
///
/// Returns an empty string when the key is not found. The label string is
/// split on every `,`, so a value that itself contains a comma is not matched
/// and also yields an empty string.
fn extract_label(labels: &str, key: &str) -> String {
    let prefix = format!("{key}=\"");
    labels
        .split(',')
        .find_map(|part| {
            let part = part.trim();
            part.strip_prefix(&prefix)
                .and_then(|rest| rest.strip_suffix('"'))
                .map(|v| v.to_string())
        })
        .unwrap_or_default()
}

#[cfg(test)]
mod tests {
    use super::*;

    // Pins one representative code per class, plus 1xx falling into "other".
    #[test]
    fn status_class_mapping() {
        assert_eq!(status_class(200), "2xx");
        assert_eq!(status_class(201), "2xx");
        assert_eq!(status_class(301), "3xx");
        assert_eq!(status_class(404), "4xx");
        assert_eq!(status_class(500), "5xx");
        assert_eq!(status_class(100), "other");
    }
}