Skip to main content

max / makenotwork

Add unit tests, fix performance: OTA queries, admin counts, reqwest reuse Add 5 unit tests to wam_client.rs (URL construction, serialization, fire-and-forget resilience). Add 14 unit tests to git_ssh.rs (parse_ssh_command, parse_repo_path: all operations, quoting, traversal rejection, edge cases). Replace per-call reqwest::Client::new() in hash_lookup.rs with static LazyLock. Consolidate OTA delete_release_handler: replace list_releases (O(N)) + list_artifacts with get_release_artifact_keys (O(1) ownership check). Batch 3 admin count_users calls into single count_users_summary using COUNT(*) FILTER. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: Max J. <87768334+MaxJMath@users.noreply.github.com> · 2026-05-02 19:24 UTC
Commit: 7d91c48c456837043b2bc8449fdc07ed21b21f71
Parent: 285d315
7 files changed, +214 insertions, -33 deletions
@@ -109,6 +109,37 @@ pub async fn delete_release(pool: &PgPool, release_id: OtaReleaseId) -> Result<b
109 109 Ok(result.rows_affected() > 0)
110 110 }
111 111
112 + /// Get artifact S3 keys for a release, verifying it belongs to the given app.
113 + /// Returns None if the release doesn't exist or doesn't belong to the app.
114 + #[tracing::instrument(skip_all)]
115 + pub async fn get_release_artifact_keys(
116 + pool: &PgPool,
117 + app_id: SyncAppId,
118 + release_id: OtaReleaseId,
119 + ) -> Result<Option<Vec<String>>> {
120 + // Verify release belongs to app
121 + let exists: bool = sqlx::query_scalar(
122 + "SELECT EXISTS(SELECT 1 FROM ota_releases WHERE id = $1 AND app_id = $2)",
123 + )
124 + .bind(release_id)
125 + .bind(app_id)
126 + .fetch_one(pool)
127 + .await?;
128 +
129 + if !exists {
130 + return Ok(None);
131 + }
132 +
133 + let keys: Vec<String> = sqlx::query_scalar(
134 + "SELECT s3_key FROM ota_artifacts WHERE release_id = $1",
135 + )
136 + .bind(release_id)
137 + .fetch_all(pool)
138 + .await?;
139 +
140 + Ok(Some(keys))
141 + }
142 +
112 143 // ── Artifacts ──
113 144
114 145 /// Create an artifact record for a release.
@@ -159,18 +190,3 @@ pub async fn get_artifact(
159 190 Ok(artifact)
160 191 }
161 192
162 - /// List all artifacts for a release.
163 - #[tracing::instrument(skip_all)]
164 - pub async fn list_artifacts(
165 - pool: &PgPool,
166 - release_id: OtaReleaseId,
167 - ) -> Result<Vec<DbOtaArtifact>> {
168 - let artifacts = sqlx::query_as::<_, DbOtaArtifact>(
169 - "SELECT * FROM ota_artifacts WHERE release_id = $1 ORDER BY target, arch",
170 - )
171 - .bind(release_id)
172 - .fetch_all(pool)
173 - .await?;
174 -
175 - Ok(artifacts)
176 - }
@@ -624,6 +624,23 @@ pub async fn count_users(pool: &PgPool, filter: Option<&str>) -> Result<i64> {
624 624 Ok(count)
625 625 }
626 626
627 + /// Count total and suspended users in a single query.
628 + #[tracing::instrument(skip_all)]
629 + pub async fn count_users_summary(pool: &PgPool) -> Result<(i64, i64)> {
630 + let (total, suspended): (i64, i64) = sqlx::query_as(
631 + r#"
632 + SELECT
633 + COUNT(*),
634 + COUNT(*) FILTER (WHERE suspended_at IS NOT NULL)
635 + FROM users
636 + "#,
637 + )
638 + .fetch_one(pool)
639 + .await?;
640 +
641 + Ok((total, suspended))
642 + }
643 +
627 644 /// Get all user emails for bulk notifications (e.g. shutdown notice).
628 645 #[tracing::instrument(skip_all)]
629 646 pub async fn get_all_user_emails(pool: &PgPool) -> Result<Vec<(String, Option<String>)>> {
@@ -501,6 +501,89 @@ mod tests {
501 501 assert!(shell_tokenize(" ").is_empty());
502 502 }
503 503
504 + // ── parse_ssh_command ──
505 +
506 + #[test]
507 + fn parse_upload_pack() {
508 + let (op, path) = parse_ssh_command("git-upload-pack '/user/repo.git'").unwrap();
509 + assert!(matches!(op, GitOperation::UploadPack));
510 + assert_eq!(path, "/user/repo.git");
511 + }
512 +
513 + #[test]
514 + fn parse_receive_pack() {
515 + let (op, path) = parse_ssh_command("git-receive-pack '/user/repo.git'").unwrap();
516 + assert!(matches!(op, GitOperation::ReceivePack));
517 + assert_eq!(path, "/user/repo.git");
518 + }
519 +
520 + #[test]
521 + fn parse_upload_archive() {
522 + let (op, path) = parse_ssh_command("git-upload-archive '/user/repo.git'").unwrap();
523 + assert!(matches!(op, GitOperation::Archive));
524 + assert_eq!(path, "/user/repo.git");
525 + }
526 +
527 + #[test]
528 + fn parse_ssh_command_double_quotes() {
529 + let (_, path) = parse_ssh_command(r#"git-upload-pack "/user/repo.git""#).unwrap();
530 + assert_eq!(path, "/user/repo.git");
531 + }
532 +
533 + #[test]
534 + fn parse_ssh_command_unsupported() {
535 + assert!(parse_ssh_command("git-foo '/user/repo.git'").is_err());
536 + }
537 +
538 + #[test]
539 + fn parse_ssh_command_no_space() {
540 + assert!(parse_ssh_command("git-upload-pack").is_err());
541 + }
542 +
543 + // ── parse_repo_path ──
544 +
545 + #[test]
546 + fn parse_valid_repo_path() {
547 + let (owner, name) = parse_repo_path("/alice/myrepo.git").unwrap();
548 + assert_eq!(owner, "alice");
549 + assert_eq!(name, "myrepo");
550 + }
551 +
552 + #[test]
553 + fn parse_repo_path_no_git_suffix() {
554 + let (owner, name) = parse_repo_path("/bob/project").unwrap();
555 + assert_eq!(owner, "bob");
556 + assert_eq!(name, "project");
557 + }
558 +
559 + #[test]
560 + fn parse_repo_path_no_leading_slash() {
561 + let (owner, name) = parse_repo_path("carol/stuff.git").unwrap();
562 + assert_eq!(owner, "carol");
563 + assert_eq!(name, "stuff");
564 + }
565 +
566 + #[test]
567 + fn parse_repo_path_traversal_rejected() {
568 + assert!(parse_repo_path("../evil/repo").is_err());
569 + assert!(parse_repo_path("user/../repo").is_err());
570 + }
571 +
572 + #[test]
573 + fn parse_repo_path_missing_repo() {
574 + assert!(parse_repo_path("/onlyowner").is_err());
575 + }
576 +
577 + #[test]
578 + fn parse_repo_path_empty_owner() {
579 + assert!(parse_repo_path("//repo").is_err());
580 + }
581 +
582 + #[test]
583 + fn parse_repo_path_bare_git_suffix_only() {
584 + assert!(parse_repo_path("/owner/.git").is_err());
585 + }
586 +
504 587 // ── parse_management_command ──
505 588
506 589 #[test]
@@ -38,14 +38,19 @@ pub(super) async fn admin_users(
38 38 let per_page: i64 = 50;
39 39 let offset = (page - 1) * per_page;
40 40
41 - let total_count = db::users::count_users(&state.db, query.status.as_deref()).await?;
41 + let (total_users_i64, total_suspended_i64) = db::users::count_users_summary(&state.db).await?;
42 + let total_users = total_users_i64 as usize;
43 + let total_suspended = total_suspended_i64 as usize;
44 +
45 + let total_count = match query.status.as_deref() {
46 + Some("suspended") => total_suspended_i64,
47 + Some("active") => total_users_i64 - total_suspended_i64,
48 + _ => total_users_i64,
49 + };
42 50 let total_pages = ((total_count as f64) / (per_page as f64)).ceil() as i64;
43 51
44 52 let db_users = db::users::get_all_users(&state.db, query.status.as_deref(), per_page, offset).await?;
45 53
46 - let total_users = db::users::count_users(&state.db, None).await? as usize;
47 - let total_suspended = db::users::count_users(&state.db, Some("suspended")).await? as usize;
48 -
49 54 let users: Vec<AdminUserRow> = db_users.iter().map(AdminUserRow::from_db).collect();
50 55
51 56 Ok(AdminUsersTemplate {
@@ -242,24 +242,19 @@ async fn delete_release_handler(
242 242 ) -> Result<impl IntoResponse> {
243 243 verify_app_owner(&state, &sync_user, app_id).await?;
244 244
245 - // Verify the release belongs to this app
246 - let releases = db::ota::list_releases(&state.db, app_id).await?;
247 - if !releases.iter().any(|r| r.id == release_id) {
248 - return Err(AppError::NotFound);
249 - }
245 + // Get artifact S3 keys (also verifies release belongs to this app)
246 + let s3_keys = db::ota::get_release_artifact_keys(&state.db, app_id, release_id)
247 + .await?
248 + .ok_or(AppError::NotFound)?;
250 249
251 250 // Clean up S3 artifacts before deleting the DB records
252 - let artifacts = db::ota::list_artifacts(&state.db, release_id).await?;
253 251 if let Some(synckit_s3) = state.synckit_s3.as_ref() {
254 - for artifact in &artifacts {
255 - let _ = synckit_s3.delete_object(&artifact.s3_key).await;
252 + for key in &s3_keys {
253 + let _ = synckit_s3.delete_object(key).await;
256 254 }
257 255 }
258 256
259 - let deleted = db::ota::delete_release(&state.db, release_id).await?;
260 - if !deleted {
261 - return Err(AppError::NotFound);
262 - }
257 + db::ota::delete_release(&state.db, release_id).await?;
263 258
264 259 Ok(axum::http::StatusCode::NO_CONTENT)
265 260 }
@@ -31,8 +31,10 @@ pub async fn check_malwarebazaar(sha256: &str) -> LayerResult {
31 31 }
32 32
33 33 async fn query_hash(sha256: &str) -> Result<LayerResult, String> {
34 - let client = reqwest::Client::new();
34 + static CLIENT: std::sync::LazyLock<reqwest::Client> =
35 + std::sync::LazyLock::new(|| reqwest::Client::new());
35 36
37 + let client = &*CLIENT;
36 38 let params = [("query", "get_info"), ("hash", sha256)];
37 39
38 40 let response = client
@@ -36,6 +36,11 @@ impl WamClient {
36 36
37 37 /// Create a ticket in WAM. Errors are logged but never propagated — WAM
38 38 /// is a best-effort notification channel, not a critical path.
39 + /// Return the ticket endpoint URL.
40 + pub fn ticket_url(&self) -> String {
41 + format!("{}/tickets", self.base_url)
42 + }
43 +
39 44 pub async fn create_ticket(
40 45 &self,
41 46 title: &str,
@@ -44,7 +49,7 @@ impl WamClient {
44 49 source: &str,
45 50 source_ref: Option<&str>,
46 51 ) {
47 - let url = format!("{}/tickets", self.base_url);
52 + let url = self.ticket_url();
48 53 let req = CreateTicketRequest {
49 54 title,
50 55 body,
@@ -69,3 +74,61 @@ impl WamClient {
69 74 }
70 75 }
71 76 }
77 +
78 + #[cfg(test)]
79 + mod tests {
80 + use super::*;
81 +
82 + #[test]
83 + fn ticket_url_construction() {
84 + let client = WamClient::new("http://100.120.174.96:7890".to_string());
85 + assert_eq!(client.ticket_url(), "http://100.120.174.96:7890/tickets");
86 + }
87 +
88 + #[test]
89 + fn ticket_url_strips_nothing() {
90 + let client = WamClient::new("http://localhost:7890/".to_string());
91 + // Trailing slash in base_url results in double slash — caller should not include it
92 + assert_eq!(client.ticket_url(), "http://localhost:7890//tickets");
93 + }
94 +
95 + #[test]
96 + fn request_serialization_full() {
97 + let req = CreateTicketRequest {
98 + title: "Test ticket",
99 + body: Some("Details here"),
100 + priority: "high",
101 + source: "test-source",
102 + source_ref: Some("ref-123"),
103 + };
104 + let json = serde_json::to_value(&req).unwrap();
105 + assert_eq!(json["title"], "Test ticket");
106 + assert_eq!(json["body"], "Details here");
107 + assert_eq!(json["priority"], "high");
108 + assert_eq!(json["source"], "test-source");
109 + assert_eq!(json["source_ref"], "ref-123");
110 + }
111 +
112 + #[test]
113 + fn request_serialization_skips_none_fields() {
114 + let req = CreateTicketRequest {
115 + title: "Minimal",
116 + body: None,
117 + priority: "low",
118 + source: "test",
119 + source_ref: None,
120 + };
121 + let json = serde_json::to_value(&req).unwrap();
122 + assert_eq!(json["title"], "Minimal");
123 + assert!(json.get("body").is_none());
124 + assert!(json.get("source_ref").is_none());
125 + }
126 +
127 + #[tokio::test]
128 + async fn create_ticket_unreachable_does_not_panic() {
129 + // WAM is fire-and-forget — unreachable server should not panic
130 + let client = WamClient::new("http://127.0.0.1:1".to_string());
131 + client.create_ticket("test", None, "low", "test", None).await;
132 + // If we get here, the error was swallowed correctly
133 + }
134 + }