Skip to main content

max / makenotwork

18.5 KB · 479 lines History Blame Raw
1 //! Prometheus metrics: HTTP request tracking, error counters, DB pool gauges.
2 //!
3 //! Call [`init`] once at startup to install the Prometheus recorder.
4 //! Call [`render`] from the `/metrics` endpoint to produce the scrape output.
5 //! The HTTP middleware in [`metrics_middleware`] records per-request metrics.
6
7 use axum::{
8 extract::{MatchedPath, Request, State},
9 middleware::Next,
10 response::{IntoResponse, Response},
11 };
12 use metrics::{counter, gauge, histogram};
13 use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
14 use std::time::Instant;
15
16 /// Install the global Prometheus recorder. Returns the handle used to render
17 /// the scrape output. Call once from `main`.
18 pub fn init() -> PrometheusHandle {
19 PrometheusBuilder::new()
20 .install_recorder()
21 .expect("failed to install Prometheus recorder")
22 }
23
24 /// Render all collected metrics in Prometheus exposition format.
25 pub async fn render(State(handle): State<PrometheusHandle>) -> impl IntoResponse {
26 handle.render()
27 }
28
29 /// Axum middleware that sets `Cache-Control` headers based on route path.
30 ///
31 /// Public content pages get CDN-friendly caching (Cloudflare caches for 60s,
32 /// browsers always revalidate). Dashboard, API, and auth routes get no caching.
33 pub async fn cache_control_middleware(request: Request, next: Next) -> Response {
34 let path = request.uri().path().to_string();
35 let mut response = next.run(request).await;
36
37 // Don't override if a handler already set Cache-Control
38 if response.headers().contains_key(axum::http::header::CACHE_CONTROL) {
39 return response;
40 }
41
42 let value = if is_public_page(&path) {
43 // CDN caches 60s, browser always revalidates, stale content served while refreshing
44 "public, max-age=0, s-maxage=60, stale-while-revalidate=300"
45 } else if path.starts_with("/api/") || path.starts_with("/stripe/") || path.starts_with("/postmark/") {
46 "no-store"
47 } else {
48 // Dashboard, admin, auth — private, always revalidate
49 "private, no-cache"
50 };
51
52 response.headers_mut().insert(
53 axum::http::header::CACHE_CONTROL,
54 axum::http::HeaderValue::from_static(value),
55 );
56
57 // Add API version header to all /api/* responses
58 if path.starts_with("/api/") {
59 response.headers_mut().insert(
60 axum::http::HeaderName::from_static("mnw-version"),
61 axum::http::HeaderValue::from_static("2026-04-23"),
62 );
63 }
64
65 response
66 }
67
/// Decide whether `path` is a public content page that should receive
/// CDN-friendly caching.
///
/// A path qualifies when it equals one of the exact landing routes or starts
/// with one of the public prefixes.
///
/// NOTE(review): `/docs` and `/feed` are bare prefixes, so any path that
/// merely begins with them (e.g. `/feedback`) is treated as public too —
/// confirm that is intended.
fn is_public_page(path: &str) -> bool {
    // Routes that must match in full.
    const EXACT_ROUTES: [&str; 4] = ["/", "/discover", "/pricing", "/source"];
    // Routes matched by prefix.
    const PUBLIC_PREFIXES: [&str; 8] = [
        "/p/",
        "/i/",
        "/u/",
        "/c/",
        "/docs",
        "/discover/",
        "/source/",
        "/feed",
    ];

    EXACT_ROUTES.contains(&path)
        || PUBLIC_PREFIXES
            .iter()
            .any(|prefix| path.starts_with(*prefix))
}
80
81 /// Axum middleware that records HTTP request metrics.
82 ///
83 /// For every request, records:
84 /// - `http_requests_total` counter with labels: method, path, status
85 /// - `http_request_duration_seconds` histogram with labels: method, path, status
86 ///
87 /// `path` uses the matched route pattern (e.g. `/api/items/:id`) to keep
88 /// cardinality bounded. Unmatched routes are grouped under `<unmatched>`.
89 pub async fn metrics_middleware(request: Request, next: Next) -> Response {
90 let method = request.method().to_string();
91 let path = request
92 .extensions()
93 .get::<MatchedPath>()
94 .map(|p| p.as_str().to_string())
95 .unwrap_or_else(|| "<unmatched>".to_string());
96
97 let start = Instant::now();
98 let response = next.run(request).await;
99 let duration = start.elapsed().as_secs_f64();
100
101 let status = status_class(response.status().as_u16());
102
103 let labels = [
104 ("method", method),
105 ("path", path),
106 ("status", status.to_string()),
107 ];
108
109 counter!("http_requests_total", &labels).increment(1);
110 histogram!("http_request_duration_seconds", &labels).record(duration);
111
112 response
113 }
114
/// Map an HTTP status code to its class label (`2xx`, `3xx`, `4xx`, `5xx`)
/// so the `status` metric label stays low-cardinality.
///
/// Anything outside 200–599 (including 1xx informational codes) is reported
/// as `other`.
fn status_class(code: u16) -> &'static str {
    // The hundreds digit identifies the class.
    match code / 100 {
        2 => "2xx",
        3 => "3xx",
        4 => "4xx",
        5 => "5xx",
        _ => "other",
    }
}
125
126 /// Record current DB pool statistics as gauges. Call periodically from the
127 /// health monitor or a dedicated task.
128 pub fn record_db_pool_stats(pool: &sqlx::PgPool) {
129 let size = pool.size() as f64;
130 let idle = pool.num_idle() as f64;
131 gauge!("db_pool_connections_max").set(size);
132 gauge!("db_pool_connections_idle").set(idle);
133 gauge!("db_pool_connections_active").set(size - idle);
134 }
135
136 /// Record server-wide Postgres saturation as Prometheus gauges. Sibling of
137 /// the local-pool gauges — this looks at the SHARED Postgres (MNW + MT +
138 /// ad hoc clients) so a dashboard can see the global ceiling, not just our
139 /// pool's share. Cheap query (single row from `pg_stat_activity`).
140 ///
141 /// Returns `(active_backends, max_connections)` so the caller can also fire
142 /// the existing WAM alert when utilization climbs past the threshold.
143 #[tracing::instrument(skip_all)]
144 pub async fn record_pg_stat_activity(pool: &sqlx::PgPool) -> Option<(i64, i64)> {
145 let row: Result<(i64, i64), _> = sqlx::query_as(
146 "SELECT \
147 (SELECT count(*) FROM pg_stat_activity \
148 WHERE state IS NOT NULL AND backend_type = 'client backend')::bigint, \
149 current_setting('max_connections')::bigint",
150 )
151 .fetch_one(pool)
152 .await;
153
154 match row {
155 Ok((active, max_conn)) if max_conn > 0 => {
156 gauge!("pg_stat_activity_active_backends").set(active as f64);
157 gauge!("pg_stat_activity_max_connections").set(max_conn as f64);
158 gauge!("pg_stat_activity_utilization_ratio").set(active as f64 / max_conn as f64);
159 Some((active, max_conn))
160 }
161 Ok(_) => None,
162 Err(e) => {
163 tracing::debug!(error = ?e, "pg_stat_activity gauge update failed");
164 None
165 }
166 }
167 }
168
169 /// Aggregated storage fill metrics across all paying creators.
170 ///
171 /// Emits three gauges so the dashboard can compute fill ratio without
172 /// joining queries client-side:
173 /// - `creator_storage_used_bytes_total` — sum of `users.storage_used_bytes`
174 /// for users with an active creator subscription.
175 /// - `creator_storage_cap_bytes_total` — sum of the corresponding tier caps.
176 /// - `creator_storage_fill_ratio` — used / cap.
177 ///
178 /// Pricing economics assume ~20% fill; 60%+ is 3× projection and warrants
179 /// re-pricing. This gauge is the canonical input for that threshold.
180 #[tracing::instrument(skip_all)]
181 pub async fn record_storage_fill_stats(pool: &sqlx::PgPool) {
182 // Tier cap table: kept in SQL because the values rarely change and
183 // mirror `CreatorTier::max_storage_bytes`. Update both sites if a tier
184 // cap moves. Tiers without a creator_subscriptions row are excluded —
185 // a user without an active subscription has no committed storage cap.
186 let row: Result<(i64, i64), _> = sqlx::query_as(
187 r#"
188 WITH tier_caps(tier, cap_bytes) AS (
189 VALUES
190 ('basic', 50::bigint * 1024 * 1024 * 1024),
191 ('small_files', 250::bigint * 1024 * 1024 * 1024),
192 ('big_files', 1024::bigint * 1024 * 1024 * 1024),
193 ('everything', 5120::bigint * 1024 * 1024 * 1024)
194 )
195 SELECT
196 COALESCE(SUM(u.storage_used_bytes), 0)::bigint AS used,
197 COALESCE(SUM(tc.cap_bytes), 0)::bigint AS cap
198 FROM users u
199 JOIN creator_subscriptions cs
200 ON cs.user_id = u.id AND cs.status = 'active'
201 JOIN tier_caps tc ON tc.tier = cs.tier
202 "#,
203 )
204 .fetch_one(pool)
205 .await;
206
207 match row {
208 Ok((used, cap)) => {
209 gauge!("creator_storage_used_bytes_total").set(used as f64);
210 gauge!("creator_storage_cap_bytes_total").set(cap as f64);
211 let ratio = if cap > 0 { used as f64 / cap as f64 } else { 0.0 };
212 gauge!("creator_storage_fill_ratio").set(ratio);
213 }
214 Err(e) => {
215 tracing::debug!(error = ?e, "storage fill stats query failed");
216 }
217 }
218 }
219
220 /// Emit the current `domain_cache` size as a gauge so dashboards can track
221 /// cache growth + correlate with `caddy_ask_total{outcome="cache_hit"}`.
222 pub fn record_domain_cache_size(size: usize) {
223 gauge!("domain_cache_entries").set(size as f64);
224 }
225
226 /// Axum middleware that implements idempotency keys for POST endpoints.
227 ///
228 /// If the request includes an `Idempotency-Key` header and the user is
229 /// authenticated, checks for a cached response. If found, returns the cached
230 /// response immediately. Otherwise, runs the handler and caches the result.
231 ///
232 /// Skips silently if no `Idempotency-Key` header is present or if the user
233 /// is not authenticated (no session user).
234 pub async fn idempotency_middleware(
235 State(state): State<crate::AppState>,
236 request: Request,
237 next: Next,
238 ) -> Response {
239 use axum::http::StatusCode;
240
241 // Only applies to POST/PUT methods
242 if !matches!(*request.method(), axum::http::Method::POST | axum::http::Method::PUT) {
243 return next.run(request).await;
244 }
245
246 // Extract idempotency key from header
247 let idem_key = request
248 .headers()
249 .get("idempotency-key")
250 .and_then(|v| v.to_str().ok())
251 .map(|s| s.to_string());
252
253 let idem_key = match idem_key {
254 Some(k) if !k.is_empty() && k.len() <= 256 => k,
255 _ => return next.run(request).await, // No key — proceed normally
256 };
257
258 // Extract user ID from session (must be authenticated)
259 let session = request.extensions().get::<tower_sessions::Session>().cloned();
260 let user_id: Option<crate::db::UserId> = if let Some(ref session) = session {
261 session
262 .get::<crate::auth::SessionUser>("user")
263 .await
264 .ok()
265 .flatten()
266 .map(|u| u.id)
267 } else {
268 None
269 };
270
271 let user_id = match user_id {
272 Some(id) => id,
273 None => return next.run(request).await, // Not authenticated — skip
274 };
275
276 let method = request.method().to_string();
277 let path = request.uri().path().to_string();
278
279 // In-memory negative cache: every POST/PUT with an Idempotency-Key was
280 // previously taking a pool conn for `get_cached_response` even when the
281 // key had never been seen — measurable cost on a hot POST that already
282 // makes 2-5 DB queries. Keys that recently returned None are tracked
283 // here so the SELECT is skipped. Single-process correctness: the DB
284 // cache table is written ONLY by this middleware's success path, which
285 // also evicts the key from this map.
286 type NegKey = (String, crate::db::UserId);
287 static NEG_CACHE: std::sync::OnceLock<dashmap::DashMap<NegKey, std::time::Instant>> =
288 std::sync::OnceLock::new();
289 const NEG_TTL_SECS: u64 = 60;
290 let neg_cache = NEG_CACHE.get_or_init(dashmap::DashMap::new);
291 let neg_key = (idem_key.clone(), user_id);
292 let recently_negative = neg_cache
293 .get(&neg_key)
294 .map(|e| e.elapsed().as_secs() < NEG_TTL_SECS)
295 .unwrap_or(false);
296
297 // Periodic GC to keep the map bounded under sustained burst (every ~1k
298 // misses we walk the map and drop expired entries).
299 static GC_TICK: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
300 if GC_TICK.fetch_add(1, std::sync::atomic::Ordering::Relaxed).is_multiple_of(1024) {
301 neg_cache.retain(|_, t| t.elapsed().as_secs() < NEG_TTL_SECS);
302 }
303
304 // Check for cached response (scoped to key + user + method + path)
305 if !recently_negative
306 && let Ok(Some(cached)) = crate::db::idempotency::get_cached_response(&state.db, &idem_key, user_id, &method, &path).await
307 {
308 tracing::debug!(key = %idem_key, "returning cached idempotency response");
309 let status = StatusCode::from_u16(cached.status_code as u16).unwrap_or(StatusCode::OK);
310 return (status, cached.response_body).into_response();
311 }
312
313 // Cache miss (or skipped via negative cache) — record the miss timestamp
314 // so subsequent calls in the next NEG_TTL_SECS can skip the DB SELECT.
315 neg_cache.insert(neg_key.clone(), std::time::Instant::now());
316
317 // Run the handler
318 let response = next.run(request).await;
319
320 // Cache the response (fire-and-forget — don't block the response on DB write)
321 let status_code = response.status().as_u16();
322
323 // Only cache successful responses (2xx/3xx) to avoid caching transient errors
324 if status_code < 400 {
325 // Only cache when content-length is present AND <= 1MB. We must decide
326 // BEFORE consuming the body, otherwise a chunked / unknown-length response
327 // that exceeds the cap would be silently truncated to empty — a correctness
328 // landmine, since the status + headers would still claim success.
329 let content_length = response.headers()
330 .get(axum::http::header::CONTENT_LENGTH)
331 .and_then(|v| v.to_str().ok())
332 .and_then(|v| v.parse::<usize>().ok());
333 let Some(len) = content_length else {
334 tracing::debug!(
335 key = %idem_key, method = %method, path = %path,
336 "no content-length on response; skipping idempotency cache (body left intact)"
337 );
338 return response;
339 };
340 if len > 1024 * 1024 {
341 tracing::info!(
342 key = %idem_key, method = %method, path = %path, len,
343 "response body exceeds 1MB; skipping idempotency cache"
344 );
345 return response;
346 }
347
348 // Extract body bytes to cache. Content-length confirms <= 1MB, so this
349 // should not exceed the limit; if it does, that's a header/body mismatch
350 // and we surface 500 rather than silently dropping the body.
351 let (parts, body) = response.into_parts();
352 let body_bytes = match axum::body::to_bytes(body, 1024 * 1024).await {
353 Ok(b) => b,
354 Err(e) => {
355 tracing::error!(
356 key = %idem_key, method = %method, path = %path, error = ?e,
357 "response body exceeded 1MB despite content-length <= 1MB; failing closed"
358 );
359 return axum::response::Response::builder()
360 .status(StatusCode::INTERNAL_SERVER_ERROR)
361 .body(axum::body::Body::from("internal error"))
362 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response());
363 }
364 };
365 // Only cache UTF-8 responses — skip binary content to avoid corruption
366 if let Ok(body_str) = std::str::from_utf8(&body_bytes) {
367 let body_owned = body_str.to_owned();
368 let db = state.db.clone();
369 let key = idem_key.clone();
370 // Evict the negative-cache entry now that this key has a real
371 // cached response — subsequent requests should hit the DB and get
372 // the cached body, not keep skipping via the stale negative.
373 neg_cache.remove(&neg_key);
374 tokio::spawn(async move {
375 if let Err(e) = crate::db::idempotency::store_response(
376 &db, &key, user_id, &method, &path, status_code, &body_owned,
377 ).await {
378 tracing::warn!(key = %key, error = ?e, "failed to store idempotency key");
379 }
380 });
381 }
382
383 axum::response::Response::from_parts(parts, axum::body::Body::from(body_bytes))
384 } else {
385 response
386 }
387 }
388
/// Snapshot of current metrics for the admin dashboard.
///
/// Produced by parsing the rendered Prometheus text. `Default` yields an
/// empty snapshot (all totals zero, no rows).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MetricsSnapshot {
    /// Sum of every `http_requests_total` sample.
    pub total_requests: u64,
    /// Sum of the `http_requests_total` samples whose `status` label is `5xx`.
    pub total_5xx: u64,
    /// Sum of every `http_errors_total` sample.
    pub total_errors: u64,
    /// (method, path, status, count) sorted by count descending
    pub top_routes: Vec<(String, String, String, u64)>,
    /// (kind, count) sorted by count descending
    pub error_breakdown: Vec<(String, u64)>,
}
399
400 /// Parse the Prometheus text output into a structured snapshot.
401 /// This avoids adding a dependency on the prometheus data model — just
402 /// string-parses the exposition format we produce ourselves.
403 pub fn snapshot(handle: &PrometheusHandle) -> MetricsSnapshot {
404 let text = handle.render();
405 let mut total_requests: u64 = 0;
406 let mut total_5xx: u64 = 0;
407 let mut routes: Vec<(String, String, String, u64)> = Vec::new();
408 let mut errors: Vec<(String, u64)> = Vec::new();
409
410 for line in text.lines() {
411 if line.starts_with('#') || line.is_empty() {
412 continue;
413 }
414
415 if let Some(rest) = line.strip_prefix("http_requests_total{") {
416 if let Some((labels, value)) = rest.rsplit_once("} ") {
417 let count: u64 = value.parse().unwrap_or(0);
418 let method = extract_label(labels, "method");
419 let path = extract_label(labels, "path");
420 let status = extract_label(labels, "status");
421 total_requests += count;
422 if status == "5xx" {
423 total_5xx += count;
424 }
425 routes.push((method, path, status, count));
426 }
427 } else if let Some(rest) = line.strip_prefix("http_errors_total{")
428 && let Some((labels, value)) = rest.rsplit_once("} ")
429 {
430 let count: u64 = value.parse().unwrap_or(0);
431 let kind = extract_label(labels, "kind");
432 errors.push((kind, count));
433 }
434 }
435
436 routes.sort_by_key(|r| std::cmp::Reverse(r.3));
437 routes.truncate(20);
438 errors.sort_by_key(|e| std::cmp::Reverse(e.1));
439
440 let total_errors = errors.iter().map(|(_, c)| c).sum();
441
442 MetricsSnapshot {
443 total_requests,
444 total_5xx,
445 total_errors,
446 top_routes: routes,
447 error_breakdown: errors,
448 }
449 }
450
/// Extract a label value from a Prometheus label string like `method="GET",path="/",status="2xx"`.
///
/// The label text is scanned pair by pair with awareness of quoting, so a
/// value may contain commas, and the exposition-format escapes `\"`, `\\`
/// and `\n` are decoded. (A plain split on `,` returns an empty string for
/// any value that contains a comma.)
///
/// Returns the value of the first label named exactly `key`, or an empty
/// string when the label is absent or the text is malformed before it is
/// reached. Segments without `=` and labels with unquoted values are skipped.
fn extract_label(labels: &str, key: &str) -> String {
    let mut rest = labels;

    loop {
        // Drop separators and padding left over from the previous pair.
        rest = rest.trim_start_matches(|c: char| c == ',' || c.is_whitespace());

        let sep = match rest.find(|c: char| c == '=' || c == ',') {
            Some(idx) => idx,
            None => return String::new(),
        };
        if rest[sep..].starts_with(',') {
            // A segment with no `=` cannot be a label; skip it.
            rest = &rest[sep + 1..];
            continue;
        }

        let name = rest[..sep].trim();
        let after_eq = &rest[sep + 1..];
        let quoted = match after_eq.strip_prefix('"') {
            Some(body) => body,
            None => match after_eq.find(',') {
                // Unquoted value: not something we produce; skip the pair.
                Some(idx) => {
                    rest = &after_eq[idx + 1..];
                    continue;
                }
                None => return String::new(),
            },
        };

        // Walk the quoted value, decoding escapes, until the closing quote.
        let mut value = String::new();
        let mut close = None;
        let mut chars = quoted.char_indices();
        while let Some((idx, ch)) = chars.next() {
            match ch {
                '"' => {
                    close = Some(idx);
                    break;
                }
                '\\' => match chars.next() {
                    Some((_, 'n')) => value.push('\n'),
                    Some((_, esc)) if esc == '"' || esc == '\\' => value.push(esc),
                    // Unknown escape: keep it verbatim.
                    Some((_, other)) => {
                        value.push('\\');
                        value.push(other);
                    }
                    None => break,
                },
                _ => value.push(ch),
            }
        }

        let close = match close {
            Some(idx) => idx,
            // Unterminated value: the text is malformed.
            None => return String::new(),
        };
        if name == key {
            return value;
        }
        // `"` is one byte, so `close + 1` is a valid boundary.
        rest = &quoted[close + 1..];
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Each representative status code maps to its expected class label.
    #[test]
    fn status_class_mapping() {
        let expectations: [(u16, &str); 6] = [
            (200, "2xx"),
            (201, "2xx"),
            (301, "3xx"),
            (404, "4xx"),
            (500, "5xx"),
            (100, "other"),
        ];

        for (code, class) in expectations {
            assert_eq!(status_class(code), class);
        }
    }
}
478 }
479