Skip to main content

max / makenotwork

20.0 KB · 624 lines History Blame Raw
1 //! Subscription queries: tier CRUD, subscription lifecycle, and access control.
2
3 use sqlx::PgPool;
4
5 use super::models::*;
6 use super::{PriceCents, ProjectId, SubscriptionId, SubscriptionTierId, UserId};
7 use crate::error::Result;
8
9 // ── Tier CRUD ──
10
11 /// Create a new subscription tier for a project.
12 #[tracing::instrument(skip_all)]
13 pub async fn create_subscription_tier(
14 pool: &PgPool,
15 project_id: ProjectId,
16 name: &str,
17 description: Option<&str>,
18 price_cents: PriceCents,
19 ) -> Result<DbSubscriptionTier> {
20 let tier = sqlx::query_as::<_, DbSubscriptionTier>(
21 r#"
22 INSERT INTO subscription_tiers (project_id, name, description, price_cents)
23 VALUES ($1, $2, $3, $4)
24 RETURNING *
25 "#,
26 )
27 .bind(project_id)
28 .bind(name)
29 .bind(description)
30 .bind(price_cents.as_i32())
31 .fetch_one(pool)
32 .await?;
33
34 Ok(tier)
35 }
36
37 /// Get a subscription tier by ID.
38 #[tracing::instrument(skip_all)]
39 pub async fn get_subscription_tier_by_id(
40 pool: &PgPool,
41 id: SubscriptionTierId,
42 ) -> Result<Option<DbSubscriptionTier>> {
43 let tier = sqlx::query_as::<_, DbSubscriptionTier>(
44 "SELECT * FROM subscription_tiers WHERE id = $1",
45 )
46 .bind(id)
47 .fetch_optional(pool)
48 .await?;
49
50 Ok(tier)
51 }
52
53 /// Get all active tiers for a project, ordered by sort_order.
54 #[tracing::instrument(skip_all)]
55 pub async fn get_active_tiers_by_project(
56 pool: &PgPool,
57 project_id: ProjectId,
58 ) -> Result<Vec<DbSubscriptionTier>> {
59 let tiers = sqlx::query_as::<_, DbSubscriptionTier>(
60 "SELECT * FROM subscription_tiers WHERE project_id = $1 AND is_active = true ORDER BY sort_order, created_at",
61 )
62 .bind(project_id)
63 .fetch_all(pool)
64 .await?;
65
66 Ok(tiers)
67 }
68
69 /// Get all tiers for a project (active and inactive), for dashboard management.
70 #[tracing::instrument(skip_all)]
71 pub async fn get_all_tiers_by_project(
72 pool: &PgPool,
73 project_id: ProjectId,
74 ) -> Result<Vec<DbSubscriptionTier>> {
75 let tiers = sqlx::query_as::<_, DbSubscriptionTier>(
76 "SELECT * FROM subscription_tiers WHERE project_id = $1 ORDER BY sort_order, created_at",
77 )
78 .bind(project_id)
79 .fetch_all(pool)
80 .await?;
81
82 Ok(tiers)
83 }
84
85 /// Update a subscription tier's name, description, and active status.
86 #[tracing::instrument(skip_all)]
87 pub async fn update_subscription_tier(
88 pool: &PgPool,
89 id: SubscriptionTierId,
90 name: &str,
91 description: Option<&str>,
92 is_active: bool,
93 ) -> Result<DbSubscriptionTier> {
94 let tier = sqlx::query_as::<_, DbSubscriptionTier>(
95 r#"
96 UPDATE subscription_tiers
97 SET name = $2, description = $3, is_active = $4
98 WHERE id = $1
99 RETURNING *
100 "#,
101 )
102 .bind(id)
103 .bind(name)
104 .bind(description)
105 .bind(is_active)
106 .fetch_one(pool)
107 .await?;
108
109 Ok(tier)
110 }
111
112 /// Store Stripe product and price IDs on a tier after creating them on connected account.
113 #[tracing::instrument(skip_all)]
114 pub async fn update_tier_stripe_ids(
115 pool: &PgPool,
116 tier_id: SubscriptionTierId,
117 product_id: &str,
118 price_id: &str,
119 ) -> Result<()> {
120 sqlx::query(
121 r#"
122 UPDATE subscription_tiers
123 SET stripe_product_id = $2, stripe_price_id = $3
124 WHERE id = $1
125 "#,
126 )
127 .bind(tier_id)
128 .bind(product_id)
129 .bind(price_id)
130 .execute(pool)
131 .await?;
132
133 Ok(())
134 }
135
136 /// Delete a subscription tier. Soft-deletes (sets is_active=false) if any
137 /// subscriptions reference it; hard-deletes otherwise.
138 ///
139 /// Uses a transaction with FOR UPDATE to prevent a TOCTOU race where a
140 /// subscription could be created between the existence check and the delete.
141 #[tracing::instrument(skip_all)]
142 pub async fn delete_subscription_tier(pool: &PgPool, id: SubscriptionTierId) -> Result<()> {
143 let mut tx = pool.begin().await?;
144
145 // Lock the tier row to serialize against concurrent subscription creation
146 sqlx::query("SELECT id FROM subscription_tiers WHERE id = $1 FOR UPDATE")
147 .bind(id)
148 .fetch_optional(&mut *tx)
149 .await?
150 .ok_or(sqlx::Error::RowNotFound)?;
151
152 let has_subscriptions: bool = sqlx::query_scalar(
153 "SELECT EXISTS(SELECT 1 FROM subscriptions WHERE tier_id = $1)",
154 )
155 .bind(id)
156 .fetch_one(&mut *tx)
157 .await?;
158
159 if has_subscriptions {
160 sqlx::query("UPDATE subscription_tiers SET is_active = false WHERE id = $1")
161 .bind(id)
162 .execute(&mut *tx)
163 .await?;
164 } else {
165 sqlx::query("DELETE FROM subscription_tiers WHERE id = $1")
166 .bind(id)
167 .execute(&mut *tx)
168 .await?;
169 }
170
171 tx.commit().await?;
172 Ok(())
173 }
174
175 // ── Subscription lifecycle ──
176
177 /// Create a new subscription record after successful checkout.
178 ///
179 /// Returns `None` if the subscription already exists (duplicate webhook
180 /// or concurrent active subscription for the same user+project).
181 /// The partial UNIQUE index on `(subscriber_id, project_id) WHERE status = 'active'`
182 /// prevents multiple active subscriptions at the DB level.
183 #[tracing::instrument(skip_all)]
184 pub async fn create_subscription<'e>(
185 executor: impl sqlx::PgExecutor<'e>,
186 subscriber_id: UserId,
187 tier_id: SubscriptionTierId,
188 project_id: ProjectId,
189 stripe_subscription_id: &str,
190 stripe_customer_id: &str,
191 ) -> Result<Option<DbSubscription>> {
192 let sub = sqlx::query_as::<_, DbSubscription>(
193 r#"
194 INSERT INTO subscriptions (subscriber_id, tier_id, project_id, stripe_subscription_id, stripe_customer_id)
195 VALUES ($1, $2, $3, $4, $5)
196 ON CONFLICT DO NOTHING
197 RETURNING *
198 "#,
199 )
200 .bind(subscriber_id)
201 .bind(tier_id)
202 .bind(project_id)
203 .bind(stripe_subscription_id)
204 .bind(stripe_customer_id)
205 .fetch_optional(executor)
206 .await?;
207
208 Ok(sub)
209 }
210
211 /// Look up a subscription by its Stripe subscription ID.
212 #[tracing::instrument(skip_all)]
213 pub async fn get_subscription_by_stripe_id(
214 pool: &PgPool,
215 stripe_sub_id: &str,
216 ) -> Result<Option<DbSubscription>> {
217 let sub = sqlx::query_as::<_, DbSubscription>(
218 "SELECT * FROM subscriptions WHERE stripe_subscription_id = $1",
219 )
220 .bind(stripe_sub_id)
221 .fetch_optional(pool)
222 .await?;
223
224 Ok(sub)
225 }
226
// `apply_stripe_update`: applies a Stripe-driven status and/or period update
// to a `subscriptions` row in one guarded statement. It is generated by
// `define_stripe_subscription_writer!`. NOTE(review): the exact signature and
// SQL live in `crate::db::subscription_writer`.
//
// `canceled` is terminal. An out-of-order `updated` (active) or `invoice.paid`
// event that lands after a `deleted` can neither revive the row (status) nor
// refresh its period, because both columns are written here under a single
// guard. This replaces the old split `update_subscription_status` +
// `update_subscription_period` pair, whose period half lacked the guard.
// Reactivation only ever happens at checkout via `create_subscription`, never
// through this path.
//
// These are plain `//` comments on purpose. Rustdoc does not document macro
// invocations, and a `///` here would trigger the `unused_doc_comments` lint.
crate::db::subscription_writer::define_stripe_subscription_writer!(
    apply_stripe_update,
    "subscriptions",
    DbSubscription
);
239
240 /// Mark a subscription as canceled.
241 #[tracing::instrument(skip_all)]
242 pub async fn cancel_subscription(
243 pool: &PgPool,
244 stripe_sub_id: &str,
245 ) -> Result<Option<DbSubscription>> {
246 let sub = sqlx::query_as::<_, DbSubscription>(
247 r#"
248 UPDATE subscriptions
249 SET status = 'canceled', canceled_at = COALESCE(canceled_at, NOW())
250 WHERE stripe_subscription_id = $1
251 RETURNING *
252 "#,
253 )
254 .bind(stripe_sub_id)
255 .fetch_optional(pool)
256 .await?;
257
258 Ok(sub)
259 }
260
261 // ── Suspension pause/resume ──
262
263 /// Get all active subscriptions to a creator's projects (for pausing on suspension).
264 #[tracing::instrument(skip_all)]
265 pub async fn get_active_subscriptions_by_creator(
266 pool: &PgPool,
267 creator_id: UserId,
268 ) -> Result<Vec<DbSubscription>> {
269 let subs = sqlx::query_as::<_, DbSubscription>(
270 r#"
271 SELECT s.* FROM subscriptions s
272 WHERE s.project_id IN (SELECT id FROM projects WHERE user_id = $1)
273 AND s.status = 'active'
274 AND s.paused_at IS NULL
275 "#,
276 )
277 .bind(creator_id)
278 .fetch_all(pool)
279 .await?;
280
281 Ok(subs)
282 }
283
284 /// Mark all active subscriptions to a creator's projects as paused.
285 #[tracing::instrument(skip_all)]
286 pub async fn pause_subscriptions_for_creator(
287 pool: &PgPool,
288 creator_id: UserId,
289 ) -> Result<u64> {
290 let result = sqlx::query(
291 r#"
292 UPDATE subscriptions SET paused_at = NOW()
293 WHERE project_id IN (SELECT id FROM projects WHERE user_id = $1)
294 AND status = 'active'
295 AND paused_at IS NULL
296 "#,
297 )
298 .bind(creator_id)
299 .execute(pool)
300 .await?;
301
302 Ok(result.rows_affected())
303 }
304
305 /// Get all paused subscriptions to a creator's projects (for cancelling on termination).
306 #[tracing::instrument(skip_all)]
307 pub async fn get_paused_subscriptions_by_creator(
308 pool: &PgPool,
309 creator_id: UserId,
310 ) -> Result<Vec<DbSubscription>> {
311 let subs = sqlx::query_as::<_, DbSubscription>(
312 r#"
313 SELECT s.* FROM subscriptions s
314 WHERE s.project_id IN (SELECT id FROM projects WHERE user_id = $1)
315 AND s.status = 'active'
316 AND s.paused_at IS NOT NULL
317 "#,
318 )
319 .bind(creator_id)
320 .fetch_all(pool)
321 .await?;
322
323 Ok(subs)
324 }
325
326 /// Resume all paused subscriptions for a creator's projects.
327 #[tracing::instrument(skip_all)]
328 pub async fn resume_subscriptions_for_creator(
329 pool: &PgPool,
330 creator_id: UserId,
331 ) -> Result<Vec<DbSubscription>> {
332 let subs = sqlx::query_as::<_, DbSubscription>(
333 r#"
334 UPDATE subscriptions SET paused_at = NULL
335 WHERE project_id IN (SELECT id FROM projects WHERE user_id = $1)
336 AND status = 'active'
337 AND paused_at IS NOT NULL
338 RETURNING *
339 "#,
340 )
341 .bind(creator_id)
342 .fetch_all(pool)
343 .await?;
344
345 Ok(subs)
346 }
347
348 // ── Access control ──
349
350 /// SQL predicate identifying a `subscriptions` row that currently grants access
351 /// to its scope. The `current_period_end` clause is defense-in-depth against a
352 /// missed/delayed `customer.subscription.deleted` webhook — `status = 'active'`
353 /// alone trusts Stripe to push the cancellation promptly.
354 ///
355 /// What a subscription access check is scoped to. Both arms run the SAME sealed
356 /// predicate inside [`gate`], so a project gate and an item gate cannot diverge.
357 #[derive(Debug, Clone, Copy)]
358 pub enum SubscriptionScope {
359 Project(ProjectId),
360 Item(super::ItemId),
361 }
362
363 pub use gate::SubscriptionGate;
364
365 /// Sealed home of the "does a subscription grant access right now" predicate.
366 ///
367 /// The predicate text lives in exactly ONE place — [`SubscriptionGate`]'s
368 /// private `PREDICATE` associated const — and is unreachable from the rest of
369 /// this module, let alone other modules. The only way to learn "this user has
370 /// access" is [`SubscriptionGate::check`] (or [`SubscriptionGate::accessible_item_ids`]
371 /// for the batch shape), each of which runs that predicate. A `SubscriptionGate`
372 /// value is a witness: its field is private and there is no public constructor,
373 /// so access-granting code can neither fabricate one nor hand-write a divergent
374 /// gate.
375 ///
376 /// Payments S1 / CHRONIC 2: the predicate used to be a shareable `&str` const,
377 /// and item gates drifted by dropping the `current_period_end` clause. A const
378 /// is copy-pasteable; a private associated const inside a sealed submodule is
379 /// not. Sealing it here makes the divergence unwritable, not merely discouraged.
380 mod gate {
381 use super::SubscriptionScope;
382 use crate::db::{ItemId, UserId};
383 use crate::error::Result;
384 use sqlx::PgPool;
385 use std::collections::HashMap;
386
387 /// Proof that a subscription currently grants access. Constructible ONLY via
388 /// [`SubscriptionGate::check`] — the private `()` field seals the type so no
389 /// other code can mint one.
390 #[derive(Debug, Clone, Copy)]
391 pub struct SubscriptionGate(());
392
393 impl SubscriptionGate {
394 /// The single source of truth for "grants access right now". Private to
395 /// this submodule: nothing outside can read it as a string, so it cannot
396 /// be copy-pasted into a divergent query. (Compile-time constant, never
397 /// user input, so the `format!` interpolation is injection-safe; the
398 /// `$N` placeholders stay bound.)
399 const PREDICATE: &'static str = "status = 'active' AND paused_at IS NULL \
400 AND (current_period_end IS NULL OR current_period_end > NOW())";
401
402 /// Does `user_id` hold a subscription that currently grants access to
403 /// `scope`? Returns `Some(gate)` iff so — the sole gate constructor and
404 /// the single entry point for project- and item-level access checks.
405 #[tracing::instrument(skip_all)]
406 pub async fn check(
407 pool: &PgPool,
408 user_id: UserId,
409 scope: SubscriptionScope,
410 ) -> Result<Option<SubscriptionGate>> {
411 let exists: bool = match scope {
412 SubscriptionScope::Project(project_id) => {
413 sqlx::query_scalar(&format!(
414 "SELECT EXISTS(SELECT 1 FROM subscriptions \
415 WHERE subscriber_id = $1 AND project_id = $2 AND {})",
416 Self::PREDICATE
417 ))
418 .bind(user_id)
419 .bind(project_id)
420 .fetch_one(pool)
421 .await?
422 }
423 SubscriptionScope::Item(item_id) => {
424 sqlx::query_scalar(&format!(
425 "SELECT EXISTS(SELECT 1 FROM subscriptions \
426 WHERE subscriber_id = $1 AND item_id = $2 AND {})",
427 Self::PREDICATE
428 ))
429 .bind(user_id)
430 .bind(item_id)
431 .fetch_one(pool)
432 .await?
433 }
434 };
435
436 Ok(exists.then_some(SubscriptionGate(())))
437 }
438
439 /// Every item ID `user_id` currently has access to via subscription
440 /// (batch gate). Runs the same sealed predicate as [`check`], so the
441 /// batch path cannot drift from the single-item gate.
442 #[tracing::instrument(skip_all)]
443 pub async fn accessible_item_ids(pool: &PgPool, user_id: UserId) -> Result<Vec<ItemId>> {
444 let item_ids: Vec<ItemId> = sqlx::query_scalar(&format!(
445 "SELECT DISTINCT item_id FROM subscriptions \
446 WHERE subscriber_id = $1 AND item_id IS NOT NULL AND {}",
447 Self::PREDICATE
448 ))
449 .bind(user_id)
450 .fetch_all(pool)
451 .await?;
452
453 Ok(item_ids)
454 }
455
456 /// Map of every item `user_id` currently has subscription access to →
457 /// its access proof. The witness-bearing batch shape used by the project
458 /// page, where each item's [`AccessContext`](crate::pricing::AccessContext)
459 /// needs its own gate. Runs the sealed predicate once.
460 #[tracing::instrument(skip_all)]
461 pub async fn subscribed_item_gates(
462 pool: &PgPool,
463 user_id: UserId,
464 ) -> Result<HashMap<ItemId, SubscriptionGate>> {
465 let ids = Self::accessible_item_ids(pool, user_id).await?;
466 Ok(ids.into_iter().map(|id| (id, SubscriptionGate(()))).collect())
467 }
468
469 /// Test-only constructor. Real gates can only be minted by running the
470 /// predicate against the DB; unit tests (e.g. `pricing`) need to
471 /// fabricate the "access granted" state without a database. Gated to
472 /// test builds so production code still cannot forge a witness.
473 #[cfg(test)]
474 pub(crate) fn test_witness() -> Self {
475 SubscriptionGate(())
476 }
477 }
478 }
479
480 /// Does `user_id` hold a subscription that currently grants access to `scope`?
481 ///
482 /// Thin boolean wrapper over the sealed [`SubscriptionGate::check`]; prefer
483 /// taking the [`SubscriptionGate`] witness directly where a proof of access is
484 /// useful downstream.
485 #[tracing::instrument(skip_all)]
486 pub async fn has_access(
487 pool: &PgPool,
488 user_id: UserId,
489 scope: SubscriptionScope,
490 ) -> Result<bool> {
491 Ok(SubscriptionGate::check(pool, user_id, scope).await?.is_some())
492 }
493
494 /// Get user subscriptions joined with project and tier data (for library display).
495 #[tracing::instrument(skip_all)]
496 pub async fn get_user_subscriptions_with_details(
497 pool: &PgPool,
498 user_id: UserId,
499 ) -> Result<Vec<DbUserSubscriptionRow>> {
500 let rows = sqlx::query_as::<_, DbUserSubscriptionRow>(
501 "SELECT s.id, s.project_id, p.title AS project_title, p.slug AS project_slug,
502 t.name AS tier_name, t.price_cents, s.status,
503 s.current_period_end, s.stripe_subscription_id
504 FROM subscriptions s
505 JOIN projects p ON p.id = s.project_id
506 JOIN subscription_tiers t ON t.id = s.tier_id
507 WHERE s.subscriber_id = $1
508 ORDER BY s.created_at DESC
509 LIMIT 1000",
510 )
511 .bind(user_id)
512 .fetch_all(pool)
513 .await?;
514
515 Ok(rows)
516 }
517
518 /// Get the number of active subscribers to a project (for dashboard display).
519 ///
520 /// NOT an access gate — this is a creator-facing headcount, so it deliberately
521 /// counts `status = 'active'` rows regardless of `current_period_end` (a sub in
522 /// its grace window is still a subscriber). Do not "align" it with
523 /// [`GRANTS_ACCESS_PREDICATE`]; the divergence here is intentional.
524 #[tracing::instrument(skip_all)]
525 pub async fn get_project_subscriber_count(
526 pool: &PgPool,
527 project_id: ProjectId,
528 ) -> Result<i64> {
529 let count: i64 = sqlx::query_scalar(
530 "SELECT COUNT(*) FROM subscriptions WHERE project_id = $1 AND status = 'active' AND paused_at IS NULL",
531 )
532 .bind(project_id)
533 .fetch_one(pool)
534 .await?;
535
536 Ok(count)
537 }
538
539 // ── Export ──
540
541 /// Export all subscribers across a creator's projects.
542 ///
543 /// Returns username, display_name, tier name, subscription status, and when.
544 #[tracing::instrument(skip_all)]
545 pub async fn get_project_subscribers_for_export(
546 pool: &PgPool,
547 user_id: UserId,
548 ) -> Result<Vec<SubscriberExportRow>> {
549 let rows = sqlx::query_as::<_, SubscriberExportRow>(
550 r#"
551 SELECT u.username, u.display_name, t.name AS tier_name, s.status, s.created_at
552 FROM subscriptions s
553 JOIN users u ON u.id = s.subscriber_id
554 JOIN subscription_tiers t ON t.id = s.tier_id
555 WHERE s.project_id IN (SELECT id FROM projects WHERE user_id = $1)
556 ORDER BY s.created_at DESC
557 "#,
558 )
559 .bind(user_id)
560 .fetch_all(pool)
561 .await?;
562
563 Ok(rows)
564 }
565
566 /// Export all subscriptions across a creator's projects with full detail.
567 ///
568 /// Returns project name, tier name, price, subscriber username, status,
569 /// billing period dates, and cancellation date.
570 #[tracing::instrument(skip_all)]
571 pub async fn get_subscriptions_for_export(
572 pool: &PgPool,
573 user_id: UserId,
574 ) -> Result<Vec<SubscriptionExportRow>> {
575 let rows = sqlx::query_as::<_, SubscriptionExportRow>(
576 r#"
577 SELECT p.title AS project_title, t.name AS tier_name, t.price_cents,
578 u.username, s.status,
579 s.current_period_start, s.current_period_end,
580 s.canceled_at, s.created_at
581 FROM subscriptions s
582 JOIN users u ON u.id = s.subscriber_id
583 JOIN subscription_tiers t ON t.id = s.tier_id
584 JOIN projects p ON p.id = s.project_id
585 WHERE s.project_id IN (SELECT id FROM projects WHERE user_id = $1)
586 ORDER BY s.created_at DESC
587 "#,
588 )
589 .bind(user_id)
590 .fetch_all(pool)
591 .await?;
592
593 Ok(rows)
594 }
595
596 // ── Event log ──
597
598 /// Log a subscription webhook event for debugging and idempotency.
599 /// The UNIQUE index on stripe_event_id makes duplicate events a no-op.
600 #[tracing::instrument(skip_all)]
601 pub async fn log_subscription_event(
602 pool: &PgPool,
603 subscription_id: Option<SubscriptionId>,
604 stripe_event_id: &str,
605 event_type: &str,
606 payload: &serde_json::Value,
607 ) -> Result<()> {
608 sqlx::query(
609 r#"
610 INSERT INTO subscription_events (subscription_id, stripe_event_id, event_type, payload)
611 VALUES ($1, $2, $3, $4)
612 ON CONFLICT (stripe_event_id) DO NOTHING
613 "#,
614 )
615 .bind(subscription_id)
616 .bind(stripe_event_id)
617 .bind(event_type)
618 .bind(payload)
619 .execute(pool)
620 .await?;
621
622 Ok(())
623 }
624