| 1 |
|
| 2 |
|
| 3 |
use crate::harness::TestHarness; |
| 4 |
use makenotwork::db::{SyncAppId, UserId}; |
| 5 |
use serde::Deserialize; |
| 6 |
use sha2::{Digest, Sha256}; |
| 7 |
use sqlx::PgPool; |
| 8 |
|
| 9 |
|
| 10 |
|
| 11 |
#[derive(Deserialize)] |
| 12 |
struct TokenResponse { |
| 13 |
access_token: String, |
| 14 |
token_type: String, |
| 15 |
expires_in: i64, |
| 16 |
user_id: UserId, |
| 17 |
app_id: SyncAppId, |
| 18 |
} |
| 19 |
|
| 20 |
|
| 21 |
|
| 22 |
|
| 23 |
async fn create_sync_app(pool: &PgPool, user_id: UserId) -> (SyncAppId, String) { |
| 24 |
let api_key = "test-oauth-client-id"; |
| 25 |
let key_hash = crate::harness::hash_api_key(api_key); |
| 26 |
let key_prefix = &api_key[..8]; |
| 27 |
let app_id: SyncAppId = sqlx::query_scalar( |
| 28 |
"INSERT INTO sync_apps (creator_id, name, api_key_hash, api_key_prefix) VALUES ($1, 'OAuth Test App', $2, $3) RETURNING id", |
| 29 |
) |
| 30 |
.bind(user_id) |
| 31 |
.bind(&key_hash) |
| 32 |
.bind(key_prefix) |
| 33 |
.fetch_one(pool) |
| 34 |
.await |
| 35 |
.expect("Failed to create sync app"); |
| 36 |
|
| 37 |
(app_id, api_key.to_string()) |
| 38 |
} |
| 39 |
|
| 40 |
|
| 41 |
fn generate_pkce() -> (String, String) { |
| 42 |
|
| 43 |
let verifier: String = (0u32..64) |
| 44 |
.map(|i| { |
| 45 |
let idx = ((i * 7 + 3) % 26) as u8; |
| 46 |
(b'A' + idx) as char |
| 47 |
}) |
| 48 |
.collect(); |
| 49 |
|
| 50 |
let mut hasher = Sha256::new(); |
| 51 |
hasher.update(verifier.as_bytes()); |
| 52 |
let digest = hasher.finalize(); |
| 53 |
|
| 54 |
use base64::Engine; |
| 55 |
let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest); |
| 56 |
|
| 57 |
(verifier, challenge) |
| 58 |
} |
| 59 |
|
| 60 |
|
| 61 |
fn extract_code_from_redirect(location: &str) -> (String, String) { |
| 62 |
let url = url::Url::parse(location).expect("Invalid redirect URL"); |
| 63 |
let mut code = String::new(); |
| 64 |
let mut state = String::new(); |
| 65 |
|
| 66 |
for (key, value) in url.query_pairs() { |
| 67 |
match key.as_ref() { |
| 68 |
"code" => code = value.to_string(), |
| 69 |
"state" => state = value.to_string(), |
| 70 |
_ => {} |
| 71 |
} |
| 72 |
} |
| 73 |
|
| 74 |
assert!(!code.is_empty(), "No code in redirect: {}", location); |
| 75 |
(code, state) |
| 76 |
} |
| 77 |
|
| 78 |
|
| 79 |
async fn authorize( |
| 80 |
h: &mut TestHarness, |
| 81 |
client_id: &str, |
| 82 |
code_challenge: &str, |
| 83 |
username: &str, |
| 84 |
password: &str, |
| 85 |
) -> (String, String) { |
| 86 |
let state_param = "test-state-12345"; |
| 87 |
let redirect_uri = "http://127.0.0.1:9999/callback"; |
| 88 |
|
| 89 |
|
| 90 |
let resp = h |
| 91 |
.client |
| 92 |
.get(&format!( |
| 93 |
"/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256", |
| 94 |
urlencoding::encode(client_id), |
| 95 |
urlencoding::encode(redirect_uri), |
| 96 |
state_param, |
| 97 |
code_challenge, |
| 98 |
)) |
| 99 |
.await; |
| 100 |
assert_eq!(resp.status.as_u16(), 200, "Authorize page failed: {}", resp.text); |
| 101 |
|
| 102 |
|
| 103 |
let csrf = h |
| 104 |
.client |
| 105 |
.csrf_token() |
| 106 |
.expect("No CSRF token after loading authorize page") |
| 107 |
.to_string(); |
| 108 |
|
| 109 |
|
| 110 |
let body = format!( |
| 111 |
"client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&login={}&password={}&_csrf={}", |
| 112 |
urlencoding::encode(client_id), |
| 113 |
urlencoding::encode(redirect_uri), |
| 114 |
state_param, |
| 115 |
code_challenge, |
| 116 |
urlencoding::encode(username), |
| 117 |
urlencoding::encode(password), |
| 118 |
urlencoding::encode(&csrf), |
| 119 |
); |
| 120 |
|
| 121 |
let resp = h.client.post_form("/oauth/authorize", &body).await; |
| 122 |
assert!( |
| 123 |
resp.status.is_redirection(), |
| 124 |
"Expected redirect after authorize POST, got {}: {}", |
| 125 |
resp.status, |
| 126 |
resp.text |
| 127 |
); |
| 128 |
|
| 129 |
let location = resp.header("location").expect("No Location header on redirect"); |
| 130 |
extract_code_from_redirect(location) |
| 131 |
} |
| 132 |
|
| 133 |
|
| 134 |
|
| 135 |
#[tokio::test] |
| 136 |
async fn oauth_full_flow() { |
| 137 |
let mut h = TestHarness::new().await; |
| 138 |
let user_id = h.signup("oauthuser", "oauthuser@test.com", "Password1!").await; |
| 139 |
|
| 140 |
h.client.post_form("/logout", "").await; |
| 141 |
|
| 142 |
let (app_id, client_id) = create_sync_app(&h.db, user_id).await; |
| 143 |
let (verifier, challenge) = generate_pkce(); |
| 144 |
|
| 145 |
let (code, state) = authorize(&mut h, &client_id, &challenge, "oauthuser", "Password1!").await; |
| 146 |
assert_eq!(state, "test-state-12345"); |
| 147 |
|
| 148 |
|
| 149 |
let resp = h |
| 150 |
.client |
| 151 |
.post_form( |
| 152 |
"/oauth/token", |
| 153 |
&format!( |
| 154 |
"grant_type=authorization_code&code={}&redirect_uri={}&code_verifier={}&client_id={}&key=test-session-key", |
| 155 |
code, |
| 156 |
urlencoding::encode("http://127.0.0.1:9999/callback"), |
| 157 |
verifier, |
| 158 |
client_id, |
| 159 |
), |
| 160 |
) |
| 161 |
.await; |
| 162 |
assert_eq!(resp.status.as_u16(), 200, "Token exchange failed: {}", resp.text); |
| 163 |
|
| 164 |
let token: TokenResponse = resp.json(); |
| 165 |
assert!(!token.access_token.is_empty()); |
| 166 |
assert_eq!(token.token_type, "Bearer"); |
| 167 |
assert!(token.expires_in > 0); |
| 168 |
assert_eq!(token.user_id, user_id); |
| 169 |
assert_eq!(token.app_id, app_id); |
| 170 |
} |
| 171 |
|
| 172 |
#[tokio::test] |
| 173 |
async fn oauth_pkce_wrong_verifier() { |
| 174 |
let mut h = TestHarness::new().await; |
| 175 |
let user_id = h.signup("oauthpkce", "oauthpkce@test.com", "Password1!").await; |
| 176 |
h.client.post_form("/logout", "").await; |
| 177 |
|
| 178 |
let (_app_id, client_id) = create_sync_app(&h.db, user_id).await; |
| 179 |
let (_verifier, challenge) = generate_pkce(); |
| 180 |
|
| 181 |
let (code, _) = authorize(&mut h, &client_id, &challenge, "oauthpkce", "Password1!").await; |
| 182 |
|
| 183 |
|
| 184 |
let resp = h |
| 185 |
.client |
| 186 |
.post_form( |
| 187 |
"/oauth/token", |
| 188 |
&format!( |
| 189 |
"grant_type=authorization_code&code={}&redirect_uri={}&code_verifier=this-is-the-wrong-verifier-and-should-fail&client_id={}&key=test-session-key", |
| 190 |
code, |
| 191 |
urlencoding::encode("http://127.0.0.1:9999/callback"), |
| 192 |
client_id, |
| 193 |
), |
| 194 |
) |
| 195 |
.await; |
| 196 |
assert_eq!(resp.status.as_u16(), 400, "Wrong PKCE verifier should be rejected"); |
| 197 |
} |
| 198 |
|
| 199 |
#[tokio::test] |
| 200 |
async fn oauth_code_single_use() { |
| 201 |
let mut h = TestHarness::new().await; |
| 202 |
let user_id = h.signup("oauthonce", "oauthonce@test.com", "Password1!").await; |
| 203 |
h.client.post_form("/logout", "").await; |
| 204 |
|
| 205 |
let (_app_id, client_id) = create_sync_app(&h.db, user_id).await; |
| 206 |
let (verifier, challenge) = generate_pkce(); |
| 207 |
|
| 208 |
let (code, _) = authorize(&mut h, &client_id, &challenge, "oauthonce", "Password1!").await; |
| 209 |
|
| 210 |
let token_body = format!( |
| 211 |
"grant_type=authorization_code&code={}&redirect_uri={}&code_verifier={}&client_id={}&key=test-session-key", |
| 212 |
code, |
| 213 |
urlencoding::encode("http://127.0.0.1:9999/callback"), |
| 214 |
verifier, |
| 215 |
client_id, |
| 216 |
); |
| 217 |
|
| 218 |
|
| 219 |
let resp = h.client.post_form("/oauth/token", &token_body).await; |
| 220 |
assert_eq!(resp.status.as_u16(), 200, "First token exchange failed: {}", resp.text); |
| 221 |
|
| 222 |
|
| 223 |
let resp = h.client.post_form("/oauth/token", &token_body).await; |
| 224 |
assert_eq!(resp.status.as_u16(), 400, "Reused auth code should be rejected"); |
| 225 |
} |
| 226 |
|
| 227 |
#[tokio::test] |
| 228 |
async fn oauth_invalid_client_id() { |
| 229 |
let mut h = TestHarness::new().await; |
| 230 |
h.signup("oauthbad", "oauthbad@test.com", "Password1!").await; |
| 231 |
|
| 232 |
let resp = h |
| 233 |
.client |
| 234 |
.get("/oauth/authorize?response_type=code&client_id=nonexistent-app&redirect_uri=http://127.0.0.1:9999/callback&state=x&code_challenge=abc&code_challenge_method=S256") |
| 235 |
.await; |
| 236 |
assert_eq!(resp.status.as_u16(), 400, "Invalid client_id should return 400"); |
| 237 |
} |
| 238 |
|
| 239 |
#[tokio::test] |
| 240 |
async fn oauth_invalid_credentials() { |
| 241 |
let mut h = TestHarness::new().await; |
| 242 |
let user_id = h.signup("oauthcred", "oauthcred@test.com", "Password1!").await; |
| 243 |
h.client.post_form("/logout", "").await; |
| 244 |
|
| 245 |
let (_app_id, client_id) = create_sync_app(&h.db, user_id).await; |
| 246 |
let (_verifier, challenge) = generate_pkce(); |
| 247 |
|
| 248 |
let state_param = "test-state-12345"; |
| 249 |
let redirect_uri = "http://127.0.0.1:9999/callback"; |
| 250 |
|
| 251 |
|
| 252 |
let resp = h |
| 253 |
.client |
| 254 |
.get(&format!( |
| 255 |
"/oauth/authorize?response_type=code&client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256", |
| 256 |
urlencoding::encode(&client_id), |
| 257 |
urlencoding::encode(redirect_uri), |
| 258 |
state_param, |
| 259 |
challenge, |
| 260 |
)) |
| 261 |
.await; |
| 262 |
assert_eq!(resp.status.as_u16(), 200); |
| 263 |
|
| 264 |
let csrf = h |
| 265 |
.client |
| 266 |
.csrf_token() |
| 267 |
.expect("No CSRF token") |
| 268 |
.to_string(); |
| 269 |
|
| 270 |
|
| 271 |
let body = format!( |
| 272 |
"client_id={}&redirect_uri={}&state={}&code_challenge={}&code_challenge_method=S256&login={}&password={}&_csrf={}", |
| 273 |
urlencoding::encode(&client_id), |
| 274 |
urlencoding::encode(redirect_uri), |
| 275 |
state_param, |
| 276 |
challenge, |
| 277 |
"oauthcred", |
| 278 |
"WrongPassword1%21", |
| 279 |
urlencoding::encode(&csrf), |
| 280 |
); |
| 281 |
|
| 282 |
let resp = h.client.post_form("/oauth/authorize", &body).await; |
| 283 |
|
| 284 |
|
| 285 |
assert_eq!( |
| 286 |
resp.status.as_u16(), |
| 287 |
200, |
| 288 |
"Invalid credentials should re-render form, got {}", |
| 289 |
resp.status |
| 290 |
); |
| 291 |
assert!( |
| 292 |
resp.text.contains("Invalid") || resp.text.contains("invalid") || resp.text.contains("password"), |
| 293 |
"Should show error message: {}", |
| 294 |
resp.text |
| 295 |
); |
| 296 |
} |
| 297 |
|
| 298 |
|
| 299 |
|
| 300 |
|
| 301 |
|
| 302 |
|
| 303 |
|
| 304 |
|
| 305 |
async fn obtain_access_token(h: &mut TestHarness, username: &str, password: &str) -> String { |
| 306 |
let user_id = sqlx::query_scalar::<_, UserId>("SELECT id FROM users WHERE username = $1") |
| 307 |
.bind(username) |
| 308 |
.fetch_one(&h.db) |
| 309 |
.await |
| 310 |
.expect("user lookup"); |
| 311 |
|
| 312 |
let (_app_id, client_id) = create_sync_app(&h.db, user_id).await; |
| 313 |
let (verifier, challenge) = generate_pkce(); |
| 314 |
|
| 315 |
h.client.post_form("/logout", "").await; |
| 316 |
let (code, _state) = authorize(h, &client_id, &challenge, username, password).await; |
| 317 |
|
| 318 |
let resp = h |
| 319 |
.client |
| 320 |
.post_form( |
| 321 |
"/oauth/token", |
| 322 |
&format!( |
| 323 |
"grant_type=authorization_code&code={}&redirect_uri={}&code_verifier={}&client_id={}&key=test-session-key", |
| 324 |
code, |
| 325 |
urlencoding::encode("http://127.0.0.1:9999/callback"), |
| 326 |
verifier, |
| 327 |
client_id, |
| 328 |
), |
| 329 |
) |
| 330 |
.await; |
| 331 |
assert_eq!(resp.status.as_u16(), 200, "Token exchange failed: {}", resp.text); |
| 332 |
let token: TokenResponse = resp.json(); |
| 333 |
token.access_token |
| 334 |
} |
| 335 |
|
| 336 |
#[derive(Deserialize)] |
| 337 |
struct UserinfoResp { |
| 338 |
user_id: UserId, |
| 339 |
username: String, |
| 340 |
display_name: Option<String>, |
| 341 |
avatar_url: Option<String>, |
| 342 |
perks: PerksResp, |
| 343 |
} |
| 344 |
|
| 345 |
#[derive(Deserialize)] |
| 346 |
struct PerksResp { |
| 347 |
fan_plus: bool, |
| 348 |
is_creator: bool, |
| 349 |
creator_tier: Option<CreatorTierResp>, |
| 350 |
} |
| 351 |
|
| 352 |
#[derive(Deserialize)] |
| 353 |
struct CreatorTierResp { |
| 354 |
tier: String, |
| 355 |
features: Vec<String>, |
| 356 |
} |
| 357 |
|
| 358 |
#[tokio::test] |
| 359 |
async fn oauth_userinfo_default() { |
| 360 |
let mut h = TestHarness::new().await; |
| 361 |
let user_id = h.signup("uinfo_def", "uinfo_def@test.com", "Password1!").await; |
| 362 |
|
| 363 |
let token = obtain_access_token(&mut h, "uinfo_def", "Password1!").await; |
| 364 |
h.client.set_bearer_token(&token); |
| 365 |
let resp = h.client.get("/oauth/userinfo").await; |
| 366 |
assert_eq!(resp.status.as_u16(), 200, "userinfo failed: {}", resp.text); |
| 367 |
|
| 368 |
let info: UserinfoResp = resp.json(); |
| 369 |
assert_eq!(info.user_id, user_id); |
| 370 |
assert_eq!(info.username, "uinfo_def"); |
| 371 |
assert!(info.display_name.is_none() || info.display_name.as_deref() == Some("")); |
| 372 |
let _ = info.avatar_url; |
| 373 |
assert!(!info.perks.fan_plus); |
| 374 |
assert!(!info.perks.is_creator); |
| 375 |
assert!(info.perks.creator_tier.is_none()); |
| 376 |
} |
| 377 |
|
| 378 |
#[tokio::test] |
| 379 |
async fn oauth_userinfo_creator_tier() { |
| 380 |
let mut h = TestHarness::new().await; |
| 381 |
let user_id = h.signup("uinfo_creator", "uinfo_creator@test.com", "Password1!").await; |
| 382 |
sqlx::query("UPDATE users SET creator_tier = 'big_files' WHERE id = $1") |
| 383 |
.bind(user_id) |
| 384 |
.execute(&h.db) |
| 385 |
.await |
| 386 |
.expect("set tier"); |
| 387 |
|
| 388 |
let token = obtain_access_token(&mut h, "uinfo_creator", "Password1!").await; |
| 389 |
h.client.set_bearer_token(&token); |
| 390 |
let resp = h.client.get("/oauth/userinfo").await; |
| 391 |
assert_eq!(resp.status.as_u16(), 200); |
| 392 |
|
| 393 |
let info: UserinfoResp = resp.json(); |
| 394 |
assert!(info.perks.is_creator); |
| 395 |
assert!(!info.perks.fan_plus); |
| 396 |
let tier = info.perks.creator_tier.expect("creator_tier populated"); |
| 397 |
assert_eq!(tier.tier, "big_files"); |
| 398 |
assert!(tier.features.iter().any(|f| f == "file_uploads")); |
| 399 |
assert!(tier.features.iter().any(|f| f == "large_files")); |
| 400 |
} |
| 401 |
|
| 402 |
#[tokio::test] |
| 403 |
async fn oauth_userinfo_fan_plus() { |
| 404 |
let mut h = TestHarness::new().await; |
| 405 |
let user_id = h.signup("uinfo_fp", "uinfo_fp@test.com", "Password1!").await; |
| 406 |
sqlx::query( |
| 407 |
"INSERT INTO fan_plus_subscriptions (user_id, stripe_subscription_id, stripe_customer_id, status) \ |
| 408 |
VALUES ($1, 'sub_uinfo_fp', 'cus_uinfo_fp', 'active')", |
| 409 |
) |
| 410 |
.bind(user_id) |
| 411 |
.execute(&h.db) |
| 412 |
.await |
| 413 |
.expect("seed fan_plus"); |
| 414 |
|
| 415 |
let token = obtain_access_token(&mut h, "uinfo_fp", "Password1!").await; |
| 416 |
h.client.set_bearer_token(&token); |
| 417 |
let resp = h.client.get("/oauth/userinfo").await; |
| 418 |
assert_eq!(resp.status.as_u16(), 200); |
| 419 |
|
| 420 |
let info: UserinfoResp = resp.json(); |
| 421 |
assert!(info.perks.fan_plus); |
| 422 |
assert!(!info.perks.is_creator); |
| 423 |
assert!(info.perks.creator_tier.is_none()); |
| 424 |
} |
| 425 |
|
| 426 |
#[tokio::test] |
| 427 |
async fn oauth_userinfo_unauthorized() { |
| 428 |
let mut h = TestHarness::new().await; |
| 429 |
|
| 430 |
let resp = h.client.get("/oauth/userinfo").await; |
| 431 |
assert_eq!(resp.status.as_u16(), 401); |
| 432 |
} |
| 433 |
|