Skip to main content

max / makenotwork

20.3 KB · 579 lines History Blame Raw
1 //! OAuth client for "Log in with Makenot.work" and session user extraction.
2 //!
3 //! Perks (Fan+, creator tier, capabilities) come from MNW's `/oauth/userinfo`
4 //! `perks` object. We cache them in the session and refresh on three triggers:
5 //! (1) login, (2) session cycle, (3) on-demand via `POST /auth/refresh`. See
6 //! `MNW/server/docs/oauth_integration.md` for the contract.
7
8 use axum::{
9 extract::{FromRequestParts, Query, State},
10 http::{request::Parts, StatusCode},
11 response::{IntoResponse, Redirect},
12 Json,
13 };
14 use base64::Engine;
15 use rand::RngCore;
16 use serde::{Deserialize, Serialize};
17 use sha2::{Digest, Sha256};
18 use tokio::time::sleep;
19 use tower_sessions::Session;
20
21 use crate::AppState;
22
23 // ── PKCE helpers ──
24
25 fn generate_verifier() -> String {
26 let mut bytes = [0u8; 32];
27 rand::thread_rng().fill_bytes(&mut bytes);
28 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
29 }
30
31 fn pkce_challenge(verifier: &str) -> String {
32 let mut hasher = Sha256::new();
33 hasher.update(verifier.as_bytes());
34 let digest = hasher.finalize();
35 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
36 }
37
38 fn generate_state_nonce() -> String {
39 let mut bytes = [0u8; 16];
40 rand::thread_rng().fill_bytes(&mut bytes);
41 hex::encode(bytes)
42 }
43
44 // ── Session user ──
45
46 /// User info cached in the session after OAuth login.
47 ///
48 /// `perks` reflects MNW state at the last refresh (login, session cycle, or
49 /// explicit `POST /auth/refresh`). Use [`UserPerks::effective_plus`] for the
50 /// canonical Fan+ gate.
51 #[derive(Clone, Debug)]
52 pub struct SessionUser {
53 pub user_id: uuid::Uuid,
54 pub username: String,
55 pub display_name: Option<String>,
56 pub perks: UserPerks,
57 }
58
59 /// Capability snapshot from MNW's `/oauth/userinfo` `perks` object.
60 ///
61 /// Default = no perks; this is what unknown / not-yet-refreshed sessions see.
62 #[derive(Clone, Debug, Default, Serialize, Deserialize)]
63 pub struct UserPerks {
64 #[serde(default)]
65 pub fan_plus: bool,
66 #[serde(default)]
67 pub is_creator: bool,
68 #[serde(default)]
69 pub creator_tier: Option<CreatorTierInfo>,
70 }
71
72 #[derive(Clone, Debug, Serialize, Deserialize)]
73 pub struct CreatorTierInfo {
74 pub tier: String,
75 pub features: Vec<String>,
76 }
77
78 impl UserPerks {
79 /// Canonical "should this user see + features" check. True for active Fan+
80 /// subscribers and for any creator (auto-grant: creators get + perks without
81 /// paying for Fan+ separately).
82 pub fn effective_plus(&self) -> bool {
83 self.fan_plus || self.is_creator
84 }
85 }
86
// Session keys for the transient OAuth handshake (written by `login`,
// consumed and removed by `callback`).
const SESSION_OAUTH_STATE: &str = "oauth_state";
const SESSION_PKCE_VERIFIER: &str = "pkce_verifier";

// Session keys for the logged-in user.
const SESSION_USER_ID: &str = "user_id";
const SESSION_USERNAME: &str = "username";
const SESSION_DISPLAY_NAME: &str = "display_name";
const SESSION_PERKS: &str = "perks";
// MNW access token, kept so `refresh_session` can re-query userinfo.
const SESSION_ACCESS_TOKEN: &str = "mnw_access_token";
94
95 impl SessionUser {
96 async fn from_session(session: &Session) -> Option<Self> {
97 let user_id: uuid::Uuid = match session.get(SESSION_USER_ID).await {
98 Ok(v) => v?,
99 Err(e) => {
100 tracing::warn!(error = %e, "failed to read user_id from session");
101 return None;
102 }
103 };
104 let username: String = match session.get(SESSION_USERNAME).await {
105 Ok(v) => v?,
106 Err(e) => {
107 tracing::warn!(error = %e, "failed to read username from session");
108 return None;
109 }
110 };
111 let display_name: Option<String> = match session.get(SESSION_DISPLAY_NAME).await {
112 Ok(v) => v,
113 Err(e) => {
114 tracing::warn!(error = %e, "failed to read display_name from session");
115 None
116 }
117 };
118 // Perks default to empty — sessions predating the perks change still load.
119 let perks: UserPerks = session
120 .get(SESSION_PERKS)
121 .await
122 .unwrap_or_default()
123 .unwrap_or_default();
124 Some(Self {
125 user_id,
126 username,
127 display_name,
128 perks,
129 })
130 }
131
132 async fn save_to_session(&self, session: &Session) {
133 if let Err(e) = session.insert(SESSION_USER_ID, self.user_id).await {
134 tracing::error!(error = %e, "failed to save user_id to session");
135 }
136 if let Err(e) = session.insert(SESSION_USERNAME, &self.username).await {
137 tracing::error!(error = %e, "failed to save username to session");
138 }
139 if let Err(e) = session.insert(SESSION_DISPLAY_NAME, &self.display_name).await {
140 tracing::error!(error = %e, "failed to save display_name to session");
141 }
142 if let Err(e) = session.insert(SESSION_PERKS, &self.perks).await {
143 tracing::error!(error = %e, "failed to save perks to session");
144 }
145 }
146 }
147
148 /// Axum extractor that yields `Option<SessionUser>`.
149 pub struct MaybeUser(pub Option<SessionUser>);
150
151 impl FromRequestParts<AppState> for MaybeUser {
152 type Rejection = std::convert::Infallible;
153
154 async fn from_request_parts(
155 parts: &mut Parts,
156 state: &AppState,
157 ) -> Result<Self, Self::Rejection> {
158 let session = Session::from_request_parts(parts, state)
159 .await
160 .expect("session layer missing");
161 Ok(MaybeUser(SessionUser::from_session(&session).await))
162 }
163 }
164
165 /// Axum extractor that requires the user to be the platform admin.
166 /// Returns 404 to non-admins (hides admin routes).
167 pub struct PlatformAdmin(pub SessionUser);
168
169 impl FromRequestParts<AppState> for PlatformAdmin {
170 type Rejection = StatusCode;
171
172 async fn from_request_parts(
173 parts: &mut Parts,
174 state: &AppState,
175 ) -> Result<Self, Self::Rejection> {
176 let session = Session::from_request_parts(parts, state)
177 .await
178 .expect("session layer missing");
179 let user = SessionUser::from_session(&session)
180 .await
181 .ok_or(StatusCode::NOT_FOUND)?;
182
183 let admin_id = state.config.platform_admin_id.ok_or(StatusCode::NOT_FOUND)?;
184 if user.user_id != admin_id {
185 return Err(StatusCode::NOT_FOUND);
186 }
187
188 Ok(PlatformAdmin(user))
189 }
190 }
191
192 // ── OAuth callback types ──
193
194 #[derive(Deserialize)]
195 pub struct CallbackQuery {
196 pub code: String,
197 pub state: String,
198 }
199
200 #[derive(Deserialize)]
201 struct TokenResponse {
202 access_token: String,
203 }
204
205 #[derive(Deserialize)]
206 struct UserinfoResponse {
207 user_id: uuid::Uuid,
208 username: String,
209 display_name: Option<String>,
210 avatar_url: Option<String>,
211 #[serde(default)]
212 perks: UserPerks,
213 }
214
/// Failure modes of a userinfo fetch or session refresh.
#[derive(Debug)]
pub enum UserinfoError {
    /// MNW answered 401, or the session holds no access token to send.
    Unauthorized,
    /// The request could not be sent, or MNW answered with a 5xx status.
    Transport,
    /// Any other non-success status, or a body that failed to parse.
    BadResponse,
}
221
222 /// Single-attempt userinfo fetch against MNW. Callers decide retry policy.
223 ///
224 /// `Unauthorized` means the bearer token is invalid or the user is gone.
225 /// `Transport` covers network and 5xx. `BadResponse` covers other 4xx and parse
226 /// errors. The login callback retries on `Transport`; `refresh_session` does
227 /// not — the client can retry.
228 async fn fetch_userinfo(
229 http: &reqwest::Client,
230 base_url: &str,
231 access_token: &str,
232 ) -> Result<UserinfoResponse, UserinfoError> {
233 let url = format!("{}/oauth/userinfo", base_url);
234 let res = http
235 .get(&url)
236 .bearer_auth(access_token)
237 .send()
238 .await
239 .map_err(|e| {
240 tracing::warn!(error = %e, "userinfo transport error");
241 UserinfoError::Transport
242 })?;
243
244 let status = res.status();
245 if status == reqwest::StatusCode::UNAUTHORIZED {
246 return Err(UserinfoError::Unauthorized);
247 }
248 if status.is_server_error() {
249 return Err(UserinfoError::Transport);
250 }
251 if !status.is_success() {
252 let body = res.text().await.unwrap_or_default();
253 tracing::warn!(%status, %body, "userinfo non-success");
254 return Err(UserinfoError::BadResponse);
255 }
256
257 res.json::<UserinfoResponse>().await.map_err(|e| {
258 tracing::warn!(error = %e, "userinfo parse failed");
259 UserinfoError::BadResponse
260 })
261 }
262
263 /// Refresh the cached perks for the current session by re-hitting MNW.
264 ///
265 /// Caller must have a logged-in session (access token stored at login). On
266 /// `Unauthorized` the session is flushed — the access token is gone for good
267 /// and the user needs to log in again. Other errors leave the session intact.
268 pub async fn refresh_session(
269 state: &AppState,
270 session: &Session,
271 ) -> Result<UserPerks, UserinfoError> {
272 let token: String = session
273 .get(SESSION_ACCESS_TOKEN)
274 .await
275 .unwrap_or(None)
276 .ok_or(UserinfoError::Unauthorized)?;
277
278 match fetch_userinfo(&state.http, &state.config.mnw_base_url, &token).await {
279 Ok(info) => {
280 if let Err(e) = session.insert(SESSION_PERKS, &info.perks).await {
281 tracing::error!(error = %e, "failed to save refreshed perks");
282 }
283 // Username/display can drift on MNW too — sync them while we're here.
284 if let Err(e) = session.insert(SESSION_USERNAME, &info.username).await {
285 tracing::error!(error = %e, "failed to save refreshed username");
286 }
287 if let Err(e) = session.insert(SESSION_DISPLAY_NAME, &info.display_name).await {
288 tracing::error!(error = %e, "failed to save refreshed display_name");
289 }
290 // Mirror perks into users table so post rendering sees the change
291 // without consulting MNW per-post. Best-effort: rendering tolerates
292 // a stale row, so DB errors here are logged but non-fatal.
293 if let Err(e) = sqlx::query(
294 "UPDATE users SET is_fan_plus = $2, is_creator = $3 WHERE mnw_account_id = $1",
295 )
296 .bind(info.user_id)
297 .bind(info.perks.fan_plus)
298 .bind(info.perks.is_creator)
299 .execute(&state.db)
300 .await
301 {
302 tracing::warn!(error = %e, "failed to mirror refreshed perks to users table");
303 }
304 let _ = info.avatar_url; // not stored in session yet
305 Ok(info.perks)
306 }
307 Err(UserinfoError::Unauthorized) => {
308 // Token revoked, expired, or user deleted — drop the session.
309 if let Err(e) = session.flush().await {
310 tracing::warn!(error = %e, "failed to flush session after auth failure");
311 }
312 Err(UserinfoError::Unauthorized)
313 }
314 Err(e) => Err(e),
315 }
316 }
317
318 // ── Handlers ──
319
320 /// `GET /auth/login` — redirect to MNW OAuth authorize endpoint.
321 #[tracing::instrument(skip_all)]
322 pub async fn login(
323 State(state): State<AppState>,
324 session: Session,
325 ) -> impl IntoResponse {
326 let verifier = generate_verifier();
327 let challenge = pkce_challenge(&verifier);
328 let oauth_state = generate_state_nonce();
329
330 if let Err(e) = session.insert(SESSION_PKCE_VERIFIER, &verifier).await {
331 tracing::error!(error = %e, "failed to save PKCE verifier to session");
332 }
333 if let Err(e) = session.insert(SESSION_OAUTH_STATE, &oauth_state).await {
334 tracing::error!(error = %e, "failed to save OAuth state to session");
335 }
336
337 let url = format!(
338 "{}/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256",
339 state.config.mnw_base_url,
340 urlencoding::encode(&state.config.oauth_client_id),
341 urlencoding::encode(&state.config.oauth_redirect_uri),
342 urlencoding::encode(&oauth_state),
343 urlencoding::encode(&challenge),
344 );
345
346 Redirect::to(&url)
347 }
348
349 /// `GET /auth/callback` — exchange code for token, fetch userinfo, create session.
350 #[tracing::instrument(skip_all)]
351 pub async fn callback(
352 State(state): State<AppState>,
353 session: Session,
354 Query(params): Query<CallbackQuery>,
355 ) -> impl IntoResponse {
356 tracing::info!("OAuth callback received");
357
358 // Verify state nonce
359 let stored_state: Option<String> = session.get(SESSION_OAUTH_STATE).await.unwrap_or(None);
360 if stored_state.as_deref() != Some(&params.state) {
361 tracing::warn!(stored = ?stored_state, received = %params.state, "state mismatch");
362 return Redirect::to("/?error=state_mismatch");
363 }
364
365 let verifier: String = match session.get(SESSION_PKCE_VERIFIER).await.unwrap_or(None) {
366 Some(v) => v,
367 None => {
368 tracing::warn!("missing PKCE verifier in session");
369 return Redirect::to("/?error=missing_verifier");
370 }
371 };
372
373 // Clean up OAuth session data
374 if let Err(e) = session.remove::<String>(SESSION_OAUTH_STATE).await {
375 tracing::warn!(error = %e, "failed to remove OAuth state from session");
376 }
377 if let Err(e) = session.remove::<String>(SESSION_PKCE_VERIFIER).await {
378 tracing::warn!(error = %e, "failed to remove PKCE verifier from session");
379 }
380
381 // Exchange code for token (retry up to 2 attempts on network/5xx errors)
382 let token_url = format!("{}/oauth/token", state.config.mnw_base_url);
383 tracing::info!(%token_url, "exchanging code for token");
384 let backoffs = [
385 std::time::Duration::from_millis(500),
386 std::time::Duration::from_millis(1000),
387 ];
388 let mut token_res = None;
389 for attempt in 0..=backoffs.len() {
390 let res = state
391 .http
392 .post(&token_url)
393 .json(&serde_json::json!({
394 "grant_type": "authorization_code",
395 "code": params.code,
396 "redirect_uri": state.config.oauth_redirect_uri,
397 "code_verifier": verifier,
398 "client_id": state.config.oauth_client_id,
399 }))
400 .send()
401 .await;
402
403 match res {
404 Ok(r) if r.status().is_server_error() => {
405 let status = r.status();
406 if attempt < backoffs.len() {
407 tracing::warn!(%status, attempt, "token exchange got 5xx, retrying");
408 sleep(backoffs[attempt]).await;
409 continue;
410 }
411 let body = r.text().await.unwrap_or_default();
412 tracing::error!(%status, %body, "token exchange failed after retries");
413 return Redirect::to("/?error=token_exchange_failed");
414 }
415 Ok(r) if !r.status().is_success() => {
416 let status = r.status();
417 let body = r.text().await.unwrap_or_default();
418 tracing::error!(%status, %body, "token exchange failed");
419 return Redirect::to("/?error=token_exchange_failed");
420 }
421 Ok(r) => {
422 token_res = Some(r);
423 break;
424 }
425 Err(e) => {
426 if attempt < backoffs.len() {
427 tracing::warn!(error = %e, attempt, "token request failed, retrying");
428 sleep(backoffs[attempt]).await;
429 continue;
430 }
431 tracing::error!(error = %e, "token request failed after retries");
432 return Redirect::to("/?error=token_request_failed");
433 }
434 }
435 }
436 // Safety: loop always either sets token_res or returns early
437 let token_res = token_res.unwrap();
438
439 let token: TokenResponse = match token_res.json().await {
440 Ok(t) => t,
441 Err(e) => {
442 tracing::error!(error = %e, "token parse failed");
443 return Redirect::to("/?error=token_parse_failed");
444 }
445 };
446
447 // Fetch userinfo (retry up to 2 attempts on transport / 5xx errors).
448 tracing::info!(base_url = %state.config.mnw_base_url, "fetching userinfo");
449 let mut info: Option<UserinfoResponse> = None;
450 for attempt in 0..=backoffs.len() {
451 match fetch_userinfo(&state.http, &state.config.mnw_base_url, &token.access_token).await {
452 Ok(i) => {
453 info = Some(i);
454 break;
455 }
456 Err(UserinfoError::Transport) if attempt < backoffs.len() => {
457 tracing::warn!(attempt, "userinfo transport error, retrying");
458 sleep(backoffs[attempt]).await;
459 continue;
460 }
461 Err(UserinfoError::Transport) => {
462 tracing::error!("userinfo transport failed after retries");
463 return Redirect::to("/?error=userinfo_fetch_failed");
464 }
465 Err(UserinfoError::Unauthorized) => {
466 tracing::error!("userinfo unauthorized — token rejected");
467 return Redirect::to("/?error=userinfo_fetch_failed");
468 }
469 Err(UserinfoError::BadResponse) => {
470 tracing::error!("userinfo bad response");
471 return Redirect::to("/?error=userinfo_parse_failed");
472 }
473 }
474 }
475 let info = info.expect("userinfo loop always sets value or returns");
476
477 tracing::info!(user_id = %info.user_id, username = %info.username, "OAuth login successful");
478
479 // Upsert local user. `is_fan_plus`/`is_creator` are denormalised here so
480 // post rendering can look up the post author's perks via JOIN — see
481 // migration 026.
482 let upsert_result = sqlx::query(
483 r#"
484 INSERT INTO users (mnw_account_id, username, display_name, avatar_url, is_fan_plus, is_creator)
485 VALUES ($1, $2, $3, $4, $5, $6)
486 ON CONFLICT (mnw_account_id) DO UPDATE
487 SET username = $2, display_name = $3, avatar_url = $4,
488 is_fan_plus = $5, is_creator = $6, updated_at = now()
489 "#,
490 )
491 .bind(info.user_id)
492 .bind(&info.username)
493 .bind(&info.display_name)
494 .bind(&info.avatar_url)
495 .bind(info.perks.fan_plus)
496 .bind(info.perks.is_creator)
497 .execute(&state.db)
498 .await;
499
500 if let Err(e) = upsert_result {
501 tracing::error!(error = %e, "user upsert failed");
502 return Redirect::to("/?error=user_upsert_failed");
503 }
504
505 // Check if user is suspended (fail-closed: DB errors block login)
506 let suspended: bool = match sqlx::query_scalar(
507 "SELECT suspended_at IS NOT NULL FROM users WHERE mnw_account_id = $1",
508 )
509 .bind(info.user_id)
510 .fetch_one(&state.db)
511 .await
512 {
513 Ok(v) => v,
514 Err(e) => {
515 tracing::error!(error = %e, "db error checking suspension status");
516 return Redirect::to("/?error=internal_error");
517 }
518 };
519
520 if suspended {
521 return Redirect::to("/?error=account_suspended");
522 }
523
524 // Save session — perks come from the same userinfo response, no second roundtrip.
525 let session_user = SessionUser {
526 user_id: info.user_id,
527 username: info.username,
528 display_name: info.display_name,
529 perks: info.perks,
530 };
531 session_user.save_to_session(&session).await;
532 // Stash the access token so `refresh_session` can re-hit userinfo without
533 // forcing the user through another OAuth round trip. Token lifetime is set
534 // by MNW (7d as of writing); after expiry, refresh returns Unauthorized and
535 // the session is flushed.
536 if let Err(e) = session.insert(SESSION_ACCESS_TOKEN, &token.access_token).await {
537 tracing::error!(error = %e, "failed to save access token to session");
538 }
539 if let Err(e) = session.cycle_id().await {
540 tracing::warn!(error = %e, "Failed to cycle session ID");
541 }
542 tracing::info!("session saved, redirecting to /");
543
544 Redirect::to("/")
545 }
546
547 /// `POST /auth/refresh` — re-fetch MNW userinfo and overwrite cached perks.
548 ///
549 /// Useful after the user takes an action that changed their MNW entitlements
550 /// (e.g., subscribing to Fan+, upgrading a creator tier) so they don't have to
551 /// log out and back in to see the new perks. Returns the refreshed perks as
552 /// JSON.
553 #[tracing::instrument(skip_all)]
554 pub async fn refresh(
555 State(state): State<AppState>,
556 session: Session,
557 ) -> Result<Json<RefreshResponse>, StatusCode> {
558 match refresh_session(&state, &session).await {
559 Ok(perks) => Ok(Json(RefreshResponse { perks })),
560 Err(UserinfoError::Unauthorized) => Err(StatusCode::UNAUTHORIZED),
561 Err(UserinfoError::Transport) => Err(StatusCode::BAD_GATEWAY),
562 Err(UserinfoError::BadResponse) => Err(StatusCode::BAD_GATEWAY),
563 }
564 }
565
566 #[derive(Serialize)]
567 pub struct RefreshResponse {
568 pub perks: UserPerks,
569 }
570
571 /// `POST /auth/logout` — flush session, redirect home.
572 #[tracing::instrument(skip_all)]
573 pub async fn logout(session: Session) -> impl IntoResponse {
574 if let Err(e) = session.flush().await {
575 tracing::warn!(error = %e, "failed to flush session on logout");
576 }
577 Redirect::to("/")
578 }
579