//! OAuth client for "Log in with Makenot.work" and session user extraction. //! //! Perks (Fan+, creator tier, capabilities) come from MNW's `/oauth/userinfo` //! `perks` object. We cache them in the session and refresh on three triggers: //! (1) login, (2) session cycle, (3) on-demand via `POST /auth/refresh`. See //! `MNW/server/docs/oauth_integration.md` for the contract. use axum::{ extract::{FromRequestParts, Query, State}, http::{request::Parts, StatusCode}, response::{IntoResponse, Redirect}, Json, }; use base64::Engine; use rand::RngCore; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tokio::time::sleep; use tower_sessions::Session; use crate::AppState; // ── PKCE helpers ── fn generate_verifier() -> String { let mut bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut bytes); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) } fn pkce_challenge(verifier: &str) -> String { let mut hasher = Sha256::new(); hasher.update(verifier.as_bytes()); let digest = hasher.finalize(); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) } fn generate_state_nonce() -> String { let mut bytes = [0u8; 16]; rand::thread_rng().fill_bytes(&mut bytes); hex::encode(bytes) } // ── Session user ── /// User info cached in the session after OAuth login. /// /// `perks` reflects MNW state at the last refresh (login, session cycle, or /// explicit `POST /auth/refresh`). Use [`UserPerks::effective_plus`] for the /// canonical Fan+ gate. #[derive(Clone, Debug)] pub struct SessionUser { pub user_id: uuid::Uuid, pub username: String, pub display_name: Option, pub perks: UserPerks, } /// Capability snapshot from MNW's `/oauth/userinfo` `perks` object. /// /// Default = no perks; this is what unknown / not-yet-refreshed sessions see. #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct UserPerks { #[serde(default)] pub fan_plus: bool, #[serde(default)] pub is_creator: bool, #[serde(default)] pub creator_tier: Option, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct CreatorTierInfo { pub tier: String, pub features: Vec, } impl UserPerks { /// Canonical "should this user see + features" check. True for active Fan+ /// subscribers and for any creator (auto-grant: creators get + perks without /// paying for Fan+ separately). pub fn effective_plus(&self) -> bool { self.fan_plus || self.is_creator } } const SESSION_USER_ID: &str = "user_id"; const SESSION_USERNAME: &str = "username"; const SESSION_DISPLAY_NAME: &str = "display_name"; const SESSION_PERKS: &str = "perks"; const SESSION_ACCESS_TOKEN: &str = "mnw_access_token"; const SESSION_OAUTH_STATE: &str = "oauth_state"; const SESSION_PKCE_VERIFIER: &str = "pkce_verifier"; impl SessionUser { async fn from_session(session: &Session) -> Option { let user_id: uuid::Uuid = match session.get(SESSION_USER_ID).await { Ok(v) => v?, Err(e) => { tracing::warn!(error = %e, "failed to read user_id from session"); return None; } }; let username: String = match session.get(SESSION_USERNAME).await { Ok(v) => v?, Err(e) => { tracing::warn!(error = %e, "failed to read username from session"); return None; } }; let display_name: Option = match session.get(SESSION_DISPLAY_NAME).await { Ok(v) => v, Err(e) => { tracing::warn!(error = %e, "failed to read display_name from session"); None } }; // Perks default to empty — sessions predating the perks change still load. let perks: UserPerks = session .get(SESSION_PERKS) .await .unwrap_or_default() .unwrap_or_default(); Some(Self { user_id, username, display_name, perks, }) } async fn save_to_session(&self, session: &Session) { if let Err(e) = session.insert(SESSION_USER_ID, self.user_id).await { tracing::error!(error = %e, "failed to save user_id to session"); } if let Err(e) = session.insert(SESSION_USERNAME, &self.username).await { tracing::error!(error = %e, "failed to save username to session"); } if let Err(e) = session.insert(SESSION_DISPLAY_NAME, &self.display_name).await { tracing::error!(error = %e, "failed to save display_name to session"); } if let Err(e) = session.insert(SESSION_PERKS, &self.perks).await { tracing::error!(error = %e, "failed to save perks to session"); } } } /// Axum extractor that yields `Option`. pub struct MaybeUser(pub Option); impl FromRequestParts for MaybeUser { type Rejection = std::convert::Infallible; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result { let session = Session::from_request_parts(parts, state) .await .expect("session layer missing"); Ok(MaybeUser(SessionUser::from_session(&session).await)) } } /// Axum extractor that requires the user to be the platform admin. /// Returns 404 to non-admins (hides admin routes). pub struct PlatformAdmin(pub SessionUser); impl FromRequestParts for PlatformAdmin { type Rejection = StatusCode; async fn from_request_parts( parts: &mut Parts, state: &AppState, ) -> Result { let session = Session::from_request_parts(parts, state) .await .expect("session layer missing"); let user = SessionUser::from_session(&session) .await .ok_or(StatusCode::NOT_FOUND)?; let admin_id = state.config.platform_admin_id.ok_or(StatusCode::NOT_FOUND)?; if user.user_id != admin_id { return Err(StatusCode::NOT_FOUND); } Ok(PlatformAdmin(user)) } } // ── OAuth callback types ── #[derive(Deserialize)] pub struct CallbackQuery { pub code: String, pub state: String, } #[derive(Deserialize)] struct TokenResponse { access_token: String, } #[derive(Deserialize)] struct UserinfoResponse { user_id: uuid::Uuid, username: String, display_name: Option, avatar_url: Option, #[serde(default)] perks: UserPerks, } #[derive(Debug)] pub enum UserinfoError { Unauthorized, Transport, BadResponse, } /// Single-attempt userinfo fetch against MNW. Callers decide retry policy. /// /// `Unauthorized` means the bearer token is invalid or the user is gone. /// `Transport` covers network and 5xx. `BadResponse` covers other 4xx and parse /// errors. The login callback retries on `Transport`; `refresh_session` does /// not — the client can retry. async fn fetch_userinfo( http: &reqwest::Client, base_url: &str, access_token: &str, ) -> Result { let url = format!("{}/oauth/userinfo", base_url); let res = http .get(&url) .bearer_auth(access_token) .send() .await .map_err(|e| { tracing::warn!(error = %e, "userinfo transport error"); UserinfoError::Transport })?; let status = res.status(); if status == reqwest::StatusCode::UNAUTHORIZED { return Err(UserinfoError::Unauthorized); } if status.is_server_error() { return Err(UserinfoError::Transport); } if !status.is_success() { let body = res.text().await.unwrap_or_default(); tracing::warn!(%status, %body, "userinfo non-success"); return Err(UserinfoError::BadResponse); } res.json::().await.map_err(|e| { tracing::warn!(error = %e, "userinfo parse failed"); UserinfoError::BadResponse }) } /// Refresh the cached perks for the current session by re-hitting MNW. /// /// Caller must have a logged-in session (access token stored at login). On /// `Unauthorized` the session is flushed — the access token is gone for good /// and the user needs to log in again. Other errors leave the session intact. pub async fn refresh_session( state: &AppState, session: &Session, ) -> Result { let token: String = session .get(SESSION_ACCESS_TOKEN) .await .unwrap_or(None) .ok_or(UserinfoError::Unauthorized)?; match fetch_userinfo(&state.http, &state.config.mnw_base_url, &token).await { Ok(info) => { if let Err(e) = session.insert(SESSION_PERKS, &info.perks).await { tracing::error!(error = %e, "failed to save refreshed perks"); } // Username/display can drift on MNW too — sync them while we're here. if let Err(e) = session.insert(SESSION_USERNAME, &info.username).await { tracing::error!(error = %e, "failed to save refreshed username"); } if let Err(e) = session.insert(SESSION_DISPLAY_NAME, &info.display_name).await { tracing::error!(error = %e, "failed to save refreshed display_name"); } // Mirror perks into users table so post rendering sees the change // without consulting MNW per-post. Best-effort: rendering tolerates // a stale row, so DB errors here are logged but non-fatal. if let Err(e) = sqlx::query( "UPDATE users SET is_fan_plus = $2, is_creator = $3 WHERE mnw_account_id = $1", ) .bind(info.user_id) .bind(info.perks.fan_plus) .bind(info.perks.is_creator) .execute(&state.db) .await { tracing::warn!(error = %e, "failed to mirror refreshed perks to users table"); } let _ = info.avatar_url; // not stored in session yet Ok(info.perks) } Err(UserinfoError::Unauthorized) => { // Token revoked, expired, or user deleted — drop the session. if let Err(e) = session.flush().await { tracing::warn!(error = %e, "failed to flush session after auth failure"); } Err(UserinfoError::Unauthorized) } Err(e) => Err(e), } } // ── Handlers ── /// `GET /auth/login` — redirect to MNW OAuth authorize endpoint. #[tracing::instrument(skip_all)] pub async fn login( State(state): State, session: Session, ) -> impl IntoResponse { let verifier = generate_verifier(); let challenge = pkce_challenge(&verifier); let oauth_state = generate_state_nonce(); if let Err(e) = session.insert(SESSION_PKCE_VERIFIER, &verifier).await { tracing::error!(error = %e, "failed to save PKCE verifier to session"); } if let Err(e) = session.insert(SESSION_OAUTH_STATE, &oauth_state).await { tracing::error!(error = %e, "failed to save OAuth state to session"); } let url = format!( "{}/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256", state.config.mnw_base_url, urlencoding::encode(&state.config.oauth_client_id), urlencoding::encode(&state.config.oauth_redirect_uri), urlencoding::encode(&oauth_state), urlencoding::encode(&challenge), ); Redirect::to(&url) } /// `GET /auth/callback` — exchange code for token, fetch userinfo, create session. #[tracing::instrument(skip_all)] pub async fn callback( State(state): State, session: Session, Query(params): Query, ) -> impl IntoResponse { tracing::info!("OAuth callback received"); // Verify state nonce let stored_state: Option = session.get(SESSION_OAUTH_STATE).await.unwrap_or(None); if stored_state.as_deref() != Some(¶ms.state) { tracing::warn!(stored = ?stored_state, received = %params.state, "state mismatch"); return Redirect::to("/?error=state_mismatch"); } let verifier: String = match session.get(SESSION_PKCE_VERIFIER).await.unwrap_or(None) { Some(v) => v, None => { tracing::warn!("missing PKCE verifier in session"); return Redirect::to("/?error=missing_verifier"); } }; // Clean up OAuth session data if let Err(e) = session.remove::(SESSION_OAUTH_STATE).await { tracing::warn!(error = %e, "failed to remove OAuth state from session"); } if let Err(e) = session.remove::(SESSION_PKCE_VERIFIER).await { tracing::warn!(error = %e, "failed to remove PKCE verifier from session"); } // Exchange code for token (retry up to 2 attempts on network/5xx errors) let token_url = format!("{}/oauth/token", state.config.mnw_base_url); tracing::info!(%token_url, "exchanging code for token"); let backoffs = [ std::time::Duration::from_millis(500), std::time::Duration::from_millis(1000), ]; let mut token_res = None; for attempt in 0..=backoffs.len() { let res = state .http .post(&token_url) .json(&serde_json::json!({ "grant_type": "authorization_code", "code": params.code, "redirect_uri": state.config.oauth_redirect_uri, "code_verifier": verifier, "client_id": state.config.oauth_client_id, })) .send() .await; match res { Ok(r) if r.status().is_server_error() => { let status = r.status(); if attempt < backoffs.len() { tracing::warn!(%status, attempt, "token exchange got 5xx, retrying"); sleep(backoffs[attempt]).await; continue; } let body = r.text().await.unwrap_or_default(); tracing::error!(%status, %body, "token exchange failed after retries"); return Redirect::to("/?error=token_exchange_failed"); } Ok(r) if !r.status().is_success() => { let status = r.status(); let body = r.text().await.unwrap_or_default(); tracing::error!(%status, %body, "token exchange failed"); return Redirect::to("/?error=token_exchange_failed"); } Ok(r) => { token_res = Some(r); break; } Err(e) => { if attempt < backoffs.len() { tracing::warn!(error = %e, attempt, "token request failed, retrying"); sleep(backoffs[attempt]).await; continue; } tracing::error!(error = %e, "token request failed after retries"); return Redirect::to("/?error=token_request_failed"); } } } // Safety: loop always either sets token_res or returns early let token_res = token_res.unwrap(); let token: TokenResponse = match token_res.json().await { Ok(t) => t, Err(e) => { tracing::error!(error = %e, "token parse failed"); return Redirect::to("/?error=token_parse_failed"); } }; // Fetch userinfo (retry up to 2 attempts on transport / 5xx errors). tracing::info!(base_url = %state.config.mnw_base_url, "fetching userinfo"); let mut info: Option = None; for attempt in 0..=backoffs.len() { match fetch_userinfo(&state.http, &state.config.mnw_base_url, &token.access_token).await { Ok(i) => { info = Some(i); break; } Err(UserinfoError::Transport) if attempt < backoffs.len() => { tracing::warn!(attempt, "userinfo transport error, retrying"); sleep(backoffs[attempt]).await; continue; } Err(UserinfoError::Transport) => { tracing::error!("userinfo transport failed after retries"); return Redirect::to("/?error=userinfo_fetch_failed"); } Err(UserinfoError::Unauthorized) => { tracing::error!("userinfo unauthorized — token rejected"); return Redirect::to("/?error=userinfo_fetch_failed"); } Err(UserinfoError::BadResponse) => { tracing::error!("userinfo bad response"); return Redirect::to("/?error=userinfo_parse_failed"); } } } let info = info.expect("userinfo loop always sets value or returns"); tracing::info!(user_id = %info.user_id, username = %info.username, "OAuth login successful"); // Upsert local user. `is_fan_plus`/`is_creator` are denormalised here so // post rendering can look up the post author's perks via JOIN — see // migration 026. let upsert_result = sqlx::query( r#" INSERT INTO users (mnw_account_id, username, display_name, avatar_url, is_fan_plus, is_creator) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (mnw_account_id) DO UPDATE SET username = $2, display_name = $3, avatar_url = $4, is_fan_plus = $5, is_creator = $6, updated_at = now() "#, ) .bind(info.user_id) .bind(&info.username) .bind(&info.display_name) .bind(&info.avatar_url) .bind(info.perks.fan_plus) .bind(info.perks.is_creator) .execute(&state.db) .await; if let Err(e) = upsert_result { tracing::error!(error = %e, "user upsert failed"); return Redirect::to("/?error=user_upsert_failed"); } // Check if user is suspended (fail-closed: DB errors block login) let suspended: bool = match sqlx::query_scalar( "SELECT suspended_at IS NOT NULL FROM users WHERE mnw_account_id = $1", ) .bind(info.user_id) .fetch_one(&state.db) .await { Ok(v) => v, Err(e) => { tracing::error!(error = %e, "db error checking suspension status"); return Redirect::to("/?error=internal_error"); } }; if suspended { return Redirect::to("/?error=account_suspended"); } // Save session — perks come from the same userinfo response, no second roundtrip. let session_user = SessionUser { user_id: info.user_id, username: info.username, display_name: info.display_name, perks: info.perks, }; session_user.save_to_session(&session).await; // Stash the access token so `refresh_session` can re-hit userinfo without // forcing the user through another OAuth round trip. Token lifetime is set // by MNW (7d as of writing); after expiry, refresh returns Unauthorized and // the session is flushed. if let Err(e) = session.insert(SESSION_ACCESS_TOKEN, &token.access_token).await { tracing::error!(error = %e, "failed to save access token to session"); } if let Err(e) = session.cycle_id().await { tracing::warn!(error = %e, "Failed to cycle session ID"); } tracing::info!("session saved, redirecting to /"); Redirect::to("/") } /// `POST /auth/refresh` — re-fetch MNW userinfo and overwrite cached perks. /// /// Useful after the user takes an action that changed their MNW entitlements /// (e.g., subscribing to Fan+, upgrading a creator tier) so they don't have to /// log out and back in to see the new perks. Returns the refreshed perks as /// JSON. #[tracing::instrument(skip_all)] pub async fn refresh( State(state): State, session: Session, ) -> Result, StatusCode> { match refresh_session(&state, &session).await { Ok(perks) => Ok(Json(RefreshResponse { perks })), Err(UserinfoError::Unauthorized) => Err(StatusCode::UNAUTHORIZED), Err(UserinfoError::Transport) => Err(StatusCode::BAD_GATEWAY), Err(UserinfoError::BadResponse) => Err(StatusCode::BAD_GATEWAY), } } #[derive(Serialize)] pub struct RefreshResponse { pub perks: UserPerks, } /// `POST /auth/logout` — flush session, redirect home. #[tracing::instrument(skip_all)] pub async fn logout(session: Session) -> impl IntoResponse { if let Err(e) = session.flush().await { tracing::warn!(error = %e, "failed to flush session on logout"); } Redirect::to("/") }