Skip to main content

max / goingson

13.6 KB · 466 lines History Blame Raw
1 //! Rhai script engine with safety configuration.
2 //!
3 //! Provides a sandboxed Rhai engine for executing plugin scripts.
4
5 use rhai::{Dynamic, Engine, AST};
6 use std::sync::Arc;
7
8 use crate::api::create_goingson_module;
9 use crate::error::{PluginError, Result};
10
/// Safety limits for script execution.
///
/// Each field is forwarded to the matching `rhai::Engine::set_max_*` setter when
/// a `PluginEngine` is built with `PluginEngine::with_limits`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafetyLimits {
    /// Maximum number of operations before termination.
    pub max_operations: u64,
    /// Maximum call stack depth (recursion limit).
    pub max_call_levels: usize,
    /// Maximum string size in bytes.
    pub max_string_size: usize,
    /// Maximum array elements.
    pub max_array_size: usize,
    /// Maximum map entries.
    pub max_map_size: usize,
}

impl Default for SafetyLimits {
    /// Conservative defaults: 100 000 operations, 32 call levels, 1 MiB strings,
    /// and at most 10 000 elements per array or map.
    fn default() -> Self {
        Self {
            max_operations: 100_000,
            max_call_levels: 32,
            max_string_size: 1024 * 1024, // 1 MiB
            max_array_size: 10_000,
            max_map_size: 10_000,
        }
    }
}
37
38 /// Plugin script engine with sandboxing.
39 pub struct PluginEngine {
40 engine: Engine,
41 limits: SafetyLimits,
42 }
43
44 impl PluginEngine {
45 /// Creates a new plugin engine with default safety limits.
46 #[tracing::instrument(skip_all)]
47 pub fn new() -> Self {
48 Self::with_limits(SafetyLimits::default())
49 }
50
51 /// Creates a new plugin engine with custom safety limits.
52 #[tracing::instrument(skip_all)]
53 pub fn with_limits(limits: SafetyLimits) -> Self {
54 let mut engine = Engine::new();
55
56 // Apply safety limits
57 engine.set_max_operations(limits.max_operations);
58 engine.set_max_call_levels(limits.max_call_levels);
59 engine.set_max_string_size(limits.max_string_size);
60 engine.set_max_array_size(limits.max_array_size);
61 engine.set_max_map_size(limits.max_map_size);
62
63 // Disable dangerous operations
64 engine.disable_symbol("eval");
65
66 // Register the goingson:: API module
67 let goingson_module = create_goingson_module();
68 engine.register_static_module("goingson", Arc::new(goingson_module));
69
70 Self { engine, limits }
71 }
72
73 /// Compiles a script into an AST for caching.
74 #[tracing::instrument(skip_all)]
75 pub fn compile(&self, script: &str) -> Result<AST> {
76 self.engine
77 .compile(script)
78 .map_err(|e| PluginError::ScriptError {
79 plugin: "unknown".to_string(),
80 message: e.to_string(),
81 })
82 }
83
84 /// Compiles a script with a plugin name for error messages.
85 #[tracing::instrument(skip_all)]
86 pub fn compile_plugin(&self, plugin_id: &str, script: &str) -> Result<AST> {
87 self.engine
88 .compile(script)
89 .map_err(|e| PluginError::script(plugin_id, e.to_string()))
90 }
91
92 /// Checks if a function exists in the compiled AST.
93 #[tracing::instrument(skip_all)]
94 pub fn has_function(&self, ast: &AST, name: &str, arity: usize) -> bool {
95 ast.iter_functions()
96 .any(|f| f.name == name && f.params.len() == arity)
97 }
98
99 /// Calls a function in the script with no arguments.
100 #[tracing::instrument(skip_all)]
101 pub fn call_fn<T: Clone + Send + Sync + 'static>(
102 &self,
103 ast: &AST,
104 plugin_id: &str,
105 fn_name: &str,
106 ) -> Result<T> {
107 self.engine
108 .call_fn::<T>(&mut rhai::Scope::new(), ast, fn_name, ())
109 .map_err(|e| self.map_rhai_error(plugin_id, e))
110 }
111
112 /// Calls a function with one argument.
113 #[tracing::instrument(skip_all)]
114 pub fn call_fn_1<A, T>(&self, ast: &AST, plugin_id: &str, fn_name: &str, arg: A) -> Result<T>
115 where
116 A: Clone + Send + Sync + 'static,
117 T: Clone + Send + Sync + 'static,
118 {
119 self.engine
120 .call_fn::<T>(&mut rhai::Scope::new(), ast, fn_name, (arg,))
121 .map_err(|e| self.map_rhai_error(plugin_id, e))
122 }
123
124 /// Calls a function with two arguments.
125 #[tracing::instrument(skip_all)]
126 pub fn call_fn_2<A, B, T>(
127 &self,
128 ast: &AST,
129 plugin_id: &str,
130 fn_name: &str,
131 arg1: A,
132 arg2: B,
133 ) -> Result<T>
134 where
135 A: Clone + Send + Sync + 'static,
136 B: Clone + Send + Sync + 'static,
137 T: Clone + Send + Sync + 'static,
138 {
139 self.engine
140 .call_fn::<T>(&mut rhai::Scope::new(), ast, fn_name, (arg1, arg2))
141 .map_err(|e| self.map_rhai_error(plugin_id, e))
142 }
143
144 /// Calls a function and returns a Dynamic result.
145 #[tracing::instrument(skip_all)]
146 pub fn call_fn_dynamic(
147 &self,
148 ast: &AST,
149 plugin_id: &str,
150 fn_name: &str,
151 args: impl rhai::FuncArgs,
152 ) -> Result<Dynamic> {
153 self.engine
154 .call_fn::<Dynamic>(&mut rhai::Scope::new(), ast, fn_name, args)
155 .map_err(|e| self.map_rhai_error(plugin_id, e))
156 }
157
158 /// Maps a Rhai error to a PluginError.
159 fn map_rhai_error(&self, plugin_id: &str, err: Box<rhai::EvalAltResult>) -> PluginError {
160 let message = err.to_string();
161
162 // Check for specific error types
163 if message.contains("Too many operations") {
164 PluginError::safety_limit(plugin_id, "Maximum operations exceeded")
165 } else if message.contains("Stack overflow") {
166 PluginError::safety_limit(plugin_id, "Maximum call depth exceeded")
167 } else {
168 PluginError::script(plugin_id, message)
169 }
170 }
171
172 /// Returns the current safety limits.
173 #[tracing::instrument(skip_all)]
174 pub fn limits(&self) -> &SafetyLimits {
175 &self.limits
176 }
177
178 /// Returns a reference to the underlying engine.
179 #[tracing::instrument(skip_all)]
180 pub fn inner(&self) -> &Engine {
181 &self.engine
182 }
183 }
184
185 impl Default for PluginEngine {
186 fn default() -> Self {
187 Self::new()
188 }
189 }
190
/// Unit tests for `PluginEngine`: compilation, calls, safety limits and recovery.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_valid_script() {
        let engine = PluginEngine::new();
        let result = engine.compile("fn hello() { 42 }");
        assert!(result.is_ok());
    }

    #[test]
    fn test_compile_invalid_script() {
        let engine = PluginEngine::new();
        let result = engine.compile("fn hello( { }");
        assert!(result.is_err());
    }

    #[test]
    fn compile_errors_are_attributed_to_plugin() {
        let engine = PluginEngine::new();

        match engine.compile_plugin("broken-plugin", "fn hello( { }") {
            Err(PluginError::ScriptError { plugin, .. }) => assert_eq!(plugin, "broken-plugin"),
            Err(other) => panic!("Expected ScriptError, got {:?}", other),
            Ok(_) => panic!("Expected a compile error"),
        }

        // Without a plugin id, compile errors are attributed to "unknown".
        match engine.compile("fn hello( { }") {
            Err(PluginError::ScriptError { plugin, .. }) => assert_eq!(plugin, "unknown"),
            Err(other) => panic!("Expected ScriptError, got {:?}", other),
            Ok(_) => panic!("Expected a compile error"),
        }
    }

    #[test]
    fn test_has_function() {
        let engine = PluginEngine::new();
        let ast = engine.compile("fn describe() { } fn parse(x, y) { }").unwrap();

        assert!(engine.has_function(&ast, "describe", 0));
        assert!(engine.has_function(&ast, "parse", 2));
        assert!(!engine.has_function(&ast, "describe", 1));
        assert!(!engine.has_function(&ast, "missing", 0));
    }

    #[test]
    fn test_call_function() {
        let engine = PluginEngine::new();
        let ast = engine.compile("fn answer() { 42 }").unwrap();

        let result: i64 = engine.call_fn(&ast, "test", "answer").unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_operation_limit() {
        let limits = SafetyLimits {
            max_operations: 100,
            ..Default::default()
        };
        let engine = PluginEngine::with_limits(limits);
        let ast = engine
            .compile("fn infinite() { let x = 0; loop { x += 1; } }")
            .unwrap();

        let result: Result<i64> = engine.call_fn(&ast, "test", "infinite");
        assert!(result.is_err());

        if let Err(PluginError::SafetyLimitExceeded { .. }) = result {
            // Expected
        } else {
            panic!("Expected SafetyLimitExceeded error");
        }
    }

    #[test]
    fn call_depth_limit_returns_safety_error() {
        // A small limit keeps native stack usage low even in debug builds.
        let limits = SafetyLimits {
            max_call_levels: 5,
            ..Default::default()
        };
        let engine = PluginEngine::with_limits(limits);
        let ast = engine.compile("fn recurse(n) { recurse(n + 1) }").unwrap();

        let result: Result<i64> = engine.call_fn_1(&ast, "deep-plugin", "recurse", 0_i64);
        match result {
            Err(PluginError::SafetyLimitExceeded { plugin, message }) => {
                assert_eq!(plugin, "deep-plugin");
                assert!(
                    message.contains("call depth"),
                    "Unexpected message: {}",
                    message
                );
            }
            Err(other) => panic!("Expected SafetyLimitExceeded, got {:?}", other),
            Ok(v) => panic!("Expected an error, got {}", v),
        }
    }

    // ============ Plugin Lifecycle: Error and Recovery ============

    #[test]
    fn script_error_returns_plugin_error_not_panic() {
        let engine = PluginEngine::new();
        let ast = engine
            .compile(
                r#"
            fn hook_on_task_created(task) {
                throw "something went wrong in hook";
            }
            "#,
            )
            .unwrap();

        let result: Result<Dynamic> =
            engine.call_fn_1(&ast, "my-hook-plugin", "hook_on_task_created", "task-1");
        assert!(result.is_err());

        match result.unwrap_err() {
            PluginError::ScriptError { plugin, message } => {
                assert_eq!(plugin, "my-hook-plugin");
                assert!(
                    message.contains("something went wrong in hook"),
                    "Unexpected message: {}",
                    message
                );
            }
            other => panic!("Expected ScriptError, got {:?}", other),
        }
    }

    #[test]
    fn plugin_recoverable_after_hook_error() {
        let engine = PluginEngine::new();
        let ast = engine
            .compile(
                r#"
            fn on_task_created(task_id) {
                if task_id == "bad" {
                    throw "invalid task";
                }
                42
            }
            "#,
            )
            .unwrap();

        // First call errors
        let err_result: Result<Dynamic> =
            engine.call_fn_1(&ast, "recoverable", "on_task_created", "bad".to_string());
        assert!(err_result.is_err());
        match err_result.unwrap_err() {
            PluginError::ScriptError { .. } => {}
            other => panic!("Expected ScriptError, got {:?}", other),
        }

        // Second call with valid input succeeds -- engine did not poison itself
        let ok_result: i64 = engine
            .call_fn_1(&ast, "recoverable", "on_task_created", "good".to_string())
            .unwrap();
        assert_eq!(ok_result, 42);
    }

    #[test]
    fn operation_limit_returns_safety_error_not_panic() {
        let limits = SafetyLimits {
            max_operations: 50,
            ..Default::default()
        };
        let engine = PluginEngine::with_limits(limits);
        let ast = engine
            .compile(
                r#"
            fn expensive() {
                let x = 0;
                while x < 999999 { x += 1; }
                x
            }
            "#,
            )
            .unwrap();

        let result: Result<i64> = engine.call_fn(&ast, "expensive-plugin", "expensive");
        assert!(result.is_err());

        match result.unwrap_err() {
            PluginError::SafetyLimitExceeded { plugin, message } => {
                assert_eq!(plugin, "expensive-plugin");
                assert!(
                    message.contains("operations"),
                    "Unexpected message: {}",
                    message
                );
            }
            other => panic!("Expected SafetyLimitExceeded, got {:?}", other),
        }
    }

    #[test]
    fn recoverable_after_operation_limit() {
        let limits = SafetyLimits {
            max_operations: 50,
            ..Default::default()
        };
        let engine = PluginEngine::with_limits(limits);
        let ast = engine
            .compile(
                r#"
            fn expensive() {
                let x = 0;
                while x < 999999 { x += 1; }
                x
            }
            fn cheap() { 1 }
            "#,
            )
            .unwrap();

        // expensive() blows the ops limit
        let err_result: Result<i64> = engine.call_fn(&ast, "ops-test", "expensive");
        assert!(matches!(
            err_result,
            Err(PluginError::SafetyLimitExceeded { .. })
        ));

        // cheap() should still work -- the engine resets its operation counter per call
        let ok_result: i64 = engine.call_fn(&ast, "ops-test", "cheap").unwrap();
        assert_eq!(ok_result, 1);
    }

    #[test]
    fn compile_execute_multiple_functions_lifecycle() {
        let engine = PluginEngine::new();

        // Compile a plugin-like script with describe + parse
        let ast = engine
            .compile(
                r#"
            fn describe() {
                #{
                    name: "lifecycle-test",
                    file_extensions: ["csv"]
                }
            }

            fn parse(file_path, options) {
                let items = [];
                items.push(#{ description: "parsed from " + file_path });
                goingson::task_result(items)
            }
            "#,
            )
            .unwrap();

        // Validate function signatures exist
        assert!(engine.has_function(&ast, "describe", 0));
        assert!(engine.has_function(&ast, "parse", 2));
        assert!(!engine.has_function(&ast, "execute", 1));

        // Execute describe()
        let desc: Dynamic = engine.call_fn(&ast, "lifecycle", "describe").unwrap();
        let map = desc.try_cast::<rhai::Map>().unwrap();
        assert_eq!(
            map.get("name").unwrap().clone().into_string().unwrap(),
            "lifecycle-test"
        );

        // Execute parse()
        let options = rhai::Map::new();
        let result: Dynamic = engine
            .call_fn_2(&ast, "lifecycle", "parse", "/tmp/test.csv".to_string(), options)
            .unwrap();
        let result_map = result.try_cast::<rhai::Map>().unwrap();
        assert_eq!(
            result_map
                .get("entity_type")
                .unwrap()
                .clone()
                .into_string()
                .unwrap(),
            "task"
        );
    }

    #[test]
    fn eval_is_disabled() {
        let engine = PluginEngine::new();
        let result = engine.compile(r#"fn sneaky() { eval("1 + 1") }"#);
        // eval is disabled at the symbol level, so compilation should fail
        assert!(result.is_err());
    }

    #[test]
    fn call_fn_with_runtime_error_returns_script_error() {
        let engine = PluginEngine::new();
        let ast = engine.compile("fn divide(a, b) { a / b }").unwrap();

        // Division by zero should produce a ScriptError
        let result: Result<Dynamic> = engine.call_fn_2(&ast, "type-test", "divide", 42_i64, 0_i64);
        assert!(result.is_err());
        match result.unwrap_err() {
            PluginError::ScriptError { plugin, .. } => {
                assert_eq!(plugin, "type-test");
            }
            other => panic!("Expected ScriptError, got {:?}", other),
        }
    }
}
466