Skip to main content

max / makenotwork

3.9 KB · 155 lines History Blame Raw
1 //! Invite code generation and redemption queries.
2
3 use rand::Rng;
4 use sqlx::PgPool;
5
6 use super::models::DbInviteCode;
7 use super::{InviteCodeId, UserId};
8 use crate::error::Result;
9
/// Alphabet that invite codes are drawn from: the uppercase letters and digits
/// that remain once the look-alike glyphs `I`, `O`, `0` and `1` are dropped.
const CODE_CHARSET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ\
                              23456789";
/// Number of characters in a raw (undashed) invite code.
const CODE_LENGTH: usize = 12;
13
14 /// Generate a random 12-character invite code from the unambiguous charset.
15 #[tracing::instrument(skip_all)]
16 pub fn generate_invite_code() -> String {
17 let mut rng = rand::rng();
18 (0..CODE_LENGTH)
19 .map(|_| CODE_CHARSET[rng.random_range(0..CODE_CHARSET.len())] as char)
20 .collect()
21 }
22
23 /// Format a raw 12-char code as `XXXX-XXXX-XXXX` for display.
24 #[tracing::instrument(skip_all)]
pub fn format_invite_code(code: &str) -> String {
    // Returns `code` split into three dash-separated groups of four characters.
    // Input that is not exactly twelve characters long is returned unchanged.
    //
    // Both the length check and the grouping work on characters. Slicing by
    // byte offset would panic (or cut the text short) for multi-byte input,
    // because offsets 4 and 8 need not fall on character boundaries.
    if code.chars().count() != 12 {
        return code.to_string();
    }

    // Two dashes are added to the original byte length.
    let mut formatted = String::with_capacity(code.len() + 2);
    for (position, ch) in code.chars().enumerate() {
        // A dash goes in front of the 5th and 9th characters.
        if position != 0 && position % 4 == 0 {
            formatted.push('-');
        }
        formatted.push(ch);
    }
    formatted
}
37
38 /// Insert a new invite code for a creator.
39 #[tracing::instrument(skip_all)]
40 pub async fn create_invite_code(
41 pool: &PgPool,
42 creator_id: UserId,
43 code: &str,
44 ) -> Result<DbInviteCode> {
45 let invite = sqlx::query_as::<_, DbInviteCode>(
46 r#"
47 INSERT INTO invite_codes (creator_id, code)
48 VALUES ($1, $2)
49 RETURNING *
50 "#,
51 )
52 .bind(creator_id)
53 .bind(code)
54 .fetch_one(pool)
55 .await?;
56
57 Ok(invite)
58 }
59
60 /// Count unredeemed (active) invite codes for a creator.
61 #[tracing::instrument(skip_all)]
62 pub async fn count_active_invites(pool: &PgPool, creator_id: UserId) -> Result<i64> {
63 let count: (i64,) = sqlx::query_as(
64 "SELECT COUNT(*) FROM invite_codes WHERE creator_id = $1 AND redeemed_by_id IS NULL",
65 )
66 .bind(creator_id)
67 .fetch_one(pool)
68 .await?;
69
70 Ok(count.0)
71 }
72
73 /// List all invite codes created by a specific creator, newest first.
74 #[tracing::instrument(skip_all)]
75 pub async fn get_invites_by_creator(pool: &PgPool, creator_id: UserId) -> Result<Vec<DbInviteCode>> {
76 let invites = sqlx::query_as::<_, DbInviteCode>(
77 "SELECT * FROM invite_codes WHERE creator_id = $1 ORDER BY created_at DESC",
78 )
79 .bind(creator_id)
80 .fetch_all(pool)
81 .await?;
82
83 Ok(invites)
84 }
85
86 /// Look up an unredeemed invite code by its raw code string.
87 #[tracing::instrument(skip_all)]
88 pub async fn get_valid_invite_code(pool: &PgPool, code: &str) -> Result<Option<DbInviteCode>> {
89 let invite = sqlx::query_as::<_, DbInviteCode>(
90 "SELECT * FROM invite_codes WHERE code = $1 AND redeemed_by_id IS NULL",
91 )
92 .bind(code)
93 .fetch_optional(pool)
94 .await?;
95
96 Ok(invite)
97 }
98
99 /// Mark an invite code as redeemed by a specific user.
100 #[tracing::instrument(skip_all)]
101 pub async fn redeem_invite_code(
102 pool: &PgPool,
103 code_id: InviteCodeId,
104 user_id: UserId,
105 ) -> Result<()> {
106 sqlx::query(
107 "UPDATE invite_codes SET redeemed_by_id = $2, redeemed_at = NOW() WHERE id = $1",
108 )
109 .bind(code_id)
110 .bind(user_id)
111 .execute(pool)
112 .await?;
113
114 Ok(())
115 }
116
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated code is twelve bytes long.
    #[test]
    fn generate_code_length() {
        assert_eq!(generate_invite_code().len(), 12);
    }

    /// Every character of a generated code belongs to `CODE_CHARSET`.
    #[test]
    fn generate_code_uses_valid_charset() {
        generate_invite_code().chars().for_each(|c| {
            assert!(
                CODE_CHARSET.contains(&(c as u8)),
                "Invalid char '{c}' in generated code"
            );
        });
    }

    /// Two consecutive generations do not produce the same code.
    #[test]
    fn generate_codes_are_unique() {
        assert_ne!(generate_invite_code(), generate_invite_code());
    }

    /// A twelve-character code is displayed as three dashed groups.
    #[test]
    fn format_code_xxxx_xxxx_xxxx() {
        let shown = format_invite_code("ABCD1234EFGH");
        assert_eq!(shown, "ABCD-1234-EFGH");
    }

    /// Input of any other length comes back unchanged.
    #[test]
    fn format_code_passthrough_on_wrong_length() {
        let shown = format_invite_code("SHORT");
        assert_eq!(shown, "SHORT");
    }
}
155