Skip to main content

max / makenotwork

3.2 KB · 105 lines History Blame Raw
1 //! OAuth 2.0 authorization code storage and retrieval.
2
3 use chrono::{DateTime, Utc};
4 use sqlx::PgPool;
5
6 use super::models::DbOAuthCode;
7 use super::{SyncAppId, UserId};
8 use crate::error::Result;
9
10 /// Store a new OAuth authorization code.
11 #[allow(clippy::too_many_arguments)]
12 #[tracing::instrument(skip_all)]
13 pub async fn create_oauth_code(
14 pool: &PgPool,
15 code: &str,
16 app_id: SyncAppId,
17 user_id: UserId,
18 code_challenge: &str,
19 code_challenge_method: &str,
20 redirect_uri: &str,
21 expires_at: DateTime<Utc>,
22 ) -> Result<DbOAuthCode> {
23 let row = sqlx::query_as::<_, DbOAuthCode>(
24 r#"
25 INSERT INTO oauth_authorization_codes
26 (code, app_id, user_id, code_challenge, code_challenge_method, redirect_uri, expires_at)
27 VALUES ($1, $2, $3, $4, $5, $6, $7)
28 RETURNING *
29 "#,
30 )
31 .bind(code)
32 .bind(app_id)
33 .bind(user_id)
34 .bind(code_challenge)
35 .bind(code_challenge_method)
36 .bind(redirect_uri)
37 .bind(expires_at)
38 .fetch_one(pool)
39 .await?;
40
41 Ok(row)
42 }
43
44 /// Atomically consume an authorization code: mark it used and return it in one step.
45 ///
46 /// Returns `Some(code)` if the code was valid and successfully consumed,
47 /// or `None` if the code was already used, expired, or does not exist.
48 /// Because this is a single UPDATE with `used_at IS NULL` in the WHERE clause,
49 /// concurrent requests for the same code will never both succeed.
50 #[tracing::instrument(skip_all)]
51 pub async fn consume_oauth_code(pool: &PgPool, code: &str) -> Result<Option<DbOAuthCode>> {
52 let row = sqlx::query_as::<_, DbOAuthCode>(
53 r#"
54 UPDATE oauth_authorization_codes
55 SET used_at = NOW()
56 WHERE code = $1
57 AND used_at IS NULL
58 AND expires_at > NOW()
59 RETURNING *
60 "#,
61 )
62 .bind(code)
63 .fetch_optional(pool)
64 .await?;
65
66 Ok(row)
67 }
68
69 /// Delete expired or used authorization codes older than 1 hour.
70 /// Called opportunistically from the health monitor loop.
71 #[tracing::instrument(skip_all)]
72 pub async fn cleanup_expired_oauth_codes(pool: &PgPool) -> Result<u64> {
73 let result = sqlx::query(
74 "DELETE FROM oauth_authorization_codes WHERE expires_at < NOW() - INTERVAL '1 hour' OR (used_at IS NOT NULL AND used_at < NOW() - INTERVAL '1 hour')",
75 )
76 .execute(pool)
77 .await?;
78
79 Ok(result.rows_affected())
80 }
81
82 /// Check if a redirect URI is registered for a given sync app.
83 ///
84 /// Returns `Ok(false)` when the app row doesn't exist or is inactive — never
85 /// surfaces a "no rows" error to the caller. Matching is **exact-string** on
86 /// the registered `redirect_uris` array; trailing slashes are significant
87 /// (`https://x/cb` and `https://x/cb/` are distinct registrations), so apps
88 /// must register every variant they intend to redirect to.
89 #[tracing::instrument(skip_all)]
90 pub async fn is_registered_redirect_uri(
91 pool: &PgPool,
92 app_id: SyncAppId,
93 uri: &str,
94 ) -> Result<bool> {
95 let row: Option<(bool,)> = sqlx::query_as(
96 "SELECT $2 = ANY(redirect_uris) FROM sync_apps WHERE id = $1 AND is_active = true",
97 )
98 .bind(app_id)
99 .bind(uri)
100 .fetch_optional(pool)
101 .await?;
102
103 Ok(row.map(|r| r.0).unwrap_or(false))
104 }
105