Skip to main content

max / makenotwork

8.7 KB · 269 lines History Blame Raw
1 //! Session tracking for remote revocation.
2
3 use sqlx::PgPool;
4
5 use super::{DbUserSession, UserId, UserSessionId};
6 use crate::error::Result;
7
8 /// Insert a tracked session row and return its ID.
9 #[tracing::instrument(skip_all)]
10 pub async fn create_user_session(
11 pool: &PgPool,
12 user_id: UserId,
13 user_agent: Option<&str>,
14 ip_address: Option<&str>,
15 ) -> Result<UserSessionId> {
16 let row = sqlx::query_scalar::<_, UserSessionId>(
17 "INSERT INTO user_sessions (user_id, user_agent, ip_address) VALUES ($1, $2, $3) RETURNING id",
18 )
19 .bind(user_id)
20 .bind(user_agent)
21 .bind(ip_address)
22 .fetch_one(pool)
23 .await?;
24
25 Ok(row)
26 }
27
28 /// Insert a `kind='pending_2fa'` row for an intermediate session held between
29 /// the password step and the TOTP/backup-code step. Exposed to
30 /// `delete_all_sessions_for_user` so "log out everywhere" sweeps a phisher
31 /// who has the password but not the second factor.
32 #[tracing::instrument(skip_all)]
33 pub async fn create_pending_2fa_session(
34 pool: &PgPool,
35 user_id: UserId,
36 user_agent: Option<&str>,
37 ip_address: Option<&str>,
38 ) -> Result<UserSessionId> {
39 let row = sqlx::query_scalar::<_, UserSessionId>(
40 "INSERT INTO user_sessions (user_id, user_agent, ip_address, kind)
41 VALUES ($1, $2, $3, 'pending_2fa') RETURNING id",
42 )
43 .bind(user_id)
44 .bind(user_agent)
45 .bind(ip_address)
46 .fetch_one(pool)
47 .await?;
48
49 Ok(row)
50 }
51
52 /// Confirm the pending_2fa tracking row is still present (i.e. wasn't swept
53 /// by `delete_all_sessions_for_user` while the user was at the TOTP prompt).
54 #[tracing::instrument(skip_all)]
55 pub async fn pending_2fa_session_exists(
56 pool: &PgPool,
57 id: UserSessionId,
58 user_id: UserId,
59 ) -> Result<bool> {
60 let exists: bool = sqlx::query_scalar(
61 "SELECT EXISTS(SELECT 1 FROM user_sessions WHERE id = $1 AND user_id = $2 AND kind = 'pending_2fa')",
62 )
63 .bind(id)
64 .bind(user_id)
65 .fetch_one(pool)
66 .await?;
67 Ok(exists)
68 }
69
70 /// Delete a pending_2fa tracking row. Called when 2FA succeeds (the caller
71 /// then `track_session`s a fresh 'active' row) or when the pending state is
72 /// cleared (expiry, account lockout, navigation away).
73 #[tracing::instrument(skip_all)]
74 pub async fn delete_pending_2fa_session(pool: &PgPool, id: UserSessionId) -> Result<()> {
75 sqlx::query("DELETE FROM user_sessions WHERE id = $1 AND kind = 'pending_2fa'")
76 .bind(id)
77 .execute(pool)
78 .await?;
79 Ok(())
80 }
81
/// Outcome of [`touch_session`]: whether the session row still exists, plus the
/// user's account state.
///
/// The account state is read live from the database on each call rather than
/// cached in the session. When `valid` is `false`, every other field is left at
/// its falsy default and carries no information.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TouchResult {
    /// `false` if the session row no longer exists (revoked or pruned).
    pub valid: bool,
    /// Current `suspended_at IS NOT NULL` from the users table.
    /// Only meaningful when `valid` is `true`.
    pub suspended: bool,
    /// Current `can_create_projects` from the users table.
    /// Only meaningful when `valid` is `true`.
    pub can_create_projects: bool,
    /// Whether the user has an active Fan+ subscription.
    /// Only meaningful when `valid` is `true`.
    pub is_fan_plus: bool,
    /// Active creator tier name (e.g. "SmallFiles"), or `None` if there is none.
    /// Only meaningful when `valid` is `true`.
    pub creator_tier: Option<String>,
}
98
99 /// Update `last_active_at`, confirm the session still exists, and return
100 /// the user's current `suspended` status from the `users` table.
101 ///
102 /// This ensures suspension takes effect immediately even if the session
103 /// was created before the admin suspended the user.
104 #[tracing::instrument(skip_all)]
105 pub async fn touch_session(pool: &PgPool, session_id: UserSessionId) -> Result<TouchResult> {
106 // Single query: update last_active_at, join users for live status, and check
107 // fan_plus + creator_tier via subqueries (avoids 2 extra round-trips in auth extractor).
108 let row = sqlx::query_as::<_, (bool, bool, bool, Option<String>)>(
109 r#"
110 UPDATE user_sessions us
111 SET last_active_at = NOW()
112 FROM users u
113 WHERE us.id = $1 AND u.id = us.user_id
114 RETURNING
115 u.suspended_at IS NOT NULL,
116 u.can_create_projects,
117 EXISTS(SELECT 1 FROM fan_plus_subscriptions fps WHERE fps.user_id = u.id AND fps.status = 'active'),
118 (SELECT cs.tier FROM creator_subscriptions cs WHERE cs.user_id = u.id AND cs.status = 'active')
119 "#,
120 )
121 .bind(session_id)
122 .fetch_optional(pool)
123 .await?;
124
125 match row {
126 Some((suspended, can_create_projects, is_fan_plus, creator_tier)) => Ok(TouchResult {
127 valid: true,
128 suspended,
129 can_create_projects,
130 is_fan_plus,
131 creator_tier,
132 }),
133 None => Ok(TouchResult { valid: false, suspended: false, can_create_projects: false, is_fan_plus: false, creator_tier: None }),
134 }
135 }
136
137 /// List all active sessions for a user, newest first.
138 #[tracing::instrument(skip_all)]
139 pub async fn get_user_sessions(pool: &PgPool, user_id: UserId) -> Result<Vec<DbUserSession>> {
140 let sessions = sqlx::query_as::<_, DbUserSession>(
141 "SELECT id, user_id, created_at, last_active_at, user_agent, ip_address
142 FROM user_sessions
143 WHERE user_id = $1
144 ORDER BY last_active_at DESC
145 LIMIT 100",
146 )
147 .bind(user_id)
148 .fetch_all(pool)
149 .await?;
150
151 Ok(sessions)
152 }
153
154 /// Count active sessions for a user.
155 #[tracing::instrument(skip_all)]
156 pub async fn count_user_sessions(pool: &PgPool, user_id: UserId) -> Result<i64> {
157 let count = sqlx::query_scalar::<_, i64>(
158 "SELECT COUNT(*) FROM user_sessions WHERE user_id = $1",
159 )
160 .bind(user_id)
161 .fetch_one(pool)
162 .await?;
163
164 Ok(count)
165 }
166
167 /// Delete a single session, scoped to the owning user. Returns `true` if deleted.
168 #[tracing::instrument(skip_all)]
169 pub async fn delete_user_session(
170 pool: &PgPool,
171 session_id: UserSessionId,
172 user_id: UserId,
173 ) -> Result<bool> {
174 let rows = sqlx::query(
175 "DELETE FROM user_sessions WHERE id = $1 AND user_id = $2",
176 )
177 .bind(session_id)
178 .bind(user_id)
179 .execute(pool)
180 .await?;
181
182 Ok(rows.rows_affected() > 0)
183 }
184
185 /// Delete a session row, scoped to a specific user.
186 ///
187 /// The user scoping isn't strictly required for correctness in the current
188 /// caller (logout reads its own tracking ID out of the session and we
189 /// trust that), but the unscoped signature was an easy footgun — anyone
190 /// who later wired this up with an attacker-controllable session_id could
191 /// delete arbitrary rows. Requiring user_id in the signature keeps the
192 /// SQL pinned to "this user, this row" so that misuse fails fast.
193 #[tracing::instrument(skip_all)]
194 pub async fn delete_session_by_id(
195 pool: &PgPool,
196 session_id: UserSessionId,
197 user_id: UserId,
198 ) -> Result<bool> {
199 let rows = sqlx::query("DELETE FROM user_sessions WHERE id = $1 AND user_id = $2")
200 .bind(session_id)
201 .bind(user_id)
202 .execute(pool)
203 .await?;
204
205 Ok(rows.rows_affected() > 0)
206 }
207
208 /// Delete expired session records (inactive longer than the given threshold).
209 /// Returns the number of rows removed.
210 #[tracing::instrument(skip_all)]
211 pub async fn prune_expired_sessions(pool: &PgPool, stale_threshold: chrono::DateTime<chrono::Utc>) -> Result<u64> {
212 let result = sqlx::query(
213 "DELETE FROM user_sessions WHERE last_active_at < $1",
214 )
215 .bind(stale_threshold)
216 .execute(pool)
217 .await?;
218
219 Ok(result.rows_affected())
220 }
221
222 /// Delete all sessions for a user except the current one. Returns count deleted.
223 #[tracing::instrument(skip_all)]
224 pub async fn delete_other_sessions(
225 pool: &PgPool,
226 current_session_id: UserSessionId,
227 user_id: UserId,
228 ) -> Result<Vec<UserSessionId>> {
229 let ids: Vec<UserSessionId> = sqlx::query_scalar(
230 "DELETE FROM user_sessions WHERE user_id = $1 AND id != $2 RETURNING id",
231 )
232 .bind(user_id)
233 .bind(current_session_id)
234 .fetch_all(pool)
235 .await?;
236
237 Ok(ids)
238 }
239
240 /// Delete ALL sessions for a user. Returns the deleted session IDs (for cache eviction).
241 ///
242 /// Also bumps `users.jwt_invalidated_at` — without this, a stolen SyncKit JWT
243 /// would survive a "log out everywhere" until its natural expiry. Both writes
244 /// run in a single transaction so a partial failure can't leave the JWTs alive
245 /// after the session rows are gone.
246 #[tracing::instrument(skip_all)]
247 pub async fn delete_all_sessions_for_user(
248 pool: &PgPool,
249 user_id: UserId,
250 ) -> Result<Vec<UserSessionId>> {
251 let mut tx = pool.begin().await?;
252
253 let ids: Vec<UserSessionId> = sqlx::query_scalar(
254 "DELETE FROM user_sessions WHERE user_id = $1 RETURNING id",
255 )
256 .bind(user_id)
257 .fetch_all(&mut *tx)
258 .await?;
259
260 sqlx::query("UPDATE users SET jwt_invalidated_at = NOW() WHERE id = $1")
261 .bind(user_id)
262 .execute(&mut *tx)
263 .await?;
264
265 tx.commit().await?;
266
267 Ok(ids)
268 }
269