Skip to main content

max / synckit-client

22.3 KB · 662 lines History Blame Raw
1 use base64::Engine;
2
3 use crate::{
4 crypto,
5 error::{Result, SyncKitError},
6 types::*,
7 };
8
9 use super::{BASE_DELAY, MAX_RETRIES, SyncKitClient};
10
11 impl SyncKitClient {
12 /// Retry an async HTTP operation with exponential backoff.
13 ///
14 /// Retries on transient errors (network failures, 5xx, 429) up to [`MAX_RETRIES`]
15 /// times with delays of 1s, 2s, 4s. Returns the last error if all attempts fail.
16 /// Client errors (4xx except 429) are considered permanent and returned immediately.
17 pub(super) async fn retry_request<F, Fut>(&self, mut operation: F) -> Result<reqwest::Response>
18 where
19 F: FnMut() -> Fut,
20 Fut: std::future::Future<Output = Result<reqwest::Response>>,
21 {
22 let mut last_err = None;
23
24 for attempt in 0..=MAX_RETRIES {
25 match operation().await {
26 Ok(resp) => return Ok(resp),
27 Err(err) => {
28 if !is_transient(&err) {
29 return Err(err);
30 }
31
32 if attempt < MAX_RETRIES {
33 let delay = BASE_DELAY * 2u32.pow(attempt);
34 tracing::debug!(
35 attempt = attempt + 1,
36 max_retries = MAX_RETRIES,
37 delay_ms = delay.as_millis() as u64,
38 error = %err,
39 "Transient error, retrying after backoff",
40 );
41 tokio::time::sleep(delay).await;
42 }
43
44 last_err = Some(err);
45 }
46 }
47 }
48
49 Err(last_err.expect("loop ran at least once"))
50 }
51
52 /// Encrypt the data field of a change entry for the wire.
53 #[cfg(test)]
54 pub(super) fn encrypt_change(&self, entry: ChangeEntry) -> Result<WireChangeEntry> {
55 if entry.data.is_some() {
56 let master_key = self.require_master_key()?;
57 Self::encrypt_change_with_key(entry, &master_key)
58 } else {
59 Self::encrypt_change_with_key(entry, &[0u8; 32])
60 }
61 }
62
63 /// Encrypt with a pre-loaded key. Used by `push()` to avoid per-entry lock acquisition.
64 pub(super) fn encrypt_change_with_key(entry: ChangeEntry, master_key: &[u8; 32]) -> Result<WireChangeEntry> {
65 let encrypted_data = match entry.data {
66 Some(ref value) => Some(crypto::encrypt_json(value, master_key)?),
67 None => None,
68 };
69
70 Ok(WireChangeEntry {
71 table: entry.table,
72 op: entry.op,
73 row_id: entry.row_id,
74 timestamp: entry.timestamp,
75 data: encrypted_data,
76 })
77 }
78
79 /// Decrypt the data field of a pulled change entry.
80 #[cfg(test)]
81 pub(super) fn decrypt_change(&self, entry: PullChangeEntry) -> Result<ChangeEntry> {
82 if entry.data.is_some() {
83 let master_key = self.require_master_key()?;
84 Self::decrypt_change_with_key(entry, &master_key)
85 } else {
86 Self::decrypt_change_with_key(entry, &[0u8; 32])
87 }
88 }
89
90 /// Decrypt with a pre-loaded key, preserving `device_id` and `seq` in a [`PulledChange`].
91 ///
92 /// Used by `pull_rich()` to produce conflict-detection-ready results.
93 pub(super) fn decrypt_change_to_pulled(entry: PullChangeEntry, master_key: &[u8; 32]) -> Result<crate::types::PulledChange> {
94 let device_id = entry.device_id;
95 let seq = entry.seq;
96 let decrypted = Self::decrypt_change_with_key(entry, master_key)?;
97 Ok(crate::types::PulledChange {
98 entry: decrypted,
99 device_id,
100 seq,
101 })
102 }
103
104 /// Decrypt with a pre-loaded key. Used by `pull()` to avoid per-entry lock acquisition.
105 pub(super) fn decrypt_change_with_key(entry: PullChangeEntry, master_key: &[u8; 32]) -> Result<ChangeEntry> {
106 let decrypted_data = match entry.data {
107 Some(ref value) => Some(crypto::decrypt_json(value, master_key)?),
108 None => None,
109 };
110
111 Ok(ChangeEntry {
112 table: entry.table,
113 op: entry.op,
114 row_id: entry.row_id,
115 timestamp: entry.timestamp,
116 data: decrypted_data,
117 })
118 }
119 }
120
121 /// Extract the `exp` claim from a JWT without verifying the signature.
122 ///
123 /// JWTs are `header.payload.signature` where the payload is base64url-encoded JSON.
124 /// We decode the payload segment and read the `exp` field. Returns `None` if
125 /// the token is malformed or `exp` is missing.
126 pub(super) fn jwt_exp(token: &str) -> Option<i64> {
127 let payload = token.split('.').nth(1)?;
128 let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
129 .decode(payload)
130 .ok()?;
131 let claims: serde_json::Value = serde_json::from_slice(&bytes).ok()?;
132 claims["exp"].as_i64()
133 }
134
/// Returns `true` if the token's `exp` claim is within [`TOKEN_EXPIRY_BUFFER_SECS`]
/// of the current time (or already past). Returns `false` if the token cannot
/// be decoded — in that case, let the server decide.
///
/// The token is unverified input, so `exp` can be any `i64`; the buffer is subtracted
/// with saturation so an extreme value cannot overflow.
#[cfg(test)]
pub(super) fn token_is_expired(token: &str) -> bool {
    let Some(exp) = jwt_exp(token) else {
        return false;
    };
    let now = chrono::Utc::now().timestamp();
    // A plain `exp - buffer` would panic in debug builds, or wrap to a far-future
    // deadline in release builds, for an `exp` close to `i64::MIN`.
    now >= exp.saturating_sub(super::TOKEN_EXPIRY_BUFFER_SECS)
}
146
147 /// Check an HTTP response for errors, returning the response on success.
148 pub(super) async fn check_response(resp: reqwest::Response) -> Result<reqwest::Response> {
149 let status = resp.status().as_u16();
150 if status >= 400 {
151 let message = resp.text().await.unwrap_or_default();
152 return Err(SyncKitError::Server { status, message });
153 }
154 Ok(resp)
155 }
156
157 /// Returns true if the error is transient and worth retrying.
158 ///
159 /// Transient errors:
160 /// - Network-level failures (connection refused, timeout, DNS, etc.)
161 /// - Server errors (5xx)
162 /// - Rate limiting (429)
163 ///
164 /// Permanent errors (not retried):
165 /// - Client errors (4xx except 429) — bad request, auth failure, not found, etc.
166 /// - Serialization errors, encryption errors, missing session, etc.
167 pub(super) fn is_transient(err: &SyncKitError) -> bool {
168 match err {
169 SyncKitError::Http(e) => {
170 // All reqwest transport errors are transient (timeout, connect, DNS, etc.)
171 // except for builder errors which indicate programming mistakes.
172 !e.is_builder()
173 }
174 SyncKitError::Server { status, .. } => {
175 // 5xx = server error (transient), 429 = rate limited (transient)
176 *status >= 500 || *status == 429
177 }
178 // Everything else (auth, crypto, serialization) is permanent
179 _ => false,
180 }
181 }
182
/// Unit tests for change encryption/decryption, transient-error classification,
/// the retry constants, and JWT expiry detection.
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use chrono::Utc;
    use std::time::Duration;

    use super::super::TOKEN_EXPIRY_BUFFER_SECS;

    // ── Fixtures ──

    fn test_config() -> super::super::SyncKitConfig {
        super::super::SyncKitConfig {
            server_url: "https://example.com".to_string(),
            api_key: "test-api-key-123".to_string(),
        }
    }

    /// Build a fake JWT with the given `exp` claim (no real signature).
    fn fake_jwt(exp: i64) -> String {
        let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let header = b64.encode(r#"{"alg":"HS256","typ":"JWT"}"#);
        let payload_json = serde_json::json!({
            "sub": "550e8400-e29b-41d4-a716-446655440000",
            "app": "6ba7b810-9dad-11d1-80b4-00c04fd430c8",
            "exp": exp,
            "iat": exp - 3600,
        });
        let payload = b64.encode(payload_json.to_string().as_bytes());
        let signature = b64.encode(b"fake-signature");
        format!("{header}.{payload}.{signature}")
    }

    /// Build a change entry stamped with the current time.
    fn change(table: &str, op: ChangeOp, row_id: &str, data: Option<serde_json::Value>) -> ChangeEntry {
        ChangeEntry {
            table: table.to_string(),
            op,
            row_id: row_id.to_string(),
            timestamp: Utc::now(),
            data,
        }
    }

    /// Re-wrap a pushed wire entry as if the server had handed it back on pull.
    fn as_pulled(wire: WireChangeEntry, seq: i64) -> PullChangeEntry {
        PullChangeEntry {
            seq,
            device_id: uuid::Uuid::new_v4(),
            table: wire.table,
            op: wire.op,
            row_id: wire.row_id,
            timestamp: wire.timestamp,
            data: wire.data,
        }
    }

    /// Server error with the given status; the message plays no part in classification.
    fn server_error(status: u16) -> SyncKitError {
        SyncKitError::Server { status, message: String::new() }
    }

    // ── encrypt_change / decrypt_change ──

    #[test]
    fn encrypt_change_with_no_data() {
        let client = SyncKitClient::new(test_config());
        let entry = change("tasks", ChangeOp::Delete, "row-1", None);

        let wire = client.encrypt_change(entry).unwrap();
        assert_eq!(wire.table, "tasks");
        assert_eq!(wire.op, ChangeOp::Delete);
        assert_eq!(wire.row_id, "row-1");
        assert!(wire.data.is_none());
    }

    #[test]
    fn encrypt_change_fails_without_master_key() {
        let client = SyncKitClient::new(test_config());
        let entry = change(
            "tasks",
            ChangeOp::Insert,
            "row-1",
            Some(serde_json::json!({"title": "test"})),
        );

        let err = client.encrypt_change(entry).unwrap_err();
        assert!(matches!(err, SyncKitError::NoMasterKey));
    }

    #[test]
    fn encrypt_change_produces_encrypted_data() {
        let client = SyncKitClient::new(test_config());
        let key = crypto::generate_master_key();
        *client.master_key.write() = Some(crypto::ZeroizeOnDrop(key));

        let original_data = serde_json::json!({"title": "Buy milk", "priority": 3});
        let entry = change("tasks", ChangeOp::Insert, "row-1", Some(original_data.clone()));

        let wire = client.encrypt_change(entry).unwrap();
        let encrypted = wire.data.expect("payload should be encrypted, not dropped");
        assert!(encrypted.is_string());
        assert_ne!(encrypted, original_data);
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let client = SyncKitClient::new(test_config());
        let key = crypto::generate_master_key();
        *client.master_key.write() = Some(crypto::ZeroizeOnDrop(key));

        let original_data = serde_json::json!({
            "title": "Buy milk",
            "tags": ["groceries", "urgent"],
            "count": 42
        });
        let entry = change("tasks", ChangeOp::Update, "row-abc", Some(original_data.clone()));

        let wire = client.encrypt_change(entry).unwrap();
        let decrypted = client.decrypt_change(as_pulled(wire, 1)).unwrap();

        assert_eq!(decrypted.table, "tasks");
        assert_eq!(decrypted.op, ChangeOp::Update);
        assert_eq!(decrypted.row_id, "row-abc");
        assert_eq!(decrypted.data.unwrap(), original_data);
    }

    #[test]
    fn decrypt_change_with_no_data() {
        let client = SyncKitClient::new(test_config());
        let pull_entry = PullChangeEntry {
            seq: 5,
            device_id: uuid::Uuid::new_v4(),
            table: "events".to_string(),
            op: ChangeOp::Delete,
            row_id: "evt-1".to_string(),
            timestamp: Utc::now(),
            data: None,
        };

        let decrypted = client.decrypt_change(pull_entry).unwrap();
        assert_eq!(decrypted.table, "events");
        assert_eq!(decrypted.op, ChangeOp::Delete);
        assert!(decrypted.data.is_none());
    }

    #[test]
    fn decrypt_change_fails_without_master_key() {
        let client = SyncKitClient::new(test_config());
        let pull_entry = PullChangeEntry {
            seq: 1,
            device_id: uuid::Uuid::new_v4(),
            table: "tasks".to_string(),
            op: ChangeOp::Insert,
            row_id: "row-1".to_string(),
            timestamp: Utc::now(),
            data: Some(serde_json::json!("some-encrypted-string")),
        };

        let err = client.decrypt_change(pull_entry).unwrap_err();
        assert!(matches!(err, SyncKitError::NoMasterKey));
    }

    // ── is_transient error classification ──

    #[test]
    fn is_transient_server_5xx() {
        for status in [500, 502, 503, 504] {
            assert!(is_transient(&server_error(status)), "{status} should be retried");
        }
    }

    #[test]
    fn is_transient_rate_limited_429() {
        assert!(is_transient(&server_error(429)));
    }

    #[test]
    fn is_not_transient_client_4xx() {
        for status in [400, 401, 403, 404, 409, 422] {
            assert!(!is_transient(&server_error(status)), "{status} should not be retried");
        }
    }

    #[test]
    fn is_not_transient_not_authenticated() {
        assert!(!is_transient(&SyncKitError::NotAuthenticated));
    }

    #[test]
    fn is_not_transient_no_master_key() {
        assert!(!is_transient(&SyncKitError::NoMasterKey));
    }

    #[test]
    fn is_not_transient_decryption_failed() {
        assert!(!is_transient(&SyncKitError::DecryptionFailed));
    }

    #[test]
    fn is_not_transient_invalid_envelope() {
        assert!(!is_transient(&SyncKitError::InvalidEnvelope("bad version".to_string())));
    }

    #[test]
    fn is_not_transient_crypto() {
        assert!(!is_transient(&SyncKitError::Crypto("encrypt failed".to_string())));
    }

    #[test]
    fn is_not_transient_json() {
        let err: SyncKitError = serde_json::from_str::<serde_json::Value>("not json")
            .unwrap_err()
            .into();
        assert!(!is_transient(&err));
    }

    #[test]
    fn is_not_transient_base64() {
        let err: SyncKitError = base64::engine::general_purpose::STANDARD
            .decode("!!!invalid!!!")
            .unwrap_err()
            .into();
        assert!(!is_transient(&err));
    }

    #[test]
    fn is_not_transient_token_expired() {
        assert!(!is_transient(&SyncKitError::TokenExpired));
    }

    #[test]
    fn is_not_transient_internal() {
        assert!(!is_transient(&SyncKitError::Internal("lock poisoned".to_string())));
    }

    // ── Retry constants ──

    #[test]
    fn retry_constants_are_sensible() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(BASE_DELAY, Duration::from_secs(1));
    }

    #[test]
    fn backoff_delays_are_exponential() {
        let delays = [0u32, 1, 2].map(|attempt| BASE_DELAY * 2u32.pow(attempt));

        assert_eq!(delays[0], Duration::from_secs(1));
        assert_eq!(delays[1], Duration::from_secs(2));
        assert_eq!(delays[2], Duration::from_secs(4));
    }

    // ── is_transient boundary: 429 vs 428, 499 vs 500 ──

    #[test]
    fn is_transient_boundary_values() {
        assert!(!is_transient(&server_error(428)));
        assert!(is_transient(&server_error(429)));
        assert!(!is_transient(&server_error(430)));
        assert!(!is_transient(&server_error(499)));
        assert!(is_transient(&server_error(500)));
    }

    // ── Token expiry detection ──

    #[test]
    fn jwt_exp_extracts_expiry() {
        let exp = Utc::now().timestamp() + 3600;
        let token = fake_jwt(exp);
        assert_eq!(jwt_exp(&token), Some(exp));
    }

    #[test]
    fn jwt_exp_returns_none_for_garbage() {
        assert_eq!(jwt_exp("not-a-jwt"), None);
        assert_eq!(jwt_exp("a.b.c"), None);
        assert_eq!(jwt_exp(""), None);
    }

    #[test]
    fn token_is_expired_for_past_exp() {
        let token = fake_jwt(Utc::now().timestamp() - 3600);
        assert!(token_is_expired(&token));
    }

    #[test]
    fn token_is_expired_within_buffer() {
        let token = fake_jwt(Utc::now().timestamp() + 10);
        assert!(token_is_expired(&token));
    }

    #[test]
    fn token_is_not_expired_when_fresh() {
        let token = fake_jwt(Utc::now().timestamp() + 3600);
        assert!(!token_is_expired(&token));
    }

    #[test]
    fn token_is_not_expired_for_garbage() {
        assert!(!token_is_expired("garbage"));
    }

    #[test]
    fn token_expires_exactly_at_buffer_boundary() {
        // A clock tick after the token is built only moves `now` further past the
        // boundary, so this direction cannot flake.
        let token = fake_jwt(Utc::now().timestamp() + TOKEN_EXPIRY_BUFFER_SECS);
        assert!(token_is_expired(&token));
    }

    #[test]
    fn token_expires_just_past_buffer() {
        // `token_is_expired` reads the wall clock itself. If the second ticks over
        // between building the token and that read, "buffer + 1" turns into "exactly
        // at the buffer" and the token is correctly reported as expired. Only trust
        // a result whose before/after clock readings fall in the same second.
        for _ in 0..5 {
            let now = Utc::now().timestamp();
            let token = fake_jwt(now + TOKEN_EXPIRY_BUFFER_SECS + 1);
            let expired = token_is_expired(&token);
            if Utc::now().timestamp() == now {
                assert!(!expired);
                return;
            }
        }
        panic!("wall clock crossed a second boundary on every attempt");
    }

    // ── encrypt_change preserves metadata ──

    #[test]
    fn encrypt_change_preserves_all_metadata() {
        let client = SyncKitClient::new(test_config());
        let key = crypto::generate_master_key();
        client.set_master_key_raw(key);

        let entry = change(
            "contacts",
            ChangeOp::Update,
            "unique-row-id",
            Some(serde_json::json!({"name": "Alice"})),
        );
        let ts = entry.timestamp;

        let wire = client.encrypt_change(entry).unwrap();
        assert_eq!(wire.table, "contacts");
        assert_eq!(wire.op, ChangeOp::Update);
        assert_eq!(wire.row_id, "unique-row-id");
        assert_eq!(wire.timestamp, ts);
    }

    // ── Multiple entries encrypt/decrypt ──

    #[test]
    fn multiple_entries_encrypt_decrypt_roundtrip() {
        let client = SyncKitClient::new(test_config());
        let key = crypto::generate_master_key();
        client.set_master_key_raw(key);

        let entries = [
            change(
                "tasks",
                ChangeOp::Insert,
                "r1",
                Some(serde_json::json!({"title": "Task 1"})),
            ),
            change(
                "tasks",
                ChangeOp::Update,
                "r2",
                Some(serde_json::json!({"title": "Task 2", "done": true})),
            ),
            change("events", ChangeOp::Delete, "r3", None),
        ];

        let wire_entries: Vec<_> = entries
            .iter()
            .cloned()
            .map(|e| client.encrypt_change(e).unwrap())
            .collect();

        assert_eq!(wire_entries.len(), 3);
        assert!(wire_entries[0].data.is_some());
        assert!(wire_entries[1].data.is_some());
        assert!(wire_entries[2].data.is_none());

        for (i, wire) in wire_entries.into_iter().enumerate() {
            let decrypted = client.decrypt_change(as_pulled(wire, i as i64)).unwrap();
            assert_eq!(decrypted.table, entries[i].table);
            assert_eq!(decrypted.op, entries[i].op);
            assert_eq!(decrypted.data, entries[i].data);
        }
    }

    // ── Unicode and edge-case roundtrips ──

    #[test]
    fn encrypt_decrypt_roundtrip_unicode_table() {
        let client = SyncKitClient::new(test_config());
        let key = crypto::generate_master_key();
        client.set_master_key_raw(key);

        let table_name = "\u{65E5}\u{672C}\u{8A9E}\u{30C6}\u{30FC}\u{30D6}\u{30EB}";
        let entry = change(
            table_name,
            ChangeOp::Insert,
            "row-1",
            Some(serde_json::json!({"name": "\u{30C6}\u{30B9}\u{30C8}"})),
        );

        let wire = client.encrypt_change(entry).unwrap();
        let decrypted = client.decrypt_change(as_pulled(wire, 1)).unwrap();
        assert_eq!(decrypted.table, table_name);
    }

    #[test]
    fn encrypt_decrypt_roundtrip_empty_row_id() {
        let client = SyncKitClient::new(test_config());
        let key = crypto::generate_master_key();
        client.set_master_key_raw(key);

        let entry = change("t", ChangeOp::Insert, "", Some(serde_json::json!(42)));

        let wire = client.encrypt_change(entry).unwrap();
        let decrypted = client.decrypt_change(as_pulled(wire, 1)).unwrap();
        assert_eq!(decrypted.row_id, "");
        assert_eq!(decrypted.data.unwrap(), serde_json::json!(42));
    }
}
662