Skip to main content

max / multithreaded

5.1 KB · 154 lines History Blame Raw
1 use multithreaded::{config::Config, csrf, AppState};
2 use sqlx::PgPool;
3 use tokio::net::TcpListener;
4 use tower_http::services::ServeDir;
5 use tower_sessions::SessionManagerLayer;
6 use tower_sessions::cookie::SameSite;
7 use tower_sessions::ExpiredDeletion;
8 use tower_sessions_sqlx_store::PostgresStore;
9 use tracing_subscriber::EnvFilter;
10
11 #[tokio::main]
12 async fn main() {
13 tracing_subscriber::fmt()
14 .with_env_filter(EnvFilter::from_default_env())
15 .init();
16
17 dotenvy::dotenv().ok();
18
19 let database_url = std::env::var("DATABASE_URL")
20 .expect("DATABASE_URL must be set");
21
22 let pool = PgPool::connect(&database_url)
23 .await
24 .expect("failed to connect to database");
25
26 sqlx::migrate!()
27 .run(&pool)
28 .await
29 .expect("failed to run migrations");
30
31 tracing::info!("migrations applied");
32
33 // Seed initial data if --seed flag is passed, then exit
34 if std::env::args().any(|a| a == "--seed") {
35 multithreaded::seed::run(&pool).await;
36 tracing::info!("seed data inserted");
37 return;
38 }
39
40 let config = Config::from_env();
41
42 // Optional S3 storage for image uploads
43 let s3 = if let Some(ref s3_config) = config.s3 {
44 match multithreaded::storage::S3Storage::new(s3_config).await {
45 Ok(client) => {
46 tracing::info!("S3 storage configured (bucket: {})", s3_config.bucket);
47 Some(std::sync::Arc::new(client))
48 }
49 Err(e) => {
50 tracing::warn!("S3 storage unavailable: {e}");
51 None
52 }
53 }
54 } else {
55 tracing::info!("S3 storage not configured (image uploads disabled)");
56 None
57 };
58
59 let state = AppState {
60 db: pool.clone(),
61 config,
62 http: reqwest::Client::builder()
63 .timeout(std::time::Duration::from_secs(15))
64 .connect_timeout(std::time::Duration::from_secs(5))
65 .build()
66 .expect("failed to build HTTP client"),
67 preview_http: multithreaded::link_preview::build_preview_client(),
68 s3,
69 };
70
71 // Session store backed by PostgreSQL
72 let session_store = PostgresStore::new(pool);
73 session_store.migrate().await.expect("failed to migrate session store");
74
75 let deletion_task = tokio::task::spawn(
76 session_store
77 .clone()
78 .continuously_delete_expired(tokio::time::Duration::from_secs(3600)),
79 );
80
81 let session_layer = SessionManagerLayer::new(session_store)
82 .with_name("mt_session")
83 .with_same_site(SameSite::Lax)
84 .with_expiry(tower_sessions::Expiry::OnInactivity(
85 time::Duration::days(7),
86 ))
87 .with_secure(state.config.cookie_secure);
88
89 let app = multithreaded::routes::forum_routes(state.clone())
90 .layer(axum::middleware::from_fn(csrf::csrf_middleware))
91 .layer(session_layer)
92 .layer(tower_http::set_header::SetResponseHeaderLayer::overriding(
93 axum::http::header::CONTENT_SECURITY_POLICY,
94 axum::http::HeaderValue::from_static(
95 "default-src 'self'; img-src 'self'; style-src 'self' 'unsafe-inline'; frame-ancestors 'none'",
96 ),
97 ))
98 .layer(tower_http::set_header::SetResponseHeaderLayer::overriding(
99 axum::http::header::X_CONTENT_TYPE_OPTIONS,
100 axum::http::HeaderValue::from_static("nosniff"),
101 ))
102 .layer(tower_http::set_header::SetResponseHeaderLayer::overriding(
103 axum::http::header::X_FRAME_OPTIONS,
104 axum::http::HeaderValue::from_static("DENY"),
105 ))
106 .layer(tower_http::set_header::SetResponseHeaderLayer::if_not_present(
107 axum::http::header::CACHE_CONTROL,
108 axum::http::HeaderValue::from_static("private, no-cache"),
109 ))
110 // Internal API routes — HMAC auth only, no CSRF/session middleware
111 .merge(multithreaded::routes::internal::internal_routes(state))
112 .nest_service("/static", ServeDir::new("static"));
113
114 let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
115 let port = std::env::var("PORT").unwrap_or_else(|_| "3400".to_string());
116 let addr = format!("{host}:{port}");
117
118 let listener = TcpListener::bind(&addr)
119 .await
120 .expect("failed to bind");
121
122 tracing::info!("listening on {}", listener.local_addr().unwrap());
123
124 axum::serve(
125 listener,
126 app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
127 )
128 .with_graceful_shutdown(shutdown_signal())
129 .await
130 .expect("server error");
131
132 deletion_task.abort();
133 let _ = deletion_task.await;
134 }
135
136 async fn shutdown_signal() {
137 use tokio::signal;
138 let ctrl_c = async { signal::ctrl_c().await.expect("failed to install Ctrl+C handler") };
139 #[cfg(unix)]
140 let terminate = async {
141 signal::unix::signal(signal::unix::SignalKind::terminate())
142 .expect("failed to install signal handler")
143 .recv()
144 .await;
145 };
146 #[cfg(not(unix))]
147 let terminate = std::future::pending::<()>();
148 tokio::select! {
149 _ = ctrl_c => {},
150 _ = terminate => {},
151 }
152 tracing::info!("Shutdown signal received");
153 }
154