Skip to main content

max / makenotwork

3.1 KB · 107 lines History Blame Raw
1 //! Per-test database isolation.
2 //!
3 //! Each test gets its own PostgreSQL database created from scratch with all
4 //! migrations applied. Dropped automatically when `TestDb` goes out of scope.
5
6 use sqlx::postgres::PgPoolOptions;
7 use sqlx::{Connection, Executor, PgConnection, PgPool};
8 use std::time::Duration;
9 use uuid::Uuid;
10
11 pub struct TestDb {
12 pub pool: PgPool,
13 db_name: String,
14 admin_url: String,
15 }
16
17 impl TestDb {
18 pub async fn new() -> Self {
19 let admin_url = std::env::var("TEST_DATABASE_URL")
20 .unwrap_or_else(|_| "postgres://localhost/postgres".to_string());
21
22 let db_name = format!("mt_test_{}", Uuid::new_v4().simple());
23
24 let mut admin_conn = PgConnection::connect(&admin_url)
25 .await
26 .expect("Failed to connect to admin database");
27
28 admin_conn
29 .execute(format!("CREATE DATABASE \"{}\"", db_name).as_str())
30 .await
31 .expect("Failed to create test database");
32
33 let test_url = Self::replace_db_name(&admin_url, &db_name);
34
35 let pool = PgPoolOptions::new()
36 .max_connections(5)
37 .acquire_timeout(Duration::from_secs(5))
38 .connect(&test_url)
39 .await
40 .expect("Failed to connect to test database");
41
42 sqlx::migrate!("./migrations")
43 .run(&pool)
44 .await
45 .expect("Failed to run migrations on test database");
46
47 TestDb {
48 pool,
49 db_name,
50 admin_url,
51 }
52 }
53
54 fn replace_db_name(url: &str, new_db: &str) -> String {
55 if let Some(pos) = url.rfind('/') {
56 let base = &url[..pos];
57 let query = url[pos + 1..]
58 .find('?')
59 .map(|q| &url[pos + 1 + q..])
60 .unwrap_or("");
61 if query.is_empty() {
62 format!("{}/{}", base, new_db)
63 } else {
64 format!("{}/{}{}", base, new_db, query)
65 }
66 } else {
67 panic!("Invalid database URL: no '/' found");
68 }
69 }
70 }
71
72 impl Drop for TestDb {
73 fn drop(&mut self) {
74 let admin_url = self.admin_url.clone();
75 let db_name = self.db_name.clone();
76
77 self.pool.close_event();
78
79 std::thread::spawn(move || {
80 let rt = tokio::runtime::Builder::new_current_thread()
81 .enable_all()
82 .build()
83 .expect("Failed to build cleanup runtime");
84
85 rt.block_on(async {
86 if let Ok(mut conn) = PgConnection::connect(&admin_url).await {
87 let _ = conn
88 .execute(
89 format!(
90 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
91 db_name
92 )
93 .as_str(),
94 )
95 .await;
96
97 let _ = conn
98 .execute(format!("DROP DATABASE IF EXISTS \"{}\"", db_name).as_str())
99 .await;
100 }
101 });
102 })
103 .join()
104 .ok();
105 }
106 }
107