Skip to main content

max / makenotwork

8.7 KB · 321 lines History Blame Raw
1 use serde_json::Value as JsonValue;
2 use sqlx::PgPool;
3
4 use crate::db::{SyncAppId, SyncDeviceId, UserId};
5 use crate::error::Result;
6
7 // ── Key Rotation ──
8
9 /// Begin a key rotation. Returns the rotation row if created, or the existing
10 /// row if this device already has an active rotation (resume support).
11 ///
12 /// Returns `None` if the key_version doesn't match (caller should 409).
13 /// Returns `Err` if a different device has an active rotation.
14 #[tracing::instrument(skip_all)]
15 pub async fn begin_key_rotation(
16 pool: &PgPool,
17 app_id: SyncAppId,
18 user_id: UserId,
19 device_id: SyncDeviceId,
20 new_encrypted_key: &str,
21 expected_key_version: i32,
22 ) -> Result<std::result::Result<crate::db::models::DbSyncKeyRotation, &'static str>> {
23 // Verify key_version matches
24 let key_row: Option<(i32, i32)> = sqlx::query_as(
25 "SELECT key_version, key_id FROM sync_keys WHERE app_id = $1 AND user_id = $2",
26 )
27 .bind(app_id)
28 .bind(user_id)
29 .fetch_optional(pool)
30 .await?;
31
32 let Some((current_version, current_key_id)) = key_row else {
33 return Ok(Err("no encryption key exists"));
34 };
35
36 if current_version != expected_key_version {
37 return Ok(Err("key version mismatch"));
38 }
39
40 // Check for existing rotation
41 let existing = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>(
42 "SELECT * FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2",
43 )
44 .bind(app_id)
45 .bind(user_id)
46 .fetch_optional(pool)
47 .await?;
48
49 if let Some(rotation) = existing {
50 if rotation.device_id == device_id {
51 // Same device resuming — return existing rotation
52 return Ok(Ok(rotation));
53 }
54 return Ok(Err("rotation already in progress by another device"));
55 }
56
57 // Get target_seq (max seq for this user)
58 let target_seq: i64 = sqlx::query_scalar(
59 "SELECT COALESCE(MAX(seq), 0) FROM sync_log WHERE app_id = $1 AND user_id = $2",
60 )
61 .bind(app_id)
62 .bind(user_id)
63 .fetch_one(pool)
64 .await?;
65
66 let new_key_id = current_key_id + 1;
67
68 let rotation = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>(
69 r#"
70 INSERT INTO sync_key_rotations (app_id, user_id, device_id, new_encrypted_key, old_key_version, new_key_id, target_seq)
71 VALUES ($1, $2, $3, $4, $5, $6, $7)
72 RETURNING *
73 "#,
74 )
75 .bind(app_id)
76 .bind(user_id)
77 .bind(device_id)
78 .bind(new_encrypted_key)
79 .bind(current_version)
80 .bind(new_key_id)
81 .bind(target_seq)
82 .fetch_one(pool)
83 .await?;
84
85 Ok(Ok(rotation))
86 }
87
88 /// Get the active rotation for a user, if any.
89 #[tracing::instrument(skip_all)]
90 pub async fn get_key_rotation(
91 pool: &PgPool,
92 app_id: SyncAppId,
93 user_id: UserId,
94 ) -> Result<Option<crate::db::models::DbSyncKeyRotation>> {
95 let rotation = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>(
96 "SELECT * FROM sync_key_rotations WHERE app_id = $1 AND user_id = $2",
97 )
98 .bind(app_id)
99 .bind(user_id)
100 .fetch_optional(pool)
101 .await?;
102
103 Ok(rotation)
104 }
105
106 /// Pull sync log entries that need re-encryption (key_id != new_key_id).
107 /// Returns entries ordered by seq, paginated by after_seq.
108 #[tracing::instrument(skip_all)]
109 pub async fn get_rotation_entries(
110 pool: &PgPool,
111 app_id: SyncAppId,
112 user_id: UserId,
113 new_key_id: i32,
114 after_seq: i64,
115 limit: i64,
116 ) -> Result<Vec<(i64, Option<JsonValue>)>> {
117 let entries: Vec<(i64, Option<JsonValue>)> = sqlx::query_as(
118 r#"
119 SELECT seq, data FROM sync_log
120 WHERE app_id = $1 AND user_id = $2
121 AND seq > $3
122 AND (key_id IS NULL OR key_id != $5)
123 ORDER BY seq ASC
124 LIMIT $4
125 "#,
126 )
127 .bind(app_id)
128 .bind(user_id)
129 .bind(after_seq)
130 .bind(limit)
131 .bind(new_key_id)
132 .fetch_all(pool)
133 .await?;
134
135 Ok(entries)
136 }
137
138 /// Batch-update re-encrypted sync log entries during key rotation.
139 /// Sets data and key_id for each (seq) in the batch.
140 /// Returns the number of rows updated.
141 #[tracing::instrument(skip_all)]
142 pub async fn submit_rotation_batch(
143 pool: &PgPool,
144 app_id: SyncAppId,
145 user_id: UserId,
146 rotation_id: uuid::Uuid,
147 new_key_id: i32,
148 entries: &[(i64, Option<JsonValue>)],
149 ) -> Result<u64> {
150 if entries.is_empty() {
151 return Ok(0);
152 }
153
154 let mut seqs: Vec<i64> = Vec::with_capacity(entries.len());
155 let mut data_values: Vec<JsonValue> = Vec::with_capacity(entries.len());
156 for (seq, data) in entries {
157 seqs.push(*seq);
158 data_values.push(data.clone().unwrap_or(JsonValue::Null));
159 }
160
161 let mut tx = pool.begin().await?;
162
163 let updated = sqlx::query(
164 r#"
165 UPDATE sync_log AS sl
166 SET data = CASE WHEN batch.new_data = 'null'::jsonb THEN NULL ELSE batch.new_data END,
167 key_id = $3
168 FROM UNNEST($4::bigint[], $5::jsonb[]) AS batch(seq, new_data)
169 WHERE sl.seq = batch.seq AND sl.app_id = $1 AND sl.user_id = $2
170 "#,
171 )
172 .bind(app_id)
173 .bind(user_id)
174 .bind(new_key_id)
175 .bind(&seqs)
176 .bind(&data_values)
177 .execute(&mut *tx)
178 .await?;
179
180 // Update progress marker
181 if let Some(&max_seq) = seqs.iter().max() {
182 sqlx::query(
183 r#"
184 UPDATE sync_key_rotations
185 SET re_encrypted_through_seq = GREATEST(re_encrypted_through_seq, $1),
186 updated_at = NOW()
187 WHERE id = $2
188 "#,
189 )
190 .bind(max_seq)
191 .bind(rotation_id)
192 .execute(&mut *tx)
193 .await?;
194 }
195
196 tx.commit().await?;
197
198 Ok(updated.rows_affected())
199 }
200
201 /// Complete a key rotation: swap the new key into sync_keys and delete the rotation.
202 /// Returns `Err("entries remain")` if un-rotated entries still exist, with the count.
203 #[tracing::instrument(skip_all)]
204 pub async fn complete_key_rotation(
205 pool: &PgPool,
206 app_id: SyncAppId,
207 user_id: UserId,
208 rotation_id: uuid::Uuid,
209 ) -> Result<std::result::Result<i32, i64>> {
210 let rotation = sqlx::query_as::<_, crate::db::models::DbSyncKeyRotation>(
211 "SELECT * FROM sync_key_rotations WHERE id = $1 AND app_id = $2 AND user_id = $3",
212 )
213 .bind(rotation_id)
214 .bind(app_id)
215 .bind(user_id)
216 .fetch_optional(pool)
217 .await?;
218
219 let Some(rotation) = rotation else {
220 return Ok(Err(0)); // No rotation found
221 };
222
223 // Check for remaining un-rotated entries up to the target_seq captured at
224 // rotation start. Entries arriving after rotation began are excluded — they
225 // will use the new key once rotation completes.
226 let remaining: i64 = sqlx::query_scalar(
227 r#"
228 SELECT COUNT(*) FROM sync_log
229 WHERE app_id = $1 AND user_id = $2
230 AND seq <= $4
231 AND (key_id IS NULL OR key_id != $3)
232 "#,
233 )
234 .bind(app_id)
235 .bind(user_id)
236 .bind(rotation.new_key_id)
237 .bind(rotation.target_seq)
238 .fetch_one(pool)
239 .await?;
240
241 if remaining > 0 {
242 return Ok(Err(remaining));
243 }
244
245 // Atomically swap key and delete rotation
246 let mut tx = pool.begin().await?;
247
248 sqlx::query(
249 r#"
250 UPDATE sync_keys
251 SET encrypted_key = $3,
252 key_version = key_version + 1,
253 key_id = $4,
254 updated_at = NOW()
255 WHERE app_id = $1 AND user_id = $2
256 "#,
257 )
258 .bind(app_id)
259 .bind(user_id)
260 .bind(&rotation.new_encrypted_key)
261 .bind(rotation.new_key_id)
262 .execute(&mut *tx)
263 .await?;
264
265 sqlx::query("DELETE FROM sync_key_rotations WHERE id = $1")
266 .bind(rotation_id)
267 .execute(&mut *tx)
268 .await?;
269
270 tx.commit().await?;
271
272 Ok(Ok(rotation.new_key_id))
273 }
274
275 /// Cancel a stale rotation (only if older than the stale threshold).
276 /// Returns true if cancelled, false if not found or not stale.
277 #[tracing::instrument(skip_all)]
278 pub async fn cancel_stale_rotation(
279 pool: &PgPool,
280 app_id: SyncAppId,
281 user_id: UserId,
282 stale_hours: i64,
283 ) -> Result<bool> {
284 let result = sqlx::query(
285 r#"
286 DELETE FROM sync_key_rotations
287 WHERE app_id = $1 AND user_id = $2
288 AND updated_at < NOW() - make_interval(hours => $3)
289 "#,
290 )
291 .bind(app_id)
292 .bind(user_id)
293 .bind(stale_hours)
294 .execute(pool)
295 .await?;
296
297 Ok(result.rows_affected() > 0)
298 }
299
300 /// Get sync status: total changes and latest seq for a user within an app.
301 #[tracing::instrument(skip_all)]
302 pub async fn get_sync_status(
303 pool: &PgPool,
304 app_id: SyncAppId,
305 user_id: UserId,
306 ) -> Result<(i64, Option<i64>)> {
307 let row: (i64, Option<i64>) = sqlx::query_as(
308 r#"
309 SELECT COUNT(*), MAX(seq)
310 FROM sync_log
311 WHERE app_id = $1 AND user_id = $2
312 "#,
313 )
314 .bind(app_id)
315 .bind(user_id)
316 .fetch_one(pool)
317 .await?;
318
319 Ok(row)
320 }
321