use sqlx::PgPool; use crate::db::{SyncAppId, SyncDeviceId, UserId}; use crate::error::Result; // ── Sync Keys ── /// Upsert an encrypted master key for a user within an app with optimistic /// concurrency control. /// /// `expected_version` is checked on UPDATE: if the current `key_version` /// doesn't match, the update is skipped and `false` is returned (the caller /// should return 409 Conflict). On INSERT (first key), `expected_version` /// must be 0. /// /// Returns `true` if the key was inserted or updated, `false` on version mismatch. #[tracing::instrument(skip_all)] pub async fn upsert_sync_key( pool: &PgPool, app_id: SyncAppId, user_id: UserId, encrypted_key: &str, expected_version: i32, ) -> Result { let result = sqlx::query( r#" INSERT INTO sync_keys (app_id, user_id, encrypted_key) VALUES ($1, $2, $3) ON CONFLICT (app_id, user_id) DO UPDATE SET encrypted_key = EXCLUDED.encrypted_key, key_version = sync_keys.key_version + 1, updated_at = NOW() WHERE sync_keys.key_version = $4 "#, ) .bind(app_id) .bind(user_id) .bind(encrypted_key) .bind(expected_version) .execute(pool) .await?; Ok(result.rows_affected() > 0) } /// Encryption key info returned by `get_sync_key`. pub struct SyncKeyInfo { pub encrypted_key: String, pub key_version: i32, pub key_id: i32, /// If a key rotation is in progress, the new key envelope and its key_id. pub pending_key: Option<(String, i32)>, } /// Get the encrypted master key, version, key_id, and any pending rotation key. #[tracing::instrument(skip_all)] pub async fn get_sync_key( pool: &PgPool, app_id: SyncAppId, user_id: UserId, ) -> Result> { let row: Option<(String, i32, i32)> = sqlx::query_as( "SELECT encrypted_key, key_version, key_id FROM sync_keys WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_optional(pool) .await?; let Some((encrypted_key, key_version, key_id)) = row else { return Ok(None); }; // Check for an active rotation let pending: Option<(String, i32)> = sqlx::query_as( "SELECT new_encrypted_key, new_key_id FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_optional(pool) .await?; Ok(Some(SyncKeyInfo { encrypted_key, key_version, key_id, pending_key: pending, })) } /// Get device count and sync log entry count for all apps owned by a creator. /// Returns Vec of (app_id, device_count, log_entry_count). Single query replaces N+1 loop. #[tracing::instrument(skip_all)] pub async fn get_sync_app_stats_batch( pool: &PgPool, creator_id: UserId, ) -> Result> { let rows: Vec<(SyncAppId, i64, i64)> = sqlx::query_as( r#" SELECT a.id, (SELECT COUNT(*) FROM sync_devices d WHERE d.app_id = a.id), (SELECT COUNT(*) FROM sync_log l WHERE l.app_id = a.id) FROM sync_apps a WHERE a.creator_id = $1 "#, ) .bind(creator_id) .fetch_all(pool) .await?; Ok(rows) } /// Delete sync log entries older than the given number of days. /// /// `retain_days` must be positive. Returns 0 immediately for non-positive values /// to prevent accidental deletion of all entries. #[tracing::instrument(skip_all)] pub async fn prune_sync_log(pool: &PgPool, retain_days: i64) -> Result { if retain_days <= 0 { tracing::warn!("prune_sync_log called with non-positive retain_days={retain_days}, skipping"); return Ok(0); } let result = sqlx::query( "DELETE FROM sync_log WHERE created_at < NOW() - make_interval(days => $1)", ) .bind(retain_days) .execute(pool) .await?; Ok(result.rows_affected()) } /// Update a device's last-pulled cursor position. #[tracing::instrument(skip_all)] pub async fn update_device_cursor( pool: &PgPool, device_id: SyncDeviceId, seq: i64, ) -> Result<()> { sqlx::query( "UPDATE sync_devices SET last_pulled_seq = GREATEST(last_pulled_seq, $1) WHERE id = $2", ) .bind(seq) .bind(device_id) .execute(pool) .await?; Ok(()) } /// Compact the sync log by removing entries that all devices for a given /// (app_id, user_id) have already pulled. Keeps a safety margin of entries /// newer than `min_age_days` regardless of cursor positions. /// /// Returns the number of entries deleted. #[allow(dead_code)] // Public API for targeted per-user compaction #[tracing::instrument(skip_all)] pub async fn compact_sync_log( pool: &PgPool, app_id: SyncAppId, user_id: UserId, min_age_days: i64, ) -> Result { if min_age_days <= 0 { return Ok(0); } // Find the lowest cursor across all devices for this user+app. // Entries below this seq have been pulled by every device. let min_cursor: Option = sqlx::query_scalar( "SELECT MIN(last_pulled_seq) FROM sync_devices WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_one(pool) .await?; let safe_seq = match min_cursor { Some(seq) if seq > 0 => seq, _ => return Ok(0), // No devices or no pulls yet }; // Delete entries below the safe cursor AND older than the safety margin. let result = sqlx::query( r#" DELETE FROM sync_log WHERE app_id = $1 AND user_id = $2 AND seq <= $3 AND created_at < NOW() - make_interval(days => $4) "#, ) .bind(app_id) .bind(user_id) .bind(safe_seq) .bind(min_age_days) .execute(pool) .await?; Ok(result.rows_affected()) } /// Compact sync logs for all user+app pairs that have compactable entries. /// Finds pairs where MIN(last_pulled_seq) across devices > 0, then deletes /// entries below that threshold (with age safety margin). /// /// Returns total entries deleted across all users. #[tracing::instrument(skip_all)] pub async fn compact_all_sync_logs(pool: &PgPool, min_age_days: i64) -> Result { if min_age_days <= 0 { return Ok(0); } // Find (app_id, user_id) pairs where compaction is possible: // all devices have pulled past seq 0, and there are old entries to delete. let pairs: Vec<(SyncAppId, UserId, i64)> = sqlx::query_as( r#" SELECT app_id, user_id, MIN(last_pulled_seq) AS min_seq FROM sync_devices WHERE last_pulled_seq > 0 GROUP BY app_id, user_id HAVING MIN(last_pulled_seq) > 0 "#, ) .fetch_all(pool) .await?; let mut total_deleted: u64 = 0; for (app_id, user_id, safe_seq) in pairs { let result = sqlx::query( r#" DELETE FROM sync_log WHERE app_id = $1 AND user_id = $2 AND seq <= $3 AND created_at < NOW() - make_interval(days => $4) "#, ) .bind(app_id) .bind(user_id) .bind(safe_seq) .bind(min_age_days) .execute(pool) .await?; total_deleted += result.rows_affected(); } Ok(total_deleted) }