//! PKCE helpers and localhost OAuth callback server. use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use sha2::{Digest, Sha256}; use synckit_client::SyncKitClient; use tracing::instrument; use crate::error::{Result, SyncError}; /// Result of the localhost callback: authorization code + state. pub struct CallbackResult { pub code: String, pub state: String, } /// Active auth session with callback receiver. pub struct AuthSession { pub auth_url: String, pub code_verifier: String, pub expected_state: String, pub port: u16, pub code_rx: tokio::sync::oneshot::Receiver, } /// Generate a random PKCE code verifier (32 bytes → base64url). pub fn generate_code_verifier() -> String { use rand::RngCore; let mut buf = [0u8; 32]; rand::rng().fill_bytes(&mut buf); URL_SAFE_NO_PAD.encode(buf) } /// Derive the PKCE code challenge from a verifier (SHA256 → base64url). pub fn generate_code_challenge(verifier: &str) -> String { let hash = Sha256::digest(verifier.as_bytes()); URL_SAFE_NO_PAD.encode(hash) } /// Generate a random CSRF state parameter (16 bytes → base64url). pub fn generate_state() -> String { use rand::RngCore; let mut buf = [0u8; 16]; rand::rng().fill_bytes(&mut buf); URL_SAFE_NO_PAD.encode(buf) } /// Start the OAuth2 PKCE flow: bind a localhost callback server, build the auth URL. #[instrument(skip_all)] pub fn start_auth(client: &SyncKitClient) -> Result { let code_verifier = generate_code_verifier(); let code_challenge = generate_code_challenge(&code_verifier); let state = generate_state(); // Bind to a random available port let listener = std::net::TcpListener::bind("127.0.0.1:0") .map_err(SyncError::Io)?; let port = listener .local_addr() .map_err(SyncError::Io)? .port(); let auth_url = client.build_authorize_url(port, &state, &code_challenge); let (tx, rx) = tokio::sync::oneshot::channel(); let thread_expected_state = state.clone(); let expected_state = state; // Spawn blocking listener thread std::thread::spawn(move || { listener .set_nonblocking(false) .ok(); // 5-minute timeout listener .set_nonblocking(true) .ok(); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300); loop { if std::time::Instant::now() > deadline { tracing::warn!("Auth callback timed out after 5 minutes"); break; } match listener.accept() { Ok((mut stream, _)) => { use std::io::{Read, Write}; let mut buf = [0u8; 2048]; let n = stream.read(&mut buf).unwrap_or(0); let request = String::from_utf8_lossy(&buf[..n]); // Parse GET /callback?code=X&state=Y or GET /?code=X&state=Y if let Some(query_start) = request.find('?') { let after_q = query_start + 1; let query_end = request[query_start..] .find(' ') .unwrap_or(request.len().saturating_sub(query_start)); let end = query_start + query_end; let query = match request.get(after_q..end) { Some(q) => q, None => { tracing::warn!("Malformed OAuth callback request, could not parse query string"); break; } }; let mut code = None; let mut cb_state = None; for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { match key { "code" => code = Some(percent_decode(value)), "state" => cb_state = Some(percent_decode(value)), _ => {} } } } if let (Some(code), Some(state)) = (code, cb_state) { if state != thread_expected_state { tracing::warn!("OAuth callback state mismatch — rejecting"); let body = "

Authentication failed

CSRF state mismatch. Please try again.

"; let response = format!( "HTTP/1.1 403 Forbidden\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", body.len(), body ); let _ = stream.write_all(response.as_bytes()); break; } let body = "

Authentication successful

You can close this tab and return to audiofiles.

"; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", body.len(), body ); let _ = stream.write_all(response.as_bytes()); let _ = tx.send(CallbackResult { code, state }); } } break; } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { std::thread::sleep(std::time::Duration::from_millis(100)); continue; } Err(_) => break, } } }); Ok(AuthSession { auth_url, code_verifier, expected_state, port, code_rx: rx, }) } /// Decode percent-encoded UTF-8 strings (RFC 3986). /// Passes through invalid sequences unchanged. fn percent_decode(input: &str) -> String { let mut bytes = Vec::with_capacity(input.len()); let mut chars = input.bytes(); while let Some(b) = chars.next() { if b == b'%' { let hi = chars.next(); let lo = chars.next(); if let (Some(hi), Some(lo)) = (hi, lo) { if let (Some(h), Some(l)) = (hex_val(hi), hex_val(lo)) { bytes.push(h << 4 | l); continue; } // Invalid hex pair — emit literally bytes.push(b'%'); bytes.push(hi); bytes.push(lo); } else { bytes.push(b'%'); if let Some(hi) = hi { bytes.push(hi); } } } else if b == b'+' { bytes.push(b' '); } else { bytes.push(b); } } String::from_utf8(bytes).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()) } fn hex_val(b: u8) -> Option { match b { b'0'..=b'9' => Some(b - b'0'), b'a'..=b'f' => Some(b - b'a' + 10), b'A'..=b'F' => Some(b - b'A' + 10), _ => None, } }