Skip to main content

max / makenotwork

11.3 KB · 384 lines History Blame Raw
1 //! License key management: CRUD, activation tracking, and revocation.
2
3 use sqlx::PgPool;
4
5 use super::models::*;
6 use super::validated_types::KeyCode;
7 use super::{ItemId, LicenseActivationId, LicenseKeyId, TransactionId, UserId};
8 use crate::error::Result;
9
10 /// Create a new license key for an item.
11 ///
12 /// Retries once on a 23505 unique-violation with a freshly-generated code.
13 /// A real collision out of the wordlist generator is vanishingly rare (the
14 /// six-word space gives ~6B coin-flip headroom), but the alternative is
15 /// surfacing a 500 to whatever flow is creating the key.
16 #[tracing::instrument(skip_all)]
17 pub async fn create_license_key(
18 pool: &PgPool,
19 item_id: ItemId,
20 owner_id: UserId,
21 transaction_id: Option<TransactionId>,
22 key_code: &KeyCode,
23 max_activations: Option<i32>,
24 ) -> Result<DbLicenseKey> {
25 const SQL: &str = r#"
26 INSERT INTO license_keys (item_id, owner_id, transaction_id, key_code, max_activations)
27 VALUES ($1, $2, $3, $4, $5)
28 RETURNING *
29 "#;
30
31 let first = sqlx::query_as::<_, DbLicenseKey>(SQL)
32 .bind(item_id)
33 .bind(owner_id)
34 .bind(transaction_id)
35 .bind(key_code)
36 .bind(max_activations)
37 .fetch_one(pool)
38 .await;
39
40 match first {
41 Ok(key) => Ok(key),
42 Err(sqlx::Error::Database(e)) if e.code().as_deref() == Some("23505") => {
43 let retry_code = crate::helpers::generate_key_code();
44 tracing::warn!(item_id = %item_id, "license key 23505 collision; retrying once");
45 let key = sqlx::query_as::<_, DbLicenseKey>(SQL)
46 .bind(item_id)
47 .bind(owner_id)
48 .bind(transaction_id)
49 .bind(&retry_code)
50 .bind(max_activations)
51 .fetch_one(pool)
52 .await?;
53 Ok(key)
54 }
55 Err(e) => Err(e.into()),
56 }
57 }
58
59 /// Look up a license key by its code.
60 #[tracing::instrument(skip_all)]
61 pub async fn get_license_key_by_code(pool: &PgPool, key_code: &KeyCode) -> Result<Option<DbLicenseKey>> {
62 let key = sqlx::query_as::<_, DbLicenseKey>(
63 "SELECT * FROM license_keys WHERE key_code = $1",
64 )
65 .bind(key_code)
66 .fetch_optional(pool)
67 .await?;
68
69 Ok(key)
70 }
71
72 /// Get a license key by ID.
73 #[tracing::instrument(skip_all)]
74 pub async fn get_license_key_by_id(pool: &PgPool, id: LicenseKeyId) -> Result<Option<DbLicenseKey>> {
75 let key = sqlx::query_as::<_, DbLicenseKey>(
76 "SELECT * FROM license_keys WHERE id = $1",
77 )
78 .bind(id)
79 .fetch_optional(pool)
80 .await?;
81
82 Ok(key)
83 }
84
85 /// Count license keys for an item.
86 #[tracing::instrument(skip_all)]
87 pub async fn count_keys_by_item(pool: &PgPool, item_id: ItemId) -> Result<i64> {
88 let count: i64 = sqlx::query_scalar(
89 "SELECT COUNT(*) FROM license_keys WHERE item_id = $1",
90 )
91 .bind(item_id)
92 .fetch_one(pool)
93 .await?;
94
95 Ok(count)
96 }
97
98 /// List all license keys for an item, newest first.
99 ///
100 /// Hard-caps at 500 rows to bound memory and response size for the creator
101 /// dashboard list view. Items with more than 500 keys are uncommon;
102 /// future work could add cursor-based pagination if needed.
103 #[tracing::instrument(skip_all)]
104 pub async fn get_license_keys_by_item(pool: &PgPool, item_id: ItemId) -> Result<Vec<DbLicenseKey>> {
105 let keys = sqlx::query_as::<_, DbLicenseKey>(
106 "SELECT * FROM license_keys WHERE item_id = $1 ORDER BY created_at DESC LIMIT 500",
107 )
108 .bind(item_id)
109 .fetch_all(pool)
110 .await?;
111
112 Ok(keys)
113 }
114
115 /// Batch-load license keys for multiple items, grouped by item_id.
116 #[tracing::instrument(skip_all)]
117 pub async fn get_license_keys_by_items(
118 pool: &PgPool,
119 item_ids: &[ItemId],
120 ) -> Result<std::collections::HashMap<ItemId, Vec<DbLicenseKey>>> {
121 let keys = sqlx::query_as::<_, DbLicenseKey>(
122 "SELECT * FROM license_keys WHERE item_id = ANY($1) ORDER BY item_id, created_at DESC",
123 )
124 .bind(item_ids)
125 .fetch_all(pool)
126 .await?;
127
128 let mut map: std::collections::HashMap<ItemId, Vec<DbLicenseKey>> = std::collections::HashMap::new();
129 for k in keys {
130 map.entry(k.item_id).or_default().push(k);
131 }
132 Ok(map)
133 }
134
135 /// Find an existing activation for a key + machine combo.
136 #[tracing::instrument(skip_all)]
137 pub async fn get_activation(
138 pool: &PgPool,
139 license_key_id: LicenseKeyId,
140 machine_id: &str,
141 ) -> Result<Option<DbLicenseActivation>> {
142 let activation = sqlx::query_as::<_, DbLicenseActivation>(
143 "SELECT * FROM license_activations WHERE license_key_id = $1 AND machine_id = $2",
144 )
145 .bind(license_key_id)
146 .bind(machine_id)
147 .fetch_optional(pool)
148 .await?;
149
150 Ok(activation)
151 }
152
153 /// Update the last_validated_at timestamp for an existing activation.
154 #[tracing::instrument(skip_all)]
155 pub async fn touch_activation(pool: &PgPool, activation_id: LicenseActivationId) -> Result<()> {
156 sqlx::query(
157 "UPDATE license_activations SET last_validated_at = NOW() WHERE id = $1",
158 )
159 .bind(activation_id)
160 .execute(pool)
161 .await?;
162
163 Ok(())
164 }
165
166 /// Activate a license key on a machine, atomically enforcing max_activations.
167 ///
168 /// Uses a transaction with `FOR UPDATE` to serialize concurrent activations
169 /// for the same key. Re-activations (same machine_id) always succeed via
170 /// upsert. New activations are rejected if the active count would exceed
171 /// `max_activations`.
172 ///
173 /// Returns `None` if the activation limit has been reached.
174 ///
175 /// After the upsert, the denormalized `activation_count` on `license_keys`
176 /// is refreshed with a full COUNT rather than an increment; this avoids
177 /// drift if a crash leaves the count out of sync.
178 #[tracing::instrument(skip_all)]
179 pub async fn try_create_activation(
180 pool: &PgPool,
181 license_key_id: LicenseKeyId,
182 machine_id: &str,
183 label: Option<&str>,
184 max_activations: Option<i32>,
185 ) -> Result<Option<DbLicenseActivation>> {
186 let mut tx = pool.begin().await?;
187
188 // Lock the license key row to serialize concurrent activations
189 sqlx::query("SELECT 1 FROM license_keys WHERE id = $1 FOR UPDATE")
190 .bind(license_key_id)
191 .fetch_one(&mut *tx)
192 .await?;
193
194 // Check if this machine already has an activation (re-activation is always OK)
195 let existing: Option<DbLicenseActivation> = sqlx::query_as(
196 "SELECT * FROM license_activations WHERE license_key_id = $1 AND machine_id = $2",
197 )
198 .bind(license_key_id)
199 .bind(machine_id)
200 .fetch_optional(&mut *tx)
201 .await?;
202
203 // For truly new activations, enforce the limit
204 if existing.is_none()
205 && let Some(max) = max_activations
206 {
207 let count: i64 = sqlx::query_scalar(
208 "SELECT COUNT(*) FROM license_activations WHERE license_key_id = $1 AND is_active = true",
209 )
210 .bind(license_key_id)
211 .fetch_one(&mut *tx)
212 .await?;
213
214 if count >= max as i64 {
215 tx.rollback().await?;
216 return Ok(None);
217 }
218 }
219
220 // Upsert: if same machine_id re-activates, reactivate it
221 let activation = sqlx::query_as::<_, DbLicenseActivation>(
222 r#"
223 INSERT INTO license_activations (license_key_id, machine_id, label)
224 VALUES ($1, $2, $3)
225 ON CONFLICT (license_key_id, machine_id)
226 DO UPDATE SET is_active = true, last_validated_at = NOW(),
227 label = COALESCE(EXCLUDED.label, license_activations.label)
228 RETURNING *
229 "#,
230 )
231 .bind(license_key_id)
232 .bind(machine_id)
233 .bind(label)
234 .fetch_one(&mut *tx)
235 .await?;
236
237 // Recount active activations to keep denormalized count accurate
238 sqlx::query(
239 r#"
240 UPDATE license_keys
241 SET activation_count = (
242 SELECT COUNT(*) FROM license_activations
243 WHERE license_key_id = $1 AND is_active = true
244 )
245 WHERE id = $1
246 "#,
247 )
248 .bind(license_key_id)
249 .execute(&mut *tx)
250 .await?;
251
252 tx.commit().await?;
253 Ok(Some(activation))
254 }
255
256 /// Deactivate a machine and update the key's activation_count.
257 ///
258 /// Only recounts if a row was actually deactivated (`rows_affected > 0`),
259 /// avoiding a wasted query when the machine wasn't active. Uses the same
260 /// full-recount strategy as [`try_create_activation`] for consistency.
261 #[tracing::instrument(skip_all)]
262 pub async fn deactivate_machine(
263 pool: &PgPool,
264 license_key_id: LicenseKeyId,
265 machine_id: &str,
266 ) -> Result<bool> {
267 let mut tx = pool.begin().await?;
268
269 let result = sqlx::query(
270 r#"
271 UPDATE license_activations
272 SET is_active = false
273 WHERE license_key_id = $1 AND machine_id = $2 AND is_active = true
274 "#,
275 )
276 .bind(license_key_id)
277 .bind(machine_id)
278 .execute(&mut *tx)
279 .await?;
280
281 if result.rows_affected() > 0 {
282 // Recount active activations
283 sqlx::query(
284 r#"
285 UPDATE license_keys
286 SET activation_count = (
287 SELECT COUNT(*) FROM license_activations
288 WHERE license_key_id = $1 AND is_active = true
289 )
290 WHERE id = $1
291 "#,
292 )
293 .bind(license_key_id)
294 .execute(&mut *tx)
295 .await?;
296
297 tx.commit().await?;
298 Ok(true)
299 } else {
300 tx.commit().await?;
301 Ok(false)
302 }
303 }
304
305 /// Revoke a license key and deactivate all its activations.
306 ///
307 /// Wrapped in a transaction so the key revocation and activation
308 /// deactivation are atomic; a crash between the two statements
309 /// cannot leave the key revoked with activations still active.
310 #[tracing::instrument(skip_all)]
311 pub async fn revoke_license_key(pool: &PgPool, key_id: LicenseKeyId) -> Result<()> {
312 let mut tx = pool.begin().await?;
313
314 sqlx::query(
315 r#"
316 UPDATE license_keys
317 SET revoked_at = NOW()
318 WHERE id = $1
319 "#,
320 )
321 .bind(key_id)
322 .execute(&mut *tx)
323 .await?;
324
325 sqlx::query(
326 "UPDATE license_activations SET is_active = false WHERE license_key_id = $1",
327 )
328 .bind(key_id)
329 .execute(&mut *tx)
330 .await?;
331
332 tx.commit().await?;
333 Ok(())
334 }
335
336 /// Revoke all license keys for a given transaction and deactivate all activations.
337 /// Called from the Stripe `charge.refunded` webhook handler.
338 ///
339 /// Two-step approach: bulk-revoke keys, then bulk-deactivate activations.
340 /// Separate queries because `license_activations` is keyed by `license_key_id`,
341 /// not `transaction_id`.
342 #[tracing::instrument(skip_all)]
343 pub async fn revoke_keys_by_transaction(
344 conn: &mut sqlx::PgConnection,
345 transaction_id: TransactionId,
346 ) -> Result<u64> {
347 // Get all key IDs for this transaction
348 let key_ids: Vec<LicenseKeyId> = sqlx::query_scalar(
349 "SELECT id FROM license_keys WHERE transaction_id = $1 AND revoked_at IS NULL",
350 )
351 .bind(transaction_id)
352 .fetch_all(&mut *conn)
353 .await?;
354
355 if key_ids.is_empty() {
356 return Ok(0);
357 }
358
359 // Revoke the keys
360 let result = sqlx::query(
361 r#"
362 UPDATE license_keys
363 SET revoked_at = NOW()
364 WHERE transaction_id = $1 AND revoked_at IS NULL
365 "#,
366 )
367 .bind(transaction_id)
368 .execute(&mut *conn)
369 .await?;
370
371 // Deactivate all activations for those keys in a single query
372 if !key_ids.is_empty() {
373 sqlx::query(
374 "UPDATE license_activations SET is_active = false WHERE license_key_id = ANY($1)",
375 )
376 .bind(&key_ids)
377 .execute(&mut *conn)
378 .await?;
379 }
380
381 Ok(result.rows_affected())
382 }
383
384