Skip to main content

max / multithreaded

10.0 KB · 329 lines History Blame Raw
1 //! OAuth client for "Log in with Makenot.work" and session user extraction.
2
3 use axum::{
4 extract::{FromRequestParts, Query, State},
5 http::{request::Parts, StatusCode},
6 response::{IntoResponse, Redirect},
7 };
8 use base64::Engine;
9 use rand::RngCore;
10 use serde::Deserialize;
11 use sha2::{Digest, Sha256};
12 use tower_sessions::Session;
13
14 use crate::AppState;
15
16 // ── PKCE helpers ──
17
18 fn generate_verifier() -> String {
19 let mut bytes = [0u8; 32];
20 rand::thread_rng().fill_bytes(&mut bytes);
21 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
22 }
23
24 fn pkce_challenge(verifier: &str) -> String {
25 let mut hasher = Sha256::new();
26 hasher.update(verifier.as_bytes());
27 let digest = hasher.finalize();
28 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
29 }
30
31 fn generate_state_nonce() -> String {
32 let mut bytes = [0u8; 16];
33 rand::thread_rng().fill_bytes(&mut bytes);
34 hex::encode(bytes)
35 }
36
37 // ── Session user ──
38
39 /// Minimal user info stored in the session after OAuth login.
40 #[derive(Clone, Debug)]
41 pub struct SessionUser {
42 pub user_id: uuid::Uuid,
43 pub username: String,
44 pub display_name: Option<String>,
45 }
46
// Keys under which the logged-in user's fields are stored in the session.
/// Session key holding the user's id.
const SESSION_USER_ID: &str = "user_id";
/// Session key holding the user's username.
const SESSION_USERNAME: &str = "username";
/// Session key holding the user's optional display name.
const SESSION_DISPLAY_NAME: &str = "display_name";

// Keys for the short-lived values that tie `/auth/login` to `/auth/callback`.
/// Session key holding the OAuth `state` nonce issued at login.
const SESSION_OAUTH_STATE: &str = "oauth_state";
/// Session key holding the PKCE code verifier issued at login.
const SESSION_PKCE_VERIFIER: &str = "pkce_verifier";
52
53 impl SessionUser {
54 async fn from_session(session: &Session) -> Option<Self> {
55 let user_id: uuid::Uuid = session.get(SESSION_USER_ID).await.ok()??;
56 let username: String = session.get(SESSION_USERNAME).await.ok()??;
57 let display_name: Option<String> = session.get(SESSION_DISPLAY_NAME).await.ok().flatten();
58 Some(Self {
59 user_id,
60 username,
61 display_name,
62 })
63 }
64
65 async fn save_to_session(&self, session: &Session) {
66 let _ = session.insert(SESSION_USER_ID, self.user_id).await;
67 let _ = session.insert(SESSION_USERNAME, &self.username).await;
68 let _ = session
69 .insert(SESSION_DISPLAY_NAME, &self.display_name)
70 .await;
71 }
72 }
73
74 /// Axum extractor that yields `Option<SessionUser>`.
75 pub struct MaybeUser(pub Option<SessionUser>);
76
77 impl FromRequestParts<AppState> for MaybeUser {
78 type Rejection = std::convert::Infallible;
79
80 async fn from_request_parts(
81 parts: &mut Parts,
82 state: &AppState,
83 ) -> Result<Self, Self::Rejection> {
84 let session = Session::from_request_parts(parts, state)
85 .await
86 .expect("session layer missing");
87 Ok(MaybeUser(SessionUser::from_session(&session).await))
88 }
89 }
90
91 /// Axum extractor that requires the user to be the platform admin.
92 /// Returns 404 to non-admins (hides admin routes).
93 pub struct PlatformAdmin(pub SessionUser);
94
95 impl FromRequestParts<AppState> for PlatformAdmin {
96 type Rejection = StatusCode;
97
98 async fn from_request_parts(
99 parts: &mut Parts,
100 state: &AppState,
101 ) -> Result<Self, Self::Rejection> {
102 let session = Session::from_request_parts(parts, state)
103 .await
104 .expect("session layer missing");
105 let user = SessionUser::from_session(&session)
106 .await
107 .ok_or(StatusCode::NOT_FOUND)?;
108
109 let admin_id = state.config.platform_admin_id.ok_or(StatusCode::NOT_FOUND)?;
110 if user.user_id != admin_id {
111 return Err(StatusCode::NOT_FOUND);
112 }
113
114 Ok(PlatformAdmin(user))
115 }
116 }
117
118 // ── OAuth callback types ──
119
120 #[derive(Deserialize)]
121 pub struct CallbackQuery {
122 pub code: String,
123 pub state: String,
124 }
125
126 #[derive(Deserialize)]
127 struct TokenResponse {
128 access_token: String,
129 }
130
131 #[derive(Deserialize)]
132 struct UserinfoResponse {
133 user_id: uuid::Uuid,
134 username: String,
135 display_name: Option<String>,
136 avatar_url: Option<String>,
137 }
138
139 // ── Handlers ──
140
141 /// `GET /auth/login` — redirect to MNW OAuth authorize endpoint.
142 #[tracing::instrument(skip_all)]
143 pub async fn login(
144 State(state): State<AppState>,
145 session: Session,
146 ) -> impl IntoResponse {
147 let verifier = generate_verifier();
148 let challenge = pkce_challenge(&verifier);
149 let oauth_state = generate_state_nonce();
150
151 let _ = session.insert(SESSION_PKCE_VERIFIER, &verifier).await;
152 let _ = session.insert(SESSION_OAUTH_STATE, &oauth_state).await;
153
154 let url = format!(
155 "{}/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256",
156 state.config.mnw_base_url,
157 urlencoding::encode(&state.config.oauth_client_id),
158 urlencoding::encode(&state.config.oauth_redirect_uri),
159 urlencoding::encode(&oauth_state),
160 urlencoding::encode(&challenge),
161 );
162
163 Redirect::to(&url)
164 }
165
166 /// `GET /auth/callback` — exchange code for token, fetch userinfo, create session.
167 #[tracing::instrument(skip_all)]
168 pub async fn callback(
169 State(state): State<AppState>,
170 session: Session,
171 Query(params): Query<CallbackQuery>,
172 ) -> impl IntoResponse {
173 tracing::info!("OAuth callback received");
174
175 // Verify state nonce
176 let stored_state: Option<String> = session.get(SESSION_OAUTH_STATE).await.unwrap_or(None);
177 if stored_state.as_deref() != Some(&params.state) {
178 tracing::warn!(stored = ?stored_state, received = %params.state, "state mismatch");
179 return Redirect::to("/?error=state_mismatch");
180 }
181
182 let verifier: String = match session.get(SESSION_PKCE_VERIFIER).await.unwrap_or(None) {
183 Some(v) => v,
184 None => {
185 tracing::warn!("missing PKCE verifier in session");
186 return Redirect::to("/?error=missing_verifier");
187 }
188 };
189
190 // Clean up OAuth session data
191 let _ = session.remove::<String>(SESSION_OAUTH_STATE).await;
192 let _ = session.remove::<String>(SESSION_PKCE_VERIFIER).await;
193
194 // Exchange code for token
195 let token_url = format!("{}/oauth/token", state.config.mnw_base_url);
196 tracing::info!(%token_url, "exchanging code for token");
197 let token_res = state
198 .http
199 .post(&token_url)
200 .json(&serde_json::json!({
201 "grant_type": "authorization_code",
202 "code": params.code,
203 "redirect_uri": state.config.oauth_redirect_uri,
204 "code_verifier": verifier,
205 "client_id": state.config.oauth_client_id,
206 }))
207 .send()
208 .await;
209
210 let token_res = match token_res {
211 Ok(r) => r,
212 Err(e) => {
213 tracing::error!(error = %e, "token request failed");
214 return Redirect::to("/?error=token_request_failed");
215 }
216 };
217
218 if !token_res.status().is_success() {
219 let status = token_res.status();
220 let body = token_res.text().await.unwrap_or_default();
221 tracing::error!(%status, %body, "token exchange failed");
222 return Redirect::to("/?error=token_exchange_failed");
223 }
224
225 let token: TokenResponse = match token_res.json().await {
226 Ok(t) => t,
227 Err(e) => {
228 tracing::error!(error = %e, "token parse failed");
229 return Redirect::to("/?error=token_parse_failed");
230 }
231 };
232
233 // Fetch userinfo
234 let userinfo_url = format!("{}/oauth/userinfo", state.config.mnw_base_url);
235 tracing::info!(%userinfo_url, "fetching userinfo");
236 let userinfo_res = state
237 .http
238 .get(&userinfo_url)
239 .bearer_auth(&token.access_token)
240 .send()
241 .await;
242
243 let userinfo_res = match userinfo_res {
244 Ok(r) => r,
245 Err(e) => {
246 tracing::error!(error = %e, "userinfo request failed");
247 return Redirect::to("/?error=userinfo_request_failed");
248 }
249 };
250
251 if !userinfo_res.status().is_success() {
252 let status = userinfo_res.status();
253 let body = userinfo_res.text().await.unwrap_or_default();
254 tracing::error!(%status, %body, "userinfo fetch failed");
255 return Redirect::to("/?error=userinfo_fetch_failed");
256 }
257
258 let info: UserinfoResponse = match userinfo_res.json().await {
259 Ok(i) => i,
260 Err(e) => {
261 tracing::error!(error = %e, "userinfo parse failed");
262 return Redirect::to("/?error=userinfo_parse_failed");
263 }
264 };
265
266 tracing::info!(user_id = %info.user_id, username = %info.username, "OAuth login successful");
267
268 // Upsert local user
269 let upsert_result = sqlx::query(
270 r#"
271 INSERT INTO users (mnw_account_id, username, display_name, avatar_url)
272 VALUES ($1, $2, $3, $4)
273 ON CONFLICT (mnw_account_id) DO UPDATE
274 SET username = $2, display_name = $3, avatar_url = $4, updated_at = now()
275 "#,
276 )
277 .bind(info.user_id)
278 .bind(&info.username)
279 .bind(&info.display_name)
280 .bind(&info.avatar_url)
281 .execute(&state.db)
282 .await;
283
284 if let Err(e) = upsert_result {
285 tracing::error!(error = %e, "user upsert failed");
286 return Redirect::to("/?error=user_upsert_failed");
287 }
288
289 // Check if user is suspended (fail-closed: DB errors block login)
290 let suspended: bool = match sqlx::query_scalar(
291 "SELECT suspended_at IS NOT NULL FROM users WHERE mnw_account_id = $1",
292 )
293 .bind(info.user_id)
294 .fetch_one(&state.db)
295 .await
296 {
297 Ok(v) => v,
298 Err(e) => {
299 tracing::error!(error = %e, "db error checking suspension status");
300 return Redirect::to("/?error=internal_error");
301 }
302 };
303
304 if suspended {
305 return Redirect::to("/?error=account_suspended");
306 }
307
308 // Save session
309 let session_user = SessionUser {
310 user_id: info.user_id,
311 username: info.username,
312 display_name: info.display_name,
313 };
314 session_user.save_to_session(&session).await;
315 if let Err(e) = session.cycle_id().await {
316 tracing::warn!(error = %e, "Failed to cycle session ID");
317 }
318 tracing::info!("session saved, redirecting to /");
319
320 Redirect::to("/")
321 }
322
323 /// `POST /auth/logout` — flush session, redirect home.
324 #[tracing::instrument(skip_all)]
325 pub async fn logout(session: Session) -> impl IntoResponse {
326 let _ = session.flush().await;
327 Redirect::to("/")
328 }
329