//! Local HTTP server for OAuth2 redirect callbacks. //! //! Runs a minimal HTTP server on localhost to receive the OAuth callback //! after the user authorizes in their browser. use std::io::{Read, Write}; use std::net::{TcpListener, TcpStream}; use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; /// Escape a string for safe inclusion in HTML content. fn html_escape(s: &str) -> String { s.replace('&', "&") .replace('<', "<") .replace('>', ">") .replace('"', """) .replace('\'', "'") } /// Result of an OAuth callback. #[derive(Debug, Clone)] pub struct CallbackResult { /// Authorization code from the OAuth provider. pub code: String, /// State parameter to verify CSRF. pub state: String, } /// Error that occurred during OAuth callback. #[derive(Debug, Clone)] pub struct CallbackError { /// Error code from OAuth provider. pub error: String, /// Human-readable error description. pub error_description: Option, } /// Stored callback data for polling. #[derive(Debug, Clone)] enum StoredCallback { Pending, Success { code: String, state: String }, Error { error: String, description: Option }, } /// A local HTTP server that handles OAuth2 redirects. pub struct OAuthCallbackServer { port: u16, receiver: Receiver>, } impl OAuthCallbackServer { /// Starts a new callback server on a random available port. /// /// Returns the server and the port it's listening on. pub fn start() -> Result { // Bind to port 0 to get a random available port let listener = TcpListener::bind("127.0.0.1:0") .map_err(|e| format!("Failed to bind callback server: {}", e))?; let port = listener .local_addr() .map_err(|e| format!("Failed to get server port: {}", e))? .port(); // Set non-blocking so we can timeout listener .set_nonblocking(true) .map_err(|e| format!("Failed to set non-blocking: {}", e))?; let (sender, receiver) = mpsc::channel(); // Shared storage for callback result (for polling) let stored = Arc::new(Mutex::new(StoredCallback::Pending)); // Spawn a thread to handle the callback let stored_clone = stored.clone(); thread::spawn(move || { Self::run_server(listener, sender, stored_clone); }); Ok(Self { port, receiver }) } /// Returns the port this server is listening on. pub fn port(&self) -> u16 { self.port } /// Waits for the OAuth callback with a timeout. /// /// # Arguments /// * `timeout` - Maximum time to wait for the callback. pub fn wait_for_callback( &self, timeout: Duration, ) -> Result, String> { self.receiver .recv_timeout(timeout) .map_err(|e| format!("Timeout waiting for OAuth callback: {}", e)) } fn run_server( listener: TcpListener, sender: Sender>, stored: Arc>, ) { // Accept connections for up to 5 minutes let deadline = std::time::Instant::now() + Duration::from_secs(300); let mut callback_received = false; while std::time::Instant::now() < deadline { match listener.accept() { Ok((stream, _)) => { let result = Self::handle_request(stream, &stored, callback_received); if let Some(callback_result) = result { let _ = sender.send(callback_result); callback_received = true; // Continue running for a bit to serve /result requests } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { thread::sleep(Duration::from_millis(100)); } Err(_) => { thread::sleep(Duration::from_millis(100)); } } // If we received a callback, keep serving for 30 more seconds for polling if callback_received { let poll_deadline = std::time::Instant::now() + Duration::from_secs(30); while std::time::Instant::now() < poll_deadline { match listener.accept() { Ok((stream, _)) => { Self::handle_request(stream, &stored, true); } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { thread::sleep(Duration::from_millis(100)); } Err(_) => { thread::sleep(Duration::from_millis(100)); } } } return; } } } fn handle_request( mut stream: TcpStream, stored: &Arc>, _callback_received: bool, ) -> Option> { let mut buffer = [0; 16384]; let n = stream.read(&mut buffer).ok()?; let request = String::from_utf8_lossy(&buffer[..n]); // Parse the GET request line let first_line = request.lines().next()?; if !first_line.starts_with("GET ") { Self::send_response(&mut stream, "405 Method Not Allowed", "text/plain", "Only GET is supported"); return None; } // Extract path and query string let path = first_line .strip_prefix("GET ")? .split_whitespace() .next()?; let path_only = path.split('?').next().unwrap_or(path); // Handle /result endpoint for polling if path_only == "/result" { let stored_guard = stored.lock().ok()?; let json = match &*stored_guard { StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(), StoredCallback::Success { code, state } => { serde_json::json!({ "status": "success", "code": code, "state": state, }).to_string() } StoredCallback::Error { error, description } => { serde_json::json!({ "status": "error", "error": error, "description": description.as_deref().unwrap_or(""), }).to_string() } }; Self::send_json_response(&mut stream, &json); return None; } // Parse query parameters for the callback let query = path.split('?').nth(1).unwrap_or(""); let params: std::collections::HashMap<&str, &str> = query .split('&') .filter_map(|pair| { let mut parts = pair.splitn(2, '='); Some((parts.next()?, parts.next().unwrap_or(""))) }) .collect(); // Check for error if let Some(error) = params.get("error") { let error_description = params.get("error_description").map(|s| { urlencoding::decode(s).unwrap_or_else(|_| s.to_string()) }); // Store the error if let Ok(mut stored_guard) = stored.lock() { *stored_guard = StoredCallback::Error { error: error.to_string(), description: error_description.clone(), }; } let safe_msg = html_escape(error_description.as_deref().unwrap_or(error)); Self::send_response( &mut stream, "200 OK", "text/html; charset=utf-8", &format!( r#" Authorization Failed

Authorization Failed

{}

You can close this window.

"#, safe_msg ), ); return Some(Err(CallbackError { error: error.to_string(), error_description, })); } // Extract code and state, URL-decoding to handle encoded chars let code = params.get("code")?; let state = params.get("state")?; let code = urlencoding::decode(code).unwrap_or_else(|_| code.to_string()); let state = urlencoding::decode(state).unwrap_or_else(|_| state.to_string()); // Store the success result if let Ok(mut stored_guard) = stored.lock() { *stored_guard = StoredCallback::Success { code: code.clone(), state: state.clone(), }; } Self::send_response( &mut stream, "200 OK", "text/html; charset=utf-8", r#" Authorization Successful

Authorization Successful

Your email account has been connected.

You can close this window and return to GoingsOn.

"#, ); Some(Ok(CallbackResult { code, state, })) } fn send_response(stream: &mut TcpStream, status: &str, content_type: &str, body: &str) { let response = format!( "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", status, content_type, body.len(), body ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); } fn send_json_response(stream: &mut TcpStream, json: &str) { Self::send_response(stream, "200 OK", "application/json", json); } } /// Simple URL decoding with proper multi-byte UTF-8 support. mod urlencoding { pub fn decode(s: &str) -> Result { let mut bytes = Vec::with_capacity(s.len()); let mut chars = s.chars(); while let Some(c) = chars.next() { if c == '%' { let hex: String = chars.by_ref().take(2).collect(); if hex.len() == 2 { if let Ok(byte) = u8::from_str_radix(&hex, 16) { bytes.push(byte); continue; } } return Err(()); } else if c == '+' { bytes.push(b' '); } else { let mut buf = [0u8; 4]; let encoded = c.encode_utf8(&mut buf); bytes.extend_from_slice(encoded.as_bytes()); } } String::from_utf8(bytes).map_err(|_| ()) } } #[cfg(test)] mod tests { use super::*; use std::net::TcpStream; use std::sync::Mutex; use std::time::Duration; // Serialize server integration tests to avoid port/timing interference // when multiple tests spawn callback servers simultaneously. // Uses unwrap_or_else to recover from a poisoned mutex (if a prior test panicked). static SERVER_TEST_LOCK: Mutex<()> = Mutex::new(()); fn lock_server_tests() -> std::sync::MutexGuard<'static, ()> { SERVER_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()) } // ============ URL Decoding Tests ============ #[test] fn url_decode_plain_string() { assert_eq!(urlencoding::decode("hello").unwrap(), "hello"); } #[test] fn url_decode_percent_encoded_spaces() { assert_eq!(urlencoding::decode("hello%20world").unwrap(), "hello world"); } #[test] fn url_decode_plus_as_space() { assert_eq!(urlencoding::decode("hello+world").unwrap(), "hello world"); } #[test] fn url_decode_special_characters() { assert_eq!(urlencoding::decode("a%3Db%26c%3Dd").unwrap(), "a=b&c=d"); } #[test] fn url_decode_slash() { assert_eq!(urlencoding::decode("%2F").unwrap(), "/"); } #[test] fn url_decode_mixed_encoded_and_plain() { assert_eq!( urlencoding::decode("access%20denied%3A+invalid+scope").unwrap(), "access denied: invalid scope" ); } #[test] fn url_decode_empty_string() { assert_eq!(urlencoding::decode("").unwrap(), ""); } #[test] fn url_decode_truncated_percent_sequence() { // "%A" has only one hex digit instead of two assert!(urlencoding::decode("%A").is_err()); } #[test] fn url_decode_invalid_hex_after_percent() { assert!(urlencoding::decode("%ZZ").is_err()); } #[test] fn url_decode_percent_at_end() { // "%" with nothing after it assert!(urlencoding::decode("hello%").is_err()); } // ============ CallbackResult / CallbackError struct tests ============ #[test] fn callback_result_stores_code_and_state() { let result = CallbackResult { code: "auth_code_123".to_string(), state: "csrf_state_abc".to_string(), }; assert_eq!(result.code, "auth_code_123"); assert_eq!(result.state, "csrf_state_abc"); } #[test] fn callback_error_with_description() { let err = CallbackError { error: "access_denied".to_string(), error_description: Some("User denied access".to_string()), }; assert_eq!(err.error, "access_denied"); assert_eq!(err.error_description.as_deref(), Some("User denied access")); } #[test] fn callback_error_without_description() { let err = CallbackError { error: "server_error".to_string(), error_description: None, }; assert_eq!(err.error, "server_error"); assert!(err.error_description.is_none()); } // ============ Server Integration Tests ============ /// Helper: send a raw HTTP GET request and read the full response. /// /// Reads headers first, extracts Content-Length, then reads the exact body. /// This avoids blocking on `read_to_string` when the server doesn't close /// the connection immediately. fn send_request(port: u16, path: &str) -> String { let request = format!("GET {} HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n", path); send_and_read(port, &request) } /// Core send-and-read implementation. /// /// The callback server uses a non-blocking accept loop (100ms poll interval), /// so under heavy CPU load the server thread may not be scheduled immediately. /// A small initial delay helps avoid connecting before the server is ready. fn send_and_read(port: u16, request: &str) -> String { // Give the server thread time to enter its accept loop. // The server polls every 100ms, so 150ms ensures at least one accept cycle. std::thread::sleep(Duration::from_millis(150)); let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).unwrap(); stream .set_read_timeout(Some(Duration::from_secs(5))) .unwrap(); stream.write_all(request.as_bytes()).unwrap(); stream.flush().unwrap(); // Read byte-by-byte until we find \r\n\r\n (end of headers) let mut raw = Vec::new(); let mut found_end = false; loop { let mut byte = [0u8; 1]; match stream.read(&mut byte) { Ok(0) => break, Ok(_) => { raw.push(byte[0]); if raw.len() >= 4 && &raw[raw.len() - 4..] == b"\r\n\r\n" { found_end = true; break; } } Err(_) => break, } } if !found_end { return String::from_utf8_lossy(&raw).to_string(); } let headers_str = String::from_utf8_lossy(&raw).to_string(); // Parse Content-Length from headers let content_length: usize = headers_str .lines() .find_map(|line| { let lower = line.to_lowercase(); if lower.starts_with("content-length:") { lower .strip_prefix("content-length:") .and_then(|v| v.trim().parse().ok()) } else { None } }) .unwrap_or(0); // Read exactly content_length bytes for the body let mut body_buf = vec![0u8; content_length]; if content_length > 0 { let mut read_so_far = 0; while read_so_far < content_length { match stream.read(&mut body_buf[read_so_far..]) { Ok(0) => break, Ok(n) => read_so_far += n, Err(_) => break, } } } let body_str = String::from_utf8_lossy(&body_buf); format!("{}{}", headers_str, body_str) } /// Helper: send a raw HTTP request with a custom method and read the response. fn send_raw_request(port: u16, request: &str) -> String { send_and_read(port, request) } /// Helper: extract the HTTP body from a raw response (everything after \r\n\r\n). fn extract_body(response: &str) -> &str { response .split("\r\n\r\n") .nth(1) .unwrap_or("") } /// Helper: extract the HTTP status line from a raw response. fn extract_status(response: &str) -> &str { response.lines().next().unwrap_or("") } #[test] fn server_starts_on_random_port() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); assert!(server.port() > 0); } #[test] fn server_returns_success_on_valid_callback() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); let response = send_request(port, "/?code=test_code_abc&state=test_state_xyz"); let body = extract_body(&response); // Should return success HTML page assert!(body.contains("Authorization Successful")); assert!(body.contains("email account has been connected")); // Should deliver result via channel let result = server.wait_for_callback(Duration::from_secs(2)); let callback = result.unwrap().unwrap(); assert_eq!(callback.code, "test_code_abc"); assert_eq!(callback.state, "test_state_xyz"); } #[test] fn server_returns_error_on_oauth_error_callback() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); let response = send_request( port, "/?error=access_denied&error_description=User%20denied%20access", ); let body = extract_body(&response); // Should return error HTML page assert!(body.contains("Authorization Failed")); assert!(body.contains("User denied access")); // Should deliver error via channel let result = server.wait_for_callback(Duration::from_secs(2)); let err = result.unwrap().unwrap_err(); assert_eq!(err.error, "access_denied"); assert_eq!(err.error_description.as_deref(), Some("User denied access")); } #[test] fn server_returns_error_without_description() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); let response = send_request(port, "/?error=server_error"); let body = extract_body(&response); assert!(body.contains("Authorization Failed")); let result = server.wait_for_callback(Duration::from_secs(2)); let err = result.unwrap().unwrap_err(); assert_eq!(err.error, "server_error"); assert!(err.error_description.is_none()); } #[test] fn server_rejects_non_get_methods() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); let response = send_raw_request(port, "POST / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n"); assert!(extract_status(&response).contains("405")); assert!(extract_body(&response).contains("Only GET is supported")); } #[test] fn server_result_endpoint_returns_pending_initially() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); let response = send_request(port, "/result"); let body = extract_body(&response); assert!(body.contains(r#""status":"pending"#)); } #[test] fn server_result_endpoint_returns_success_after_callback() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); // First, trigger the callback send_request(port, "/?code=mycode&state=mystate"); // Small delay to let the server process std::thread::sleep(Duration::from_millis(200)); // Then poll /result let response = send_request(port, "/result"); let body = extract_body(&response); assert!(body.contains(r#""status":"success"#)); assert!(body.contains(r#""code":"mycode"#)); assert!(body.contains(r#""state":"mystate"#)); } #[test] fn server_result_endpoint_returns_error_after_error_callback() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); // Trigger an error callback send_request(port, "/?error=invalid_grant&error_description=Expired"); std::thread::sleep(Duration::from_millis(200)); // Then poll /result let response = send_request(port, "/result"); let body = extract_body(&response); assert!(body.contains(r#""status":"error"#)); assert!(body.contains(r#""error":"invalid_grant"#)); assert!(body.contains(r#""description":"Expired"#)); } #[test] fn server_timeout_returns_error() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); // Don't send any request, just wait with a very short timeout let result = server.wait_for_callback(Duration::from_millis(50)); assert!(result.is_err()); assert!(result.unwrap_err().contains("Timeout")); } #[test] fn server_ignores_request_with_no_code_or_error() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); // Send a request to root with no query params. // The handler returns None early (no code/error), so no HTTP response is sent. // We just connect and send the request without expecting a response. let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).unwrap(); stream.set_write_timeout(Some(Duration::from_secs(2))).unwrap(); let request = "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n"; stream.write_all(request.as_bytes()).unwrap(); stream.flush().unwrap(); drop(stream); // Give the server a moment to process std::thread::sleep(Duration::from_millis(100)); // The channel should still be empty (no callback delivered) let result = server.wait_for_callback(Duration::from_millis(100)); assert!(result.is_err()); } #[test] fn server_handles_code_without_state() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); // Code present but state missing -- should not produce a callback send_request(port, "/?code=only_code"); let result = server.wait_for_callback(Duration::from_millis(100)); assert!(result.is_err()); } #[test] fn server_handles_state_without_code() { let _lock = lock_server_tests(); let server = OAuthCallbackServer::start().unwrap(); let port = server.port(); // State present but code missing -- should not produce a callback send_request(port, "/?state=only_state"); let result = server.wait_for_callback(Duration::from_millis(100)); assert!(result.is_err()); } }