Skip to main content

max / goingson

24.6 KB · 733 lines History Blame Raw
1 //! Local HTTP server for OAuth2 redirect callbacks.
2 //!
3 //! Runs a minimal HTTP server on localhost to receive the OAuth callback
4 //! after the user authorizes in their browser.
5
6 use std::io::{Read, Write};
7 use std::net::{TcpListener, TcpStream};
8 use std::sync::mpsc::{self, Receiver, Sender};
9 use std::sync::{Arc, Mutex};
10 use std::thread;
11 use std::time::Duration;
12
/// Escape a string for safe inclusion in HTML text content or a quoted attribute value.
///
/// `&` is replaced first so that the entities introduced by the later replacements
/// (`&lt;`, `&gt;`, ...) are not themselves re-escaped.
fn html_escape(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#x27;")
}
21
/// Successful outcome of an OAuth redirect: the values the provider appended
/// to the redirect URI.
#[derive(Clone, Debug)]
pub struct CallbackResult {
    /// Authorization code returned by the OAuth provider.
    pub code: String,
    /// The `state` value echoed back by the provider, used for CSRF verification.
    pub state: String,
}
30
/// Failure reported by the OAuth provider on the redirect (an `error` query parameter).
#[derive(Clone, Debug)]
pub struct CallbackError {
    /// The provider's `error` code.
    pub error: String,
    /// Human-readable explanation taken from `error_description`, when the provider sent one.
    pub error_description: Option<String>,
}
39
/// Snapshot of the callback outcome, kept so that pollers can read it back.
#[derive(Clone, Debug)]
enum StoredCallback {
    /// No callback has been handled yet.
    Pending,
    /// The provider redirected with an authorization code and state.
    Success { code: String, state: String },
    /// The provider redirected with an error code and optional description.
    Error { error: String, description: Option<String> },
}
47
48 /// A local HTTP server that handles OAuth2 redirects.
49 pub struct OAuthCallbackServer {
50 port: u16,
51 receiver: Receiver<Result<CallbackResult, CallbackError>>,
52 }
53
54 impl OAuthCallbackServer {
55 /// Starts a new callback server on a random available port.
56 ///
57 /// Returns the server and the port it's listening on.
58 pub fn start() -> Result<Self, String> {
59 // Bind to port 0 to get a random available port
60 let listener = TcpListener::bind("127.0.0.1:0")
61 .map_err(|e| format!("Failed to bind callback server: {}", e))?;
62
63 let port = listener
64 .local_addr()
65 .map_err(|e| format!("Failed to get server port: {}", e))?
66 .port();
67
68 // Set non-blocking so we can timeout
69 listener
70 .set_nonblocking(true)
71 .map_err(|e| format!("Failed to set non-blocking: {}", e))?;
72
73 let (sender, receiver) = mpsc::channel();
74
75 // Shared storage for callback result (for polling)
76 let stored = Arc::new(Mutex::new(StoredCallback::Pending));
77
78 // Spawn a thread to handle the callback
79 let stored_clone = stored.clone();
80 thread::spawn(move || {
81 Self::run_server(listener, sender, stored_clone);
82 });
83
84 Ok(Self { port, receiver })
85 }
86
87 /// Returns the port this server is listening on.
88 pub fn port(&self) -> u16 {
89 self.port
90 }
91
92 /// Waits for the OAuth callback with a timeout.
93 ///
94 /// # Arguments
95 /// * `timeout` - Maximum time to wait for the callback.
96 pub fn wait_for_callback(
97 &self,
98 timeout: Duration,
99 ) -> Result<Result<CallbackResult, CallbackError>, String> {
100 self.receiver
101 .recv_timeout(timeout)
102 .map_err(|e| format!("Timeout waiting for OAuth callback: {}", e))
103 }
104
105 fn run_server(
106 listener: TcpListener,
107 sender: Sender<Result<CallbackResult, CallbackError>>,
108 stored: Arc<Mutex<StoredCallback>>,
109 ) {
110 // Accept connections for up to 5 minutes
111 let deadline = std::time::Instant::now() + Duration::from_secs(300);
112 let mut callback_received = false;
113
114 while std::time::Instant::now() < deadline {
115 match listener.accept() {
116 Ok((stream, _)) => {
117 let result = Self::handle_request(stream, &stored, callback_received);
118 if let Some(callback_result) = result {
119 let _ = sender.send(callback_result);
120 callback_received = true;
121 // Continue running for a bit to serve /result requests
122 }
123 }
124 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
125 thread::sleep(Duration::from_millis(100));
126 }
127 Err(_) => {
128 thread::sleep(Duration::from_millis(100));
129 }
130 }
131
132 // If we received a callback, keep serving for 30 more seconds for polling
133 if callback_received {
134 let poll_deadline = std::time::Instant::now() + Duration::from_secs(30);
135 while std::time::Instant::now() < poll_deadline {
136 match listener.accept() {
137 Ok((stream, _)) => {
138 Self::handle_request(stream, &stored, true);
139 }
140 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
141 thread::sleep(Duration::from_millis(100));
142 }
143 Err(_) => {
144 thread::sleep(Duration::from_millis(100));
145 }
146 }
147 }
148 return;
149 }
150 }
151 }
152
153 fn handle_request(
154 mut stream: TcpStream,
155 stored: &Arc<Mutex<StoredCallback>>,
156 _callback_received: bool,
157 ) -> Option<Result<CallbackResult, CallbackError>> {
158 let mut buffer = [0; 16384];
159 let n = stream.read(&mut buffer).ok()?;
160 let request = String::from_utf8_lossy(&buffer[..n]);
161
162 // Parse the GET request line
163 let first_line = request.lines().next()?;
164 if !first_line.starts_with("GET ") {
165 Self::send_response(&mut stream, "405 Method Not Allowed", "text/plain", "Only GET is supported");
166 return None;
167 }
168
169 // Extract path and query string
170 let path = first_line
171 .strip_prefix("GET ")?
172 .split_whitespace()
173 .next()?;
174
175 let path_only = path.split('?').next().unwrap_or(path);
176
177 // Handle /result endpoint for polling
178 if path_only == "/result" {
179 let stored_guard = stored.lock().ok()?;
180 let json = match &*stored_guard {
181 StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(),
182 StoredCallback::Success { code, state } => {
183 serde_json::json!({
184 "status": "success",
185 "code": code,
186 "state": state,
187 }).to_string()
188 }
189 StoredCallback::Error { error, description } => {
190 serde_json::json!({
191 "status": "error",
192 "error": error,
193 "description": description.as_deref().unwrap_or(""),
194 }).to_string()
195 }
196 };
197 Self::send_json_response(&mut stream, &json);
198 return None;
199 }
200
201 // Parse query parameters for the callback
202 let query = path.split('?').nth(1).unwrap_or("");
203 let params: std::collections::HashMap<&str, &str> = query
204 .split('&')
205 .filter_map(|pair| {
206 let mut parts = pair.splitn(2, '=');
207 Some((parts.next()?, parts.next().unwrap_or("")))
208 })
209 .collect();
210
211 // Check for error
212 if let Some(error) = params.get("error") {
213 let error_description = params.get("error_description").map(|s| {
214 urlencoding::decode(s).unwrap_or_else(|_| s.to_string())
215 });
216
217 // Store the error
218 if let Ok(mut stored_guard) = stored.lock() {
219 *stored_guard = StoredCallback::Error {
220 error: error.to_string(),
221 description: error_description.clone(),
222 };
223 }
224
225 let safe_msg = html_escape(error_description.as_deref().unwrap_or(error));
226 Self::send_response(
227 &mut stream,
228 "200 OK",
229 "text/html; charset=utf-8",
230 &format!(
231 r#"<!DOCTYPE html>
232 <html>
233 <head><title>Authorization Failed</title></head>
234 <body style="font-family: system-ui; padding: 2rem; text-align: center;">
235 <h1 style="color: #d33;">Authorization Failed</h1>
236 <p>{}</p>
237 <p style="color: #666;">You can close this window.</p>
238 </body>
239 </html>"#,
240 safe_msg
241 ),
242 );
243
244 return Some(Err(CallbackError {
245 error: error.to_string(),
246 error_description,
247 }));
248 }
249
250 // Extract code and state, URL-decoding to handle encoded chars
251 let code = params.get("code")?;
252 let state = params.get("state")?;
253 let code = urlencoding::decode(code).unwrap_or_else(|_| code.to_string());
254 let state = urlencoding::decode(state).unwrap_or_else(|_| state.to_string());
255
256 // Store the success result
257 if let Ok(mut stored_guard) = stored.lock() {
258 *stored_guard = StoredCallback::Success {
259 code: code.clone(),
260 state: state.clone(),
261 };
262 }
263
264 Self::send_response(
265 &mut stream,
266 "200 OK",
267 "text/html; charset=utf-8",
268 r#"<!DOCTYPE html>
269 <html>
270 <head><title>Authorization Successful</title></head>
271 <body style="font-family: system-ui; padding: 2rem; text-align: center;">
272 <h1 style="color: #090;">Authorization Successful</h1>
273 <p>Your email account has been connected.</p>
274 <p style="color: #666;">You can close this window and return to GoingsOn.</p>
275 <script>setTimeout(function() { window.close(); }, 2000);</script>
276 </body>
277 </html>"#,
278 );
279
280 Some(Ok(CallbackResult {
281 code,
282 state,
283 }))
284 }
285
286 fn send_response(stream: &mut TcpStream, status: &str, content_type: &str, body: &str) {
287 let response = format!(
288 "HTTP/1.1 {}\r\nContent-Type: {}\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
289 status,
290 content_type,
291 body.len(),
292 body
293 );
294 let _ = stream.write_all(response.as_bytes());
295 let _ = stream.flush();
296 }
297
298 fn send_json_response(stream: &mut TcpStream, json: &str) {
299 Self::send_response(stream, "200 OK", "application/json", json);
300 }
301 }
302
/// Minimal URL (query-component) decoding with proper multi-byte UTF-8 support.
mod urlencoding {
    /// Decodes `%XX` escapes, and `+` as a space, in a query-string component.
    ///
    /// Decoding works on bytes, so percent-encoded multi-byte UTF-8 sequences
    /// (e.g. `%C3%A9` → `é`) are reassembled before the final UTF-8 check.
    ///
    /// # Errors
    /// Returns `Err(())` if a `%` is not followed by exactly two ASCII hex digits,
    /// or if the decoded bytes are not valid UTF-8.
    pub fn decode(s: &str) -> Result<String, ()> {
        let input = s.as_bytes();
        let mut bytes = Vec::with_capacity(input.len());
        let mut i = 0;

        while i < input.len() {
            match input[i] {
                b'%' => {
                    // Each digit is validated explicitly. `u8::from_str_radix` would
                    // also accept a leading `+` sign, wrongly decoding "%+1" to 0x01.
                    let hi = input.get(i + 1).copied().and_then(hex_value).ok_or(())?;
                    let lo = input.get(i + 2).copied().and_then(hex_value).ok_or(())?;
                    bytes.push((hi << 4) | lo);
                    i += 3;
                }
                b'+' => {
                    bytes.push(b' ');
                    i += 1;
                }
                // `%` and `+` are ASCII, so they never occur inside a multi-byte
                // UTF-8 sequence. Every other byte is copied through unchanged.
                other => {
                    bytes.push(other);
                    i += 1;
                }
            }
        }

        String::from_utf8(bytes).map_err(|_| ())
    }

    /// Returns the value of an ASCII hexadecimal digit, or `None` for any other byte.
    fn hex_value(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }
}
331
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpStream;
    use std::sync::Mutex;
    use std::time::Duration;

    // Serialize server integration tests to avoid port/timing interference
    // when multiple tests spawn callback servers simultaneously.
    // `lock_server_tests` recovers from a poisoned mutex (if a prior test panicked),
    // so one failure does not cascade into every later server test.
    static SERVER_TEST_LOCK: Mutex<()> = Mutex::new(());

    fn lock_server_tests() -> std::sync::MutexGuard<'static, ()> {
        SERVER_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    // ============ HTML Escaping Tests ============

    #[test]
    fn html_escape_escapes_markup_and_quotes() {
        assert_eq!(
            html_escape("<b>\"x\" 'y'</b>"),
            "&lt;b&gt;&quot;x&quot; &#x27;y&#x27;&lt;/b&gt;"
        );
    }

    #[test]
    fn html_escape_leaves_plain_text_unchanged() {
        assert_eq!(html_escape("User denied access"), "User denied access");
    }

    // ============ URL Decoding Tests ============

    #[test]
    fn url_decode_plain_string() {
        assert_eq!(urlencoding::decode("hello").unwrap(), "hello");
    }

    #[test]
    fn url_decode_percent_encoded_spaces() {
        assert_eq!(urlencoding::decode("hello%20world").unwrap(), "hello world");
    }

    #[test]
    fn url_decode_plus_as_space() {
        assert_eq!(urlencoding::decode("hello+world").unwrap(), "hello world");
    }

    #[test]
    fn url_decode_special_characters() {
        assert_eq!(urlencoding::decode("a%3Db%26c%3Dd").unwrap(), "a=b&c=d");
    }

    #[test]
    fn url_decode_slash() {
        assert_eq!(urlencoding::decode("%2F").unwrap(), "/");
    }

    #[test]
    fn url_decode_mixed_encoded_and_plain() {
        assert_eq!(
            urlencoding::decode("access%20denied%3A+invalid+scope").unwrap(),
            "access denied: invalid scope"
        );
    }

    #[test]
    fn url_decode_empty_string() {
        assert_eq!(urlencoding::decode("").unwrap(), "");
    }

    #[test]
    fn url_decode_multibyte_utf8() {
        // A two-byte UTF-8 sequence split across two percent escapes.
        assert_eq!(urlencoding::decode("caf%C3%A9").unwrap(), "café");
    }

    #[test]
    fn url_decode_invalid_utf8_is_error() {
        // 0xFF can never appear in well-formed UTF-8.
        assert!(urlencoding::decode("%FF").is_err());
    }

    #[test]
    fn url_decode_truncated_percent_sequence() {
        // "%A" has only one hex digit instead of two
        assert!(urlencoding::decode("%A").is_err());
    }

    #[test]
    fn url_decode_invalid_hex_after_percent() {
        assert!(urlencoding::decode("%ZZ").is_err());
    }

    #[test]
    fn url_decode_percent_at_end() {
        // "%" with nothing after it
        assert!(urlencoding::decode("hello%").is_err());
    }

    // ============ CallbackResult / CallbackError struct tests ============

    #[test]
    fn callback_result_stores_code_and_state() {
        let result = CallbackResult {
            code: "auth_code_123".to_string(),
            state: "csrf_state_abc".to_string(),
        };
        assert_eq!(result.code, "auth_code_123");
        assert_eq!(result.state, "csrf_state_abc");
    }

    #[test]
    fn callback_error_with_description() {
        let err = CallbackError {
            error: "access_denied".to_string(),
            error_description: Some("User denied access".to_string()),
        };
        assert_eq!(err.error, "access_denied");
        assert_eq!(err.error_description.as_deref(), Some("User denied access"));
    }

    #[test]
    fn callback_error_without_description() {
        let err = CallbackError {
            error: "server_error".to_string(),
            error_description: None,
        };
        assert_eq!(err.error, "server_error");
        assert!(err.error_description.is_none());
    }

    // ============ Server Integration Tests ============

    /// Helper: send a raw HTTP GET request and read the full response.
    ///
    /// Reads headers first, extracts Content-Length, then reads the exact body.
    /// This avoids blocking on `read_to_string` when the server doesn't close
    /// the connection immediately.
    fn send_request(port: u16, path: &str) -> String {
        let request = format!("GET {} HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n", path);
        send_and_read(port, &request)
    }

    /// Core send-and-read implementation.
    ///
    /// The callback server uses a non-blocking accept loop (100ms poll interval),
    /// so under heavy CPU load the server thread may not be scheduled immediately.
    /// A small initial delay helps avoid connecting before the server is ready.
    fn send_and_read(port: u16, request: &str) -> String {
        // Give the server thread time to enter its accept loop.
        // The server polls every 100ms, so 150ms ensures at least one accept cycle.
        std::thread::sleep(Duration::from_millis(150));

        let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).unwrap();
        stream
            .set_read_timeout(Some(Duration::from_secs(5)))
            .unwrap();
        stream.write_all(request.as_bytes()).unwrap();
        stream.flush().unwrap();

        // Read byte-by-byte until we find \r\n\r\n (end of headers)
        let mut raw = Vec::new();
        let mut found_end = false;
        loop {
            let mut byte = [0u8; 1];
            match stream.read(&mut byte) {
                Ok(0) => break,
                Ok(_) => {
                    raw.push(byte[0]);
                    if raw.len() >= 4 && &raw[raw.len() - 4..] == b"\r\n\r\n" {
                        found_end = true;
                        break;
                    }
                }
                Err(_) => break,
            }
        }

        if !found_end {
            return String::from_utf8_lossy(&raw).to_string();
        }

        let headers_str = String::from_utf8_lossy(&raw).to_string();

        // Parse Content-Length from headers (header names are case-insensitive)
        let content_length: usize = headers_str
            .lines()
            .find_map(|line| {
                let lower = line.to_lowercase();
                if lower.starts_with("content-length:") {
                    lower
                        .strip_prefix("content-length:")
                        .and_then(|v| v.trim().parse().ok())
                } else {
                    None
                }
            })
            .unwrap_or(0);

        // Read exactly content_length bytes for the body
        let mut body_buf = vec![0u8; content_length];
        if content_length > 0 {
            let mut read_so_far = 0;
            while read_so_far < content_length {
                match stream.read(&mut body_buf[read_so_far..]) {
                    Ok(0) => break,
                    Ok(n) => read_so_far += n,
                    Err(_) => break,
                }
            }
        }

        let body_str = String::from_utf8_lossy(&body_buf);
        format!("{}{}", headers_str, body_str)
    }

    /// Helper: send a raw HTTP request with a custom method and read the response.
    fn send_raw_request(port: u16, request: &str) -> String {
        send_and_read(port, request)
    }

    /// Helper: extract the HTTP body from a raw response (everything after the
    /// first \r\n\r\n). `split_once` keeps any blank lines inside the body intact.
    fn extract_body(response: &str) -> &str {
        response
            .split_once("\r\n\r\n")
            .map(|(_, body)| body)
            .unwrap_or("")
    }

    /// Helper: extract the HTTP status line from a raw response.
    fn extract_status(response: &str) -> &str {
        response.lines().next().unwrap_or("")
    }

    #[test]
    fn server_starts_on_random_port() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        assert!(server.port() > 0);
    }

    #[test]
    fn server_returns_success_on_valid_callback() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        let response = send_request(port, "/?code=test_code_abc&state=test_state_xyz");
        let body = extract_body(&response);

        // Should return success HTML page
        assert!(body.contains("Authorization Successful"));
        assert!(body.contains("email account has been connected"));

        // Should deliver result via channel
        let result = server.wait_for_callback(Duration::from_secs(2));
        let callback = result.unwrap().unwrap();
        assert_eq!(callback.code, "test_code_abc");
        assert_eq!(callback.state, "test_state_xyz");
    }

    #[test]
    fn server_returns_error_on_oauth_error_callback() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        let response = send_request(
            port,
            "/?error=access_denied&error_description=User%20denied%20access",
        );
        let body = extract_body(&response);

        // Should return error HTML page
        assert!(body.contains("Authorization Failed"));
        assert!(body.contains("User denied access"));

        // Should deliver error via channel
        let result = server.wait_for_callback(Duration::from_secs(2));
        let err = result.unwrap().unwrap_err();
        assert_eq!(err.error, "access_denied");
        assert_eq!(err.error_description.as_deref(), Some("User denied access"));
    }

    #[test]
    fn server_returns_error_without_description() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        let response = send_request(port, "/?error=server_error");
        let body = extract_body(&response);

        assert!(body.contains("Authorization Failed"));

        let result = server.wait_for_callback(Duration::from_secs(2));
        let err = result.unwrap().unwrap_err();
        assert_eq!(err.error, "server_error");
        assert!(err.error_description.is_none());
    }

    #[test]
    fn server_rejects_non_get_methods() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        let response = send_raw_request(port, "POST / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n");

        assert!(extract_status(&response).contains("405"));
        assert!(extract_body(&response).contains("Only GET is supported"));
    }

    #[test]
    fn server_result_endpoint_returns_pending_initially() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        let response = send_request(port, "/result");
        let body = extract_body(&response);

        assert!(body.contains(r#""status":"pending"#));
    }

    #[test]
    fn server_result_endpoint_returns_success_after_callback() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        // First, trigger the callback
        send_request(port, "/?code=mycode&state=mystate");

        // Small delay to let the server process
        std::thread::sleep(Duration::from_millis(200));

        // Then poll /result
        let response = send_request(port, "/result");
        let body = extract_body(&response);

        assert!(body.contains(r#""status":"success"#));
        assert!(body.contains(r#""code":"mycode"#));
        assert!(body.contains(r#""state":"mystate"#));
    }

    #[test]
    fn server_result_endpoint_returns_error_after_error_callback() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        // Trigger an error callback
        send_request(port, "/?error=invalid_grant&error_description=Expired");

        std::thread::sleep(Duration::from_millis(200));

        // Then poll /result
        let response = send_request(port, "/result");
        let body = extract_body(&response);

        assert!(body.contains(r#""status":"error"#));
        assert!(body.contains(r#""error":"invalid_grant"#));
        assert!(body.contains(r#""description":"Expired"#));
    }

    #[test]
    fn server_timeout_returns_error() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();

        // Don't send any request, just wait with a very short timeout
        let result = server.wait_for_callback(Duration::from_millis(50));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Timeout"));
    }

    #[test]
    fn server_ignores_request_with_no_code_or_error() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        // Send a request to root with no query params. This test only checks that
        // no callback is delivered, so it sends the request and disconnects
        // without reading any response.
        let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)).unwrap();
        stream.set_write_timeout(Some(Duration::from_secs(2))).unwrap();
        let request = "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n";
        stream.write_all(request.as_bytes()).unwrap();
        stream.flush().unwrap();
        drop(stream);

        // Give the server a moment to process
        std::thread::sleep(Duration::from_millis(100));

        // The channel should still be empty (no callback delivered)
        let result = server.wait_for_callback(Duration::from_millis(100));
        assert!(result.is_err());
    }

    #[test]
    fn server_handles_code_without_state() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        // Code present but state missing -- should not produce a callback
        send_request(port, "/?code=only_code");

        let result = server.wait_for_callback(Duration::from_millis(100));
        assert!(result.is_err());
    }

    #[test]
    fn server_handles_state_without_code() {
        let _lock = lock_server_tests();
        let server = OAuthCallbackServer::start().unwrap();
        let port = server.port();

        // State present but code missing -- should not produce a callback
        send_request(port, "/?state=only_state");

        let result = server.wait_for_callback(Duration::from_millis(100));
        assert!(result.is_err());
    }
}
733