| 1 |
|
| 2 |
|
| 3 |
use rand::Rng; |
| 4 |
use sqlx::PgPool; |
| 5 |
|
| 6 |
use super::models::DbInviteCode; |
| 7 |
use super::{InviteCodeId, UserId}; |
| 8 |
use crate::error::Result; |
| 9 |
|
| 10 |
|
| 11 |
/// Alphabet invite codes are drawn from: the 24 uppercase letters without
/// `I` and `O`, followed by the digits `2`-`9` (no `0` or `1`) — 32 symbols
/// in total. The omitted characters are presumably excluded because they are
/// easy to confuse with one another when read or typed.
const CODE_CHARSET: &[u8] = concat!("ABCDEFGHJKLMNPQRSTUVWXYZ", "23456789").as_bytes();

/// Number of characters in a raw (undashed) invite code.
const CODE_LENGTH: usize = 12;
| 13 |
|
| 14 |
|
| 15 |
#[tracing::instrument(skip_all)] |
| 16 |
pub fn generate_invite_code() -> String { |
| 17 |
let mut rng = rand::rng(); |
| 18 |
(0..CODE_LENGTH) |
| 19 |
.map(|_| CODE_CHARSET[rng.random_range(0..CODE_CHARSET.len())] as char) |
| 20 |
.collect() |
| 21 |
} |
| 22 |
|
| 23 |
|
| 24 |
#[tracing::instrument(skip_all)] |
| 25 |
pub fn format_invite_code(code: &str) -> String { |
| 26 |
let chars: Vec<char> = code.chars().collect(); |
| 27 |
if chars.len() != 12 { |
| 28 |
return code.to_string(); |
| 29 |
} |
| 30 |
format!( |
| 31 |
"{}-{}-{}", |
| 32 |
&code[..4], |
| 33 |
&code[4..8], |
| 34 |
&code[8..12] |
| 35 |
) |
| 36 |
} |
| 37 |
|
| 38 |
|
| 39 |
#[tracing::instrument(skip_all)] |
| 40 |
pub async fn create_invite_code( |
| 41 |
pool: &PgPool, |
| 42 |
creator_id: UserId, |
| 43 |
code: &str, |
| 44 |
) -> Result<DbInviteCode> { |
| 45 |
let invite = sqlx::query_as::<_, DbInviteCode>( |
| 46 |
r#" |
| 47 |
INSERT INTO invite_codes (creator_id, code) |
| 48 |
VALUES ($1, $2) |
| 49 |
RETURNING * |
| 50 |
"#, |
| 51 |
) |
| 52 |
.bind(creator_id) |
| 53 |
.bind(code) |
| 54 |
.fetch_one(pool) |
| 55 |
.await?; |
| 56 |
|
| 57 |
Ok(invite) |
| 58 |
} |
| 59 |
|
| 60 |
|
| 61 |
#[tracing::instrument(skip_all)] |
| 62 |
pub async fn count_active_invites(pool: &PgPool, creator_id: UserId) -> Result<i64> { |
| 63 |
let count: (i64,) = sqlx::query_as( |
| 64 |
"SELECT COUNT(*) FROM invite_codes WHERE creator_id = $1 AND redeemed_by_id IS NULL", |
| 65 |
) |
| 66 |
.bind(creator_id) |
| 67 |
.fetch_one(pool) |
| 68 |
.await?; |
| 69 |
|
| 70 |
Ok(count.0) |
| 71 |
} |
| 72 |
|
| 73 |
|
| 74 |
#[tracing::instrument(skip_all)] |
| 75 |
pub async fn get_invites_by_creator(pool: &PgPool, creator_id: UserId) -> Result<Vec<DbInviteCode>> { |
| 76 |
let invites = sqlx::query_as::<_, DbInviteCode>( |
| 77 |
"SELECT * FROM invite_codes WHERE creator_id = $1 ORDER BY created_at DESC", |
| 78 |
) |
| 79 |
.bind(creator_id) |
| 80 |
.fetch_all(pool) |
| 81 |
.await?; |
| 82 |
|
| 83 |
Ok(invites) |
| 84 |
} |
| 85 |
|
| 86 |
|
| 87 |
#[tracing::instrument(skip_all)] |
| 88 |
pub async fn get_valid_invite_code(pool: &PgPool, code: &str) -> Result<Option<DbInviteCode>> { |
| 89 |
let invite = sqlx::query_as::<_, DbInviteCode>( |
| 90 |
"SELECT * FROM invite_codes WHERE code = $1 AND redeemed_by_id IS NULL", |
| 91 |
) |
| 92 |
.bind(code) |
| 93 |
.fetch_optional(pool) |
| 94 |
.await?; |
| 95 |
|
| 96 |
Ok(invite) |
| 97 |
} |
| 98 |
|
| 99 |
|
| 100 |
#[tracing::instrument(skip_all)] |
| 101 |
pub async fn redeem_invite_code( |
| 102 |
pool: &PgPool, |
| 103 |
code_id: InviteCodeId, |
| 104 |
user_id: UserId, |
| 105 |
) -> Result<()> { |
| 106 |
sqlx::query( |
| 107 |
"UPDATE invite_codes SET redeemed_by_id = $2, redeemed_at = NOW() WHERE id = $1", |
| 108 |
) |
| 109 |
.bind(code_id) |
| 110 |
.bind(user_id) |
| 111 |
.execute(pool) |
| 112 |
.await?; |
| 113 |
|
| 114 |
Ok(()) |
| 115 |
} |
| 116 |
|
| 117 |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_code_length() {
        let code = generate_invite_code();
        // Count chars (not bytes) and tie the expectation to the real constant.
        assert_eq!(code.chars().count(), CODE_LENGTH);
    }

    #[test]
    fn generate_code_uses_valid_charset() {
        let code = generate_invite_code();
        for c in code.chars() {
            assert!(
                CODE_CHARSET.contains(&(c as u8)),
                "Invalid char '{c}' in generated code"
            );
        }
    }

    #[test]
    fn generate_codes_are_unique() {
        // 32^12 possible codes: a collision here is astronomically unlikely.
        let a = generate_invite_code();
        let b = generate_invite_code();
        assert_ne!(a, b);
    }

    #[test]
    fn format_code_xxxx_xxxx_xxxx() {
        assert_eq!(format_invite_code("ABCD1234EFGH"), "ABCD-1234-EFGH");
    }

    #[test]
    fn format_generated_code_groups_by_four() {
        let code = generate_invite_code();
        let formatted = format_invite_code(&code);
        let groups: Vec<&str> = formatted.split('-').collect();
        assert_eq!(groups.len(), 3);
        assert!(groups.iter().all(|g| g.len() == 4));
        assert_eq!(groups.concat(), code);
    }

    #[test]
    fn format_code_passthrough_on_wrong_length() {
        assert_eq!(format_invite_code("SHORT"), "SHORT");
    }

    #[test]
    fn format_code_passthrough_on_too_long() {
        assert_eq!(format_invite_code("ABCD1234EFGHJ"), "ABCD1234EFGHJ");
    }

    #[test]
    fn format_code_passthrough_on_empty() {
        assert_eq!(format_invite_code(""), "");
    }
}
| 155 |
|