Skip to main content

max / makenotwork

5.2 KB · 151 lines History Blame Raw
1 //! SSE push notification endpoint for SyncKit subscribers.
2 //!
3 //! Clients connect to `GET /api/sync/subscribe?app_id={uuid}` and receive
4 //! zero-data "changed" events whenever another device pushes changes.
5 //! No payload is included; the client should pull to get actual data.
6 //! This preserves E2E encryption (server never sends plaintext).
7
8 use std::convert::Infallible;
9 use std::sync::atomic::Ordering;
10 use std::time::Duration;
11
12 use axum::{
13 extract::{Query, State},
14 response::sse::{Event, KeepAlive, Sse},
15 };
16 use serde::Deserialize;
17 use tokio_stream::wrappers::BroadcastStream;
18 use tokio_stream::StreamExt;
19
20 use crate::{
21 constants,
22 db::{SyncAppId, UserId},
23 error::{AppError, Result},
24 synckit_auth::SyncUser,
25 AppState,
26 };
27
28 /// Drop guard that decrements the per-user SSE connection counter.
29 struct SseConnectionGuard {
30 sse_connections: std::sync::Arc<dashmap::DashMap<UserId, std::sync::atomic::AtomicUsize>>,
31 sync_notify: std::sync::Arc<dashmap::DashMap<(SyncAppId, UserId), tokio::sync::broadcast::Sender<()>>>,
32 user_id: UserId,
33 app_id: SyncAppId,
34 }
35
36 impl Drop for SseConnectionGuard {
37 fn drop(&mut self) {
38 // Decrement connection counter, then remove entry if it hit zero.
39 // Must drop the read guard before calling remove() to avoid deadlocking
40 // on the same DashMap shard.
41 let should_remove_connection = self.sse_connections.get(&self.user_id).map(|counter| {
42 let prev = counter.value().fetch_sub(1, Ordering::AcqRel);
43 prev <= 1
44 }).unwrap_or(false);
45
46 if should_remove_connection {
47 self.sse_connections.remove(&self.user_id);
48 }
49
50 // Prune sync_notify channel if no receivers remain.
51 // Same pattern: read first, drop guard, then remove.
52 let key = (self.app_id, self.user_id);
53 let should_remove_notify = self.sync_notify.get(&key)
54 .map(|entry| entry.value().receiver_count() == 0)
55 .unwrap_or(false);
56
57 if should_remove_notify {
58 self.sync_notify.remove(&key);
59 }
60 }
61 }
62
63 /// Stream wrapper that holds an SSE connection guard. When the stream is
64 /// dropped (client disconnects), the guard decrements the connection counter.
65 struct GuardedStream<S> {
66 inner: S,
67 _guard: SseConnectionGuard,
68 }
69
70 impl<S: tokio_stream::Stream + Unpin> tokio_stream::Stream for GuardedStream<S> {
71 type Item = S::Item;
72
73 fn poll_next(
74 mut self: std::pin::Pin<&mut Self>,
75 cx: &mut std::task::Context<'_>,
76 ) -> std::task::Poll<Option<Self::Item>> {
77 std::pin::Pin::new(&mut self.inner).poll_next(cx)
78 }
79 }
80
81 #[derive(Deserialize)]
82 pub struct SubscribeQuery {
83 pub app_id: SyncAppId,
84 }
85
86 /// SSE endpoint for real-time sync push notifications.
87 ///
88 /// `GET /api/sync/subscribe?app_id={uuid}`: JWT auth required.
89 ///
90 /// Returns an SSE stream that emits `event: changed` (with empty data) whenever
91 /// a push is made to the same app+user. The client should pull on each event.
92 /// A keepalive comment is sent every 30 seconds to prevent proxy timeouts.
93 #[tracing::instrument(skip_all, name = "synckit::subscribe")]
94 pub(super) async fn sync_subscribe(
95 State(state): State<AppState>,
96 sync_user: SyncUser,
97 Query(query): Query<SubscribeQuery>,
98 ) -> Result<Sse<impl tokio_stream::Stream<Item = std::result::Result<Event, Infallible>>>> {
99 // Validate that the requested app_id matches the JWT's app_id
100 if query.app_id != sync_user.app_id {
101 return Err(AppError::BadRequest(
102 "app_id does not match authenticated session".to_string(),
103 ));
104 }
105
106 // Enforce per-user SSE connection limit
107 let counter = state.sse_connections
108 .entry(sync_user.user_id)
109 .or_insert_with(|| std::sync::atomic::AtomicUsize::new(0));
110 let current = counter.value().fetch_add(1, Ordering::AcqRel);
111 if current >= constants::SYNCKIT_MAX_SSE_CONNECTIONS_PER_USER {
112 counter.value().fetch_sub(1, Ordering::AcqRel);
113 return Err(AppError::BadRequest(
114 "Too many concurrent SSE connections".to_string(),
115 ));
116 }
117
118 let key = (sync_user.app_id, sync_user.user_id);
119
120 let guard = SseConnectionGuard {
121 sse_connections: state.sse_connections.clone(),
122 sync_notify: state.sync_notify.clone(),
123 user_id: sync_user.user_id,
124 app_id: sync_user.app_id,
125 };
126
127 // Get or create the broadcast channel for this app+user
128 let rx = {
129 let entry = state.sync_notify.entry(key).or_insert_with(|| {
130 let (tx, _) = tokio::sync::broadcast::channel(16);
131 tx
132 });
133 entry.value().subscribe()
134 };
135
136 let stream = BroadcastStream::new(rx).filter_map(|result| match result {
137 Ok(()) => Some(Ok(Event::default().event("changed").data("{}"))),
138 Err(_) => None, // Lagged — skip missed events, client will pull anyway
139 });
140
141 // Wrap stream with the connection guard — when the client disconnects and
142 // the stream is dropped, the guard's Drop impl decrements the counter.
143 let stream = GuardedStream { inner: stream, _guard: guard };
144
145 Ok(Sse::new(stream).keep_alive(
146 KeepAlive::new()
147 .interval(Duration::from_secs(30))
148 .text("keepalive"),
149 ))
150 }
151