| 1 |
|
| 2 |
|
| 3 |
|
| 4 |
|
| 5 |
|
| 6 |
|
| 7 |
|
| 8 |
|
| 9 |
|
| 10 |
|
| 11 |
|
| 12 |
|
| 13 |
|
| 14 |
|
| 15 |
|
| 16 |
|
| 17 |
|
| 18 |
|
| 19 |
|
| 20 |
use anyhow::{Context, Result}; |
| 21 |
use async_trait::async_trait; |
| 22 |
use std::process::{ExitStatus, Stdio}; |
| 23 |
use std::sync::Arc; |
| 24 |
use tokio::io::{AsyncRead, AsyncReadExt}; |
| 25 |
use tokio::process::Command; |
| 26 |
use tokio::sync::Mutex; |
| 27 |
|
| 28 |
|
| 29 |
|
| 30 |
/// Options passed to every `ssh` invocation, as `-o KEY=VALUE` pairs.
///
/// - `BatchMode=yes`: never prompt interactively; fail instead.
/// - `ConnectTimeout=10`: give up connecting after ten seconds.
/// - `StrictHostKeyChecking=accept-new`: record keys of hosts never seen
///   before, but still refuse hosts whose recorded key has changed.
#[rustfmt::skip]
pub const SSH_FLAGS: &[&str] = &[
    "-o", "BatchMode=yes",
    "-o", "ConnectTimeout=10",
    "-o", "StrictHostKeyChecking=accept-new",
];
| 38 |
|
| 39 |
|
| 40 |
|
| 41 |
|
| 42 |
|
| 43 |
|
| 44 |
|
| 45 |
|
| 46 |
|
| 47 |
#[async_trait] |
| 48 |
pub trait LogSink: Send { |
| 49 |
async fn write_chunk(&mut self, bytes: &[u8]); |
| 50 |
} |
| 51 |
|
| 52 |
|
| 53 |
|
| 54 |
/// A machine that scripts can be run on, identified by an `ssh` destination.
///
/// The special targets `"local"` and `""` denote the current machine; scripts
/// for it run through `sh -c` instead of over ssh.
///
/// `PartialEq`/`Eq`/`Hash` are derived so hosts can be compared and used as
/// keys in maps and sets (e.g. to group work per host).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RemoteHost {
    /// Destination string handed verbatim to `ssh` when the host is remote.
    ssh_target: String,
}
| 58 |
|
| 59 |
|
| 60 |
|
| 61 |
/// Outcome of a finished command: how it exited plus the output captured
/// from it.
///
/// All fields are plain values (`ExitStatus` is `Copy + Eq`), so the result
/// can be cloned and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunOutput {
    /// Exit status of the spawned process (for remote hosts, of `ssh`).
    pub status: ExitStatus,
    /// Bytes captured from the process's standard output, in order.
    pub stdout: Vec<u8>,
    /// Bytes captured from the process's standard error, in order.
    pub stderr: Vec<u8>,
}

impl RunOutput {
    /// Returns `true` if the process exited successfully, as defined by
    /// [`ExitStatus::success`].
    #[must_use]
    pub fn success(&self) -> bool {
        self.status.success()
    }
}
| 73 |
|
| 74 |
impl RemoteHost { |
| 75 |
|
| 76 |
|
| 77 |
pub fn new(ssh_target: impl Into<String>) -> Self { |
| 78 |
Self { ssh_target: ssh_target.into() } |
| 79 |
} |
| 80 |
|
| 81 |
pub fn is_local(&self) -> bool { |
| 82 |
self.ssh_target == "local" || self.ssh_target.is_empty() |
| 83 |
} |
| 84 |
|
| 85 |
pub fn ssh_target(&self) -> &str { |
| 86 |
&self.ssh_target |
| 87 |
} |
| 88 |
|
| 89 |
|
| 90 |
|
| 91 |
|
| 92 |
|
| 93 |
pub(crate) fn command(&self, script: &str) -> Command { |
| 94 |
if self.is_local() { |
| 95 |
let mut cmd = Command::new("sh"); |
| 96 |
cmd.arg("-c").arg(script); |
| 97 |
cmd |
| 98 |
} else { |
| 99 |
let mut cmd = Command::new("ssh"); |
| 100 |
cmd.args(SSH_FLAGS).arg(&self.ssh_target).arg(script); |
| 101 |
cmd |
| 102 |
} |
| 103 |
} |
| 104 |
|
| 105 |
|
| 106 |
|
| 107 |
|
| 108 |
|
| 109 |
pub async fn run_streaming<S>(&self, script: &str, sink: Arc<Mutex<S>>) -> Result<RunOutput> |
| 110 |
where |
| 111 |
S: LogSink + Send + 'static, |
| 112 |
{ |
| 113 |
let mut child = self |
| 114 |
.command(script) |
| 115 |
.stdout(Stdio::piped()) |
| 116 |
.stderr(Stdio::piped()) |
| 117 |
.kill_on_drop(true) |
| 118 |
.spawn() |
| 119 |
.with_context(|| format!("spawning command on {}", self.ssh_target))?; |
| 120 |
|
| 121 |
let stdout_task = tokio::spawn(drain(child.stdout.take(), sink.clone())); |
| 122 |
let stderr_task = tokio::spawn(drain(child.stderr.take(), sink.clone())); |
| 123 |
let status = child.wait().await.context("waiting on child")?; |
| 124 |
let stdout = stdout_task.await.unwrap_or_default(); |
| 125 |
let stderr = stderr_task.await.unwrap_or_default(); |
| 126 |
Ok(RunOutput { status, stdout, stderr }) |
| 127 |
} |
| 128 |
} |
| 129 |
|
| 130 |
|
| 131 |
async fn drain<R, S>(stream: Option<R>, sink: Arc<Mutex<S>>) -> Vec<u8> |
| 132 |
where |
| 133 |
R: AsyncRead + Unpin + Send + 'static, |
| 134 |
S: LogSink + Send + 'static, |
| 135 |
{ |
| 136 |
let mut total = Vec::new(); |
| 137 |
let Some(mut s) = stream else { return total }; |
| 138 |
let mut buf = [0u8; 4096]; |
| 139 |
loop { |
| 140 |
match s.read(&mut buf).await { |
| 141 |
Ok(0) | Err(_) => break, |
| 142 |
Ok(n) => { |
| 143 |
total.extend_from_slice(&buf[..n]); |
| 144 |
sink.lock().await.write_chunk(&buf[..n]).await; |
| 145 |
} |
| 146 |
} |
| 147 |
} |
| 148 |
total |
| 149 |
} |
| 150 |
|
| 151 |
|
| 152 |
|
| 153 |
|
| 154 |
/// Quotes `s` so a POSIX shell reads it back as exactly one literal word.
///
/// The text is wrapped in single quotes. Each embedded `'` becomes `'\''`:
/// close the quote, emit an escaped quote, then reopen the quote.
pub fn sh_quote(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str(r"'\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
| 158 |
|
| 159 |
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory sink that appends every chunk it receives, in arrival order.
    #[derive(Default)]
    pub(crate) struct VecSink(pub Vec<u8>);

    #[async_trait]
    impl LogSink for VecSink {
        async fn write_chunk(&mut self, bytes: &[u8]) {
            self.0.extend(bytes.iter().copied());
        }
    }

    /// A fresh, empty sink ready to be shared with `run_streaming`.
    fn empty_sink() -> Arc<Mutex<VecSink>> {
        Arc::new(Mutex::new(VecSink::default()))
    }

    #[test]
    fn sh_quote_escapes() {
        let cases = [("hello", "'hello'"), ("it's", r"'it'\''s'")];
        for (input, expected) in cases {
            assert_eq!(sh_quote(input), expected);
        }
    }

    #[test]
    fn local_detection() {
        for target in ["local", ""] {
            assert!(RemoteHost::new(target).is_local());
        }
        assert!(!RemoteHost::new("mbp").is_local());
    }

    #[tokio::test]
    async fn local_run_streams_stdout_and_captures_status() {
        let sink = empty_sink();
        let out = RemoteHost::new("local")
            .run_streaming("printf 'hello '; printf 'world'", Arc::clone(&sink))
            .await
            .unwrap();
        assert!(out.success());
        assert_eq!(out.stdout, b"hello world");
        assert_eq!(sink.lock().await.0, b"hello world");
    }

    #[tokio::test]
    async fn local_run_streams_stderr_too() {
        let sink = empty_sink();
        let out = RemoteHost::new("local")
            .run_streaming("echo out; echo err 1>&2", Arc::clone(&sink))
            .await
            .unwrap();
        assert!(out.success());
        assert_eq!(out.stdout, b"out\n");
        assert_eq!(out.stderr, b"err\n");
        let seen = sink.lock().await.0.clone();
        // The two streams may interleave in the sink, so only check that
        // each line shows up somewhere.
        for needle in [b"out\n", b"err\n"] {
            assert!(seen.windows(needle.len()).any(|w| w == needle));
        }
    }

    #[tokio::test]
    async fn local_run_reports_nonzero_exit() {
        let out = RemoteHost::new("local")
            .run_streaming("exit 3", empty_sink())
            .await
            .unwrap();
        assert!(!out.success());
        assert_eq!(out.status.code(), Some(3));
    }
}
| 224 |
|