//! OAuth client for "Log in with Makenot.work" and session user extraction. use axum::{ extract::{FromRequestParts, Query, State}, http::{request::Parts, StatusCode}, response::{IntoResponse, Redirect}, }; use base64::Engine; use rand::RngCore; use serde::Deserialize; use sha2::{Digest, Sha256}; use tower_sessions::Session; use crate::AppState; // ── PKCE helpers ── fn generate_verifier() -> String { let mut bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut bytes); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) } fn pkce_challenge(verifier: &str) -> String { let mut hasher = Sha256::new(); hasher.update(verifier.as_bytes()); let digest = hasher.finalize(); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) } fn generate_state_nonce() -> String { let mut bytes = [0u8; 16]; rand::thread_rng().fill_bytes(&mut bytes); hex::encode(bytes) } // ── Session user ── /// Minimal user info stored in the session after OAuth login. #[derive(Clone, Debug)] pub struct SessionUser { pub user_id: uuid::Uuid, pub username: String, pub display_name: Option, } const SESSION_USER_ID: &str = "user_id"; const SESSION_USERNAME: &str = "username"; const SESSION_DISPLAY_NAME: &str = "display_name"; const SESSION_OAUTH_STATE: &str = "oauth_state"; const SESSION_PKCE_VERIFIER: &str = "pkce_verifier"; impl SessionUser { async fn from_session(session: &Session) -> Option { let user_id: uuid::Uuid = session.get(SESSION_USER_ID).await.ok()??; let username: String = session.get(SESSION_USERNAME).await.ok()??; let display_name: Option = session.get(SESSION_DISPLAY_NAME).await.ok().flatten(); Some(Self { user_id, username, display_name, }) } async fn save_to_session(&self, session: &Session) { let _ = session.insert(SESSION_USER_ID, self.user_id).await; let _ = session.insert(SESSION_USERNAME, &self.username).await; let _ = session .insert(SESSION_DISPLAY_NAME, &self.display_name) .await; } } /// Axum extractor that yields `Option`. 
///
/// Never rejects: a request without a readable session user yields `MaybeUser(None)`.
pub struct MaybeUser(pub Option<SessionUser>);

impl FromRequestParts<AppState> for MaybeUser {
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // Failing to extract the session means the session layer is not installed:
        // a router wiring bug, so panicking is deliberate.
        let session = Session::from_request_parts(parts, state)
            .await
            .expect("session layer missing");
        let current = SessionUser::from_session(&session).await;
        Ok(Self(current))
    }
}

/// Axum extractor that requires the user to be the platform admin.
/// Returns 404 to non-admins (hides admin routes).
///
/// The 404 is also returned when there is no logged-in user, or when
/// `config.platform_admin_id` is `None`.
pub struct PlatformAdmin(pub SessionUser);

impl FromRequestParts<AppState> for PlatformAdmin {
    type Rejection = StatusCode;

    async fn from_request_parts(
        parts: &mut Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        let session = Session::from_request_parts(parts, state)
            .await
            .expect("session layer missing");

        let Some(user) = SessionUser::from_session(&session).await else {
            return Err(StatusCode::NOT_FOUND);
        };

        // `None == Some(_)` is false, so an unconfigured admin id also yields 404.
        if state.config.platform_admin_id == Some(user.user_id) {
            Ok(Self(user))
        } else {
            Err(StatusCode::NOT_FOUND)
        }
    }
}

// ── OAuth callback types ──

/// Query parameters expected on `GET /auth/callback`.
#[derive(Deserialize)]
pub struct CallbackQuery {
    /// Authorization code to exchange at the token endpoint.
    pub code: String,
    /// Echo of the `state` nonce; must equal the copy `login` stored in the session.
    pub state: String,
}

/// The part of the token endpoint's JSON response that this client reads.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
}

/// JSON body of the provider's `/oauth/userinfo` endpoint.
#[derive(Deserialize)]
struct UserinfoResponse {
    user_id: uuid::Uuid,
    username: String,
    display_name: Option<String>,
    avatar_url: Option<String>,
}

// ── Handlers ──

/// `GET /auth/login` — redirect to MNW OAuth authorize endpoint.
#[tracing::instrument(skip_all)] pub async fn login( State(state): State, session: Session, ) -> impl IntoResponse { let verifier = generate_verifier(); let challenge = pkce_challenge(&verifier); let oauth_state = generate_state_nonce(); let _ = session.insert(SESSION_PKCE_VERIFIER, &verifier).await; let _ = session.insert(SESSION_OAUTH_STATE, &oauth_state).await; let url = format!( "{}/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256", state.config.mnw_base_url, urlencoding::encode(&state.config.oauth_client_id), urlencoding::encode(&state.config.oauth_redirect_uri), urlencoding::encode(&oauth_state), urlencoding::encode(&challenge), ); Redirect::to(&url) } /// `GET /auth/callback` — exchange code for token, fetch userinfo, create session. #[tracing::instrument(skip_all)] pub async fn callback( State(state): State, session: Session, Query(params): Query, ) -> impl IntoResponse { tracing::info!("OAuth callback received"); // Verify state nonce let stored_state: Option = session.get(SESSION_OAUTH_STATE).await.unwrap_or(None); if stored_state.as_deref() != Some(¶ms.state) { tracing::warn!(stored = ?stored_state, received = %params.state, "state mismatch"); return Redirect::to("/?error=state_mismatch"); } let verifier: String = match session.get(SESSION_PKCE_VERIFIER).await.unwrap_or(None) { Some(v) => v, None => { tracing::warn!("missing PKCE verifier in session"); return Redirect::to("/?error=missing_verifier"); } }; // Clean up OAuth session data let _ = session.remove::(SESSION_OAUTH_STATE).await; let _ = session.remove::(SESSION_PKCE_VERIFIER).await; // Exchange code for token let token_url = format!("{}/oauth/token", state.config.mnw_base_url); tracing::info!(%token_url, "exchanging code for token"); let token_res = state .http .post(&token_url) .json(&serde_json::json!({ "grant_type": "authorization_code", "code": params.code, "redirect_uri": state.config.oauth_redirect_uri, "code_verifier": 
verifier, "client_id": state.config.oauth_client_id, })) .send() .await; let token_res = match token_res { Ok(r) => r, Err(e) => { tracing::error!(error = %e, "token request failed"); return Redirect::to("/?error=token_request_failed"); } }; if !token_res.status().is_success() { let status = token_res.status(); let body = token_res.text().await.unwrap_or_default(); tracing::error!(%status, %body, "token exchange failed"); return Redirect::to("/?error=token_exchange_failed"); } let token: TokenResponse = match token_res.json().await { Ok(t) => t, Err(e) => { tracing::error!(error = %e, "token parse failed"); return Redirect::to("/?error=token_parse_failed"); } }; // Fetch userinfo let userinfo_url = format!("{}/oauth/userinfo", state.config.mnw_base_url); tracing::info!(%userinfo_url, "fetching userinfo"); let userinfo_res = state .http .get(&userinfo_url) .bearer_auth(&token.access_token) .send() .await; let userinfo_res = match userinfo_res { Ok(r) => r, Err(e) => { tracing::error!(error = %e, "userinfo request failed"); return Redirect::to("/?error=userinfo_request_failed"); } }; if !userinfo_res.status().is_success() { let status = userinfo_res.status(); let body = userinfo_res.text().await.unwrap_or_default(); tracing::error!(%status, %body, "userinfo fetch failed"); return Redirect::to("/?error=userinfo_fetch_failed"); } let info: UserinfoResponse = match userinfo_res.json().await { Ok(i) => i, Err(e) => { tracing::error!(error = %e, "userinfo parse failed"); return Redirect::to("/?error=userinfo_parse_failed"); } }; tracing::info!(user_id = %info.user_id, username = %info.username, "OAuth login successful"); // Upsert local user let upsert_result = sqlx::query( r#" INSERT INTO users (mnw_account_id, username, display_name, avatar_url) VALUES ($1, $2, $3, $4) ON CONFLICT (mnw_account_id) DO UPDATE SET username = $2, display_name = $3, avatar_url = $4, updated_at = now() "#, ) .bind(info.user_id) .bind(&info.username) .bind(&info.display_name) 
.bind(&info.avatar_url) .execute(&state.db) .await; if let Err(e) = upsert_result { tracing::error!(error = %e, "user upsert failed"); return Redirect::to("/?error=user_upsert_failed"); } // Check if user is suspended (fail-closed: DB errors block login) let suspended: bool = match sqlx::query_scalar( "SELECT suspended_at IS NOT NULL FROM users WHERE mnw_account_id = $1", ) .bind(info.user_id) .fetch_one(&state.db) .await { Ok(v) => v, Err(e) => { tracing::error!(error = %e, "db error checking suspension status"); return Redirect::to("/?error=internal_error"); } }; if suspended { return Redirect::to("/?error=account_suspended"); } // Save session let session_user = SessionUser { user_id: info.user_id, username: info.username, display_name: info.display_name, }; session_user.save_to_session(&session).await; if let Err(e) = session.cycle_id().await { tracing::warn!(error = %e, "Failed to cycle session ID"); } tracing::info!("session saved, redirecting to /"); Redirect::to("/") } /// `POST /auth/logout` — flush session, redirect home. #[tracing::instrument(skip_all)] pub async fn logout(session: Session) -> impl IntoResponse { let _ = session.flush().await; Redirect::to("/") }