@@ -164,6 +164,7 @@ impl HookExecutor {
164
164
& mut self ,
165
165
hooks : Vec < & Hook > ,
166
166
output : & mut impl Write ,
167
+ prompt : Option < & str >
167
168
) -> Result < Vec < ( Hook , String ) > , ChatError > {
168
169
let mut results = Vec :: with_capacity ( hooks. len ( ) ) ;
169
170
let mut futures = FuturesUnordered :: new ( ) ;
@@ -181,7 +182,7 @@ impl HookExecutor {
181
182
results. push ( ( index, ( hook. clone ( ) , cached. clone ( ) ) ) ) ;
182
183
continue ;
183
184
}
184
- let future = self . execute_hook ( hook) ;
185
+ let future = self . execute_hook ( hook, prompt ) ;
185
186
futures. push ( async move { ( index, future. await ) } ) ;
186
187
}
187
188
@@ -299,38 +300,47 @@ impl HookExecutor {
299
300
Ok ( results. into_iter ( ) . map ( |( _, r) | r) . collect ( ) )
300
301
}
301
302
302
- async fn execute_hook < ' a > ( & self , hook : & ' a Hook ) -> ( & ' a Hook , Result < String > , Duration ) {
303
+ async fn execute_hook < ' a > ( & self , hook : & ' a Hook , prompt : Option < & str > ) -> ( & ' a Hook , Result < String > , Duration ) {
303
304
let start_time = Instant :: now ( ) ;
304
305
let result = match hook. r#type {
305
- HookType :: Inline => self . execute_inline_hook ( hook) . await ,
306
+ HookType :: Inline => self . execute_inline_hook ( hook, prompt ) . await ,
306
307
} ;
307
308
308
309
( hook, result, start_time. elapsed ( ) )
309
310
}
310
311
311
- async fn execute_inline_hook ( & self , hook : & Hook ) -> Result < String > {
312
+ async fn execute_inline_hook ( & self , hook : & Hook , user_prompt : Option < & str > ) -> Result < String > {
312
313
let command = hook. command . as_ref ( ) . ok_or_else ( || eyre ! ( "no command specified" ) ) ?;
313
314
314
315
#[ cfg( unix) ]
315
- let command_future = tokio:: process:: Command :: new ( "bash" )
316
- . arg ( "-c" )
316
+ let mut cmd = tokio:: process:: Command :: new ( "bash" ) ;
317
+ #[ cfg( unix) ]
318
+ let cmd = cmd. arg ( "-c" )
317
319
. arg ( command)
318
320
. stdin ( Stdio :: piped ( ) )
319
321
. stdout ( Stdio :: piped ( ) )
320
- . stderr ( Stdio :: piped ( ) )
321
- . output ( ) ;
322
+ . stderr ( Stdio :: piped ( ) ) ;
322
323
323
324
#[ cfg( windows) ]
324
- let command_future = tokio:: process:: Command :: new ( "cmd" )
325
- . arg ( "/C" )
325
+ let mut cmd = tokio:: process:: Command :: new ( "cmd" ) ;
326
+ #[ cfg( windows) ]
327
+ let cmd = cmd. arg ( "/C" )
326
328
. arg ( command)
327
329
. stdin ( Stdio :: piped ( ) )
328
330
. stdout ( Stdio :: piped ( ) )
329
- . stderr ( Stdio :: piped ( ) )
330
- . output ( ) ;
331
+ . stderr ( Stdio :: piped ( ) ) ;
331
332
332
333
let timeout = Duration :: from_millis ( hook. timeout_ms ) ;
333
334
335
+ // Set USER_PROMPT environment variable if provided
336
+ if let Some ( prompt) = user_prompt {
337
+ // Sanitize the prompt to avoid issues with special characters
338
+ let sanitized_prompt = sanitize_user_prompt ( prompt) ;
339
+ cmd. env ( "USER_PROMPT" , sanitized_prompt) ;
340
+ }
341
+
342
+ let command_future = cmd. output ( ) ;
343
+
334
344
// Run with timeout
335
345
match tokio:: time:: timeout ( timeout, command_future) . await {
336
346
Ok ( result) => {
@@ -387,6 +397,19 @@ impl HookExecutor {
387
397
}
388
398
}
389
399
400
+ /// Sanitizes a string value to be used as an environment variable
401
+ fn sanitize_user_prompt ( input : & str ) -> String {
402
+ // Limit the size of input to first 4096 characters
403
+ let truncated = if input. len ( ) > 4096 {
404
+ & input[ 0 ..4096 ]
405
+ } else {
406
+ input
407
+ } ;
408
+
409
+ // Remove any potentially problematic characters
410
+ truncated. replace ( |c : char | c. is_control ( ) && c != '\n' && c != '\r' && c != '\t' , "" )
411
+ }
412
+
390
413
#[ deny( missing_docs) ]
391
414
#[ derive( Debug , PartialEq , Args ) ]
392
415
#[ command(
0 commit comments