Skip to main content

max / makenotwork

15.9 KB · 489 lines History Blame Raw
1 use bytes::Bytes;
2 use tracing::instrument;
3 use uuid::Uuid;
4
5 use crate::{
6 crypto,
7 error::Result,
8 types::*,
9 };
10
11 use super::SyncKitClient;
12 use super::helpers::check_response;
13
14 impl SyncKitClient {
15 // ── Devices ──
16
17 /// Register a device for sync.
18 ///
19 /// If a device with the same name already exists for this user/app, the
20 /// server upserts: it updates the existing device's platform and
21 /// `last_seen_at` rather than creating a duplicate.
22 #[instrument(skip(self))]
23 pub async fn register_device(
24 &self,
25 device_name: &str,
26 platform: &str,
27 ) -> Result<Device> {
28 let token = self.require_token()?;
29
30 let body = Bytes::from(serde_json::to_vec(&RegisterDeviceRequest {
31 device_name: device_name.to_string(),
32 platform: platform.to_string(),
33 })?);
34
35 self.retry_request_json(|| {
36 let req = self
37 .http
38 .post(&self.endpoints.devices)
39 .bearer_auth(&token)
40 .header("content-type", "application/json")
41 .body(body.clone());
42 async move { check_response(req.send().await?).await }
43 })
44 .await
45 }
46
47 /// List all devices for the current user.
48 #[instrument(skip(self))]
49 pub async fn list_devices(&self) -> Result<Vec<Device>> {
50 let token = self.require_token()?;
51
52 self.retry_request_json(|| {
53 let req = self.http.get(&self.endpoints.devices).bearer_auth(&token);
54 async move { check_response(req.send().await?).await }
55 })
56 .await
57 }
58
59 // ── Push / Pull ──
60
61 /// Push changes to the server. Encrypts `data` fields automatically.
62 /// Returns the server cursor after the push.
63 ///
64 /// Retries on transient failures (network errors, 5xx, 429) with exponential backoff.
65 #[instrument(skip(self, changes))]
66 pub async fn push(
67 &self,
68 device_id: Uuid,
69 changes: Vec<ChangeEntry>,
70 ) -> Result<i64> {
71 let token = self.require_token()?;
72
73 // Extract key once for the entire batch (only needed if any entry has data)
74 let has_data = changes.iter().any(|c| c.data.is_some());
75 let key_holder = if has_data {
76 self.require_master_key()?
77 } else {
78 crypto::ZeroizeOnDrop([0u8; 32])
79 };
80 let master_key: &[u8; 32] = &key_holder;
81 let wire_changes = changes
82 .into_iter()
83 .map(|c| Self::encrypt_change_with_key(c, master_key))
84 .collect::<Result<Vec<_>>>()?;
85
86 let body = Bytes::from(serde_json::to_vec(&WirePushRequest {
87 device_id,
88 batch_id: Uuid::new_v4(),
89 changes: wire_changes,
90 })?);
91
92 let push_resp: PushResponse = self
93 .retry_request_json(|| {
94 let req = self
95 .http
96 .post(&self.endpoints.push)
97 .bearer_auth(&token)
98 .header("content-type", "application/json")
99 .body(body.clone());
100 async move { check_response(req.send().await?).await }
101 })
102 .await?;
103 Ok(push_resp.cursor)
104 }
105
106 /// Pull changes from the server since the given cursor.
107 /// Decrypts `data` fields automatically.
108 /// Returns (changes, new_cursor, has_more).
109 ///
110 /// Retries on transient failures (network errors, 5xx, 429) with exponential backoff.
111 #[instrument(skip(self))]
112 pub async fn pull(
113 &self,
114 device_id: Uuid,
115 cursor: i64,
116 ) -> Result<(Vec<ChangeEntry>, i64, bool)> {
117 let body = Bytes::from(serde_json::to_vec(&PullRequest { device_id, cursor })?);
118 self.pull_inner(body, Self::decrypt_change_with_key).await
119 }
120
121 /// Pull changes from the server with optional table and timestamp filters.
122 /// Decrypts `data` fields automatically.
123 /// Returns (changes, new_cursor, has_more).
124 ///
125 /// Identical to [`pull`](SyncKitClient::pull) when the filter is empty/default.
126 #[instrument(skip(self, filter))]
127 pub async fn pull_filtered(
128 &self,
129 device_id: Uuid,
130 cursor: i64,
131 filter: PullFilter,
132 ) -> Result<(Vec<ChangeEntry>, i64, bool)> {
133 let body = Bytes::from(serde_json::to_vec(&FilteredPullRequest {
134 device_id,
135 cursor,
136 tables: filter.tables,
137 since: filter.since,
138 })?);
139 self.pull_inner(body, Self::decrypt_change_with_key).await
140 }
141
142 /// Pull changes from the server, preserving `device_id` and `seq` metadata.
143 ///
144 /// Same HTTP call and decryption as [`pull`](SyncKitClient::pull), but returns
145 /// [`PulledChange`] wrappers that retain server metadata needed for conflict
146 /// detection. Returns (changes, new_cursor, has_more).
147 #[instrument(skip(self))]
148 pub async fn pull_rich(
149 &self,
150 device_id: Uuid,
151 cursor: i64,
152 ) -> Result<(Vec<PulledChange>, i64, bool)> {
153 let body = Bytes::from(serde_json::to_vec(&PullRequest { device_id, cursor })?);
154 self.pull_inner(body, Self::decrypt_change_to_pulled).await
155 }
156
157 /// Pull changes with filters, preserving `device_id` and `seq` metadata.
158 ///
159 /// Same as [`pull_rich`](SyncKitClient::pull_rich) but with table/timestamp
160 /// filtering support. Returns (changes, new_cursor, has_more).
161 #[instrument(skip(self, filter))]
162 pub async fn pull_filtered_rich(
163 &self,
164 device_id: Uuid,
165 cursor: i64,
166 filter: PullFilter,
167 ) -> Result<(Vec<PulledChange>, i64, bool)> {
168 let body = Bytes::from(serde_json::to_vec(&FilteredPullRequest {
169 device_id,
170 cursor,
171 tables: filter.tables,
172 since: filter.since,
173 })?);
174 self.pull_inner(body, Self::decrypt_change_to_pulled).await
175 }
176
177 /// Shared pull implementation: sends the request, extracts the master key,
178 /// and decrypts each change using the provided function.
179 ///
180 /// If a pending rotation key is cached on the client, entries are decrypted
181 /// with key selection based on each entry's `key_id` field.
182 async fn pull_inner<T, F>(
183 &self,
184 body: Bytes,
185 decrypt_fn: F,
186 ) -> Result<(Vec<T>, i64, bool)>
187 where
188 F: Fn(PullChangeEntry, &[u8; 32]) -> Result<T>,
189 {
190 let token = self.require_token()?;
191
192 let pull_resp: PullResponse = self
193 .retry_request_json(|| {
194 let req = self
195 .http
196 .post(&self.endpoints.pull)
197 .bearer_auth(&token)
198 .header("content-type", "application/json")
199 .body(body.clone());
200 async move { check_response(req.send().await?).await }
201 })
202 .await?;
203
204 // Extract key once for the entire batch (only needed if any entry has data)
205 let has_data = pull_resp.changes.iter().any(|c| c.data.is_some());
206 let key_holder = if has_data {
207 self.require_master_key()?
208 } else {
209 crypto::ZeroizeOnDrop([0u8; 32])
210 };
211 let master_key: &[u8; 32] = &key_holder;
212
213 // Check for pending rotation key (multi-key decryption)
214 let pending_guard = self.pending_key.read();
215 let has_pending = pending_guard.is_some() && has_data;
216
217 let changes = if has_pending {
218 let pending = pending_guard.as_ref().unwrap();
219 let _primary_key_id = *self.master_key_id.read();
220 pull_resp
221 .changes
222 .into_iter()
223 .map(|c| {
224 let effective_key_id = c.key_id.unwrap_or(1);
225 if effective_key_id == pending.key_id {
226 decrypt_fn(c, &pending.key)
227 } else {
228 decrypt_fn(c, master_key)
229 }
230 })
231 .collect::<Result<Vec<_>>>()?
232 } else {
233 drop(pending_guard);
234 pull_resp
235 .changes
236 .into_iter()
237 .map(|c| decrypt_fn(c, master_key))
238 .collect::<Result<Vec<_>>>()?
239 };
240
241 Ok((changes, pull_resp.cursor, pull_resp.has_more))
242 }
243
244 /// Get sync status (total changes, latest cursor).
245 #[instrument(skip(self))]
246 pub async fn status(&self) -> Result<SyncStatus> {
247 let token = self.require_token()?;
248
249 self.retry_request_json(|| {
250 let req = self.http.get(&self.endpoints.status).bearer_auth(&token);
251 async move { check_response(req.send().await?).await }
252 })
253 .await
254 }
255 }
256
#[cfg(test)]
mod tests {
    use chrono::{DateTime, Utc};
    use uuid::Uuid;

    use crate::types::*;

    /// A fixed, whole-second timestamp. Round-trip assertions use it so that
    /// equality holds regardless of the sub-second precision the serialized
    /// form keeps.
    fn fixed_time() -> DateTime<Utc> {
        "2025-01-15T10:00:00Z"
            .parse()
            .expect("literal is a valid RFC 3339 timestamp")
    }

    // ── Type serialization / deserialization ──

    #[test]
    fn change_entry_serialization_roundtrip() {
        let entry = ChangeEntry {
            table: "tasks".to_string(),
            op: ChangeOp::Insert,
            row_id: Uuid::new_v4().to_string(),
            timestamp: fixed_time(),
            data: Some(serde_json::json!({"title": "Test task", "done": false})),
        };

        let json = serde_json::to_string(&entry).unwrap();
        let deserialized: ChangeEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.table, entry.table);
        assert_eq!(deserialized.op, entry.op);
        assert_eq!(deserialized.row_id, entry.row_id);
        assert_eq!(deserialized.timestamp, entry.timestamp);
        assert_eq!(deserialized.data, entry.data);
    }

    #[test]
    fn change_entry_with_none_data_omits_field() {
        let entry = ChangeEntry {
            table: "tasks".to_string(),
            op: ChangeOp::Delete,
            row_id: "abc-123".to_string(),
            timestamp: Utc::now(),
            data: None,
        };

        let json = serde_json::to_string(&entry).unwrap();
        assert!(!json.contains("\"data\""));
    }

    #[test]
    fn change_entry_deserialization_with_missing_data() {
        let json = r#"{
            "table": "events",
            "op": "DELETE",
            "row_id": "evt-1",
            "timestamp": "2025-01-15T10:00:00Z"
        }"#;

        let entry: ChangeEntry = serde_json::from_str(json).unwrap();
        assert_eq!(entry.table, "events");
        assert_eq!(entry.op, ChangeOp::Delete);
        assert!(entry.data.is_none());
    }

    #[test]
    fn device_serialization_roundtrip() {
        let device = Device {
            id: Uuid::new_v4(),
            app_id: Uuid::new_v4(),
            user_id: Uuid::new_v4(),
            device_name: "MacBook Pro".to_string(),
            platform: "macos".to_string(),
            last_seen_at: fixed_time(),
            created_at: fixed_time(),
        };

        let json = serde_json::to_string(&device).unwrap();
        let deserialized: Device = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.id, device.id);
        assert_eq!(deserialized.app_id, device.app_id);
        assert_eq!(deserialized.user_id, device.user_id);
        assert_eq!(deserialized.device_name, device.device_name);
        assert_eq!(deserialized.platform, device.platform);
        assert_eq!(deserialized.last_seen_at, device.last_seen_at);
        assert_eq!(deserialized.created_at, device.created_at);
    }

    #[test]
    fn sync_status_deserialization() {
        let json = r#"{"total_changes": 42, "latest_cursor": 100}"#;
        let status: SyncStatus = serde_json::from_str(json).unwrap();
        assert_eq!(status.total_changes, 42);
        assert_eq!(status.latest_cursor, Some(100));
    }

    #[test]
    fn sync_status_with_null_cursor() {
        let json = r#"{"total_changes": 0, "latest_cursor": null}"#;
        let status: SyncStatus = serde_json::from_str(json).unwrap();
        assert_eq!(status.total_changes, 0);
        assert_eq!(status.latest_cursor, None);
    }

    #[test]
    fn register_device_request_serialization() {
        let req = RegisterDeviceRequest {
            device_name: "iPhone 15".to_string(),
            platform: "ios".to_string(),
        };

        let json = serde_json::to_string(&req).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["device_name"], "iPhone 15");
        assert_eq!(parsed["platform"], "ios");
    }

    // ── Wire types ──

    #[test]
    fn wire_push_request_serialization() {
        let device_id = Uuid::new_v4();
        let req = WirePushRequest {
            device_id,
            batch_id: Uuid::new_v4(),
            changes: vec![WireChangeEntry {
                table: "tasks".to_string(),
                op: ChangeOp::Insert,
                row_id: "r1".to_string(),
                timestamp: Utc::now(),
                data: Some(serde_json::json!("encrypted-blob")),
            }],
        };

        let json = serde_json::to_string(&req).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["device_id"], device_id.to_string());
        assert_eq!(parsed["changes"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn pull_request_serialization() {
        let device_id = Uuid::new_v4();
        let req = PullRequest {
            device_id,
            cursor: 42,
        };

        let json = serde_json::to_string(&req).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["device_id"], device_id.to_string());
        assert_eq!(parsed["cursor"], 42);
    }

    #[test]
    fn pull_response_deserialization() {
        let device_id = Uuid::new_v4();
        let json = format!(
            r#"{{
                "changes": [
                    {{
                        "seq": 1,
                        "device_id": "{}",
                        "table": "tasks",
                        "op": "INSERT",
                        "row_id": "r1",
                        "timestamp": "2025-06-01T12:00:00Z",
                        "data": "encrypted"
                    }}
                ],
                "cursor": 5,
                "has_more": true
            }}"#,
            device_id
        );

        let resp: PullResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp.changes.len(), 1);
        assert_eq!(resp.cursor, 5);
        assert!(resp.has_more);
        assert_eq!(resp.changes[0].seq, 1);
        assert_eq!(resp.changes[0].table, "tasks");
        assert!(resp.changes[0].data.is_some());
        // No `key_id` on the wire: the client falls back to the default key id.
        assert!(resp.changes[0].key_id.is_none());
    }

    #[test]
    fn pull_response_empty_changes() {
        let json = r#"{"changes": [], "cursor": 0, "has_more": false}"#;
        let resp: PullResponse = serde_json::from_str(json).unwrap();
        assert!(resp.changes.is_empty());
        assert_eq!(resp.cursor, 0);
        assert!(!resp.has_more);
    }

    #[test]
    fn push_response_deserialization() {
        let json = r#"{"cursor": 99}"#;
        let resp: PushResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.cursor, 99);
    }

    // ── ChangeOp display and parsing ──

    #[test]
    fn change_op_display() {
        assert_eq!(ChangeOp::Insert.to_string(), "INSERT");
        assert_eq!(ChangeOp::Update.to_string(), "UPDATE");
        assert_eq!(ChangeOp::Delete.to_string(), "DELETE");
    }

    #[test]
    fn change_op_from_str_valid() {
        assert_eq!(ChangeOp::from_str_opt("INSERT"), Some(ChangeOp::Insert));
        assert_eq!(ChangeOp::from_str_opt("UPDATE"), Some(ChangeOp::Update));
        assert_eq!(ChangeOp::from_str_opt("DELETE"), Some(ChangeOp::Delete));
    }

    #[test]
    fn change_op_from_str_invalid() {
        assert_eq!(ChangeOp::from_str_opt("insert"), None);
        assert_eq!(ChangeOp::from_str_opt("UPSERT"), None);
        assert_eq!(ChangeOp::from_str_opt(""), None);
    }

    // ── Malformed response types ──

    #[test]
    fn pull_response_missing_changes_fails() {
        let json = r#"{"cursor": 0, "has_more": false}"#;
        assert!(serde_json::from_str::<PullResponse>(json).is_err());
    }

    #[test]
    fn pull_response_missing_cursor_fails() {
        let json = r#"{"changes": [], "has_more": false}"#;
        assert!(serde_json::from_str::<PullResponse>(json).is_err());
    }

    #[test]
    fn push_response_missing_cursor_fails() {
        let json = r#"{}"#;
        assert!(serde_json::from_str::<PushResponse>(json).is_err());
    }
}
489