| 1 |
|
| 2 |
|
| 3 |
use base64::engine::general_purpose::URL_SAFE_NO_PAD; |
| 4 |
use base64::Engine; |
| 5 |
use sha2::{Digest, Sha256}; |
| 6 |
use synckit_client::SyncKitClient; |
| 7 |
|
| 8 |
use tracing::instrument; |
| 9 |
|
| 10 |
use crate::error::{Result, SyncError}; |
| 11 |
|
| 12 |
|
| 13 |
/// Parameters captured from the OAuth redirect that reached the loopback listener.
pub struct CallbackResult {
    /// Authorization code, percent-decoded, to be exchanged together with the PKCE verifier.
    pub code: String,
    /// The `state` value echoed back by the authorization server, percent-decoded.
    pub state: String,
}
| 17 |
|
| 18 |
|
| 19 |
pub struct AuthSession { |
| 20 |
pub auth_url: String, |
| 21 |
pub code_verifier: String, |
| 22 |
pub expected_state: String, |
| 23 |
pub port: u16, |
| 24 |
pub code_rx: tokio::sync::oneshot::Receiver<CallbackResult>, |
| 25 |
} |
| 26 |
|
| 27 |
|
| 28 |
pub fn generate_code_verifier() -> String { |
| 29 |
use rand::RngCore; |
| 30 |
let mut buf = [0u8; 32]; |
| 31 |
rand::rng().fill_bytes(&mut buf); |
| 32 |
URL_SAFE_NO_PAD.encode(buf) |
| 33 |
} |
| 34 |
|
| 35 |
|
| 36 |
pub fn generate_code_challenge(verifier: &str) -> String { |
| 37 |
let hash = Sha256::digest(verifier.as_bytes()); |
| 38 |
URL_SAFE_NO_PAD.encode(hash) |
| 39 |
} |
| 40 |
|
| 41 |
|
| 42 |
pub fn generate_state() -> String { |
| 43 |
use rand::RngCore; |
| 44 |
let mut buf = [0u8; 16]; |
| 45 |
rand::rng().fill_bytes(&mut buf); |
| 46 |
URL_SAFE_NO_PAD.encode(buf) |
| 47 |
} |
| 48 |
|
| 49 |
|
| 50 |
#[instrument(skip_all)] |
| 51 |
pub fn start_auth(client: &SyncKitClient) -> Result<AuthSession> { |
| 52 |
let code_verifier = generate_code_verifier(); |
| 53 |
let code_challenge = generate_code_challenge(&code_verifier); |
| 54 |
let state = generate_state(); |
| 55 |
|
| 56 |
|
| 57 |
let listener = std::net::TcpListener::bind("127.0.0.1:0") |
| 58 |
.map_err(SyncError::Io)?; |
| 59 |
let port = listener |
| 60 |
.local_addr() |
| 61 |
.map_err(SyncError::Io)? |
| 62 |
.port(); |
| 63 |
|
| 64 |
let auth_url = client.build_authorize_url(port, &state, &code_challenge); |
| 65 |
|
| 66 |
let (tx, rx) = tokio::sync::oneshot::channel(); |
| 67 |
let thread_expected_state = state.clone(); |
| 68 |
let expected_state = state; |
| 69 |
|
| 70 |
|
| 71 |
std::thread::spawn(move || { |
| 72 |
listener |
| 73 |
.set_nonblocking(false) |
| 74 |
.ok(); |
| 75 |
|
| 76 |
|
| 77 |
listener |
| 78 |
.set_nonblocking(true) |
| 79 |
.ok(); |
| 80 |
|
| 81 |
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300); |
| 82 |
|
| 83 |
loop { |
| 84 |
if std::time::Instant::now() > deadline { |
| 85 |
tracing::warn!("Auth callback timed out after 5 minutes"); |
| 86 |
break; |
| 87 |
} |
| 88 |
|
| 89 |
match listener.accept() { |
| 90 |
Ok((mut stream, _)) => { |
| 91 |
use std::io::{Read, Write}; |
| 92 |
|
| 93 |
let mut buf = [0u8; 2048]; |
| 94 |
let n = stream.read(&mut buf).unwrap_or(0); |
| 95 |
let request = String::from_utf8_lossy(&buf[..n]); |
| 96 |
|
| 97 |
|
| 98 |
if let Some(query_start) = request.find('?') { |
| 99 |
let after_q = query_start + 1; |
| 100 |
let query_end = request[query_start..] |
| 101 |
.find(' ') |
| 102 |
.unwrap_or(request.len().saturating_sub(query_start)); |
| 103 |
let end = query_start + query_end; |
| 104 |
|
| 105 |
let query = match request.get(after_q..end) { |
| 106 |
Some(q) => q, |
| 107 |
None => { |
| 108 |
tracing::warn!("Malformed OAuth callback request, could not parse query string"); |
| 109 |
break; |
| 110 |
} |
| 111 |
}; |
| 112 |
|
| 113 |
let mut code = None; |
| 114 |
let mut cb_state = None; |
| 115 |
|
| 116 |
for param in query.split('&') { |
| 117 |
if let Some((key, value)) = param.split_once('=') { |
| 118 |
match key { |
| 119 |
"code" => code = Some(percent_decode(value)), |
| 120 |
"state" => cb_state = Some(percent_decode(value)), |
| 121 |
_ => {} |
| 122 |
} |
| 123 |
} |
| 124 |
} |
| 125 |
|
| 126 |
if let (Some(code), Some(state)) = (code, cb_state) { |
| 127 |
if state != thread_expected_state { |
| 128 |
tracing::warn!("OAuth callback state mismatch — rejecting"); |
| 129 |
let body = "<html><body><h1>Authentication failed</h1><p>CSRF state mismatch. Please try again.</p></body></html>"; |
| 130 |
let response = format!( |
| 131 |
"HTTP/1.1 403 Forbidden\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", |
| 132 |
body.len(), |
| 133 |
body |
| 134 |
); |
| 135 |
let _ = stream.write_all(response.as_bytes()); |
| 136 |
break; |
| 137 |
} |
| 138 |
|
| 139 |
let body = "<html><body><h1>Authentication successful</h1><p>You can close this tab and return to audiofiles.</p></body></html>"; |
| 140 |
let response = format!( |
| 141 |
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", |
| 142 |
body.len(), |
| 143 |
body |
| 144 |
); |
| 145 |
let _ = stream.write_all(response.as_bytes()); |
| 146 |
let _ = tx.send(CallbackResult { code, state }); |
| 147 |
} |
| 148 |
} |
| 149 |
break; |
| 150 |
} |
| 151 |
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { |
| 152 |
std::thread::sleep(std::time::Duration::from_millis(100)); |
| 153 |
continue; |
| 154 |
} |
| 155 |
Err(_) => break, |
| 156 |
} |
| 157 |
} |
| 158 |
}); |
| 159 |
|
| 160 |
Ok(AuthSession { |
| 161 |
auth_url, |
| 162 |
code_verifier, |
| 163 |
expected_state, |
| 164 |
port, |
| 165 |
code_rx: rx, |
| 166 |
}) |
| 167 |
} |
| 168 |
|
| 169 |
|
| 170 |
|
| 171 |
/// Decodes an `application/x-www-form-urlencoded` query value: `%XX` escapes become the
/// corresponding byte and `+` becomes a space.
///
/// A `%` that does not start a complete two-hex-digit escape is kept literally, and the bytes
/// after it are decoded normally, so `%%41` yields `%A`. If the decoded bytes are not valid
/// UTF-8, invalid sequences are replaced with U+FFFD.
fn percent_decode(input: &str) -> String {
    let mut decoded = Vec::with_capacity(input.len());
    let mut rest = input.as_bytes();
    while let Some((&byte, tail)) = rest.split_first() {
        rest = tail;
        match byte {
            b'%' => {
                if let [hi, lo, after @ ..] = rest {
                    if let (Some(h), Some(l)) = (hex_val(*hi), hex_val(*lo)) {
                        decoded.push((h << 4) | l);
                        rest = after;
                        continue;
                    }
                }
                // Not a complete escape: emit the `%` and re-examine the following bytes,
                // which may begin an escape of their own.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
    }
    String::from_utf8(decoded)
        .unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
}

/// Returns the value (0-15) of an ASCII hexadecimal digit, in either case, or `None` if `b`
/// is not one.
fn hex_val(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
| 210 |
|