Skip to main content

max / makenotwork

5.1 KB · 196 lines History Blame Raw
1 //! Custom domain CRUD queries.
2
3 use sqlx::PgPool;
4
5 use super::models::DbCustomDomain;
6 use super::{CustomDomainId, UserId};
7 use crate::error::{AppError, Result};
8
9 /// Create a custom domain entry with a verification token.
10 /// Enforces a 1-domain-per-user limit using a transaction to prevent TOCTOU races.
11 #[tracing::instrument(skip_all)]
12 pub async fn create_custom_domain(
13 pool: &PgPool,
14 user_id: UserId,
15 domain: &str,
16 verification_token: &str,
17 ) -> Result<DbCustomDomain> {
18 let mut tx = pool.begin().await?;
19
20 // Lock existing rows to serialize concurrent domain creation attempts
21 let existing: Vec<(CustomDomainId,)> = sqlx::query_as(
22 "SELECT id FROM custom_domains WHERE user_id = $1 FOR UPDATE",
23 )
24 .bind(user_id)
25 .fetch_all(&mut *tx)
26 .await?;
27
28 if !existing.is_empty() {
29 return Err(AppError::BadRequest(
30 "You already have a custom domain configured. Remove it first to add a new one.".to_string(),
31 ));
32 }
33
34 let row = sqlx::query_as::<_, DbCustomDomain>(
35 r#"
36 INSERT INTO custom_domains (user_id, domain, verification_token)
37 VALUES ($1, $2, $3)
38 RETURNING *
39 "#,
40 )
41 .bind(user_id)
42 .bind(domain)
43 .bind(verification_token)
44 .fetch_one(&mut *tx)
45 .await?;
46
47 tx.commit().await?;
48 Ok(row)
49 }
50
51 /// Get the custom domain for a user (at most one).
52 #[tracing::instrument(skip_all)]
53 pub async fn get_custom_domain_by_user(
54 pool: &PgPool,
55 user_id: UserId,
56 ) -> Result<Option<DbCustomDomain>> {
57 let row = sqlx::query_as::<_, DbCustomDomain>(
58 "SELECT * FROM custom_domains WHERE user_id = $1",
59 )
60 .bind(user_id)
61 .fetch_optional(pool)
62 .await?;
63
64 Ok(row)
65 }
66
67 /// Look up a verified domain by hostname (for routing).
68 #[tracing::instrument(skip_all)]
69 pub async fn get_verified_domain(
70 pool: &PgPool,
71 domain: &str,
72 ) -> Result<Option<DbCustomDomain>> {
73 let row = sqlx::query_as::<_, DbCustomDomain>(
74 "SELECT * FROM custom_domains WHERE domain = $1 AND verified = true",
75 )
76 .bind(domain)
77 .fetch_optional(pool)
78 .await?;
79
80 Ok(row)
81 }
82
83 /// Mark a domain as verified.
84 #[tracing::instrument(skip_all)]
85 pub async fn mark_domain_verified(pool: &PgPool, domain_id: CustomDomainId) -> Result<()> {
86 sqlx::query("UPDATE custom_domains SET verified = true, verified_at = NOW() WHERE id = $1")
87 .bind(domain_id)
88 .execute(pool)
89 .await?;
90
91 Ok(())
92 }
93
94 /// Delete a custom domain (only if owned by the given user).
95 #[tracing::instrument(skip_all)]
96 pub async fn delete_custom_domain(
97 pool: &PgPool,
98 domain_id: CustomDomainId,
99 user_id: UserId,
100 ) -> Result<()> {
101 let result = sqlx::query(
102 "DELETE FROM custom_domains WHERE id = $1 AND user_id = $2",
103 )
104 .bind(domain_id)
105 .bind(user_id)
106 .execute(pool)
107 .await?;
108
109 if result.rows_affected() == 0 {
110 return Err(AppError::NotFound);
111 }
112
113 Ok(())
114 }
115
116 /// Get all verified domains (for cache warm-up on startup).
117 #[tracing::instrument(skip_all)]
118 pub async fn get_all_verified_domains(pool: &PgPool) -> Result<Vec<DbCustomDomain>> {
119 let rows = sqlx::query_as::<_, DbCustomDomain>(
120 "SELECT * FROM custom_domains WHERE verified = true",
121 )
122 .fetch_all(pool)
123 .await?;
124
125 Ok(rows)
126 }
127
#[cfg(test)]
mod tests {
    // Unit tests for the id newtypes, error mapping, and derived traits used by
    // the custom-domain queries. The queries themselves need a live database
    // and are not exercised here.
    use super::*;

    #[test]
    fn custom_domain_id_new_is_unique() {
        let a = CustomDomainId::new();
        let b = CustomDomainId::new();
        assert_ne!(a, b);
    }

    #[test]
    fn custom_domain_id_nil() {
        let nil = CustomDomainId::nil();
        assert_eq!(*nil.as_uuid(), uuid::Uuid::nil());
    }

    #[test]
    fn custom_domain_id_display() {
        let id = CustomDomainId::nil();
        assert_eq!(id.to_string(), "00000000-0000-0000-0000-000000000000");
    }

    #[test]
    fn custom_domain_id_serde_roundtrip() {
        let id = CustomDomainId::new();
        let json = serde_json::to_string(&id).unwrap();
        let parsed: CustomDomainId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn bad_request_error_contains_message() {
        let err = AppError::BadRequest(
            "You already have a custom domain configured. Remove it first to add a new one."
                .to_string(),
        );
        let msg = err.user_message();
        assert!(msg.contains("already have a custom domain"));
    }

    #[test]
    fn not_found_error_status() {
        let err = AppError::NotFound;
        assert_eq!(err.status_code(), axum::http::StatusCode::NOT_FOUND);
    }

    #[test]
    fn user_id_and_custom_domain_id_are_distinct_types() {
        // Distinct newtypes have distinct `TypeId`s. This fails if the two ids
        // are ever collapsed into one type (e.g. via a type alias). Comparing
        // two random UUID values would still pass in that case.
        assert_ne!(
            std::any::TypeId::of::<UserId>(),
            std::any::TypeId::of::<CustomDomainId>()
        );
    }

    #[test]
    fn db_custom_domain_struct_is_clone() {
        // DbCustomDomain derives Clone — verify it compiles.
        fn assert_clone<T: Clone>() {}
        assert_clone::<DbCustomDomain>();
    }

    #[test]
    fn db_custom_domain_struct_is_debug() {
        // DbCustomDomain implements Debug — verify it compiles.
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<DbCustomDomain>();
    }
}
196