Skip to main content

max / makenotwork

7.0 KB · 185 lines History Blame Raw
1 //! Delegated login: "Sign in with Makenot.work" (OAuth client side).
2 //!
3 //! Used on the testnot.work staging mirror. Instead of a local password form,
4 //! the login page redirects to an upstream MNW provider (production), where the
5 //! user authenticates — so a password is only ever entered on the real site.
6 //! On callback we exchange the code for the provider's response, take the
7 //! verified `user_id`, look that user up in our own (mirrored) DB, and start a
8 //! local session. The provider's OAuth flow is the SyncKit one
9 //! (`src/routes/oauth.rs`); we discard its sync token and use only `user_id`.
10 //!
11 //! Active only when `[sso]` is configured (the three `SSO_*` vars). Routes are
12 //! allowlisted in the access gate so an unauthenticated visitor can reach them.
13
14 use axum::{
15 extract::{Query, State},
16 http::HeaderMap,
17 response::{IntoResponse, Redirect, Response},
18 routing::get,
19 Router,
20 };
21 use base64::Engine;
22 use rand::RngCore;
23 use serde::Deserialize;
24 use sha2::{Digest, Sha256};
25 use tower_sessions::Session;
26
27 use crate::{
28 auth::{login_user, track_session, SessionUser},
29 db::{self, UserId},
30 error::{AppError, Result},
31 AppState,
32 };
33
/// Session key for the random `state` value minted by `sso_login`; the
/// callback compares the `state` echoed by the provider against it.
const SSO_STATE_KEY: &str = "sso_state";
/// Session key for the PKCE code verifier minted by `sso_login`; the callback
/// sends it to the provider's token endpoint as `code_verifier`.
const SSO_VERIFIER_KEY: &str = "sso_pkce_verifier";
36
37 fn b64url(data: &[u8]) -> String {
38 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
39 }
40
41 /// GET /sso/login — begin the delegated-login flow.
42 ///
43 /// Generates PKCE + state, stashes them in the session, and redirects to the
44 /// provider's authorize endpoint. No-op (404-ish redirect home) when SSO is off.
45 #[tracing::instrument(skip_all, name = "sso::login")]
46 async fn sso_login(State(state): State<AppState>, session: Session) -> Result<Response> {
47 let Some(sso) = state.config.sso.as_ref() else {
48 // SSO not configured — nothing to delegate to.
49 return Ok(Redirect::to("/login").into_response());
50 };
51
52 // PKCE verifier (43-char base64url of 32 random bytes) + S256 challenge.
53 let mut vbytes = [0u8; 32];
54 rand::rng().fill_bytes(&mut vbytes);
55 let verifier = b64url(&vbytes);
56 let challenge = b64url(Sha256::digest(verifier.as_bytes()).as_ref());
57
58 // CSRF-style state to bind the callback to this session.
59 let mut sbytes = [0u8; 16];
60 rand::rng().fill_bytes(&mut sbytes);
61 let state_param = b64url(&sbytes);
62
63 session.insert(SSO_STATE_KEY, &state_param).await.map_err(|e| AppError::Internal(e.into()))?;
64 session.insert(SSO_VERIFIER_KEY, &verifier).await.map_err(|e| AppError::Internal(e.into()))?;
65
66 let redirect_uri = format!("{}/sso/callback", state.config.host_url);
67 let authorize = format!(
68 "{}/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256",
69 sso.provider_url,
70 urlencoding::encode(&sso.client_id),
71 urlencoding::encode(&redirect_uri),
72 urlencoding::encode(&state_param),
73 urlencoding::encode(&challenge),
74 );
75 Ok(Redirect::to(&authorize).into_response())
76 }
77
78 #[derive(Deserialize)]
79 struct CallbackQuery {
80 code: Option<String>,
81 state: Option<String>,
82 error: Option<String>,
83 }
84
85 /// Minimal view of the provider's token response — we only need the user id.
86 #[derive(Deserialize)]
87 struct TokenResponse {
88 user_id: UserId,
89 }
90
91 /// GET /sso/callback — provider redirected back with `code` + `state`.
92 #[tracing::instrument(skip_all, name = "sso::callback")]
93 async fn sso_callback(
94 State(state): State<AppState>,
95 session: Session,
96 headers: HeaderMap,
97 Query(q): Query<CallbackQuery>,
98 ) -> Result<Response> {
99 let Some(sso) = state.config.sso.as_ref() else {
100 return Ok(Redirect::to("/login").into_response());
101 };
102
103 let fail = |msg: &str| Ok(Redirect::to(&format!("/login?sso_error={}", urlencoding::encode(msg))).into_response());
104
105 if let Some(err) = q.error.as_deref() {
106 tracing::warn!(error = %err, "sso provider returned error");
107 return fail("Sign-in was cancelled or denied.");
108 }
109 let (Some(code), Some(returned_state)) = (q.code.as_deref(), q.state.as_deref()) else {
110 return fail("Sign-in response was incomplete. Please try again.");
111 };
112
113 // Validate state against the session, and consume the one-shot PKCE values.
114 let expected_state: Option<String> = session.get(SSO_STATE_KEY).await.ok().flatten();
115 let verifier: Option<String> = session.get(SSO_VERIFIER_KEY).await.ok().flatten();
116 let _ = session.remove::<String>(SSO_STATE_KEY).await;
117 let _ = session.remove::<String>(SSO_VERIFIER_KEY).await;
118
119 let (Some(expected_state), Some(verifier)) = (expected_state, verifier) else {
120 return fail("Your sign-in session expired. Please try again.");
121 };
122 if !crate::helpers::constant_time_compare(&expected_state, returned_state) {
123 tracing::warn!("sso state mismatch");
124 return fail("Sign-in could not be verified. Please try again.");
125 }
126
127 // Exchange the code at the provider's token endpoint.
128 let redirect_uri = format!("{}/sso/callback", state.config.host_url);
129 let resp = reqwest::Client::new()
130 .post(format!("{}/oauth/token", sso.provider_url))
131 .timeout(std::time::Duration::from_secs(10))
132 .form(&[
133 ("grant_type", "authorization_code"),
134 ("code", code),
135 ("redirect_uri", &redirect_uri),
136 ("code_verifier", &verifier),
137 ("client_id", &sso.client_id),
138 ("key", &sso.key),
139 ])
140 .send()
141 .await;
142
143 let resp = match resp {
144 Ok(r) if r.status().is_success() => r,
145 Ok(r) => {
146 tracing::warn!(status = %r.status(), "sso token exchange rejected");
147 return fail("Sign-in failed at the provider. Please try again.");
148 }
149 Err(e) => {
150 tracing::warn!(error = ?e, "sso token exchange request failed");
151 return fail("Could not reach the sign-in provider. Please try again.");
152 }
153 };
154
155 let token: TokenResponse = match resp.json().await {
156 Ok(t) => t,
157 Err(e) => {
158 tracing::warn!(error = ?e, "sso token response parse failed");
159 return fail("Sign-in failed at the provider. Please try again.");
160 }
161 };
162
163 // Map the verified provider user id onto our mirrored account.
164 let db_user = match db::users::get_user_by_id(&state.db, token.user_id).await? {
165 Some(u) => u,
166 None => return fail("Your account isn't in the preview yet — it syncs from production daily."),
167 };
168 if db_user.is_suspended() || db_user.is_deactivated() {
169 return fail("This account is not active.");
170 }
171
172 let user_id = db_user.id;
173 let session_user = SessionUser::from_db_user(db_user, &state.db, state.config.admin_user_id).await;
174 login_user(&session, session_user).await?;
175 track_session(&session, &state.db, user_id, &headers).await?;
176
177 Ok(Redirect::to("/").into_response())
178 }
179
180 pub fn sso_routes() -> Router<AppState> {
181 Router::new()
182 .route("/sso/login", get(sso_login))
183 .route("/sso/callback", get(sso_callback))
184 }
185