//! Invite code generation and redemption queries. use rand::Rng; use sqlx::PgPool; use super::models::DbInviteCode; use super::{InviteCodeId, UserId}; use crate::error::Result; /// Charset for invite codes: uppercase alphanumeric minus ambiguous chars (I/O/0/1). const CODE_CHARSET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; const CODE_LENGTH: usize = 12; /// Generate a random 12-character invite code from the unambiguous charset. #[tracing::instrument(skip_all)] pub fn generate_invite_code() -> String { let mut rng = rand::rng(); (0..CODE_LENGTH) .map(|_| CODE_CHARSET[rng.random_range(0..CODE_CHARSET.len())] as char) .collect() } /// Format a raw 12-char code as `XXXX-XXXX-XXXX` for display. #[tracing::instrument(skip_all)] pub fn format_invite_code(code: &str) -> String { let chars: Vec = code.chars().collect(); if chars.len() != 12 { return code.to_string(); } format!( "{}-{}-{}", &code[..4], &code[4..8], &code[8..12] ) } /// Insert a new invite code for a creator. #[tracing::instrument(skip_all)] pub async fn create_invite_code( pool: &PgPool, creator_id: UserId, code: &str, ) -> Result { let invite = sqlx::query_as::<_, DbInviteCode>( r#" INSERT INTO invite_codes (creator_id, code) VALUES ($1, $2) RETURNING * "#, ) .bind(creator_id) .bind(code) .fetch_one(pool) .await?; Ok(invite) } /// Count unredeemed (active) invite codes for a creator. #[tracing::instrument(skip_all)] pub async fn count_active_invites(pool: &PgPool, creator_id: UserId) -> Result { let count: (i64,) = sqlx::query_as( "SELECT COUNT(*) FROM invite_codes WHERE creator_id = $1 AND redeemed_by_id IS NULL", ) .bind(creator_id) .fetch_one(pool) .await?; Ok(count.0) } /// List all invite codes created by a specific creator, newest first. #[tracing::instrument(skip_all)] pub async fn get_invites_by_creator(pool: &PgPool, creator_id: UserId) -> Result> { let invites = sqlx::query_as::<_, DbInviteCode>( "SELECT * FROM invite_codes WHERE creator_id = $1 ORDER BY created_at DESC", ) .bind(creator_id) .fetch_all(pool) .await?; Ok(invites) } /// Look up an unredeemed invite code by its raw code string. #[tracing::instrument(skip_all)] pub async fn get_valid_invite_code(pool: &PgPool, code: &str) -> Result> { let invite = sqlx::query_as::<_, DbInviteCode>( "SELECT * FROM invite_codes WHERE code = $1 AND redeemed_by_id IS NULL", ) .bind(code) .fetch_optional(pool) .await?; Ok(invite) } /// Mark an invite code as redeemed by a specific user. #[tracing::instrument(skip_all)] pub async fn redeem_invite_code( pool: &PgPool, code_id: InviteCodeId, user_id: UserId, ) -> Result<()> { sqlx::query( "UPDATE invite_codes SET redeemed_by_id = $2, redeemed_at = NOW() WHERE id = $1", ) .bind(code_id) .bind(user_id) .execute(pool) .await?; Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn generate_code_length() { let code = generate_invite_code(); assert_eq!(code.len(), 12); } #[test] fn generate_code_uses_valid_charset() { let code = generate_invite_code(); for c in code.chars() { assert!( CODE_CHARSET.contains(&(c as u8)), "Invalid char '{c}' in generated code" ); } } #[test] fn generate_codes_are_unique() { let a = generate_invite_code(); let b = generate_invite_code(); assert_ne!(a, b); } #[test] fn format_code_xxxx_xxxx_xxxx() { assert_eq!(format_invite_code("ABCD1234EFGH"), "ABCD-1234-EFGH"); } #[test] fn format_code_passthrough_on_wrong_length() { assert_eq!(format_invite_code("SHORT"), "SHORT"); } }