Skip to main content

max / makenotwork

6.7 KB · 200 lines History Blame Raw
1 //! Two-factor authentication queries: TOTP secrets, backup codes.
2
3 use sqlx::PgPool;
4
5 use super::UserId;
6 use crate::error::Result;
7
8 /// Get the stored TOTP secret for a user (None if not set up).
9 #[tracing::instrument(skip_all)]
10 pub async fn get_totp_secret(pool: &PgPool, user_id: UserId) -> Result<Option<String>> {
11 let secret: Option<String> =
12 sqlx::query_scalar("SELECT totp_secret FROM users WHERE id = $1")
13 .bind(user_id)
14 .fetch_one(pool)
15 .await?;
16
17 Ok(secret)
18 }
19
20 /// Store a TOTP secret for a user (does not enable 2FA yet).
21 #[tracing::instrument(skip_all)]
22 pub async fn set_totp_secret(pool: &PgPool, user_id: UserId, secret: &str) -> Result<()> {
23 // Clear the replay step alongside the secret. Without this, a user who
24 // disables and re-enables TOTP (potentially with a new secret) inherits a
25 // stale `totp_last_used_step` and any first-attempt code in a lower step
26 // window is false-rejected as a replay.
27 sqlx::query("UPDATE users SET totp_secret = $2, totp_last_used_step = NULL WHERE id = $1")
28 .bind(user_id)
29 .bind(secret)
30 .execute(pool)
31 .await?;
32
33 Ok(())
34 }
35
36 /// Enable TOTP 2FA for a user (called after first successful code verification).
37 #[tracing::instrument(skip_all)]
38 pub async fn enable_totp(pool: &PgPool, user_id: UserId) -> Result<()> {
39 sqlx::query("UPDATE users SET totp_enabled = true WHERE id = $1")
40 .bind(user_id)
41 .execute(pool)
42 .await?;
43
44 Ok(())
45 }
46
47 /// Disable TOTP 2FA: clear the secret, set enabled to false, delete backup codes.
48 #[tracing::instrument(skip_all)]
49 pub async fn disable_totp(pool: &PgPool, user_id: UserId) -> Result<()> {
50 sqlx::query("UPDATE users SET totp_secret = NULL, totp_enabled = false, totp_last_used_step = NULL WHERE id = $1")
51 .bind(user_id)
52 .execute(pool)
53 .await?;
54
55 sqlx::query("DELETE FROM backup_codes WHERE user_id = $1")
56 .bind(user_id)
57 .execute(pool)
58 .await?;
59
60 Ok(())
61 }
62
63 /// Get the last accepted TOTP time step for a user (for replay prevention).
64 #[tracing::instrument(skip_all)]
65 pub async fn get_totp_last_used_step(pool: &PgPool, user_id: UserId) -> Result<Option<i64>> {
66 let step: Option<i64> =
67 sqlx::query_scalar("SELECT totp_last_used_step FROM users WHERE id = $1")
68 .bind(user_id)
69 .fetch_one(pool)
70 .await?;
71
72 Ok(step)
73 }
74
75 /// Update the last accepted TOTP time step (set after successful verification).
76 #[tracing::instrument(skip_all)]
77 pub async fn set_totp_last_used_step(pool: &PgPool, user_id: UserId, step: i64) -> Result<()> {
78 sqlx::query("UPDATE users SET totp_last_used_step = $2 WHERE id = $1")
79 .bind(user_id)
80 .bind(step)
81 .execute(pool)
82 .await?;
83
84 Ok(())
85 }
86
87 /// Check if a user has TOTP 2FA enabled.
88 #[tracing::instrument(skip_all)]
89 pub async fn is_totp_enabled(pool: &PgPool, user_id: UserId) -> Result<bool> {
90 let enabled: bool =
91 sqlx::query_scalar("SELECT totp_enabled FROM users WHERE id = $1")
92 .bind(user_id)
93 .fetch_one(pool)
94 .await?;
95
96 Ok(enabled)
97 }
98
99 /// Delete existing backup codes and insert new ones (atomic replacement).
100 #[tracing::instrument(skip_all)]
101 pub async fn create_backup_codes(
102 pool: &PgPool,
103 user_id: UserId,
104 code_hashes: &[String],
105 ) -> Result<()> {
106 let mut tx = pool.begin().await?;
107
108 // Delete any existing codes
109 sqlx::query("DELETE FROM backup_codes WHERE user_id = $1")
110 .bind(user_id)
111 .execute(&mut *tx)
112 .await?;
113
114 // Batch insert all codes in a single query
115 sqlx::query(
116 "INSERT INTO backup_codes (user_id, code_hash) SELECT $1, UNNEST($2::text[])",
117 )
118 .bind(user_id)
119 .bind(code_hashes)
120 .execute(&mut *tx)
121 .await?;
122
123 tx.commit().await?;
124 Ok(())
125 }
126
127 /// Verify a backup code and mark it as used if found.
128 ///
129 /// `code` is the raw 8-char token the user typed; `legacy_hmac` is the
130 /// HMAC-SHA256 of the same code (passed in pre-computed by the caller so the
131 /// secret stays in route-layer scope). Returns `Ok(true)` when a matching
132 /// unused code is consumed.
133 ///
134 /// Dual-read window: rows hashed under the old HMAC scheme remain valid
135 /// until the user regenerates their backup codes (each regeneration writes
136 /// fresh Argon2 hashes). Argon2 PHC strings begin with `$argon2`; anything
137 /// else is treated as a legacy 64-char hex HMAC.
138 #[tracing::instrument(skip_all)]
139 pub async fn verify_and_consume_backup_code(
140 pool: &PgPool,
141 user_id: UserId,
142 code: &str,
143 legacy_hmac: &str,
144 ) -> Result<bool> {
145 use argon2::{password_hash::PasswordVerifier, Argon2, PasswordHash};
146
147 let rows: Vec<(uuid::Uuid, String)> = sqlx::query_as(
148 "SELECT id, code_hash FROM backup_codes WHERE user_id = $1 AND used_at IS NULL",
149 )
150 .bind(user_id)
151 .fetch_all(pool)
152 .await?;
153
154 // Timing note: the loop `break`s on the first match, which is NOT a usable
155 // timing oracle. A wrong guess (the attacker's case) matches nothing, so the
156 // loop always runs to completion and scans every row in constant time —
157 // independent of code ordering. The early exit fires only on a *successful*
158 // verify, by which point the caller already supplied a valid code and has
159 // nothing left to learn. Retaining the break also avoids forcing N Argon2
160 // verifications (each ~46 MiB) on every attempt, which would hand an attacker
161 // a memory-amplification lever on the 2FA endpoint. Brute force is bounded
162 // separately by the shared failed-attempt lockout.
163 let mut matched_id: Option<uuid::Uuid> = None;
164 for (id, stored) in &rows {
165 let is_match = if stored.starts_with("$argon2") {
166 match PasswordHash::new(stored) {
167 Ok(parsed) => Argon2::default()
168 .verify_password(code.as_bytes(), &parsed)
169 .is_ok(),
170 Err(e) => {
171 tracing::warn!(error = %e, "malformed argon2 backup code hash in DB; skipping");
172 false
173 }
174 }
175 } else {
176 // Legacy HMAC-SHA256 hex. Length-equality short-circuits before
177 // the constant-time compare, matching the existing behavior of
178 // `crypto::constant_time_compare`.
179 crate::crypto::constant_time_compare(stored, legacy_hmac)
180 };
181 if is_match {
182 matched_id = Some(*id);
183 break;
184 }
185 }
186
187 let Some(id) = matched_id else {
188 return Ok(false);
189 };
190
191 let result = sqlx::query(
192 "UPDATE backup_codes SET used_at = NOW() WHERE id = $1 AND used_at IS NULL",
193 )
194 .bind(id)
195 .execute(pool)
196 .await?;
197
198 Ok(result.rows_affected() > 0)
199 }
200