Skip to main content

max / makenotwork

7.5 KB · 235 lines History Blame Raw
1 //! Per-test database isolation using PostgreSQL template databases.
2 //!
//! A shared template database is created once per test process (with all
//! migrations applied), and each test gets a cheap `CREATE DATABASE ... TEMPLATE`
//! clone, which is dropped automatically when its `TestDb` goes out of scope.
6
7 use sqlx::postgres::PgPoolOptions;
8 use sqlx::{Connection, Executor, PgConnection, PgPool};
9 use std::sync::Once;
10 use std::time::Duration;
11 use uuid::Uuid;
12
/// Name of the database every per-test clone is copied from. It is rebuilt
/// from scratch (all migrations re-applied) the first time a `TestDb` is
/// created in a process.
const TEMPLATE_DB_NAME: &str = "mnw_test_template";

/// Guard ensuring the template is built at most once per process, no matter
/// how many threads or tokio runtimes the tests happen to use.
static TEMPLATE_INIT: Once = Once::new();
18
/// Connection URL used for administrative statements (creating and dropping
/// databases). Taken from `TEST_DATABASE_URL`; when that variable is unset or
/// not valid Unicode, falls back to the local `postgres` database.
fn admin_url() -> String {
    match std::env::var("TEST_DATABASE_URL") {
        Ok(url) => url,
        Err(_) => String::from("postgres://localhost/postgres"),
    }
}
23
24 /// Create the template database with all migrations. Runs in a dedicated
25 /// single-threaded tokio runtime so it works from any context (including
26 /// inside `#[tokio::test]` and plain `#[test]`).
27 fn ensure_template() {
28 TEMPLATE_INIT.call_once(|| {
29 let rt = tokio::runtime::Builder::new_current_thread()
30 .enable_all()
31 .build()
32 .expect("build template setup runtime");
33
34 rt.block_on(async {
35 let t0 = std::time::Instant::now();
36 let admin = admin_url();
37 let mut conn = PgConnection::connect(&admin)
38 .await
39 .expect("connect to admin DB for template setup");
40
41 // Drop stale template if it exists (migrations may have changed)
42 let _ = conn.execute(format!(
43 "DROP DATABASE IF EXISTS \"{TEMPLATE_DB_NAME}\" WITH (FORCE)"
44 ).as_str()).await;
45
46 conn.execute(format!(
47 "CREATE DATABASE \"{TEMPLATE_DB_NAME}\""
48 ).as_str())
49 .await
50 .expect("create template database");
51
52 // Connect to the template and run all migrations
53 let tpl_url = replace_db_name(&admin, TEMPLATE_DB_NAME);
54 let tpl_pool = PgPoolOptions::new()
55 .max_connections(2)
56 .acquire_timeout(Duration::from_secs(10))
57 .connect(&tpl_url)
58 .await
59 .expect("connect to template database");
60
61 let t_migrate = std::time::Instant::now();
62 sqlx::migrate!("./migrations")
63 .run(&tpl_pool)
64 .await
65 .expect("run migrations on template");
66 let migrate_ms = t_migrate.elapsed().as_millis();
67
68 // Also create the session store table
69 let session_store = tower_sessions_sqlx_store::PostgresStore::new(tpl_pool.clone());
70 session_store.migrate().await.expect("session store migration on template");
71
72 tpl_pool.close().await;
73
74 let total_ms = t0.elapsed().as_millis();
75 eprintln!(
76 "[test-harness] Template DB created in {}ms (migrations: {}ms)",
77 total_ms, migrate_ms
78 );
79 });
80 });
81 }
82
83 /// An isolated test database that cleans up after itself.
84 pub struct TestDb {
85 pub pool: PgPool,
86 db_name: String,
87 admin_url: String,
88 #[allow(dead_code)]
89 test_url: String,
90 /// Whether the session store table already exists (from template).
91 pub session_migrated: bool,
92 }
93
94 impl TestDb {
95 /// Create a fresh database cloned from the shared template.
96 pub async fn new() -> Self {
97 // ensure_template uses std::sync::Once + its own runtime, safe from any context.
98 // When called from an async context, we run it on a blocking thread to avoid
99 // nesting runtimes.
100 tokio::task::spawn_blocking(ensure_template)
101 .await
102 .expect("template setup panicked");
103
104 let t0 = std::time::Instant::now();
105 let admin = admin_url();
106 let db_name = format!("mnw_test_{}", Uuid::new_v4().simple());
107
108 let mut admin_conn = PgConnection::connect(&admin)
109 .await
110 .expect("Failed to connect to admin database");
111
112 admin_conn
113 .execute(
114 format!(
115 "CREATE DATABASE \"{db_name}\" TEMPLATE \"{TEMPLATE_DB_NAME}\""
116 )
117 .as_str(),
118 )
119 .await
120 .expect("Failed to create test database from template");
121
122 let test_url = replace_db_name(&admin, &db_name);
123
124 let pool = PgPoolOptions::new()
125 .max_connections(5)
126 .acquire_timeout(Duration::from_secs(5))
127 .connect(&test_url)
128 .await
129 .expect("Failed to connect to test database");
130
131 let clone_ms = t0.elapsed().as_millis();
132 if clone_ms > 500 {
133 eprintln!(
134 "[test-harness] SLOW DB clone: {}ms for {}",
135 clone_ms, db_name
136 );
137 }
138
139 TestDb {
140 pool,
141 db_name,
142 admin_url: admin,
143 test_url,
144 session_migrated: true,
145 }
146 }
147
148 /// The connection URL for this test database.
149 #[allow(dead_code)]
150 pub fn url(&self) -> &str {
151 &self.test_url
152 }
153 }
154
155 impl Drop for TestDb {
156 fn drop(&mut self) {
157 let admin_url = self.admin_url.clone();
158 let db_name = self.db_name.clone();
159
160 self.pool.close_event();
161
162 std::thread::spawn(move || {
163 let rt = tokio::runtime::Builder::new_current_thread()
164 .enable_all()
165 .build()
166 .expect("Failed to build cleanup runtime");
167
168 rt.block_on(async {
169 if let Ok(mut conn) = PgConnection::connect(&admin_url).await {
170 let _ = conn
171 .execute(
172 format!(
173 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
174 db_name
175 )
176 .as_str(),
177 )
178 .await;
179
180 let _ = conn
181 .execute(format!("DROP DATABASE IF EXISTS \"{}\"", db_name).as_str())
182 .await;
183 }
184 });
185 })
186 .join()
187 .ok();
188 }
189 }
190
/// Replace the database name in a PostgreSQL connection URL.
///
/// The query string (if any) is split off first, so a `/` inside it (e.g.
/// `sslrootcert=/etc/ssl/ca.pem`) is never mistaken for the path separator.
/// The database path is taken to start at the first `/` after `scheme://`.
/// A URL with no database path (`postgres://host:5432`) gets one appended.
///
/// # Panics
/// Panics if `url` does not contain `://` (for example, a `key=value` style
/// connection string), because it cannot be rewritten as a URL.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Everything from the first '?' onward is the query; re-attached verbatim.
    let (location, query) = url.split_at(url.find('?').unwrap_or(url.len()));

    let authority_start = match location.find("://") {
        Some(idx) => idx + "://".len(),
        // Do not echo the URL: it may contain a password.
        None => panic!("Invalid database URL: expected `scheme://host/dbname`"),
    };

    // Keep scheme + authority (user, password, host, port); drop the old path.
    let base = match location[authority_start..].find('/') {
        Some(slash) => &location[..authority_start + slash],
        None => location,
    };

    format!("{base}/{new_db}{query}")
}
208
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Host plus database: only the trailing path segment is swapped.
    #[test]
    fn replace_db_name_simple() {
        assert_eq!(
            replace_db_name("postgres://localhost/postgres", "test_db"),
            "postgres://localhost/test_db"
        );
    }

    /// Credentials and port in the authority survive the rewrite untouched.
    #[test]
    fn replace_db_name_with_auth() {
        let input = "postgres://user:pass@localhost:5432/mydb";
        let expected = "postgres://user:pass@localhost:5432/test_db";
        assert_eq!(replace_db_name(input, "test_db"), expected);
    }

    /// A trailing query string is re-attached after the new database name.
    #[test]
    fn replace_db_name_with_query() {
        let input = "postgres://localhost/postgres?sslmode=disable";
        let expected = "postgres://localhost/test_db?sslmode=disable";
        assert_eq!(replace_db_name(input, "test_db"), expected);
    }
}
235