Skip to main content

max / makenotwork

29.1 KB · 847 lines History Blame Raw
1 //! Unified promo code management: creation, validation, usage tracking, and deletion.
2 //!
3 //! Replaces the old `discount_codes` and `download_codes` modules. Supports three
4 //! code purposes: discount, free_access, and free_trial.
5
6 use sqlx::PgPool;
7
8 use super::enums::DiscountType;
9 use super::models::*;
10 use super::{CodePurpose, ItemId, ProjectId, PromoCodeId, SubscriptionTierId, UserId};
11 use crate::error::{AppError, Result};
12
13 /// Create a new promo code for a creator.
14 #[allow(clippy::too_many_arguments)]
15 #[tracing::instrument(skip_all)]
16 pub async fn create_promo_code(
17 pool: &PgPool,
18 creator_id: UserId,
19 code: &str,
20 code_purpose: super::CodePurpose,
21 discount_type: Option<DiscountType>,
22 discount_value: Option<i32>,
23 min_price_cents: i32,
24 trial_days: Option<i32>,
25 max_uses: Option<i32>,
26 expires_at: Option<chrono::DateTime<chrono::Utc>>,
27 starts_at: Option<chrono::DateTime<chrono::Utc>>,
28 item_id: Option<ItemId>,
29 project_id: Option<ProjectId>,
30 tier_id: Option<SubscriptionTierId>,
31 ) -> Result<DbPromoCode> {
32 let promo_code = sqlx::query_as::<_, DbPromoCode>(
33 r#"
34 INSERT INTO promo_codes (creator_id, code, code_purpose, discount_type, discount_value,
35 min_price_cents, trial_days, max_uses, expires_at, starts_at, item_id, project_id, tier_id)
36 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
37 RETURNING *
38 "#,
39 )
40 .bind(creator_id)
41 .bind(code)
42 .bind(code_purpose)
43 .bind(discount_type)
44 .bind(discount_value)
45 .bind(min_price_cents)
46 .bind(trial_days)
47 .bind(max_uses)
48 .bind(expires_at)
49 .bind(starts_at)
50 .bind(item_id)
51 .bind(project_id)
52 .bind(tier_id)
53 .fetch_one(pool)
54 .await?;
55
56 Ok(promo_code)
57 }
58
59 /// Fetch a promo code by primary key.
60 #[tracing::instrument(skip_all)]
61 pub async fn get_promo_code_by_id(pool: &PgPool, id: PromoCodeId) -> Result<Option<DbPromoCode>> {
62 let code = sqlx::query_as::<_, DbPromoCode>(
63 "SELECT * FROM promo_codes WHERE id = $1",
64 )
65 .bind(id)
66 .fetch_optional(pool)
67 .await?;
68
69 Ok(code)
70 }
71
72 /// Look up a promo code by creator ID and code string (case-insensitive).
73 /// Used at checkout to validate discount codes.
74 #[tracing::instrument(skip_all)]
75 pub async fn get_promo_code_by_creator_and_code(
76 pool: &PgPool,
77 creator_id: UserId,
78 code: &str,
79 ) -> Result<Option<DbPromoCode>> {
80 let promo_code = sqlx::query_as::<_, DbPromoCode>(
81 "SELECT * FROM promo_codes WHERE creator_id = $1 AND upper(code) = upper($2)",
82 )
83 .bind(creator_id)
84 .bind(code)
85 .fetch_optional(pool)
86 .await?;
87
88 Ok(promo_code)
89 }
90
91 /// Look up a free_access promo code by code string (case-insensitive, cross-creator).
92 /// Used for free_access code claims where the buyer doesn't know the creator.
93 /// Scoped to free_access purpose to prevent cross-creator collision with discount codes.
94 #[tracing::instrument(skip_all)]
95 pub async fn get_promo_code_by_code(
96 pool: &PgPool,
97 code: &str,
98 ) -> Result<Option<DbPromoCode>> {
99 let promo_code = sqlx::query_as::<_, DbPromoCode>(
100 "SELECT * FROM promo_codes WHERE upper(code) = upper($1) AND code_purpose = 'free_access'",
101 )
102 .bind(code)
103 .fetch_optional(pool)
104 .await?;
105
106 Ok(promo_code)
107 }
108
/// Shared `SELECT … FROM … LEFT JOIN …` prefix for the promo-code listing queries.
///
/// It yields every `promo_codes` column plus `item_title` and `project_title` from the code's
/// scoped item and project. Because both joins are LEFT JOINs, a title is NULL when the code
/// isn't scoped to that kind of target, or the referenced row is gone.
///
/// Callers append their own `WHERE` / `ORDER BY` / `LIMIT`. The literal begins and ends with
/// a newline, so plain concatenation keeps the appended clauses separated.
const PROMO_CODE_WITH_NAMES_SELECT: &str = r#"
SELECT pc.*, i.title AS item_title, p.title AS project_title
FROM promo_codes pc
LEFT JOIN items i ON pc.item_id = i.id
LEFT JOIN projects p ON pc.project_id = p.id
"#;
117
118 /// List all promo codes for a creator, newest first. Capped at 500.
119 #[tracing::instrument(skip_all)]
120 pub async fn get_promo_codes_by_creator(pool: &PgPool, creator_id: UserId) -> Result<Vec<DbPromoCodeWithNames>> {
121 let query = format!("{PROMO_CODE_WITH_NAMES_SELECT} WHERE pc.creator_id = $1 ORDER BY pc.created_at DESC LIMIT 500");
122 let codes = sqlx::query_as::<_, DbPromoCodeWithNames>(&query)
123 .bind(creator_id)
124 .fetch_all(pool)
125 .await?;
126
127 Ok(codes)
128 }
129
130 /// List all promo codes scoped to a project, newest first. Capped at 500.
131 #[tracing::instrument(skip_all)]
132 pub async fn get_promo_codes_by_project(pool: &PgPool, project_id: ProjectId) -> Result<Vec<DbPromoCodeWithNames>> {
133 let query = format!("{PROMO_CODE_WITH_NAMES_SELECT} WHERE pc.project_id = $1 ORDER BY pc.created_at DESC LIMIT 500");
134 let codes = sqlx::query_as::<_, DbPromoCodeWithNames>(&query)
135 .bind(project_id)
136 .fetch_all(pool)
137 .await?;
138
139 Ok(codes)
140 }
141
142 /// List all promo codes scoped to an item, newest first. Capped at 500.
143 #[tracing::instrument(skip_all)]
144 pub async fn get_promo_codes_by_item(pool: &PgPool, item_id: ItemId) -> Result<Vec<DbPromoCodeWithNames>> {
145 let query = format!("{PROMO_CODE_WITH_NAMES_SELECT} WHERE pc.item_id = $1 ORDER BY pc.created_at DESC LIMIT 500");
146 let codes = sqlx::query_as::<_, DbPromoCodeWithNames>(&query)
147 .bind(item_id)
148 .fetch_all(pool)
149 .await?;
150
151 Ok(codes)
152 }
153
154 /// Batch-load item-scoped promo codes for multiple items, grouped by item_id.
155 #[tracing::instrument(skip_all)]
156 pub async fn get_promo_codes_by_items(
157 pool: &PgPool,
158 item_ids: &[ItemId],
159 ) -> Result<std::collections::HashMap<ItemId, Vec<DbPromoCodeWithNames>>> {
160 let query = format!("{PROMO_CODE_WITH_NAMES_SELECT} WHERE pc.item_id = ANY($1) ORDER BY pc.item_id, pc.created_at DESC");
161 let codes = sqlx::query_as::<_, DbPromoCodeWithNames>(&query)
162 .bind(item_ids)
163 .fetch_all(pool)
164 .await?;
165
166 let mut map: std::collections::HashMap<ItemId, Vec<DbPromoCodeWithNames>> = std::collections::HashMap::new();
167 for pc in codes {
168 if let Some(item_id) = pc.item_id {
169 map.entry(item_id).or_default().push(pc);
170 }
171 }
172 Ok(map)
173 }
174
175 /// Atomically increment use_count, respecting the max_uses limit.
176 ///
177 /// Returns `true` if the increment succeeded, `false` if the code has already
178 /// reached its usage limit. The `WHERE` clause enforces the limit at the
179 /// database level, preventing TOCTOU races.
180 ///
181 /// Accepts any sqlx executor (`&PgPool`, `&mut Transaction`, etc.) so callers
182 /// can include this in a larger transaction when needed.
183 #[tracing::instrument(skip_all)]
184 pub async fn try_increment_use_count<'e>(
185 executor: impl sqlx::PgExecutor<'e>,
186 id: PromoCodeId,
187 ) -> Result<bool> {
188 let result = sqlx::query(
189 "UPDATE promo_codes SET use_count = use_count + 1 \
190 WHERE id = $1 \
191 AND (max_uses IS NULL OR use_count < max_uses) \
192 AND (expires_at IS NULL OR expires_at > NOW()) \
193 AND (starts_at IS NULL OR starts_at <= NOW())",
194 )
195 .bind(id)
196 .execute(executor)
197 .await?;
198
199 Ok(result.rows_affected() > 0)
200 }
201
202 /// Release a reserved use_count slot (decrement, clamped to 0).
203 ///
204 /// Used in two places that must coordinate so the count doesn't drop twice
205 /// for the same reservation:
206 /// 1. Route handlers, when a Stripe checkout creation or pending-tx
207 /// insert fails AFTER the use_count was reserved. They call
208 /// `release_use_count_and_detach` (below) which also nulls the
209 /// `promo_code_id` on any pending transaction rows for this
210 /// reservation, so `cleanup_stale_pending` can't fire a second
211 /// release for the same buyer's promo hold.
212 /// 2. `cleanup_stale_pending` itself, when it deletes stale pending
213 /// rows past the 24h checkout-session expiry. Those rows still
214 /// carry their `promo_code_id`, so this plain function is the
215 /// right call from there.
216 ///
217 /// `GREATEST(0, ...)` makes a double-release harmless (count clamps at
218 /// zero) but the structural fix above prevents it from happening at all.
219 #[tracing::instrument(skip_all)]
220 pub async fn release_use_count(pool: &PgPool, id: PromoCodeId) -> Result<()> {
221 sqlx::query(
222 "UPDATE promo_codes SET use_count = GREATEST(0, use_count - 1) WHERE id = $1",
223 )
224 .bind(id)
225 .execute(pool)
226 .await?;
227
228 Ok(())
229 }
230
231 /// Release a use_count slot AND detach the same promo_code_id from any
232 /// pending transactions for `buyer_id` so the scheduler's
233 /// `cleanup_stale_pending` doesn't release it a second time when those
234 /// stale rows eventually time out.
235 ///
236 /// Use this from route-level failure paths (Stripe session creation
237 /// failed, pending-tx insert failed mid-cart, etc). The detach is a
238 /// no-op when the failure happened BEFORE any pending row was inserted;
239 /// it's the safety net for when a partial pending row may have landed.
240 #[tracing::instrument(skip_all)]
241 pub async fn release_use_count_and_detach(
242 pool: &PgPool,
243 id: PromoCodeId,
244 buyer_id: UserId,
245 ) -> Result<()> {
246 let mut tx = pool.begin().await?;
247
248 sqlx::query(
249 "UPDATE transactions SET promo_code_id = NULL \
250 WHERE buyer_id = $1 AND promo_code_id = $2 AND status = 'pending'",
251 )
252 .bind(buyer_id)
253 .bind(id)
254 .execute(&mut *tx)
255 .await?;
256
257 sqlx::query(
258 "UPDATE promo_codes SET use_count = GREATEST(0, use_count - 1) WHERE id = $1",
259 )
260 .bind(id)
261 .execute(&mut *tx)
262 .await?;
263
264 tx.commit().await?;
265 Ok(())
266 }
267
268 /// Update editable fields on a promo code (expires_at, starts_at, max_uses).
269 #[tracing::instrument(skip_all)]
270 pub async fn update_promo_code(
271 pool: &PgPool,
272 id: PromoCodeId,
273 expires_at: Option<Option<chrono::DateTime<chrono::Utc>>>,
274 starts_at: Option<Option<chrono::DateTime<chrono::Utc>>>,
275 max_uses: Option<Option<i32>>,
276 ) -> Result<DbPromoCode> {
277 // Build SET clauses for provided fields only
278 let mut sets = Vec::new();
279 let mut param_idx = 2u32; // $1 = id
280
281 if expires_at.is_some() {
282 sets.push(format!("expires_at = ${param_idx}"));
283 param_idx += 1;
284 }
285 if starts_at.is_some() {
286 sets.push(format!("starts_at = ${param_idx}"));
287 param_idx += 1;
288 }
289 if max_uses.is_some() {
290 sets.push(format!("max_uses = ${param_idx}"));
291 // Final SET clause; param_idx is never read after this point, so the
292 // increment is elided to avoid an unused_assignments warning. Restore
293 // it if a new optional field is added below.
294 }
295
296 if sets.is_empty() {
297 // Nothing to update — just return current state
298 return get_promo_code_by_id(pool, id)
299 .await?
300 .ok_or_else(|| crate::error::AppError::NotFound);
301 }
302
303 let sql = format!("UPDATE promo_codes SET {} WHERE id = $1 RETURNING *", sets.join(", "));
304 let mut query = sqlx::query_as::<_, DbPromoCode>(&sql).bind(id);
305
306 if let Some(val) = expires_at {
307 query = query.bind(val);
308 }
309 if let Some(val) = starts_at {
310 query = query.bind(val);
311 }
312 if let Some(val) = max_uses {
313 query = query.bind(val);
314 }
315
316 let code = query.fetch_one(pool).await?;
317 Ok(code)
318 }
319
320 /// Delete all expired promo codes for a creator. Returns number of rows deleted.
321 #[tracing::instrument(skip_all)]
322 pub async fn delete_expired_by_creator(pool: &PgPool, creator_id: UserId) -> Result<u64> {
323 let result = sqlx::query(
324 "DELETE FROM promo_codes WHERE creator_id = $1 AND expires_at IS NOT NULL AND expires_at < NOW()",
325 )
326 .bind(creator_id)
327 .execute(pool)
328 .await?;
329
330 Ok(result.rows_affected())
331 }
332
333 /// Delete a promo code permanently.
334 #[tracing::instrument(skip_all)]
335 pub async fn delete_promo_code(pool: &PgPool, id: PromoCodeId) -> Result<()> {
336 sqlx::query("DELETE FROM promo_codes WHERE id = $1")
337 .bind(id)
338 .execute(pool)
339 .await?;
340
341 Ok(())
342 }
343
/// A single row in the "who redeemed this code" view, as produced by `list_redemptions`.
///
/// Serialized with serde in field declaration order.
#[derive(Debug, sqlx::FromRow, serde::Serialize)]
pub struct PromoRedemption {
    /// The transaction's `completed_at`, falling back to its `created_at` when that is NULL.
    pub redeemed_at: chrono::DateTime<chrono::Utc>,
    /// Buyer's display name. `None` when the transaction has no matching user (e.g. guest
    /// checkouts), because the query LEFT JOINs `users`.
    pub display_name: Option<String>,
    /// Buyer's username; `None` under the same conditions as `display_name`.
    pub username: Option<String>,
    /// Email recorded on the transaction for unauthenticated checkouts. It identifies the
    /// buyer when `display_name` / `username` are `None`.
    pub guest_email: Option<String>,
    /// Item title as stored on the transaction row, not joined from `items`. Renaming an
    /// item later therefore doesn't rewrite the audit trail.
    pub item_title: Option<String>,
    /// `amount_cents` as stored on the transaction.
    /// NOTE(review): confirm whether this is the pre- or post-discount amount.
    pub amount_cents: i32,
}
358
359 /// List redemptions of a single promo code, newest first.
360 ///
361 /// Joins through to `users` for buyer identity but falls back to the
362 /// transaction's `guest_email` for unauthenticated checkouts. Capped at 500
363 /// rows — promo codes that exceed that bound are an outlier worth its own
364 /// CSV-export flow rather than a paginated UI.
365 #[tracing::instrument(skip_all)]
366 pub async fn list_redemptions(
367 pool: &PgPool,
368 id: PromoCodeId,
369 ) -> Result<Vec<PromoRedemption>> {
370 let rows = sqlx::query_as::<_, PromoRedemption>(
371 r#"
372 SELECT
373 COALESCE(t.completed_at, t.created_at) AS redeemed_at,
374 u.display_name AS display_name,
375 u.username AS username,
376 t.guest_email AS guest_email,
377 t.item_title AS item_title,
378 t.amount_cents AS amount_cents
379 FROM transactions t
380 LEFT JOIN users u ON u.id = t.buyer_id
381 WHERE t.promo_code_id = $1
382 AND t.status = 'completed'
383 ORDER BY redeemed_at DESC
384 LIMIT 500
385 "#,
386 )
387 .bind(id)
388 .fetch_all(pool)
389 .await?;
390
391 Ok(rows)
392 }
393
394 /// Create a platform-wide promo code (used for Fan+ monthly credits).
395 ///
396 /// Same as `create_promo_code` but sets `is_platform_wide = true`.
397 /// Platform-wide codes are not scoped to a specific creator's items.
398 #[allow(clippy::too_many_arguments)]
399 #[tracing::instrument(skip_all)]
400 pub async fn create_platform_promo_code(
401 pool: &PgPool,
402 creator_id: UserId,
403 code: &str,
404 code_purpose: super::CodePurpose,
405 discount_type: Option<DiscountType>,
406 discount_value: Option<i32>,
407 min_price_cents: i32,
408 trial_days: Option<i32>,
409 max_uses: Option<i32>,
410 expires_at: Option<chrono::DateTime<chrono::Utc>>,
411 ) -> Result<DbPromoCode> {
412 let promo_code = sqlx::query_as::<_, DbPromoCode>(
413 r#"
414 INSERT INTO promo_codes (creator_id, code, code_purpose, discount_type, discount_value,
415 min_price_cents, trial_days, max_uses, expires_at, is_platform_wide)
416 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, true)
417 RETURNING *
418 "#,
419 )
420 .bind(creator_id)
421 .bind(code)
422 .bind(code_purpose)
423 .bind(discount_type)
424 .bind(discount_value)
425 .bind(min_price_cents)
426 .bind(trial_days)
427 .bind(max_uses)
428 .bind(expires_at)
429 .fetch_one(pool)
430 .await?;
431
432 Ok(promo_code)
433 }
434
435 /// Look up a platform-wide promo code by user ID and code string (case-insensitive).
436 ///
437 /// Used at checkout to validate Fan+ credits: the buyer owns the code, and it
438 /// applies to any item on the platform.
439 #[tracing::instrument(skip_all)]
440 pub async fn get_platform_promo_code_by_user_and_code(
441 pool: &PgPool,
442 user_id: UserId,
443 code: &str,
444 ) -> Result<Option<DbPromoCode>> {
445 let promo_code = sqlx::query_as::<_, DbPromoCode>(
446 "SELECT * FROM promo_codes WHERE creator_id = $1 AND upper(code) = upper($2) AND is_platform_wide = true",
447 )
448 .bind(user_id)
449 .bind(code)
450 .fetch_optional(pool)
451 .await?;
452
453 Ok(promo_code)
454 }
455
456 /// Apply a discount to a price, returning the discounted price in cents (minimum 0).
457 /// Negative discount values are clamped to 0 to prevent price increases.
458 #[tracing::instrument(skip_all)]
459 pub fn apply_discount(price_cents: i32, discount_type: DiscountType, discount_value: i32) -> i32 {
460 let discount_value = discount_value.max(0);
461 match discount_type {
462 DiscountType::Percentage => {
463 let discount = (price_cents as i64 * discount_value as i64) / 100;
464 (price_cents as i64 - discount).max(0) as i32
465 }
466 // Subtract in i64 (like the Percentage arm) so a configuration where
467 // `discount_value > i32::MAX - price_cents` can't underflow before the
468 // `.max(0)` clamp catches it. discount_value is i32 so the sub is
469 // bounded; we cast for parity with the Percentage path.
470 DiscountType::Fixed => (price_cents as i64 - discount_value as i64).max(0) as i32,
471 }
472 }
473
474 // ── Shared checkout promo validation ─────────────────────────────────────────
475 //
476 // Every checkout path (single item, guest, cart ×2) needs the same promo logic:
477 // look the code up, run the code-level window/limit checks, then apply it to each
478 // item with scope + minimum-price + discount math. These two functions are that
479 // logic in one place, so a fix (the NULL-discount rejection, the min-price floor)
480 // can't land in three copies and miss the fourth.
481
/// A promo code that passed the code-level checks in `lookup_and_validate_promo`.
///
/// At lookup time the code existed, was not a trial code, was inside its
/// `starts_at`/`expires_at` window, and was under its `max_uses` limit. Apply it per item
/// with [`apply_promo_to_item`]. Reserve a use with [`try_increment_use_count`], which
/// re-checks the limit and window atomically in the database.
pub struct ValidatedPromo {
    /// The full promo-code row as loaded during validation.
    pub code: DbPromoCode,
    /// `true` for a platform-wide Fan+ credit (valid on any seller's items) rather than a
    /// seller-scoped code. `lookup_and_validate_promo` copies it from
    /// `code.is_platform_wide`. When set, `apply_promo_to_item` skips the scope and
    /// minimum-price checks.
    pub is_platform_wide: bool,
}
491
492 impl ValidatedPromo {
493 pub fn id(&self) -> PromoCodeId {
494 self.code.id
495 }
496 }
497
498 /// Look up and code-level-validate a checkout promo. Tries the seller's code
499 /// first; when `buyer_id` is `Some`, falls back to that buyer's platform-wide
500 /// Fan+ credit. Returns `Ok(None)` for a blank code, `Err` for an
501 /// unknown/not-yet-active/expired/exhausted/trial code. Per-item scope, minimum
502 /// price, and discount math are done by [`apply_promo_to_item`], not here.
503 #[tracing::instrument(skip_all)]
504 pub async fn lookup_and_validate_promo(
505 pool: &PgPool,
506 seller_id: UserId,
507 buyer_id: Option<UserId>,
508 raw_code: &str,
509 ) -> Result<Option<ValidatedPromo>> {
510 let code_str = raw_code.trim().to_uppercase();
511 if code_str.is_empty() {
512 return Ok(None);
513 }
514
515 let code = match get_promo_code_by_creator_and_code(pool, seller_id, &code_str).await? {
516 Some(pc) => pc,
517 None => match buyer_id {
518 Some(uid) => get_platform_promo_code_by_user_and_code(pool, uid, &code_str)
519 .await?
520 .ok_or_else(|| AppError::BadRequest("Invalid promo code".to_string()))?,
521 None => return Err(AppError::BadRequest("Invalid promo code".to_string())),
522 },
523 };
524
525 if code.code_purpose == CodePurpose::FreeTrial {
526 return Err(AppError::BadRequest(
527 "Trial codes can only be used for subscriptions".to_string(),
528 ));
529 }
530 let now = chrono::Utc::now();
531 if let Some(starts) = code.starts_at
532 && starts > now
533 {
534 return Err(AppError::BadRequest("This promo code is not yet active".to_string()));
535 }
536 if let Some(expires) = code.expires_at
537 && expires < now
538 {
539 return Err(AppError::BadRequest("This promo code has expired".to_string()));
540 }
541 if let Some(max) = code.max_uses
542 && code.use_count >= max
543 {
544 return Err(AppError::BadRequest("This promo code has reached its usage limit".to_string()));
545 }
546
547 let is_platform_wide = code.is_platform_wide;
548 Ok(Some(ValidatedPromo { code, is_platform_wide }))
549 }
550
/// Why a validated promo doesn't apply to a particular item.
///
/// This is a soft, per-item outcome rather than a hard error. It is a plain fieldless enum,
/// so it derives the usual value traits: it can be logged, compared in tests, and copied
/// freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromoIneligible {
    /// The code is scoped to a different item or project.
    ScopeMismatch,
    /// The item's price is below the code's `min_price_cents` floor.
    BelowMinPrice,
}
558
/// Result of applying a validated promo to one item (see [`apply_promo_to_item`]).
pub enum PromoApplication {
    /// The code applies. The payload is the item's price in cents after the code: `0` for
    /// free-access codes, or the discounted price for discount codes.
    Apply(i32),
    /// The code doesn't apply to this item. Cart checkout skips the item; single-item
    /// checkout rejects the purchase.
    Ineligible(PromoIneligible),
}
566
567 /// Apply a validated promo to one item's base price. A misconfigured Discount
568 /// code (NULL type/value) is a hard `Err` (never reserve-and-charge-full);
569 /// scope or minimum-price ineligibility is `Ok(Ineligible(_))` so cart checkout
570 /// can skip the item while single-item checkout turns it into an error.
571 pub fn apply_promo_to_item(
572 validated: &ValidatedPromo,
573 item_id: ItemId,
574 project_id: ProjectId,
575 base_price_cents: i32,
576 ) -> Result<PromoApplication> {
577 let code = &validated.code;
578
579 // Scope checks apply to seller codes only; a platform-wide credit is valid
580 // on any item.
581 if !validated.is_platform_wide {
582 if let Some(scoped_item) = code.item_id
583 && scoped_item != item_id
584 {
585 return Ok(PromoApplication::Ineligible(PromoIneligible::ScopeMismatch));
586 }
587 if let Some(scoped_project) = code.project_id
588 && project_id != scoped_project
589 {
590 return Ok(PromoApplication::Ineligible(PromoIneligible::ScopeMismatch));
591 }
592 }
593
594 match code.code_purpose {
595 CodePurpose::FreeAccess => Ok(PromoApplication::Apply(0)),
596 CodePurpose::Discount => {
597 if !validated.is_platform_wide && base_price_cents < code.min_price_cents {
598 return Ok(PromoApplication::Ineligible(PromoIneligible::BelowMinPrice));
599 }
600 let (dt, dv) = match (code.discount_type, code.discount_value) {
601 (Some(dt), Some(dv)) => (dt, dv),
602 _ => {
603 return Err(AppError::BadRequest(
604 "This promo code is misconfigured. Please contact the creator.".to_string(),
605 ));
606 }
607 };
608 Ok(PromoApplication::Apply(apply_discount(base_price_cents, dt, dv)))
609 }
610 // Rejected up front in `lookup_and_validate_promo`.
611 CodePurpose::FreeTrial => Ok(PromoApplication::Apply(base_price_cents)),
612 }
613 }
614
// Tests for the pure price math in `apply_discount`. The database helpers in this module
// need a live Postgres and are not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn percentage_discount_50() {
        assert_eq!(apply_discount(1000, DiscountType::Percentage, 50), 500);
    }

    #[test]
    fn percentage_discount_100() {
        assert_eq!(apply_discount(1000, DiscountType::Percentage, 100), 0);
    }

    #[test]
    fn percentage_discount_10() {
        // 999 * 10 / 100 = 99 (integer), 999 - 99 = 900
        assert_eq!(apply_discount(999, DiscountType::Percentage, 10), 900);
    }

    #[test]
    fn fixed_discount() {
        assert_eq!(apply_discount(1000, DiscountType::Fixed, 300), 700);
    }

    #[test]
    fn fixed_discount_exceeds_price() {
        assert_eq!(apply_discount(100, DiscountType::Fixed, 500), 0);
    }

    // -- Percentage discount edge cases --

    #[test]
    fn percentage_discount_0() {
        assert_eq!(apply_discount(1000, DiscountType::Percentage, 0), 1000);
    }

    #[test]
    fn percentage_discount_over_100() {
        // 150% discount should clamp to 0
        assert_eq!(apply_discount(1000, DiscountType::Percentage, 150), 0);
    }

    #[test]
    fn percentage_discount_1_percent() {
        // 1000 * 1 / 100 = 10, result = 990
        assert_eq!(apply_discount(1000, DiscountType::Percentage, 1), 990);
    }

    #[test]
    fn percentage_discount_99_percent() {
        // 1000 * 99 / 100 = 990, result = 10
        assert_eq!(apply_discount(1000, DiscountType::Percentage, 99), 10);
    }

    #[test]
    fn percentage_discount_rounding() {
        // 1 cent * 50 / 100 = 0 (integer division), result = 1
        assert_eq!(apply_discount(1, DiscountType::Percentage, 50), 1);
        // 3 * 33 / 100 = 0 (integer), result = 3
        assert_eq!(apply_discount(3, DiscountType::Percentage, 33), 3);
        // 199 * 50 / 100 = 99, result = 100
        assert_eq!(apply_discount(199, DiscountType::Percentage, 50), 100);
    }

    // -- Fixed discount edge cases --

    #[test]
    fn fixed_discount_exact_price() {
        assert_eq!(apply_discount(500, DiscountType::Fixed, 500), 0);
    }

    #[test]
    fn fixed_discount_zero_value() {
        assert_eq!(apply_discount(1000, DiscountType::Fixed, 0), 1000);
    }

    #[test]
    fn fixed_discount_one_cent() {
        assert_eq!(apply_discount(1000, DiscountType::Fixed, 1), 999);
    }

    // -- Zero price --

    #[test]
    fn zero_price_percentage() {
        assert_eq!(apply_discount(0, DiscountType::Percentage, 50), 0);
    }

    #[test]
    fn zero_price_fixed() {
        assert_eq!(apply_discount(0, DiscountType::Fixed, 100), 0);
    }

    // -- Negative values (defensive) --

    #[test]
    fn negative_discount_value_percentage() {
        // Negative discount values are clamped to 0, so price is unchanged
        assert_eq!(apply_discount(1000, DiscountType::Percentage, -50), 1000);
    }

    #[test]
    fn negative_discount_value_fixed() {
        // Negative discount values are clamped to 0, so price is unchanged
        assert_eq!(apply_discount(1000, DiscountType::Fixed, -500), 1000);
    }

    #[test]
    fn negative_price_percentage() {
        // Negative price with percentage discount — documents current behavior
        // -1000 * 50 / 100 = -500, -1000 - (-500) = -500, max(0) = 0
        assert_eq!(apply_discount(-1000, DiscountType::Percentage, 50), 0);
    }

    #[test]
    fn negative_price_fixed() {
        // -1000 - 500 = -1500, max(0) = 0
        assert_eq!(apply_discount(-1000, DiscountType::Fixed, 500), 0);
    }

    // -- Large values (overflow safety) --

    #[test]
    fn large_price_percentage_no_overflow() {
        // The function uses i64 intermediate to avoid overflow
        // i32::MAX = 2_147_483_647; 50% of that is 2_147_483_647 * 50 / 100 = 1_073_741_823
        let price = i32::MAX;
        let result = apply_discount(price, DiscountType::Percentage, 50);
        assert_eq!(result, 1_073_741_824); // (MAX - MAX*50/100)
    }

    // ── Adversarial (test-fuzz) ──

    #[test]
    fn adversarial_percentage_max_price_max_percentage() {
        // i32::MAX price with 100% discount
        let result = apply_discount(i32::MAX, DiscountType::Percentage, 100);
        assert_eq!(result, 0, "100% discount on any price should be 0");
    }

    #[test]
    fn adversarial_percentage_max_price_99_percent() {
        let result = apply_discount(i32::MAX, DiscountType::Percentage, 99);
        // Computed in i64: 2_147_483_647 * 99 = 212_600_881_053, and integer division by
        // 100 gives a discount of 2_126_008_810.
        // 2_147_483_647 - 2_126_008_810 = 21_474_837
        assert_eq!(result, 21_474_837);
        assert!(result > 0, "99% discount should leave some remaining");
    }

    #[test]
    fn adversarial_fixed_max_price_max_discount() {
        let result = apply_discount(i32::MAX, DiscountType::Fixed, i32::MAX);
        assert_eq!(result, 0);
    }

    #[test]
    fn adversarial_both_negative() {
        // Both negative price and negative discount. The negative discount is clamped to 0
        // first, so -100 - 0 = -100, which the final floor raises to 0.
        let result = apply_discount(-100, DiscountType::Fixed, -100);
        assert_eq!(result, 0);
    }

    #[test]
    fn adversarial_percentage_discount_exactly_50_odd_price() {
        // Rounding: 1 cent * 50% = 0 (integer division), so result = 1
        assert_eq!(apply_discount(1, DiscountType::Percentage, 50), 1);
        // 3 cents * 50% = 1 (via i64: 3*50/100=1), result = 2
        assert_eq!(apply_discount(3, DiscountType::Percentage, 50), 2);
    }

    #[test]
    fn adversarial_apply_discount_invariant() {
        // For any valid (positive) price and percentage 0-100,
        // result should be in [0, price]
        for price in [1, 50, 100, 999, 10000, 1_000_000] {
            for pct in [0, 1, 10, 25, 33, 50, 75, 99, 100] {
                let result = apply_discount(price, DiscountType::Percentage, pct);
                assert!(
                    result >= 0 && result <= price,
                    "Invariant violated: price={}, pct={}, result={}",
                    price, pct, result
                );
            }
        }
    }

    #[test]
    fn adversarial_fixed_discount_invariant() {
        // For any positive price and positive discount, result should be in [0, price]
        for price in [1, 50, 100, 999, 10000] {
            for discount in [0, 1, 50, 100, 999, 10000, 999999] {
                let result = apply_discount(price, DiscountType::Fixed, discount);
                assert!(
                    result >= 0 && result <= price,
                    "Invariant violated: price={}, discount={}, result={}",
                    price, discount, result
                );
            }
        }
    }

    // ── Property-based tests (proptest) ──

    proptest::proptest! {
        #[test]
        fn prop_percentage_discount_in_range(price in 0..=1_000_000i32, pct in 0..=100i32) {
            let result = apply_discount(price, DiscountType::Percentage, pct);
            proptest::prop_assert!(result >= 0, "Result {} should be >= 0", result);
            proptest::prop_assert!(result <= price, "Result {} should be <= price {}", result, price);
        }

        #[test]
        fn prop_fixed_discount_in_range(price in 0..=1_000_000i32, discount in 0..=1_000_000i32) {
            let result = apply_discount(price, DiscountType::Fixed, discount);
            proptest::prop_assert!(result >= 0, "Result {} should be >= 0", result);
            proptest::prop_assert!(result <= price, "Result {} should be <= price {}", result, price);
        }

        #[test]
        fn prop_100_percent_discount_is_zero(price in 0..=1_000_000i32) {
            proptest::prop_assert_eq!(apply_discount(price, DiscountType::Percentage, 100), 0);
        }

        #[test]
        fn prop_0_percent_discount_is_identity(price in 0..=1_000_000i32) {
            proptest::prop_assert_eq!(apply_discount(price, DiscountType::Percentage, 0), price);
        }
    }
}
847