Skip to main content

max / audiofiles

7.5 KB · 210 lines History Blame Raw
1 //! PKCE helpers and localhost OAuth callback server.
2
3 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4 use base64::Engine;
5 use sha2::{Digest, Sha256};
6 use synckit_client::SyncKitClient;
7
8 use tracing::instrument;
9
10 use crate::error::{Result, SyncError};
11
/// Values extracted from the query string of the localhost OAuth redirect.
pub struct CallbackResult {
    /// Percent-decoded value of the `code` query parameter (the authorization code).
    pub code: String,
    /// Percent-decoded value of the `state` query parameter echoed back in the redirect.
    pub state: String,
}
17
18 /// Active auth session with callback receiver.
19 pub struct AuthSession {
20 pub auth_url: String,
21 pub code_verifier: String,
22 pub expected_state: String,
23 pub port: u16,
24 pub code_rx: tokio::sync::oneshot::Receiver<CallbackResult>,
25 }
26
27 /// Generate a random PKCE code verifier (32 bytes → base64url).
28 pub fn generate_code_verifier() -> String {
29 use rand::RngCore;
30 let mut buf = [0u8; 32];
31 rand::rng().fill_bytes(&mut buf);
32 URL_SAFE_NO_PAD.encode(buf)
33 }
34
35 /// Derive the PKCE code challenge from a verifier (SHA256 → base64url).
36 pub fn generate_code_challenge(verifier: &str) -> String {
37 let hash = Sha256::digest(verifier.as_bytes());
38 URL_SAFE_NO_PAD.encode(hash)
39 }
40
41 /// Generate a random CSRF state parameter (16 bytes → base64url).
42 pub fn generate_state() -> String {
43 use rand::RngCore;
44 let mut buf = [0u8; 16];
45 rand::rng().fill_bytes(&mut buf);
46 URL_SAFE_NO_PAD.encode(buf)
47 }
48
49 /// Start the OAuth2 PKCE flow: bind a localhost callback server, build the auth URL.
50 #[instrument(skip_all)]
51 pub fn start_auth(client: &SyncKitClient) -> Result<AuthSession> {
52 let code_verifier = generate_code_verifier();
53 let code_challenge = generate_code_challenge(&code_verifier);
54 let state = generate_state();
55
56 // Bind to a random available port
57 let listener = std::net::TcpListener::bind("127.0.0.1:0")
58 .map_err(SyncError::Io)?;
59 let port = listener
60 .local_addr()
61 .map_err(SyncError::Io)?
62 .port();
63
64 let auth_url = client.build_authorize_url(port, &state, &code_challenge);
65
66 let (tx, rx) = tokio::sync::oneshot::channel();
67 let thread_expected_state = state.clone();
68 let expected_state = state;
69
70 // Spawn blocking listener thread
71 std::thread::spawn(move || {
72 listener
73 .set_nonblocking(false)
74 .ok();
75
76 // 5-minute timeout
77 listener
78 .set_nonblocking(true)
79 .ok();
80
81 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300);
82
83 loop {
84 if std::time::Instant::now() > deadline {
85 tracing::warn!("Auth callback timed out after 5 minutes");
86 break;
87 }
88
89 match listener.accept() {
90 Ok((mut stream, _)) => {
91 use std::io::{Read, Write};
92
93 let mut buf = [0u8; 2048];
94 let n = stream.read(&mut buf).unwrap_or(0);
95 let request = String::from_utf8_lossy(&buf[..n]);
96
97 // Parse GET /callback?code=X&state=Y or GET /?code=X&state=Y
98 if let Some(query_start) = request.find('?') {
99 let after_q = query_start + 1;
100 let query_end = request[query_start..]
101 .find(' ')
102 .unwrap_or(request.len().saturating_sub(query_start));
103 let end = query_start + query_end;
104
105 let query = match request.get(after_q..end) {
106 Some(q) => q,
107 None => {
108 tracing::warn!("Malformed OAuth callback request, could not parse query string");
109 break;
110 }
111 };
112
113 let mut code = None;
114 let mut cb_state = None;
115
116 for param in query.split('&') {
117 if let Some((key, value)) = param.split_once('=') {
118 match key {
119 "code" => code = Some(percent_decode(value)),
120 "state" => cb_state = Some(percent_decode(value)),
121 _ => {}
122 }
123 }
124 }
125
126 if let (Some(code), Some(state)) = (code, cb_state) {
127 if state != thread_expected_state {
128 tracing::warn!("OAuth callback state mismatch — rejecting");
129 let body = "<html><body><h1>Authentication failed</h1><p>CSRF state mismatch. Please try again.</p></body></html>";
130 let response = format!(
131 "HTTP/1.1 403 Forbidden\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
132 body.len(),
133 body
134 );
135 let _ = stream.write_all(response.as_bytes());
136 break;
137 }
138
139 let body = "<html><body><h1>Authentication successful</h1><p>You can close this tab and return to audiofiles.</p></body></html>";
140 let response = format!(
141 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
142 body.len(),
143 body
144 );
145 let _ = stream.write_all(response.as_bytes());
146 let _ = tx.send(CallbackResult { code, state });
147 }
148 }
149 break;
150 }
151 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
152 std::thread::sleep(std::time::Duration::from_millis(100));
153 continue;
154 }
155 Err(_) => break,
156 }
157 }
158 });
159
160 Ok(AuthSession {
161 auth_url,
162 code_verifier,
163 expected_state,
164 port,
165 code_rx: rx,
166 })
167 }
168
169 /// Decode percent-encoded UTF-8 strings (RFC 3986).
170 /// Passes through invalid sequences unchanged.
171 fn percent_decode(input: &str) -> String {
172 let mut bytes = Vec::with_capacity(input.len());
173 let mut chars = input.bytes();
174 while let Some(b) = chars.next() {
175 if b == b'%' {
176 let hi = chars.next();
177 let lo = chars.next();
178 if let (Some(hi), Some(lo)) = (hi, lo) {
179 if let (Some(h), Some(l)) = (hex_val(hi), hex_val(lo)) {
180 bytes.push(h << 4 | l);
181 continue;
182 }
183 // Invalid hex pair — emit literally
184 bytes.push(b'%');
185 bytes.push(hi);
186 bytes.push(lo);
187 } else {
188 bytes.push(b'%');
189 if let Some(hi) = hi {
190 bytes.push(hi);
191 }
192 }
193 } else if b == b'+' {
194 bytes.push(b' ');
195 } else {
196 bytes.push(b);
197 }
198 }
199 String::from_utf8(bytes).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
200 }
201
/// Convert one ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) to its numeric value.
///
/// Returns `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` yields values below 16, so narrowing to `u8` cannot truncate.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
210