use serde_json::Value as JsonValue; use sqlx::PgPool; use crate::db::{SyncAppId, SyncDeviceId, UserId}; use crate::error::Result; // ── Key Rotation ── /// Begin a key rotation. Returns the rotation row if created, or the existing /// row if this device already has an active rotation (resume support). /// /// Returns `None` if the key_version doesn't match (caller should 409). /// Returns `Err` if a different device has an active rotation. #[tracing::instrument(skip_all)] pub async fn begin_key_rotation( pool: &PgPool, app_id: SyncAppId, user_id: UserId, device_id: SyncDeviceId, new_encrypted_key: &str, expected_key_version: i32, ) -> Result> { // Verify key_version matches let key_row: Option<(i32, i32)> = sqlx::query_as( "SELECT key_version, key_id FROM sync_keys WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_optional(pool) .await?; let Some((current_version, current_key_id)) = key_row else { return Ok(Err("no encryption key exists")); }; if current_version != expected_key_version { return Ok(Err("key version mismatch")); } // Check for existing rotation let existing = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>( "SELECT * FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_optional(pool) .await?; if let Some(rotation) = existing { if rotation.device_id == device_id { // Same device resuming — return existing rotation return Ok(Ok(rotation)); } return Ok(Err("rotation already in progress by another device")); } // Get target_seq (max seq for this user) let target_seq: i64 = sqlx::query_scalar( "SELECT COALESCE(MAX(seq), 0) FROM sync_log WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_one(pool) .await?; let new_key_id = current_key_id + 1; let rotation = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>( r#" INSERT INTO sync_key_rotations (app_id, user_id, device_id, new_encrypted_key, old_key_version, new_key_id, target_seq) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING * "#, ) .bind(app_id) .bind(user_id) .bind(device_id) .bind(new_encrypted_key) .bind(current_version) .bind(new_key_id) .bind(target_seq) .fetch_one(pool) .await?; Ok(Ok(rotation)) } /// Get the active rotation for a user, if any. #[tracing::instrument(skip_all)] pub async fn get_key_rotation( pool: &PgPool, app_id: SyncAppId, user_id: UserId, ) -> Result> { let rotation = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>( "SELECT * FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2", ) .bind(app_id) .bind(user_id) .fetch_optional(pool) .await?; Ok(rotation) } /// Pull sync log entries that need re-encryption (key_id != new_key_id). /// Returns entries ordered by seq, paginated by after_seq. #[tracing::instrument(skip_all)] pub async fn get_rotation_entries( pool: &PgPool, app_id: SyncAppId, user_id: UserId, new_key_id: i32, after_seq: i64, limit: i64, ) -> Result)>> { let entries: Vec<(i64, Option)> = sqlx::query_as( r#" SELECT seq, data FROM sync_log WHERE app_id = $1 AND user_id = $2 AND seq > $3 AND (key_id IS NULL OR key_id != $5) ORDER BY seq ASC LIMIT $4 "#, ) .bind(app_id) .bind(user_id) .bind(after_seq) .bind(limit) .bind(new_key_id) .fetch_all(pool) .await?; Ok(entries) } /// Batch-update re-encrypted sync log entries during key rotation. /// Sets data and key_id for each (seq) in the batch. /// Returns the number of rows updated. #[tracing::instrument(skip_all)] pub async fn submit_rotation_batch( pool: &PgPool, app_id: SyncAppId, user_id: UserId, rotation_id: uuid::Uuid, new_key_id: i32, entries: &[(i64, Option)], ) -> Result { if entries.is_empty() { return Ok(0); } let mut seqs: Vec = Vec::with_capacity(entries.len()); let mut data_values: Vec = Vec::with_capacity(entries.len()); for (seq, data) in entries { seqs.push(*seq); data_values.push(data.clone().unwrap_or(JsonValue::Null)); } let mut tx = pool.begin().await?; let updated = sqlx::query( r#" UPDATE sync_log AS sl SET data = CASE WHEN batch.new_data = 'null'::jsonb THEN NULL ELSE batch.new_data END, key_id = $3 FROM UNNEST($4::bigint[], $5::jsonb[]) AS batch(seq, new_data) WHERE sl.seq = batch.seq AND sl.app_id = $1 AND sl.user_id = $2 "#, ) .bind(app_id) .bind(user_id) .bind(new_key_id) .bind(&seqs) .bind(&data_values) .execute(&mut *tx) .await?; // Update progress marker if let Some(&max_seq) = seqs.iter().max() { sqlx::query( r#" UPDATE sync_key_rotations SET re_encrypted_through_seq = GREATEST(re_encrypted_through_seq, $1), updated_at = NOW() WHERE id = $2 "#, ) .bind(max_seq) .bind(rotation_id) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(updated.rows_affected()) } /// Complete a key rotation: swap the new key into sync_keys and delete the rotation. /// Returns `Err("entries remain")` if un-rotated entries still exist, with the count. #[tracing::instrument(skip_all)] pub async fn complete_key_rotation( pool: &PgPool, app_id: SyncAppId, user_id: UserId, rotation_id: uuid::Uuid, ) -> Result> { let rotation = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>( "SELECT * FROM sync_key_rotations WHERE id = $1 AND app_id = $2 AND user_id = $3", ) .bind(rotation_id) .bind(app_id) .bind(user_id) .fetch_optional(pool) .await?; let Some(rotation) = rotation else { return Ok(Err(0)); // No rotation found }; // Check for remaining un-rotated entries up to the target_seq captured at // rotation start. Entries arriving after rotation began are excluded — they // will use the new key once rotation completes. let remaining: i64 = sqlx::query_scalar( r#" SELECT COUNT(*) FROM sync_log WHERE app_id = $1 AND user_id = $2 AND seq <= $4 AND (key_id IS NULL OR key_id != $3) "#, ) .bind(app_id) .bind(user_id) .bind(rotation.new_key_id) .bind(rotation.target_seq) .fetch_one(pool) .await?; if remaining > 0 { return Ok(Err(remaining)); } // Atomically swap key and delete rotation let mut tx = pool.begin().await?; sqlx::query( r#" UPDATE sync_keys SET encrypted_key = $3, key_version = key_version + 1, key_id = $4, updated_at = NOW() WHERE app_id = $1 AND user_id = $2 "#, ) .bind(app_id) .bind(user_id) .bind(&rotation.new_encrypted_key) .bind(rotation.new_key_id) .execute(&mut *tx) .await?; sqlx::query("DELETE FROM sync_key_rotations WHERE id = $1") .bind(rotation_id) .execute(&mut *tx) .await?; tx.commit().await?; Ok(Ok(rotation.new_key_id)) } /// Cancel a stale rotation (only if older than the stale threshold). /// Returns true if cancelled, false if not found or not stale. #[tracing::instrument(skip_all)] pub async fn cancel_stale_rotation( pool: &PgPool, app_id: SyncAppId, user_id: UserId, stale_hours: i64, ) -> Result { let result = sqlx::query( r#" DELETE FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2 AND updated_at < NOW() - make_interval(hours => $3) "#, ) .bind(app_id) .bind(user_id) .bind(stale_hours) .execute(pool) .await?; Ok(result.rows_affected() > 0) } /// Get sync status: total changes and latest seq for a user within an app. #[tracing::instrument(skip_all)] pub async fn get_sync_status( pool: &PgPool, app_id: SyncAppId, user_id: UserId, ) -> Result<(i64, Option)> { let row: (i64, Option) = sqlx::query_as( r#" SELECT COUNT(*), MAX(seq) FROM sync_log WHERE app_id = $1 AND user_id = $2 "#, ) .bind(app_id) .bind(user_id) .fetch_one(pool) .await?; Ok(row) }