Skip to main content

max / makenotwork

7.2 KB · 251 lines History Blame Raw
1 use sqlx::PgPool;
2
3 use crate::db::{SyncAppId, SyncDeviceId, UserId};
4 use crate::error::Result;
5
6 // ── Sync Keys ──
7
8 /// Upsert an encrypted master key for a user within an app with optimistic
9 /// concurrency control.
10 ///
11 /// `expected_version` is checked on UPDATE: if the current `key_version`
12 /// doesn't match, the update is skipped and `false` is returned (the caller
13 /// should return 409 Conflict). On INSERT (first key), `expected_version`
14 /// must be 0.
15 ///
16 /// Returns `true` if the key was inserted or updated, `false` on version mismatch.
17 #[tracing::instrument(skip_all)]
18 pub async fn upsert_sync_key(
19 pool: &PgPool,
20 app_id: SyncAppId,
21 user_id: UserId,
22 encrypted_key: &str,
23 expected_version: i32,
24 ) -> Result<bool> {
25 let result = sqlx::query(
26 r#"
27 INSERT INTO sync_keys (app_id, user_id, encrypted_key)
28 VALUES ($1, $2, $3)
29 ON CONFLICT (app_id, user_id)
30 DO UPDATE SET encrypted_key = EXCLUDED.encrypted_key,
31 key_version = sync_keys.key_version + 1,
32 updated_at = NOW()
33 WHERE sync_keys.key_version = $4
34 "#,
35 )
36 .bind(app_id)
37 .bind(user_id)
38 .bind(encrypted_key)
39 .bind(expected_version)
40 .execute(pool)
41 .await?;
42
43 Ok(result.rows_affected() > 0)
44 }
45
/// Key material and version metadata returned by `get_sync_key`.
pub struct SyncKeyInfo {
    /// The stored `encrypted_key` value for the current key.
    pub encrypted_key: String,
    /// The stored `key_version` of the current key.
    pub key_version: i32,
    /// The stored `key_id` of the current key.
    pub key_id: i32,
    /// `Some((new_key_envelope, new_key_id))` while a key rotation is in
    /// progress, `None` otherwise.
    pub pending_key: Option<(String, i32)>,
}
54
55 /// Get the encrypted master key, version, key_id, and any pending rotation key.
56 #[tracing::instrument(skip_all)]
57 pub async fn get_sync_key(
58 pool: &PgPool,
59 app_id: SyncAppId,
60 user_id: UserId,
61 ) -> Result<Option<SyncKeyInfo>> {
62 let row: Option<(String, i32, i32)> = sqlx::query_as(
63 "SELECT encrypted_key, key_version, key_id FROM sync_keys WHERE app_id = $1 AND user_id = $2",
64 )
65 .bind(app_id)
66 .bind(user_id)
67 .fetch_optional(pool)
68 .await?;
69
70 let Some((encrypted_key, key_version, key_id)) = row else {
71 return Ok(None);
72 };
73
74 // Check for an active rotation
75 let pending: Option<(String, i32)> = sqlx::query_as(
76 "SELECT new_encrypted_key, new_key_id FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2",
77 )
78 .bind(app_id)
79 .bind(user_id)
80 .fetch_optional(pool)
81 .await?;
82
83 Ok(Some(SyncKeyInfo {
84 encrypted_key,
85 key_version,
86 key_id,
87 pending_key: pending,
88 }))
89 }
90
91 /// Get device count and sync log entry count for all apps owned by a creator.
92 /// Returns Vec of (app_id, device_count, log_entry_count). Single query replaces N+1 loop.
93 #[tracing::instrument(skip_all)]
94 pub async fn get_sync_app_stats_batch(
95 pool: &PgPool,
96 creator_id: UserId,
97 ) -> Result<Vec<(SyncAppId, i64, i64)>> {
98 let rows: Vec<(SyncAppId, i64, i64)> = sqlx::query_as(
99 r#"
100 SELECT
101 a.id,
102 (SELECT COUNT(*) FROM sync_devices d WHERE d.app_id = a.id),
103 (SELECT COUNT(*) FROM sync_log l WHERE l.app_id = a.id)
104 FROM sync_apps a
105 WHERE a.creator_id = $1
106 "#,
107 )
108 .bind(creator_id)
109 .fetch_all(pool)
110 .await?;
111
112 Ok(rows)
113 }
114
115 /// Delete sync log entries older than the given number of days.
116 ///
117 /// `retain_days` must be positive. Returns 0 immediately for non-positive values
118 /// to prevent accidental deletion of all entries.
119 #[tracing::instrument(skip_all)]
120 pub async fn prune_sync_log(pool: &PgPool, retain_days: i64) -> Result<u64> {
121 if retain_days <= 0 {
122 tracing::warn!("prune_sync_log called with non-positive retain_days={retain_days}, skipping");
123 return Ok(0);
124 }
125 let result = sqlx::query(
126 "DELETE FROM sync_log WHERE created_at < NOW() - make_interval(days => $1)",
127 )
128 .bind(retain_days)
129 .execute(pool)
130 .await?;
131
132 Ok(result.rows_affected())
133 }
134
135 /// Update a device's last-pulled cursor position.
136 #[tracing::instrument(skip_all)]
137 pub async fn update_device_cursor(
138 pool: &PgPool,
139 device_id: SyncDeviceId,
140 seq: i64,
141 ) -> Result<()> {
142 sqlx::query(
143 "UPDATE sync_devices SET last_pulled_seq = GREATEST(last_pulled_seq, $1) WHERE id = $2",
144 )
145 .bind(seq)
146 .bind(device_id)
147 .execute(pool)
148 .await?;
149 Ok(())
150 }
151
152 /// Compact the sync log by removing entries that all devices for a given
153 /// (app_id, user_id) have already pulled. Keeps a safety margin of entries
154 /// newer than `min_age_days` regardless of cursor positions.
155 ///
156 /// Returns the number of entries deleted.
157 #[allow(dead_code)] // Public API for targeted per-user compaction
158 #[tracing::instrument(skip_all)]
159 pub async fn compact_sync_log(
160 pool: &PgPool,
161 app_id: SyncAppId,
162 user_id: UserId,
163 min_age_days: i64,
164 ) -> Result<u64> {
165 if min_age_days <= 0 {
166 return Ok(0);
167 }
168
169 // Find the lowest cursor across all devices for this user+app.
170 // Entries below this seq have been pulled by every device.
171 let min_cursor: Option<i64> = sqlx::query_scalar(
172 "SELECT MIN(last_pulled_seq) FROM sync_devices WHERE app_id = $1 AND user_id = $2",
173 )
174 .bind(app_id)
175 .bind(user_id)
176 .fetch_one(pool)
177 .await?;
178
179 let safe_seq = match min_cursor {
180 Some(seq) if seq > 0 => seq,
181 _ => return Ok(0), // No devices or no pulls yet
182 };
183
184 // Delete entries below the safe cursor AND older than the safety margin.
185 let result = sqlx::query(
186 r#"
187 DELETE FROM sync_log
188 WHERE app_id = $1 AND user_id = $2
189 AND seq <= $3
190 AND created_at < NOW() - make_interval(days => $4)
191 "#,
192 )
193 .bind(app_id)
194 .bind(user_id)
195 .bind(safe_seq)
196 .bind(min_age_days)
197 .execute(pool)
198 .await?;
199
200 Ok(result.rows_affected())
201 }
202
203 /// Compact sync logs for all user+app pairs that have compactable entries.
204 /// Finds pairs where MIN(last_pulled_seq) across devices > 0, then deletes
205 /// entries below that threshold (with age safety margin).
206 ///
207 /// Returns total entries deleted across all users.
208 #[tracing::instrument(skip_all)]
209 pub async fn compact_all_sync_logs(pool: &PgPool, min_age_days: i64) -> Result<u64> {
210 if min_age_days <= 0 {
211 return Ok(0);
212 }
213
214 // Find (app_id, user_id) pairs where compaction is possible:
215 // all devices have pulled past seq 0, and there are old entries to delete.
216 let pairs: Vec<(SyncAppId, UserId, i64)> = sqlx::query_as(
217 r#"
218 SELECT app_id, user_id, MIN(last_pulled_seq) AS min_seq
219 FROM sync_devices
220 WHERE last_pulled_seq > 0
221 GROUP BY app_id, user_id
222 HAVING MIN(last_pulled_seq) > 0
223 "#,
224 )
225 .fetch_all(pool)
226 .await?;
227
228 let mut total_deleted: u64 = 0;
229 for (app_id, user_id, safe_seq) in pairs {
230 let result = sqlx::query(
231 r#"
232 DELETE FROM sync_log
233 WHERE app_id = $1 AND user_id = $2
234 AND seq <= $3
235 AND created_at < NOW() - make_interval(days => $4)
236 "#,
237 )
238 .bind(app_id)
239 .bind(user_id)
240 .bind(safe_seq)
241 .bind(min_age_days)
242 .execute(pool)
243 .await?;
244
245 total_deleted += result.rows_affected();
246 }
247
248 Ok(total_deleted)
249 }
250
251