//! SyncKit cloud sync commands. //! //! Provides Tauri commands for authenticating with the MNW sync service //! via OAuth2 PKCE flow, managing encryption, manual sync, and settings. use super::error::ApiError; use crate::state::AppState; use crate::sync_service; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tauri::{Emitter, State}; use tracing::instrument; // ── Helpers ── /// Extract the sync client from state (clones the Arc for use across await points). fn get_sync_client(state: &AppState) -> Option> { state.sync_client.read().clone() } fn require_sync_client(state: &AppState) -> Result, ApiError> { get_sync_client(state).ok_or_else(|| ApiError::bad_request("Sync is not configured")) } // ── Types ── #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct SyncStatusResponse { pub configured: bool, pub authenticated: bool, pub encryption_ready: bool, pub has_server_key: Option, pub device_id: Option, pub auto_sync_enabled: bool, pub sync_interval_minutes: u32, pub last_sync_at: Option, pub pending_changes: i64, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] pub struct SyncAuthStartResponse { pub auth_url: String, pub state: String, pub code_verifier: String, pub port: u16, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SyncAuthCompleteInput { pub code: String, pub state: String, pub expected_state: String, pub code_verifier: String, pub port: u16, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SyncSettingsInput { pub auto_sync_enabled: Option, pub sync_interval_minutes: Option, } // ── PKCE helpers ── fn generate_code_verifier() -> String { use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use rand::RngCore; let mut bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut bytes); URL_SAFE_NO_PAD.encode(bytes) } fn generate_code_challenge(verifier: &str) -> String { use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use sha2::{Digest, Sha256}; let hash = Sha256::digest(verifier.as_bytes()); URL_SAFE_NO_PAD.encode(hash) } fn generate_state() -> String { use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use rand::RngCore; let mut bytes = [0u8; 16]; rand::thread_rng().fill_bytes(&mut bytes); URL_SAFE_NO_PAD.encode(bytes) } // ── Callback server ── /// Shared flag to signal previous callback servers to stop. static CALLBACK_CANCEL: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); /// Stored callback state for the /result polling endpoint. #[derive(Clone)] enum StoredCallback { Pending, Success { code: String, state: String }, Error { error: String }, } /// Start a minimal HTTP server on a random port that waits for the OAuth redirect. /// Returns the port. The server handles: /// - The browser redirect with `?code=...&state=...` (stores result, returns HTML) /// - `/result` polling endpoint (returns JSON: pending, success, or error) /// Any previously running callback server is cancelled via the shared generation counter. fn start_callback_server() -> Result { let listener = std::net::TcpListener::bind("127.0.0.1:0") .map_err(|e| ApiError::internal(format!("Failed to bind callback server: {}", e)))?; let port = listener .local_addr() .map_err(|e| ApiError::internal(format!("Failed to get callback port: {}", e)))? .port(); listener .set_nonblocking(true) .map_err(|e| ApiError::internal(format!("Failed to set non-blocking: {}", e)))?; // Increment generation to cancel any previous callback server thread let generation = CALLBACK_CANCEL.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1; std::thread::spawn(move || { use std::io::{Read, Write}; use std::sync::{Arc, Mutex}; let stored = Arc::new(Mutex::new(StoredCallback::Pending)); let deadline = std::time::Instant::now() + std::time::Duration::from_secs(300); let mut callback_received = false; while std::time::Instant::now() < deadline { if CALLBACK_CANCEL.load(std::sync::atomic::Ordering::Relaxed) != generation { break; } match listener.accept() { Ok((mut stream, _)) => { let mut buf = [0u8; 4096]; let n = stream.read(&mut buf).unwrap_or(0); let request = String::from_utf8_lossy(&buf[..n]); let path = request .lines() .next() .and_then(|line| line.split_whitespace().nth(1)) .unwrap_or("/"); let path_only = path.split('?').next().unwrap_or(path); // Handle /result polling endpoint if path_only == "/result" { let json = match &*stored.lock().unwrap() { StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(), StoredCallback::Success { code, state } => { format!( r#"{{"status":"success","code":"{}","state":"{}"}}"#, code.replace('"', "\\\""), state.replace('"', "\\\"") ) } StoredCallback::Error { error } => { format!( r#"{{"status":"error","error":"{}"}}"#, error.replace('"', "\\\"") ) } }; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", json.len(), json ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); continue; } // Parse query parameters for the OAuth callback let query = path.split('?').nth(1).unwrap_or(""); let mut code = None; let mut cb_state = None; let mut error = None; for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { match key { "code" => code = Some(value.to_string()), "state" => cb_state = Some(value.to_string()), "error" => error = Some(value.to_string()), _ => {} } } } if let Some(err) = error { *stored.lock().unwrap() = StoredCallback::Error { error: err }; let body = "

Authentication failed

You can close this tab.

"; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", body.len(), body ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); callback_received = true; } else if let (Some(code), Some(state)) = (code, cb_state) { *stored.lock().unwrap() = StoredCallback::Success { code: code.clone(), state: state.clone(), }; let body = "

Authenticated

You can close this tab and return to Balanced Breakfast.

"; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", body.len(), body ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); callback_received = true; } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { std::thread::sleep(std::time::Duration::from_millis(100)); } Err(_) => break, } // After callback, keep serving /result for 30s then exit if callback_received { let poll_deadline = std::time::Instant::now() + std::time::Duration::from_secs(30); while std::time::Instant::now() < poll_deadline { if CALLBACK_CANCEL.load(std::sync::atomic::Ordering::Relaxed) != generation { break; } match listener.accept() { Ok((mut stream, _)) => { let mut buf = [0u8; 4096]; let n = stream.read(&mut buf).unwrap_or(0); let request = String::from_utf8_lossy(&buf[..n]); let path = request .lines() .next() .and_then(|line| line.split_whitespace().nth(1)) .unwrap_or("/"); if path.starts_with("/result") { let json = match &*stored.lock().unwrap() { StoredCallback::Pending => r#"{"status":"pending"}"#.to_string(), StoredCallback::Success { code, state } => { format!( r#"{{"status":"success","code":"{}","state":"{}"}}"#, code.replace('"', "\\\""), state.replace('"', "\\\"") ) } StoredCallback::Error { error } => { format!( r#"{{"status":"error","error":"{}"}}"#, error.replace('"', "\\\"") ) } }; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}", json.len(), json ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { std::thread::sleep(std::time::Duration::from_millis(100)); } Err(_) => break, } } break; } } }); Ok(port) } // ── Commands ── /// Fetch available pricing tiers for this app (no auth required, uses API key). #[tauri::command] #[instrument(skip_all)] pub async fn sync_get_tiers( state: State<'_, Arc>, ) -> Result, ApiError> { let client = require_sync_client(&state)?; let tiers = client.get_available_tiers() .await .map_err(|e| ApiError::internal(format!("Failed to fetch tiers: {e}")))?; Ok(tiers) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_status( state: State<'_, Arc>, ) -> Result { let (configured, encryption_ready, has_server_key) = match get_sync_client(&state) { Some(client) => { let enc_ready = client.has_master_key(); let authenticated = client.session_info().is_some(); let server_key = if authenticated { client.has_server_key().await.ok() } else { None }; (true, enc_ready, server_key) } None => (false, false, None), }; let authenticated = get_sync_client(&state) .is_some_and(|c| c.session_info().is_some()); let pool = state.orchestrator.database().pool(); let device_id = sync_service::get_sync_state(pool, "device_id") .await .ok() .filter(|s| !s.is_empty()); let auto_sync_enabled = sync_service::get_sync_state(pool, "auto_sync_enabled") .await .map(|v| v == "1") .unwrap_or(false); let sync_interval_minutes = sync_service::get_sync_state(pool, "sync_interval_minutes") .await .ok() .and_then(|v| v.parse().ok()) .unwrap_or(15); let last_sync_at = sync_service::get_sync_state(pool, "last_sync_at") .await .ok() .filter(|s| !s.is_empty()); let pending_changes = sync_service::count_pending_changes(pool) .await .unwrap_or(0); Ok(SyncStatusResponse { configured, authenticated, encryption_ready, has_server_key, device_id, auto_sync_enabled, sync_interval_minutes, last_sync_at, pending_changes, }) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_start_auth( state: State<'_, Arc>, ) -> Result { let client = require_sync_client(&state)?; let code_verifier = generate_code_verifier(); let code_challenge = generate_code_challenge(&code_verifier); let csrf_state = generate_state(); let port = start_callback_server()?; let auth_url = client.build_authorize_url(port, &csrf_state, &code_challenge); Ok(SyncAuthStartResponse { auth_url, state: csrf_state, code_verifier, port, }) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_complete_auth( state: State<'_, Arc>, input: SyncAuthCompleteInput, ) -> Result { let client = require_sync_client(&state)?; if input.state != input.expected_state { return Err(ApiError::bad_request("OAuth state mismatch")); } client .authenticate_with_code(&input.code, &input.code_verifier, input.port, "__internal__") .await .map_err(|e| ApiError::internal(format!("Token exchange failed: {}", e)))?; // Try to load encryption key from keychain (may already exist on this device) match client.try_load_key_from_keychain() { Ok(true) => tracing::info!("Sync encryption key loaded from keychain"), Ok(false) => tracing::debug!("No sync encryption key in keychain yet"), Err(e) => tracing::warn!(error = %e, "Failed to load sync encryption key"), } Ok(true) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_disconnect( state: State<'_, Arc>, ) -> Result { // Clear in-memory session and master key if let Some(client) = get_sync_client(&state) { client.clear_session(); } // Clear persisted sync state (device_id, cursors, flags, changelog) let pool = state.orchestrator.database().pool(); sync_service::clear_all_sync_state(pool).await?; Ok(true) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_now( state: State<'_, Arc>, app: tauri::AppHandle, ) -> Result { let client = require_sync_client(&state)?; if client.session_info().is_none() { return Err(ApiError::bad_request("Not authenticated")); } if !client.has_master_key() { return Err(ApiError::bad_request("Encryption not set up")); } let pool = state.orchestrator.database().pool(); // Prevent concurrent sync operations (manual + scheduler). let _sync_guard = state.sync_mutex.lock().await; // Create initial snapshot if needed let snapshot_done = sync_service::get_sync_state(pool, "initial_snapshot_done") .await .unwrap_or_default(); if snapshot_done != "1" { sync_service::create_initial_snapshot(pool) .await .map_err(|e| ApiError::internal(format!("Failed to create initial snapshot: {}", e)))?; } let result = sync_service::perform_sync(pool, &client).await?; if result.pulled > 0 { let _ = app.emit("sync:changes-applied", ()); } // Cleanup after manual sync too if let Err(e) = sync_service::cleanup_changelog(pool).await { tracing::warn!(error = %e, "Sync changelog cleanup failed after manual sync"); } Ok(result) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_setup_encryption_new( state: State<'_, Arc>, password: String, ) -> Result { let client = require_sync_client(&state)?; client .setup_encryption_new(&password) .await .map_err(|e| ApiError::internal(format!("Encryption setup failed: {}", e)))?; Ok(true) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_setup_encryption_existing( state: State<'_, Arc>, password: String, ) -> Result { let client = require_sync_client(&state)?; client .setup_encryption_existing(&password) .await .map_err(|e| ApiError::internal(format!("Encryption setup failed: {}", e)))?; Ok(true) } #[tauri::command] #[instrument(skip_all)] pub async fn sync_update_settings( state: State<'_, Arc>, input: SyncSettingsInput, ) -> Result { let pool = state.orchestrator.database().pool(); if let Some(enabled) = input.auto_sync_enabled { sync_service::set_sync_state( pool, "auto_sync_enabled", if enabled { "1" } else { "0" }, ) .await?; } if let Some(minutes) = input.sync_interval_minutes { sync_service::set_sync_state(pool, "sync_interval_minutes", &minutes.to_string()) .await?; } Ok(true) } // ── Subscription Commands ── /// Check subscription status for this user + app. #[tauri::command] #[instrument(skip_all)] pub async fn sync_subscription_status( state: State<'_, Arc>, ) -> Result { let client = require_sync_client(&state)?; if client.session_info().is_none() { return Err(ApiError::bad_request("Not authenticated")); } client .get_subscription_status() .await .map_err(|e| ApiError::internal(e.to_string())) } /// Create a Stripe Checkout session for subscribing to cloud sync. /// Opens the checkout URL in the user's default browser. #[tauri::command] #[instrument(skip_all)] pub async fn sync_subscribe( state: State<'_, Arc>, interval: String, ) -> Result { let client = require_sync_client(&state)?; if client.session_info().is_none() { return Err(ApiError::bad_request("Not authenticated")); } let response = client .create_subscription_checkout("standard", &interval) .await .map_err(|e| ApiError::internal(e.to_string()))?; // Open in default browser if let Err(e) = open::that(&response.checkout_url) { tracing::warn!(error = %e, "Failed to open browser, returning URL"); } Ok(response.checkout_url) }