//! SSE push notification endpoint for SyncKit subscribers. //! //! Clients connect to `GET /api/sync/subscribe?app_id={uuid}` and receive //! zero-data "changed" events whenever another device pushes changes. //! No payload is included; the client should pull to get actual data. //! This preserves E2E encryption (server never sends plaintext). use std::convert::Infallible; use std::sync::atomic::Ordering; use std::time::Duration; use axum::{ extract::{Query, State}, response::sse::{Event, KeepAlive, Sse}, }; use serde::Deserialize; use tokio_stream::wrappers::BroadcastStream; use tokio_stream::StreamExt; use crate::{ constants, db::{SyncAppId, UserId}, error::{AppError, Result}, synckit_auth::SyncUser, AppState, }; /// Drop guard that decrements the per-user SSE connection counter. struct SseConnectionGuard { sse_connections: std::sync::Arc>, sync_notify: std::sync::Arc>>, user_id: UserId, app_id: SyncAppId, } impl Drop for SseConnectionGuard { fn drop(&mut self) { // Decrement connection counter, then remove entry if it hit zero. // Must drop the read guard before calling remove() to avoid deadlocking // on the same DashMap shard. let should_remove_connection = self.sse_connections.get(&self.user_id).map(|counter| { let prev = counter.value().fetch_sub(1, Ordering::AcqRel); prev <= 1 }).unwrap_or(false); if should_remove_connection { self.sse_connections.remove(&self.user_id); } // Prune sync_notify channel if no receivers remain. // Same pattern: read first, drop guard, then remove. let key = (self.app_id, self.user_id); let should_remove_notify = self.sync_notify.get(&key) .map(|entry| entry.value().receiver_count() == 0) .unwrap_or(false); if should_remove_notify { self.sync_notify.remove(&key); } } } /// Stream wrapper that holds an SSE connection guard. When the stream is /// dropped (client disconnects), the guard decrements the connection counter. struct GuardedStream { inner: S, _guard: SseConnectionGuard, } impl tokio_stream::Stream for GuardedStream { type Item = S::Item; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { std::pin::Pin::new(&mut self.inner).poll_next(cx) } } #[derive(Deserialize)] pub struct SubscribeQuery { pub app_id: SyncAppId, } /// SSE endpoint for real-time sync push notifications. /// /// `GET /api/sync/subscribe?app_id={uuid}`: JWT auth required. /// /// Returns an SSE stream that emits `event: changed` (with empty data) whenever /// a push is made to the same app+user. The client should pull on each event. /// A keepalive comment is sent every 30 seconds to prevent proxy timeouts. #[tracing::instrument(skip_all, name = "synckit::subscribe")] pub(super) async fn sync_subscribe( State(state): State, sync_user: SyncUser, Query(query): Query, ) -> Result>>> { // Validate that the requested app_id matches the JWT's app_id if query.app_id != sync_user.app_id { return Err(AppError::BadRequest( "app_id does not match authenticated session".to_string(), )); } // Enforce per-user SSE connection limit let counter = state.sse_connections .entry(sync_user.user_id) .or_insert_with(|| std::sync::atomic::AtomicUsize::new(0)); let current = counter.value().fetch_add(1, Ordering::AcqRel); if current >= constants::SYNCKIT_MAX_SSE_CONNECTIONS_PER_USER { counter.value().fetch_sub(1, Ordering::AcqRel); return Err(AppError::BadRequest( "Too many concurrent SSE connections".to_string(), )); } let key = (sync_user.app_id, sync_user.user_id); let guard = SseConnectionGuard { sse_connections: state.sse_connections.clone(), sync_notify: state.sync_notify.clone(), user_id: sync_user.user_id, app_id: sync_user.app_id, }; // Get or create the broadcast channel for this app+user let rx = { let entry = state.sync_notify.entry(key).or_insert_with(|| { let (tx, _) = tokio::sync::broadcast::channel(16); tx }); entry.value().subscribe() }; let stream = BroadcastStream::new(rx).filter_map(|result| match result { Ok(()) => Some(Ok(Event::default().event("changed").data("{}"))), Err(_) => None, // Lagged — skip missed events, client will pull anyway }); // Wrap stream with the connection guard — when the client disconnects and // the stream is dropped, the guard's Drop impl decrements the counter. let stream = GuardedStream { inner: stream, _guard: guard }; Ok(Sse::new(stream).keep_alive( KeepAlive::new() .interval(Duration::from_secs(30)) .text("keepalive"), )) }