Skip to main content

max / balanced_breakfast

20.1 KB · 581 lines History Blame Raw
1 //! SyncKit cloud sync commands.
2 //!
3 //! Provides Tauri commands for authenticating with the MNW sync service
4 //! via OAuth2 PKCE flow, managing encryption, manual sync, and settings.
5
6 use super::error::ApiError;
7 use crate::state::AppState;
8 use crate::sync_service;
9 use serde::{Deserialize, Serialize};
10 use std::sync::Arc;
11 use tauri::{Emitter, State};
12 use tracing::instrument;
13
14 // ── Helpers ──
15
16 /// Extract the sync client from state (clones the Arc for use across await points).
17 fn get_sync_client(state: &AppState) -> Option<Arc<synckit_client::SyncKitClient>> {
18 state.sync_client.read().clone()
19 }
20
21 fn require_sync_client(state: &AppState) -> Result<Arc<synckit_client::SyncKitClient>, ApiError> {
22 get_sync_client(state).ok_or_else(|| ApiError::bad_request("Sync is not configured"))
23 }
24
25 // ── Types ──
26
/// Snapshot of the cloud-sync state returned by the `sync_status` command.
///
/// Serialized with camelCase keys for the frontend.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SyncStatusResponse {
    /// A sync client is installed in `AppState`.
    pub configured: bool,
    /// The sync client currently has a session (`session_info()` is `Some`).
    pub authenticated: bool,
    /// The sync client holds its master encryption key (`has_master_key()`).
    pub encryption_ready: bool,
    /// Whether the server reports a stored key; `None` when not authenticated or when the
    /// server query failed.
    pub has_server_key: Option<bool>,
    /// Persisted device id; `None` when absent, empty, or unreadable.
    pub device_id: Option<String>,
    /// Persisted auto-sync flag (stored as "1"/"0"); `false` when absent or unreadable.
    pub auto_sync_enabled: bool,
    /// Persisted auto-sync interval in minutes; 15 when absent or unparsable.
    pub sync_interval_minutes: u32,
    /// Persisted timestamp of the last sync; `None` when absent or empty.
    /// NOTE(review): the timestamp format is set by `sync_service` — confirm.
    pub last_sync_at: Option<String>,
    /// Number of local changes not yet synced; 0 if counting fails.
    pub pending_changes: i64,
}
40
/// Result of `sync_start_auth`: everything the frontend needs to drive the OAuth2 PKCE flow.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SyncAuthStartResponse {
    /// Authorization URL to open in the user's browser.
    pub auth_url: String,
    /// Random CSRF `state` that the redirect must echo back.
    pub state: String,
    /// PKCE code verifier; the frontend passes it back to `sync_complete_auth`.
    pub code_verifier: String,
    /// Loopback port of the local callback server (also serves `/result` polling).
    pub port: u16,
}
49
/// Input of `sync_complete_auth`, sent by the frontend after the OAuth redirect.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SyncAuthCompleteInput {
    /// Authorization code received by the callback server.
    pub code: String,
    /// `state` value received by the callback server.
    pub state: String,
    /// `state` issued by `sync_start_auth`; must equal `state`.
    pub expected_state: String,
    /// PKCE code verifier issued by `sync_start_auth`.
    pub code_verifier: String,
    /// Callback server port used for this authorization attempt.
    pub port: u16,
}
59
/// Partial settings update for `sync_update_settings`; `None` fields are left unchanged.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SyncSettingsInput {
    /// New auto-sync on/off flag.
    pub auto_sync_enabled: Option<bool>,
    /// New auto-sync interval, in minutes.
    pub sync_interval_minutes: Option<u32>,
}
66
67 // ── PKCE helpers ──
68
69 fn generate_code_verifier() -> String {
70 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
71 use base64::Engine;
72 use rand::RngCore;
73
74 let mut bytes = [0u8; 32];
75 rand::thread_rng().fill_bytes(&mut bytes);
76 URL_SAFE_NO_PAD.encode(bytes)
77 }
78
79 fn generate_code_challenge(verifier: &str) -> String {
80 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
81 use base64::Engine;
82 use sha2::{Digest, Sha256};
83
84 let hash = Sha256::digest(verifier.as_bytes());
85 URL_SAFE_NO_PAD.encode(hash)
86 }
87
88 fn generate_state() -> String {
89 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
90 use base64::Engine;
91 use rand::RngCore;
92
93 let mut bytes = [0u8; 16];
94 rand::thread_rng().fill_bytes(&mut bytes);
95 URL_SAFE_NO_PAD.encode(bytes)
96 }
97
98 // ── Callback server ──
99
/// Generation counter used to stop stale OAuth callback servers.
///
/// Each `start_callback_server` call increments it and remembers the new value. A
/// server thread exits as soon as the counter no longer equals the generation it was
/// started with, so only the most recently started server stays alive.
static CALLBACK_CANCEL: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
102
/// Outcome of the OAuth redirect as recorded by a callback server thread.
///
/// Reported to clients polling the `/result` endpoint.
#[derive(Clone)]
enum StoredCallback {
    /// No redirect has been received yet.
    Pending,
    /// The redirect carried an authorization `code` and the echoed `state`.
    Success { code: String, state: String },
    /// The redirect carried an `error` parameter.
    Error { error: String },
}
110
111 /// Start a minimal HTTP server on a random port that waits for the OAuth redirect.
112 /// Returns the port. The server handles:
113 /// - The browser redirect with `?code=...&state=...` (stores result, returns HTML)
114 /// - `/result` polling endpoint (returns JSON: pending, success, or error)
115 /// Any previously running callback server is cancelled via the shared generation counter.
116 fn start_callback_server() -> Result<u16, ApiError> {
117 let listener = std::net::TcpListener::bind("127.0.0.1:0")
118 .map_err(|e| ApiError::internal(format!("Failed to bind callback server: {}", e)))?;
119 let port = listener
120 .local_addr()
121 .map_err(|e| ApiError::internal(format!("Failed to get callback port: {}", e)))?
122 .port();
123 listener
124 .set_nonblocking(true)
125 .map_err(|e| ApiError::internal(format!("Failed to set non-blocking: {}", e)))?;
126
127 // Increment generation to cancel any previous callback server thread
128 let generation = CALLBACK_CANCEL.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
129
130 std::thread::spawn(move || {
131 use std::io::{Read, Write};
132 use std::sync::{Arc, Mutex};
133
134 let stored = Arc::new(Mutex::new(StoredCallback::Pending));
135 let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300);
136 let mut callback_received = false;
137
138 while std::time::Instant::now() < deadline {
139 if CALLBACK_CANCEL.load(std::sync::atomic::Ordering::Relaxed) != generation {
140 break;
141 }
142 match listener.accept() {
143 Ok((mut stream, _)) => {
144 let mut buf = [0u8; 4096];
145 let n = stream.read(&mut buf).unwrap_or(0);
146 let request = String::from_utf8_lossy(&buf[..n]);
147
148 let path = request
149 .lines()
150 .next()
151 .and_then(|line| line.split_whitespace().nth(1))
152 .unwrap_or("/");
153
154 let path_only = path.split('?').next().unwrap_or(path);
155
156 // Handle /result polling endpoint
157 if path_only == "/result" {
158 let json = match &*stored.lock().unwrap() {
159 StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(),
160 StoredCallback::Success { code, state } => {
161 format!(
162 r#"{{"status":"success","code":"{}","state":"{}"}}"#,
163 code.replace('"', "\\\""),
164 state.replace('"', "\\\"")
165 )
166 }
167 StoredCallback::Error { error } => {
168 format!(
169 r#"{{"status":"error","error":"{}"}}"#,
170 error.replace('"', "\\\"")
171 )
172 }
173 };
174 let response = format!(
175 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
176 json.len(), json
177 );
178 let _ = stream.write_all(response.as_bytes());
179 let _ = stream.flush();
180 continue;
181 }
182
183 // Parse query parameters for the OAuth callback
184 let query = path.split('?').nth(1).unwrap_or("");
185 let mut code = None;
186 let mut cb_state = None;
187 let mut error = None;
188
189 for param in query.split('&') {
190 if let Some((key, value)) = param.split_once('=') {
191 match key {
192 "code" => code = Some(value.to_string()),
193 "state" => cb_state = Some(value.to_string()),
194 "error" => error = Some(value.to_string()),
195 _ => {}
196 }
197 }
198 }
199
200 if let Some(err) = error {
201 *stored.lock().unwrap() = StoredCallback::Error { error: err };
202 let body = "<html><body><h1>Authentication failed</h1><p>You can close this tab.</p></body></html>";
203 let response = format!(
204 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
205 body.len(), body
206 );
207 let _ = stream.write_all(response.as_bytes());
208 let _ = stream.flush();
209 callback_received = true;
210 } else if let (Some(code), Some(state)) = (code, cb_state) {
211 *stored.lock().unwrap() = StoredCallback::Success {
212 code: code.clone(),
213 state: state.clone(),
214 };
215 let body = "<html><body><h1>Authenticated</h1><p>You can close this tab and return to Balanced Breakfast.</p></body></html>";
216 let response = format!(
217 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
218 body.len(), body
219 );
220 let _ = stream.write_all(response.as_bytes());
221 let _ = stream.flush();
222 callback_received = true;
223 }
224 }
225 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
226 std::thread::sleep(std::time::Duration::from_millis(100));
227 }
228 Err(_) => break,
229 }
230
231 // After callback, keep serving /result for 30s then exit
232 if callback_received {
233 let poll_deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
234 while std::time::Instant::now() < poll_deadline {
235 if CALLBACK_CANCEL.load(std::sync::atomic::Ordering::Relaxed) != generation {
236 break;
237 }
238 match listener.accept() {
239 Ok((mut stream, _)) => {
240 let mut buf = [0u8; 4096];
241 let n = stream.read(&mut buf).unwrap_or(0);
242 let request = String::from_utf8_lossy(&buf[..n]);
243
244 let path = request
245 .lines()
246 .next()
247 .and_then(|line| line.split_whitespace().nth(1))
248 .unwrap_or("/");
249
250 if path.starts_with("/result") {
251 let json = match &*stored.lock().unwrap() {
252 StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(),
253 StoredCallback::Success { code, state } => {
254 format!(
255 r#"{{"status":"success","code":"{}","state":"{}"}}"#,
256 code.replace('"', "\\\""),
257 state.replace('"', "\\\"")
258 )
259 }
260 StoredCallback::Error { error } => {
261 format!(
262 r#"{{"status":"error","error":"{}"}}"#,
263 error.replace('"', "\\\"")
264 )
265 }
266 };
267 let response = format!(
268 "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
269 json.len(), json
270 );
271 let _ = stream.write_all(response.as_bytes());
272 let _ = stream.flush();
273 }
274 }
275 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
276 std::thread::sleep(std::time::Duration::from_millis(100));
277 }
278 Err(_) => break,
279 }
280 }
281 break;
282 }
283 }
284 });
285
286 Ok(port)
287 }
288
289 // ── Commands ──
290
291 /// Fetch available pricing tiers for this app (no auth required, uses API key).
292 #[tauri::command]
293 #[instrument(skip_all)]
294 pub async fn sync_get_tiers(
295 state: State<'_, Arc<AppState>>,
296 ) -> Result<Vec<synckit_client::TierInfo>, ApiError> {
297 let client = require_sync_client(&state)?;
298 let tiers = client.get_available_tiers()
299 .await
300 .map_err(|e| ApiError::internal(format!("Failed to fetch tiers: {e}")))?;
301 Ok(tiers)
302 }
303
304 #[tauri::command]
305 #[instrument(skip_all)]
306 pub async fn sync_status(
307 state: State<'_, Arc<AppState>>,
308 ) -> Result<SyncStatusResponse, ApiError> {
309 let (configured, encryption_ready, has_server_key) = match get_sync_client(&state) {
310 Some(client) => {
311 let enc_ready = client.has_master_key();
312 let authenticated = client.session_info().is_some();
313 let server_key = if authenticated {
314 client.has_server_key().await.ok()
315 } else {
316 None
317 };
318 (true, enc_ready, server_key)
319 }
320 None => (false, false, None),
321 };
322
323 let authenticated = get_sync_client(&state)
324 .is_some_and(|c| c.session_info().is_some());
325
326 let pool = state.orchestrator.database().pool();
327
328 let device_id = sync_service::get_sync_state(pool, "device_id")
329 .await
330 .ok()
331 .filter(|s| !s.is_empty());
332
333 let auto_sync_enabled = sync_service::get_sync_state(pool, "auto_sync_enabled")
334 .await
335 .map(|v| v == "1")
336 .unwrap_or(false);
337
338 let sync_interval_minutes = sync_service::get_sync_state(pool, "sync_interval_minutes")
339 .await
340 .ok()
341 .and_then(|v| v.parse().ok())
342 .unwrap_or(15);
343
344 let last_sync_at = sync_service::get_sync_state(pool, "last_sync_at")
345 .await
346 .ok()
347 .filter(|s| !s.is_empty());
348
349 let pending_changes = sync_service::count_pending_changes(pool)
350 .await
351 .unwrap_or(0);
352
353 Ok(SyncStatusResponse {
354 configured,
355 authenticated,
356 encryption_ready,
357 has_server_key,
358 device_id,
359 auto_sync_enabled,
360 sync_interval_minutes,
361 last_sync_at,
362 pending_changes,
363 })
364 }
365
366 #[tauri::command]
367 #[instrument(skip_all)]
368 pub async fn sync_start_auth(
369 state: State<'_, Arc<AppState>>,
370 ) -> Result<SyncAuthStartResponse, ApiError> {
371 let client = require_sync_client(&state)?;
372
373 let code_verifier = generate_code_verifier();
374 let code_challenge = generate_code_challenge(&code_verifier);
375 let csrf_state = generate_state();
376
377 let port = start_callback_server()?;
378
379 let auth_url = client.build_authorize_url(port, &csrf_state, &code_challenge);
380
381 Ok(SyncAuthStartResponse {
382 auth_url,
383 state: csrf_state,
384 code_verifier,
385 port,
386 })
387 }
388
389 #[tauri::command]
390 #[instrument(skip_all)]
391 pub async fn sync_complete_auth(
392 state: State<'_, Arc<AppState>>,
393 input: SyncAuthCompleteInput,
394 ) -> Result<bool, ApiError> {
395 let client = require_sync_client(&state)?;
396
397 if input.state != input.expected_state {
398 return Err(ApiError::bad_request("OAuth state mismatch"));
399 }
400
401 client
402 .authenticate_with_code(&input.code, &input.code_verifier, input.port, "__internal__")
403 .await
404 .map_err(|e| ApiError::internal(format!("Token exchange failed: {}", e)))?;
405
406 // Try to load encryption key from keychain (may already exist on this device)
407 match client.try_load_key_from_keychain() {
408 Ok(true) => tracing::info!("Sync encryption key loaded from keychain"),
409 Ok(false) => tracing::debug!("No sync encryption key in keychain yet"),
410 Err(e) => tracing::warn!(error = %e, "Failed to load sync encryption key"),
411 }
412
413 Ok(true)
414 }
415
416 #[tauri::command]
417 #[instrument(skip_all)]
418 pub async fn sync_disconnect(
419 state: State<'_, Arc<AppState>>,
420 ) -> Result<bool, ApiError> {
421 // Clear in-memory session and master key
422 if let Some(client) = get_sync_client(&state) {
423 client.clear_session();
424 }
425
426 // Clear persisted sync state (device_id, cursors, flags, changelog)
427 let pool = state.orchestrator.database().pool();
428 sync_service::clear_all_sync_state(pool).await?;
429
430 Ok(true)
431 }
432
433 #[tauri::command]
434 #[instrument(skip_all)]
435 pub async fn sync_now(
436 state: State<'_, Arc<AppState>>,
437 app: tauri::AppHandle,
438 ) -> Result<sync_service::SyncResult, ApiError> {
439 let client = require_sync_client(&state)?;
440
441 if client.session_info().is_none() {
442 return Err(ApiError::bad_request("Not authenticated"));
443 }
444
445 if !client.has_master_key() {
446 return Err(ApiError::bad_request("Encryption not set up"));
447 }
448
449 let pool = state.orchestrator.database().pool();
450
451 // Prevent concurrent sync operations (manual + scheduler).
452 let _sync_guard = state.sync_mutex.lock().await;
453
454 // Create initial snapshot if needed
455 let snapshot_done = sync_service::get_sync_state(pool, "initial_snapshot_done")
456 .await
457 .unwrap_or_default();
458 if snapshot_done != "1" {
459 sync_service::create_initial_snapshot(pool)
460 .await
461 .map_err(|e| ApiError::internal(format!("Failed to create initial snapshot: {}", e)))?;
462 }
463
464 let result = sync_service::perform_sync(pool, &client).await?;
465
466 if result.pulled > 0 {
467 let _ = app.emit("sync:changes-applied", ());
468 }
469
470 // Cleanup after manual sync too
471 if let Err(e) = sync_service::cleanup_changelog(pool).await {
472 tracing::warn!(error = %e, "Sync changelog cleanup failed after manual sync");
473 }
474
475 Ok(result)
476 }
477
478 #[tauri::command]
479 #[instrument(skip_all)]
480 pub async fn sync_setup_encryption_new(
481 state: State<'_, Arc<AppState>>,
482 password: String,
483 ) -> Result<bool, ApiError> {
484 let client = require_sync_client(&state)?;
485
486 client
487 .setup_encryption_new(&password)
488 .await
489 .map_err(|e| ApiError::internal(format!("Encryption setup failed: {}", e)))?;
490
491 Ok(true)
492 }
493
494 #[tauri::command]
495 #[instrument(skip_all)]
496 pub async fn sync_setup_encryption_existing(
497 state: State<'_, Arc<AppState>>,
498 password: String,
499 ) -> Result<bool, ApiError> {
500 let client = require_sync_client(&state)?;
501
502 client
503 .setup_encryption_existing(&password)
504 .await
505 .map_err(|e| ApiError::internal(format!("Encryption setup failed: {}", e)))?;
506
507 Ok(true)
508 }
509
510 #[tauri::command]
511 #[instrument(skip_all)]
512 pub async fn sync_update_settings(
513 state: State<'_, Arc<AppState>>,
514 input: SyncSettingsInput,
515 ) -> Result<bool, ApiError> {
516 let pool = state.orchestrator.database().pool();
517
518 if let Some(enabled) = input.auto_sync_enabled {
519 sync_service::set_sync_state(
520 pool,
521 "auto_sync_enabled",
522 if enabled { "1" } else { "0" },
523 )
524 .await?;
525 }
526
527 if let Some(minutes) = input.sync_interval_minutes {
528 sync_service::set_sync_state(pool, "sync_interval_minutes", &minutes.to_string())
529 .await?;
530 }
531
532 Ok(true)
533 }
534
535 // ── Subscription Commands ──
536
537 /// Check subscription status for this user + app.
538 #[tauri::command]
539 #[instrument(skip_all)]
540 pub async fn sync_subscription_status(
541 state: State<'_, Arc<AppState>>,
542 ) -> Result<synckit_client::SubscriptionStatus, ApiError> {
543 let client = require_sync_client(&state)?;
544
545 if client.session_info().is_none() {
546 return Err(ApiError::bad_request("Not authenticated"));
547 }
548
549 client
550 .get_subscription_status()
551 .await
552 .map_err(|e| ApiError::internal(e.to_string()))
553 }
554
555 /// Create a Stripe Checkout session for subscribing to cloud sync.
556 /// Opens the checkout URL in the user's default browser.
557 #[tauri::command]
558 #[instrument(skip_all)]
559 pub async fn sync_subscribe(
560 state: State<'_, Arc<AppState>>,
561 interval: String,
562 ) -> Result<String, ApiError> {
563 let client = require_sync_client(&state)?;
564
565 if client.session_info().is_none() {
566 return Err(ApiError::bad_request("Not authenticated"));
567 }
568
569 let response = client
570 .create_subscription_checkout("standard", &interval)
571 .await
572 .map_err(|e| ApiError::internal(e.to_string()))?;
573
574 // Open in default browser
575 if let Err(e) = open::that(&response.checkout_url) {
576 tracing::warn!(error = %e, "Failed to open browser, returning URL");
577 }
578
579 Ok(response.checkout_url)
580 }
581