//! Custom domain management API endpoints. use axum::{ extract::{Path, Query, State}, response::IntoResponse, Form, Json, }; use serde::Deserialize; use serde_json::json; use crate::{ auth::AuthUser, db::{self, CustomDomainId}, error::{AppError, Result}, AppState, }; #[derive(Deserialize)] pub(super) struct AddDomainRequest { domain: String, } /// POST /api/domains: add a custom domain. #[tracing::instrument(skip_all, name = "api::domains::add")] pub(super) async fn add_domain( State(state): State, AuthUser(session_user): AuthUser, Form(req): Form, ) -> Result { session_user.check_not_sandbox()?; let domain = normalize_domain(&req.domain)?; validate_domain(&domain)?; let verification_token = generate_verification_token(); let _row = db::custom_domains::create_custom_domain(&state.db, session_user.id, &domain, &verification_token) .await?; let instructions = format!( "Add two DNS records, then click Verify: a CNAME {0}connect.makenot.work (set DNS-only / unproxied), and a TXT _mnw-verify.{0} with value {1}.", domain, verification_token ); Ok(axum::response::Html(format!( "

{}

", instructions ))) } #[derive(Deserialize)] pub(super) struct VerifyDomainRequest { domain_id: CustomDomainId, } /// POST /api/domains/verify: trigger DNS verification. #[tracing::instrument(skip_all, name = "api::domains::verify")] pub(super) async fn verify_domain( State(state): State, AuthUser(session_user): AuthUser, Form(req): Form, ) -> Result { session_user.check_not_sandbox()?; let cd = db::custom_domains::get_custom_domain_by_user(&state.db, session_user.id) .await? .ok_or(AppError::NotFound)?; if cd.id != req.domain_id { return Err(AppError::NotFound); } if cd.verified { return Ok(axum::response::Html("

Domain already verified.

".to_string())); } // Query DNS via Cloudflare DNS-over-HTTPS let lookup_name = format!("_mnw-verify.{}", cd.domain); let txt_records = dns_lookup_txt(&lookup_name).await?; let matched = txt_records .iter() .any(|txt| txt.trim() == cd.verification_token); if !matched { return Ok(axum::response::Html(format!( "

TXT record not found. Add _mnw-verify.{} TXT {} and try again.

", cd.domain, cd.verification_token ))); } // Mark verified in DB and update cache db::custom_domains::mark_domain_verified(&state.db, cd.id).await?; state .domain_cache .insert(cd.domain.clone(), session_user.id); Ok(axum::response::Html("

Domain verified successfully. Reload to see changes.

".to_string())) } /// DELETE /api/domains/{id}: remove a custom domain. #[tracing::instrument(skip_all, name = "api::domains::remove")] pub(super) async fn remove_domain( State(state): State, AuthUser(session_user): AuthUser, Path(id): Path, ) -> Result { session_user.check_not_sandbox()?; // Fetch first so we can remove from cache let cd = db::custom_domains::get_custom_domain_by_user(&state.db, session_user.id) .await? .ok_or(AppError::NotFound)?; if cd.id != id { return Err(AppError::NotFound); } db::custom_domains::delete_custom_domain(&state.db, id, session_user.id).await?; // Remove from cache state.domain_cache.remove(&cd.domain); Ok(axum::http::StatusCode::NO_CONTENT) } /// GET /api/domains: get current user's custom domain. #[tracing::instrument(skip_all, name = "api::domains::get")] pub(super) async fn get_domain( State(state): State, AuthUser(session_user): AuthUser, ) -> Result { let cd = db::custom_domains::get_custom_domain_by_user(&state.db, session_user.id).await?; match cd { Some(d) => { let instructions = if d.verified { String::new() } else { format!( "Point {0} at connect.makenot.work (CNAME, DNS-only) and add a TXT _mnw-verify.{0} with value {1}, then verify.", d.domain, d.verification_token ) }; Ok(Json(json!({ "id": d.id, "domain": d.domain, "verified": d.verified, "verification_token": d.verification_token, "instructions": instructions, "verified_at": d.verified_at, }))) } None => Ok(Json(json!(null))), } } #[derive(Deserialize)] pub(super) struct CaddyAskQuery { domain: String, } /// GET /api/domains/caddy-ask: Caddy on-demand TLS check. /// /// Unauthenticated (called by Caddy, not users). Returns 200 if the domain /// is verified, 404 if not. Caddy treats anything outside the 2xx range as /// "do not issue." /// /// Abuse model: every request reaches MNW from Caddy's IP, so per-IP rate /// limiting can't distinguish a legitimate visitor from an attacker probing /// thousands of bogus hostnames. Defenses here are content-based: /// 1. Reject syntactically-invalid hostnames before touching the DB. /// 2. Bound concurrent cache-miss DB lookups via `caddy_ask_semaphore`; /// at capacity, return 503 so Caddy backs off and the pool stays free. #[tracing::instrument(skip_all, name = "api::domains::caddy_ask")] pub(super) async fn caddy_ask( State(state): State, Query(q): Query, ) -> impl IntoResponse { use metrics::{counter, gauge}; let domain = q.domain.to_lowercase(); // Cheap structural reject: hostnames with no dot, spaces, control chars, // or absurd length can't be verified domains. Skipping the DB on these // shuts down the easiest flood vector. if domain.is_empty() || domain.len() > 253 || !domain.contains('.') || domain.contains(|c: char| c.is_whitespace() || c.is_control()) { counter!("caddy_ask_total", "outcome" => "rejected_invalid").increment(1); return axum::http::StatusCode::NOT_FOUND; } // Fast path: cache hit, no DB, no semaphore. if state.domain_cache.contains_key(&domain) { counter!("caddy_ask_total", "outcome" => "cache_hit").increment(1); return axum::http::StatusCode::OK; } // Slow path: cap concurrent DB lookups. `try_acquire` is non-blocking; // hitting the cap means we're already under pressure and want Caddy to // retry rather than queue. let Ok(_permit) = state.caddy_ask_semaphore.try_acquire() else { tracing::warn!(domain = %domain, "caddy-ask: cache-miss concurrency cap reached"); counter!("caddy_ask_total", "outcome" => "rejected_at_cap").increment(1); return axum::http::StatusCode::SERVICE_UNAVAILABLE; }; match db::custom_domains::get_verified_domain(&state.db, &domain).await { Ok(Some(d)) => { state.domain_cache.insert(d.domain, d.user_id); // Update cache-size gauge after the insert so it reflects current state. gauge!("domain_cache_entries").set(state.domain_cache.len() as f64); counter!("caddy_ask_total", "outcome" => "miss_found").increment(1); axum::http::StatusCode::OK } _ => { counter!("caddy_ask_total", "outcome" => "miss_notfound").increment(1); axum::http::StatusCode::NOT_FOUND } } } // ── Helpers ── /// Normalize a domain: lowercase, strip protocol, strip trailing slash/path. fn normalize_domain(input: &str) -> Result { let mut domain = input.trim().to_lowercase(); // Strip protocol if present if let Some(rest) = domain.strip_prefix("https://") { domain = rest.to_string(); } else if let Some(rest) = domain.strip_prefix("http://") { domain = rest.to_string(); } // Strip path if let Some(pos) = domain.find('/') { domain.truncate(pos); } // Strip port if let Some(pos) = domain.find(':') { domain.truncate(pos); } if domain.is_empty() { return Err(AppError::validation("Domain cannot be empty.".to_string())); } Ok(domain) } /// Basic domain validation: must have at least one dot, no spaces, reasonable length. fn validate_domain(domain: &str) -> Result<()> { if domain.len() > 253 { return Err(AppError::validation( "Domain name is too long.".to_string(), )); } if !domain.contains('.') { return Err(AppError::validation( "Domain must include a TLD (e.g. example.com).".to_string(), )); } if domain.contains(' ') || domain.contains('\t') { return Err(AppError::validation( "Domain cannot contain spaces.".to_string(), )); } // Block MNW domains if domain == "makenot.work" || domain.ends_with(".makenot.work") || domain == "makenotwork.com" || domain.ends_with(".makenotwork.com") { return Err(AppError::validation( "Cannot use a makenot.work domain.".to_string(), )); } // Check labels for label in domain.split('.') { if label.is_empty() || label.len() > 63 { return Err(AppError::validation( "Each domain label must be 1-63 characters.".to_string(), )); } if !label .chars() .all(|c| c.is_ascii_alphanumeric() || c == '-') { return Err(AppError::validation( "Domain labels can only contain letters, numbers, and hyphens.".to_string(), )); } if label.starts_with('-') || label.ends_with('-') { return Err(AppError::validation( "Domain labels cannot start or end with a hyphen.".to_string(), )); } } Ok(()) } /// Generate a random verification token. fn generate_verification_token() -> String { let mut bytes = [0u8; 16]; rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes); format!("mnw-verify-{}", hex::encode(bytes)) } /// Query TXT records via Cloudflare DNS-over-HTTPS. async fn dns_lookup_txt(name: &str) -> Result> { let client = reqwest::Client::new(); let resp = client .get("https://cloudflare-dns.com/dns-query") .query(&[("name", name), ("type", "TXT")]) .header("Accept", "application/dns-json") .timeout(std::time::Duration::from_secs(10)) .send() .await .map_err(|e| { tracing::warn!(error = ?e, name = %name, "DNS lookup failed"); AppError::BadRequest("DNS lookup failed. Please try again.".to_string()) })?; if !resp.status().is_success() { return Err(AppError::BadRequest( "DNS lookup failed. Please try again.".to_string(), )); } let body: serde_json::Value = resp.json().await.map_err(|e| { tracing::warn!(error = ?e, "Failed to parse DNS response"); AppError::BadRequest("DNS lookup failed. Please try again.".to_string()) })?; // Parse Cloudflare DNS-over-HTTPS JSON response let mut records = Vec::new(); if let Some(answers) = body["Answer"].as_array() { for answer in answers { if answer["type"].as_u64() == Some(16) { // TXT record type = 16 if let Some(data) = answer["data"].as_str() { // Remove surrounding quotes from TXT record data let cleaned = data.trim_matches('"'); records.push(cleaned.to_string()); } } } } Ok(records) } #[cfg(test)] mod tests { use super::*; #[test] fn normalize_strips_protocol() { assert_eq!(normalize_domain("https://example.com").unwrap(), "example.com"); assert_eq!(normalize_domain("http://example.com").unwrap(), "example.com"); } #[test] fn normalize_strips_path_and_port() { assert_eq!(normalize_domain("example.com/path").unwrap(), "example.com"); assert_eq!(normalize_domain("example.com:443").unwrap(), "example.com"); } #[test] fn normalize_lowercases() { assert_eq!(normalize_domain("EXAMPLE.COM").unwrap(), "example.com"); } #[test] fn normalize_empty_errors() { assert!(normalize_domain("").is_err()); } #[test] fn validate_valid_domains() { assert!(validate_domain("example.com").is_ok()); assert!(validate_domain("sub.example.com").is_ok()); assert!(validate_domain("my-site.co.uk").is_ok()); } #[test] fn validate_no_tld() { assert!(validate_domain("localhost").is_err()); } #[test] fn validate_mnw_blocked() { assert!(validate_domain("makenot.work").is_err()); assert!(validate_domain("sub.makenot.work").is_err()); assert!(validate_domain("makenotwork.com").is_err()); } #[test] fn validate_spaces() { assert!(validate_domain("exam ple.com").is_err()); } #[test] fn validate_hyphen_edges() { assert!(validate_domain("-example.com").is_err()); assert!(validate_domain("example-.com").is_err()); } #[test] fn verification_token_format() { let token = generate_verification_token(); assert!(token.starts_with("mnw-verify-")); assert_eq!(token.len(), "mnw-verify-".len() + 32); } }