@@ -3,6 +3,7 @@ use std::io::Cursor;
33use std:: io:: Result as IoResult ;
44use std:: pin:: Pin ;
55use std:: sync:: Arc ;
6+ use std:: sync:: Mutex ;
67use std:: task:: Context ;
78use std:: task:: Poll ;
89use std:: time:: Duration ;
@@ -233,8 +234,8 @@ impl FaultInjector for OpenAiInjector {
233234 return Ok ( ( status, headers, body) ) ;
234235 }
235236
236- if let Some ( instruction ) = & self . settings . instruction {
237- if rand :: random :: < f64 > ( ) < self . settings . probability {
237+ if rand :: random :: < f64 > ( ) < self . settings . probability {
238+ if let Some ( instruction ) = & self . settings . instruction {
238239 let _ = event. with_fault ( FaultEvent :: Llm {
239240 direction : Direction :: Ingress ,
240241 side : StreamSide :: Client ,
@@ -263,7 +264,7 @@ impl FaultInjector for OpenAiInjector {
263264 let payload = json ! ( {
264265 "id" : id,
265266 "model" : model,
266- "choices" : [ { "delta" : { "role" : "system " , "content" : instruction } , "index" : 0 , "finish_reason" : null } ]
267+ "choices" : [ { "delta" : { "role" : "assistant " , "content" : instruction } , "index" : 0 , "finish_reason" : null } ]
267268 } )
268269 . to_string ( ) ;
269270 let mut framed = String :: new ( ) ;
@@ -287,6 +288,28 @@ impl FaultInjector for OpenAiInjector {
287288 case : self . settings . case ,
288289 } ) ;
289290
291+ return Ok ( ( status, headers, boxed) ) ;
292+ } else if let ( Some ( regex) , Some ( replacement) ) =
293+ ( & self . regex , & self . settings . replacement )
294+ {
295+ let _ = event. with_fault ( FaultEvent :: Llm {
296+ direction : Direction :: Ingress ,
297+ side : StreamSide :: Client ,
298+ case : self . settings . case ,
299+ } ) ;
300+
301+ let regex = regex. clone ( ) ;
302+ let rep = replacement. clone ( ) ;
303+
304+ let boxed: BoxChunkStream =
305+ rewrite_sse_stream ( body, regex, rep) ;
306+
307+ let _ = event. on_applied ( FaultEvent :: Llm {
308+ direction : Direction :: Ingress ,
309+ side : StreamSide :: Client ,
310+ case : self . settings . case ,
311+ } ) ;
312+
290313 return Ok ( ( status, headers, boxed) ) ;
291314 }
292315 }
@@ -491,7 +514,7 @@ fn mutate_request(
491514 }
492515 } ;
493516
494- tracing:: debug!( "LLM request {}" , doc) ;
517+ // tracing::debug!("LLM request {}", doc);
495518
496519 if path. starts_with ( "/v1/chat/completions" )
497520 || path. starts_with ( "/api/v1/chat/completions" )
@@ -578,8 +601,6 @@ fn scramble_response(
578601) -> Result < Vec < u8 > , ProxyError > {
579602 match serde_json:: from_slice :: < Value > ( & body) {
580603 Ok ( mut doc) => {
581- tracing:: debug!( "{:?}" , doc) ;
582-
583604 if let Some ( object) = doc. get ( "object" ) . and_then ( Value :: as_str) {
584605 if object == "chat.completion" {
585606 if let Some ( choices) =
@@ -652,3 +673,133 @@ fn get_delay(rng: &mut SmallRng, mean: &f64, stddev: &f64) -> Duration {
652673 let nanos = ( ( sample - millis as f64 ) * 1_000_000.0 ) . round ( ) as u32 ;
653674 Duration :: from_millis ( millis) + Duration :: from_nanos ( nanos as u64 )
654675}
676+
677+ fn rewrite_sse_stream (
678+ body : BoxChunkStream ,
679+ regex : Regex ,
680+ replacement : String ,
681+ ) -> BoxChunkStream {
682+ // we need an accumulator to handle reading of chunks that span
683+ // multiple stream read
684+ let acc = Arc :: new ( Mutex :: new ( String :: new ( ) ) ) ;
685+ let rep = replacement;
686+
687+ let stream = body. map ( move |chunk_res| {
688+ let regex = regex. clone ( ) ;
689+ let acc = acc. clone ( ) ;
690+ chunk_res. map ( |bytes| {
691+ // Decode chunk (SSE is text)
692+ let mut piece = match std:: str:: from_utf8 ( & bytes) {
693+ Ok ( s) => s. replace ( "\r \n " , "\n " ) , // normalize EOL
694+ Err ( _) => return bytes, // pass through if not UTF-8
695+ } ;
696+
697+ // Append to accumulator
698+ {
699+ let mut a = acc. lock ( ) . unwrap ( ) ;
700+ a. push_str ( & piece) ;
701+ // We consume complete events from `a` into `out`,
702+ // leaving any partial event in `a` for the next chunk.
703+ }
704+
705+ // Extract complete events
706+ let mut out = String :: new ( ) ;
707+ {
708+ let mut a = acc. lock ( ) . unwrap ( ) ;
709+
710+ loop {
711+ // Find end of next complete event (blank line)
712+ let Some ( pos) = a. find ( "\n \n " ) else { break } ;
713+
714+ // Split out one full event (without the trailing blank
715+ // line)
716+ let event = a[ ..pos] . to_string ( ) ;
717+
718+ // Drain consumed part + the blank line
719+ a. drain ( ..pos + 2 ) ;
720+
721+ // Parse the event lines
722+ let mut header_lines: Vec < & str > = Vec :: new ( ) ;
723+ let mut data_lines: Vec < & str > = Vec :: new ( ) ;
724+
725+ for line in event. lines ( ) {
726+ if let Some ( rest) = line. strip_prefix ( "data:" ) {
727+ data_lines
728+ . push ( rest. strip_prefix ( ' ' ) . unwrap_or ( rest) ) ;
729+ } else {
730+ header_lines. push ( line) ;
731+ }
732+ }
733+
734+ // Let's re-emit events with no data unchanged
735+ if data_lines. is_empty ( ) {
736+ out. push_str ( & event) ;
737+ out. push_str ( "\n \n " ) ;
738+ continue ;
739+ }
740+
741+ let data_payload = data_lines. join ( "\n " ) ;
742+
743+ // Special case: [DONE]
744+ // https://platform.openai.com/docs/api-reference/runs/createRun#runs_createrun-stream
745+ if data_payload. trim ( ) == "[DONE]" {
746+ if !header_lines. is_empty ( ) {
747+ out. push_str ( & header_lines. join ( "\n " ) ) ;
748+ out. push ( '\n' ) ;
749+ }
750+ out. push_str ( "data: [DONE]\n \n " ) ;
751+ continue ;
752+ }
753+
754+ // Let's try rewriting choices[0].delta.content
755+ match serde_json:: from_str :: < Value > ( & data_payload) {
756+ Ok ( mut json) => {
757+ if let Some ( content_val) = json
758+ . get_mut ( "choices" )
759+ . and_then ( |c| c. get_mut ( 0 ) )
760+ . and_then ( |c| c. get_mut ( "delta" ) )
761+ . and_then ( |d| d. get_mut ( "content" ) )
762+ {
763+ if let Some ( content_str) = content_val. as_str ( )
764+ {
765+ let replaced = regex
766+ . replace_all ( content_str, rep. as_str ( ) )
767+ . to_string ( ) ;
768+ * content_val = Value :: String ( replaced) ;
769+ }
770+ }
771+
772+ // Re-serialize + emit, preserving headers
773+ if !header_lines. is_empty ( ) {
774+ out. push_str ( & header_lines. join ( "\n " ) ) ;
775+ out. push ( '\n' ) ;
776+ }
777+ match serde_json:: to_string ( & json) {
778+ Ok ( s) => {
779+ out. push_str ( "data: " ) ;
780+ out. push_str ( & s) ;
781+ out. push_str ( "\n \n " ) ;
782+ }
783+ Err ( _) => {
784+ // fallback: original event untouched
785+ out. push_str ( & event) ;
786+ out. push_str ( "\n \n " ) ;
787+ }
788+ }
789+ }
790+ Err ( _) => {
791+ // Not JSON? Emit original event
792+ out. push_str ( & event) ;
793+ out. push_str ( "\n \n " ) ;
794+ }
795+ }
796+ }
797+ }
798+
799+ // Emit only processed complete events for this chunk
800+ Bytes :: from ( out)
801+ } )
802+ } ) ;
803+
804+ Box :: pin ( stream)
805+ }
0 commit comments