Skip to main content

max / makenotwork

3.8 KB · 97 lines History Blame Raw
1 use tokio::task::JoinHandle;
2 use tracing::info;
3
4 use pom::alerts::Alerter;
5 use pom::checks::dns;
6 use pom::config::Config;
7 use pom::db;
8
9 pub(crate) fn spawn_dns_tasks(
10 config: &Config,
11 pool: &sqlx::SqlitePool,
12 cancel: &tokio_util::sync::CancellationToken,
13 alerter: &Option<Alerter>,
14 ) -> Vec<JoinHandle<()>> {
15 let dns_interval_secs = config.serve.dns_check_interval_secs;
16 let mut handles = Vec::new();
17
18 for name in config.target_names() {
19 let target_config = config.get_target(&name).unwrap().clone();
20 if target_config.dns.is_empty() {
21 continue;
22 }
23 let dns_records = target_config.dns.clone();
24 let label = target_config.label.clone();
25 let pool = pool.clone();
26 let alerter = alerter.clone();
27 let cancel = cancel.clone();
28 let n = dns_records.len();
29
30 info!("{name}: DNS check every {dns_interval_secs}s ({n} records)");
31
32 handles.push(tokio::spawn(async move {
33 // Prune stale DNS data from DB (records removed from config)
34 let expected_dns_keys: Vec<(String, String)> = dns_records
35 .iter()
36 .map(|d| (d.name.clone(), d.record_type.to_string()))
37 .collect();
38 match db::prune_stale_dns(&pool, &name, &expected_dns_keys).await {
39 Ok(0) => {}
40 Ok(n) => info!("{name}: pruned {n} stale DNS check rows"),
41 Err(e) => tracing::error!("{name}: failed to prune stale DNS: {e}"),
42 }
43
44 let mut interval = tokio::time::interval(
45 std::time::Duration::from_secs(dns_interval_secs),
46 );
47 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
48 let mut prev_mismatched: std::collections::HashSet<(String, pom::types::DnsRecordType)> = std::collections::HashSet::new();
49
50 interval.tick().await; // consume immediate first tick
51 loop {
52 tokio::select! {
53 _ = cancel.cancelled() => break,
54 _ = interval.tick() => {}
55 }
56 let results = dns::check_dns(&name, &dns_records).await;
57
58 for result in &results {
59 if let Err(e) = db::insert_dns_check(&pool, result).await {
60 tracing::error!("{}: failed to store DNS check for {} {}: {e}", name, result.name, result.record_type);
61 }
62 }
63
64 let current_mismatched: std::collections::HashSet<(String, pom::types::DnsRecordType)> = results
65 .iter()
66 .filter(|r| !r.matches)
67 .map(|r| (r.name.clone(), r.record_type))
68 .collect();
69
70 let ok_count = results.iter().filter(|r| r.matches).count();
71 info!("{name}: DNS {ok_count}/{n} match");
72
73 if let Some(ref alerter) = alerter {
74 // New mismatches
75 let new_mismatches: Vec<&pom::types::DnsCheckResult> = results
76 .iter()
77 .filter(|r| !r.matches && !prev_mismatched.contains(&(r.name.clone(), r.record_type)))
78 .collect();
79 if !new_mismatches.is_empty() {
80 let owned: Vec<pom::types::DnsCheckResult> = new_mismatches.into_iter().cloned().collect();
81 alerter.send_dns_mismatch_alert(&name, &label, &owned).await;
82 }
83
84 // All recovered
85 if !prev_mismatched.is_empty() && current_mismatched.is_empty() {
86 alerter.send_dns_recovery_alert(&name, &label).await;
87 }
88 }
89
90 prev_mismatched = current_mismatched;
91 }
92 }));
93 }
94
95 handles
96 }
97