| 100 |
100 |
|
/// Shared flag to signal previous callback servers to stop.
|
| 101 |
101 |
|
static CALLBACK_CANCEL: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
|
| 102 |
102 |
|
|
|
103 |
+ |
/// Stored callback state for the /result polling endpoint.
|
|
104 |
+ |
#[derive(Clone)]
|
|
105 |
+ |
enum StoredCallback {
|
|
106 |
+ |
Pending,
|
|
107 |
+ |
Success { code: String, state: String },
|
|
108 |
+ |
Error { error: String },
|
|
109 |
+ |
}
|
|
110 |
+ |
|
| 103 |
111 |
|
/// Start a minimal HTTP server on a random port that waits for the OAuth redirect.
|
| 104 |
|
- |
/// Returns the port. The server accepts one connection, parses the query string,
|
| 105 |
|
- |
/// responds with a success/error page, then shuts down.
|
|
112 |
+ |
/// Returns the port. The server handles:
|
|
113 |
+ |
/// - The browser redirect with `?code=...&state=...` (stores result, returns HTML)
|
|
114 |
+ |
/// - `/result` polling endpoint (returns JSON: pending, success, or error)
|
| 106 |
115 |
|
/// Any previously running callback server is cancelled via the shared generation counter.
|
| 107 |
116 |
|
fn start_callback_server() -> Result<u16, ApiError> {
|
| 108 |
117 |
|
let listener = std::net::TcpListener::bind("127.0.0.1:0")
|
| 120 |
129 |
|
|
| 121 |
130 |
|
std::thread::spawn(move || {
|
| 122 |
131 |
|
use std::io::{Read, Write};
|
|
132 |
+ |
use std::sync::{Arc, Mutex};
|
| 123 |
133 |
|
|
| 124 |
|
- |
let timeout = std::time::Instant::now() + std::time::Duration::from_secs(300);
|
|
134 |
+ |
let stored = Arc::new(Mutex::new(StoredCallback::Pending));
|
|
135 |
+ |
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300);
|
|
136 |
+ |
let mut callback_received = false;
|
| 125 |
137 |
|
|
| 126 |
|
- |
while std::time::Instant::now() < timeout {
|
| 127 |
|
- |
// Check if a newer callback server has been started
|
|
138 |
+ |
while std::time::Instant::now() < deadline {
|
| 128 |
139 |
|
if CALLBACK_CANCEL.load(std::sync::atomic::Ordering::Relaxed) != generation {
|
| 129 |
140 |
|
break;
|
| 130 |
141 |
|
}
|
| 134 |
145 |
|
let n = stream.read(&mut buf).unwrap_or(0);
|
| 135 |
146 |
|
let request = String::from_utf8_lossy(&buf[..n]);
|
| 136 |
147 |
|
|
| 137 |
|
- |
// Parse GET /callback?code=xxx&state=xxx
|
| 138 |
|
- |
let has_code = request
|
|
148 |
+ |
let path = request
|
| 139 |
149 |
|
.lines()
|
| 140 |
150 |
|
.next()
|
| 141 |
151 |
|
.and_then(|line| line.split_whitespace().nth(1))
|
| 142 |
|
- |
.and_then(|path| path.split('?').nth(1))
|
| 143 |
|
- |
.map(|query| query.split('&').any(|pair| pair.starts_with("code=")))
|
| 144 |
|
- |
.unwrap_or(false);
|
|
152 |
+ |
.unwrap_or("/");
|
|
153 |
+ |
|
|
154 |
+ |
let path_only = path.split('?').next().unwrap_or(path);
|
|
155 |
+ |
|
|
156 |
+ |
// Handle /result polling endpoint
|
|
157 |
+ |
if path_only == "/result" {
|
|
158 |
+ |
let json = match &*stored.lock().unwrap() {
|
|
159 |
+ |
StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(),
|
|
160 |
+ |
StoredCallback::Success { code, state } => {
|
|
161 |
+ |
format!(
|
|
162 |
+ |
r#"{{"status":"success","code":"{}","state":"{}"}}"#,
|
|
163 |
+ |
code.replace('"', "\\\""),
|
|
164 |
+ |
state.replace('"', "\\\"")
|
|
165 |
+ |
)
|
|
166 |
+ |
}
|
|
167 |
+ |
StoredCallback::Error { error } => {
|
|
168 |
+ |
format!(
|
|
169 |
+ |
r#"{{"status":"error","error":"{}"}}"#,
|
|
170 |
+ |
error.replace('"', "\\\"")
|
|
171 |
+ |
)
|
|
172 |
+ |
}
|
|
173 |
+ |
};
|
|
174 |
+ |
let response = format!(
|
|
175 |
+ |
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
|
|
176 |
+ |
json.len(), json
|
|
177 |
+ |
);
|
|
178 |
+ |
let _ = stream.write_all(response.as_bytes());
|
|
179 |
+ |
let _ = stream.flush();
|
|
180 |
+ |
continue;
|
|
181 |
+ |
}
|
|
182 |
+ |
|
|
183 |
+ |
// Parse query parameters for the OAuth callback
|
|
184 |
+ |
let query = path.split('?').nth(1).unwrap_or("");
|
|
185 |
+ |
let mut code = None;
|
|
186 |
+ |
let mut cb_state = None;
|
|
187 |
+ |
let mut error = None;
|
|
188 |
+ |
|
|
189 |
+ |
for param in query.split('&') {
|
|
190 |
+ |
if let Some((key, value)) = param.split_once('=') {
|
|
191 |
+ |
match key {
|
|
192 |
+ |
"code" => code = Some(value.to_string()),
|
|
193 |
+ |
"state" => cb_state = Some(value.to_string()),
|
|
194 |
+ |
"error" => error = Some(value.to_string()),
|
|
195 |
+ |
_ => {}
|
|
196 |
+ |
}
|
|
197 |
+ |
}
|
|
198 |
+ |
}
|
| 145 |
199 |
|
|
| 146 |
|
- |
if has_code {
|
|
200 |
+ |
if let Some(err) = error {
|
|
201 |
+ |
*stored.lock().unwrap() = StoredCallback::Error { error: err };
|
|
202 |
+ |
let body = "<html><body><h1>Authentication failed</h1><p>You can close this tab.</p></body></html>";
|
|
203 |
+ |
let response = format!(
|
|
204 |
+ |
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
|
|
205 |
+ |
body.len(), body
|
|
206 |
+ |
);
|
|
207 |
+ |
let _ = stream.write_all(response.as_bytes());
|
|
208 |
+ |
let _ = stream.flush();
|
|
209 |
+ |
callback_received = true;
|
|
210 |
+ |
} else if let (Some(code), Some(state)) = (code, cb_state) {
|
|
211 |
+ |
*stored.lock().unwrap() = StoredCallback::Success {
|
|
212 |
+ |
code: code.clone(),
|
|
213 |
+ |
state: state.clone(),
|
|
214 |
+ |
};
|
| 147 |
215 |
|
let body = "<html><body><h1>Authenticated</h1><p>You can close this tab and return to Balanced Breakfast.</p></body></html>";
|
| 148 |
216 |
|
let response = format!(
|
| 149 |
|
- |
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
|
217 |
+ |
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
|
| 150 |
218 |
|
body.len(), body
|
| 151 |
219 |
|
);
|
| 152 |
220 |
|
let _ = stream.write_all(response.as_bytes());
|
| 153 |
221 |
|
let _ = stream.flush();
|
| 154 |
|
- |
break;
|
|
222 |
+ |
callback_received = true;
|
| 155 |
223 |
|
}
|
| 156 |
|
- |
|
| 157 |
|
- |
let body = "Waiting for authentication...";
|
| 158 |
|
- |
let response = format!(
|
| 159 |
|
- |
"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
| 160 |
|
- |
body.len(), body
|
| 161 |
|
- |
);
|
| 162 |
|
- |
let _ = stream.write_all(response.as_bytes());
|
| 163 |
224 |
|
}
|
| 164 |
225 |
|
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
| 165 |
226 |
|
std::thread::sleep(std::time::Duration::from_millis(100));
|
| 166 |
227 |
|
}
|
| 167 |
228 |
|
Err(_) => break,
|
| 168 |
229 |
|
}
|
|
230 |
+ |
|
|
231 |
+ |
// After callback, keep serving /result for 30s then exit
|
|
232 |
+ |
if callback_received {
|
|
233 |
+ |
let poll_deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
|
|
234 |
+ |
while std::time::Instant::now() < poll_deadline {
|
|
235 |
+ |
if CALLBACK_CANCEL.load(std::sync::atomic::Ordering::Relaxed) != generation {
|
|
236 |
+ |
break;
|
|
237 |
+ |
}
|
|
238 |
+ |
match listener.accept() {
|
|
239 |
+ |
Ok((mut stream, _)) => {
|
|
240 |
+ |
let mut buf = [0u8; 4096];
|
|
241 |
+ |
let n = stream.read(&mut buf).unwrap_or(0);
|
|
242 |
+ |
let request = String::from_utf8_lossy(&buf[..n]);
|
|
243 |
+ |
|
|
244 |
+ |
let path = request
|
|
245 |
+ |
.lines()
|
|
246 |
+ |
.next()
|
|
247 |
+ |
.and_then(|line| line.split_whitespace().nth(1))
|
|
248 |
+ |
.unwrap_or("/");
|
|
249 |
+ |
|
|
250 |
+ |
if path.starts_with("/result") {
|
|
251 |
+ |
let json = match &*stored.lock().unwrap() {
|
|
252 |
+ |
StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(),
|
|
253 |
+ |
StoredCallback::Success { code, state } => {
|
|
254 |
+ |
format!(
|
|
255 |
+ |
r#"{{"status":"success","code":"{}","state":"{}"}}"#,
|
|
256 |
+ |
code.replace('"', "\\\""),
|
|
257 |
+ |
state.replace('"', "\\\"")
|
|
258 |
+ |
)
|
|
259 |
+ |
}
|
|
260 |
+ |
StoredCallback::Error { error } => {
|
|
261 |
+ |
format!(
|
|
262 |
+ |
r#"{{"status":"error","error":"{}"}}"#,
|
|
263 |
+ |
error.replace('"', "\\\"")
|
|
264 |
+ |
)
|
|
265 |
+ |
}
|
|
266 |
+ |
};
|
|
267 |
+ |
let response = format!(
|
|
268 |
+ |
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
|
|
269 |
+ |
json.len(), json
|
|
270 |
+ |
);
|
|
271 |
+ |
let _ = stream.write_all(response.as_bytes());
|
|
272 |
+ |
let _ = stream.flush();
|
|
273 |
+ |
}
|
|
274 |
+ |
}
|
|
275 |
+ |
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
|
276 |
+ |
std::thread::sleep(std::time::Duration::from_millis(100));
|
|
277 |
+ |
}
|
|
278 |
+ |
Err(_) => break,
|
|
279 |
+ |
}
|
|
280 |
+ |
}
|
|
281 |
+ |
break;
|
|
282 |
+ |
}
|
| 169 |
283 |
|
}
|
| 170 |
284 |
|
});
|
| 171 |
285 |
|
|