Skip to content

Commit a3e7bfc

Browse files
committed
fix llm faults
Signed-off-by: Sylvain Hellegouarch <[email protected]>
1 parent ac4bbe4 commit a3e7bfc

File tree

4 files changed

+163
-18
lines changed

4 files changed

+163
-18
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22

33
## Changed
44

5-
- Fix container registry
5+
- Fix a variety of minor bugs with LLM fault injection
6+
- Better SSE support
7+

fault-cli/src/cli.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,6 @@ pub struct LlmOptions {
323323
/// Regex pattern to scramble in prompt
324324
#[clap(
325325
long,
326-
group = "prompt-scramble",
327326
help_heading = "Prompt Scramble",
328327
env = "FAULT_LLM_SCRAMBLE_PATTERN"
329328
)]
@@ -332,7 +331,6 @@ pub struct LlmOptions {
332331
/// Substitute text for scramble
333332
#[clap(
334333
long,
335-
group = "prompt-scramble",
336334
help_heading = "Prompt Scramble",
337335
env = "FAULT_LLM_SCRAMBLE_WITH"
338336
)]
@@ -347,18 +345,12 @@ pub struct LlmOptions {
347345
pub instruction: Option<String>,
348346

349347
/// Regex pattern for bias
350-
#[clap(
351-
long,
352-
group = "inject-bias",
353-
help_heading = "Inject Bias",
354-
env = "FAULT_LLM_BIAS_PATTERN"
355-
)]
348+
#[clap(long, help_heading = "Inject Bias", env = "FAULT_LLM_BIAS_PATTERN")]
356349
pub bias_pattern: Option<String>,
357350

358351
/// Substitute text for bias
359352
#[clap(
360353
long,
361-
group = "inject-bias",
362354
help_heading = "Inject Bias",
363355
env = "FAULT_LLM_BIAS_REPLACEMENT"
364356
)]

fault-cli/src/config.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,8 @@ impl From<(LlmCase, &LlmOptions)> for FaultConfig {
437437
LlmCase::PromptScramble => {
438438
let s = OpenAiSettings {
439439
case: LlmCase::PromptScramble,
440-
pattern: None,
441-
replacement: None,
440+
pattern: options.scramble_pattern.clone(),
441+
replacement: options.scramble_with.clone(),
442442
instruction: options.instruction.clone(),
443443
probability: options.probability,
444444
kind: FaultKind::PromptScramble,

fault-cli/src/fault/llm/openai.rs

Lines changed: 157 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::io::Cursor;
33
use std::io::Result as IoResult;
44
use std::pin::Pin;
55
use std::sync::Arc;
6+
use std::sync::Mutex;
67
use std::task::Context;
78
use std::task::Poll;
89
use std::time::Duration;
@@ -233,8 +234,8 @@ impl FaultInjector for OpenAiInjector {
233234
return Ok((status, headers, body));
234235
}
235236

236-
if let Some(instruction) = &self.settings.instruction {
237-
if rand::random::<f64>() < self.settings.probability {
237+
if rand::random::<f64>() < self.settings.probability {
238+
if let Some(instruction) = &self.settings.instruction {
238239
let _ = event.with_fault(FaultEvent::Llm {
239240
direction: Direction::Ingress,
240241
side: StreamSide::Client,
@@ -263,7 +264,7 @@ impl FaultInjector for OpenAiInjector {
263264
let payload = json!({
264265
"id": id,
265266
"model": model,
266-
"choices": [{ "delta": { "role": "system", "content": instruction }, "index": 0, "finish_reason": null }]
267+
"choices": [{ "delta": { "role": "assistant", "content": instruction }, "index": 0, "finish_reason": null }]
267268
})
268269
.to_string();
269270
let mut framed = String::new();
@@ -287,6 +288,28 @@ impl FaultInjector for OpenAiInjector {
287288
case: self.settings.case,
288289
});
289290

291+
return Ok((status, headers, boxed));
292+
} else if let (Some(regex), Some(replacement)) =
293+
(&self.regex, &self.settings.replacement)
294+
{
295+
let _ = event.with_fault(FaultEvent::Llm {
296+
direction: Direction::Ingress,
297+
side: StreamSide::Client,
298+
case: self.settings.case,
299+
});
300+
301+
let regex = regex.clone();
302+
let rep = replacement.clone();
303+
304+
let boxed: BoxChunkStream =
305+
rewrite_sse_stream(body, regex, rep);
306+
307+
let _ = event.on_applied(FaultEvent::Llm {
308+
direction: Direction::Ingress,
309+
side: StreamSide::Client,
310+
case: self.settings.case,
311+
});
312+
290313
return Ok((status, headers, boxed));
291314
}
292315
}
@@ -491,7 +514,7 @@ fn mutate_request(
491514
}
492515
};
493516

494-
tracing::debug!("LLM request {}", doc);
517+
//tracing::debug!("LLM request {}", doc);
495518

496519
if path.starts_with("/v1/chat/completions")
497520
|| path.starts_with("/api/v1/chat/completions")
@@ -578,8 +601,6 @@ fn scramble_response(
578601
) -> Result<Vec<u8>, ProxyError> {
579602
match serde_json::from_slice::<Value>(&body) {
580603
Ok(mut doc) => {
581-
tracing::debug!("{:?}", doc);
582-
583604
if let Some(object) = doc.get("object").and_then(Value::as_str) {
584605
if object == "chat.completion" {
585606
if let Some(choices) =
@@ -652,3 +673,133 @@ fn get_delay(rng: &mut SmallRng, mean: &f64, stddev: &f64) -> Duration {
652673
let nanos = ((sample - millis as f64) * 1_000_000.0).round() as u32;
653674
Duration::from_millis(millis) + Duration::from_nanos(nanos as u64)
654675
}
676+
677+
fn rewrite_sse_stream(
678+
body: BoxChunkStream,
679+
regex: Regex,
680+
replacement: String,
681+
) -> BoxChunkStream {
682+
// we need an accumulator to handle reading of chunks that span
683+
// multiple stream read
684+
let acc = Arc::new(Mutex::new(String::new()));
685+
let rep = replacement;
686+
687+
let stream = body.map(move |chunk_res| {
688+
let regex = regex.clone();
689+
let acc = acc.clone();
690+
chunk_res.map(|bytes| {
691+
// Decode chunk (SSE is text)
692+
let mut piece = match std::str::from_utf8(&bytes) {
693+
Ok(s) => s.replace("\r\n", "\n"), // normalize EOL
694+
Err(_) => return bytes, // pass through if not UTF-8
695+
};
696+
697+
// Append to accumulator
698+
{
699+
let mut a = acc.lock().unwrap();
700+
a.push_str(&piece);
701+
// We consume complete events from `a` into `out`,
702+
// leaving any partial event in `a` for the next chunk.
703+
}
704+
705+
// Extract complete events
706+
let mut out = String::new();
707+
{
708+
let mut a = acc.lock().unwrap();
709+
710+
loop {
711+
// Find end of next complete event (blank line)
712+
let Some(pos) = a.find("\n\n") else { break };
713+
714+
// Split out one full event (without the trailing blank
715+
// line)
716+
let event = a[..pos].to_string();
717+
718+
// Drain consumed part + the blank line
719+
a.drain(..pos + 2);
720+
721+
// Parse the event lines
722+
let mut header_lines: Vec<&str> = Vec::new();
723+
let mut data_lines: Vec<&str> = Vec::new();
724+
725+
for line in event.lines() {
726+
if let Some(rest) = line.strip_prefix("data:") {
727+
data_lines
728+
.push(rest.strip_prefix(' ').unwrap_or(rest));
729+
} else {
730+
header_lines.push(line);
731+
}
732+
}
733+
734+
// Let's re-emit events with no data unchanged
735+
if data_lines.is_empty() {
736+
out.push_str(&event);
737+
out.push_str("\n\n");
738+
continue;
739+
}
740+
741+
let data_payload = data_lines.join("\n");
742+
743+
// Special case: [DONE]
744+
// https://platform.openai.com/docs/api-reference/runs/createRun#runs_createrun-stream
745+
if data_payload.trim() == "[DONE]" {
746+
if !header_lines.is_empty() {
747+
out.push_str(&header_lines.join("\n"));
748+
out.push('\n');
749+
}
750+
out.push_str("data: [DONE]\n\n");
751+
continue;
752+
}
753+
754+
// Let's try rewriting choices[0].delta.content
755+
match serde_json::from_str::<Value>(&data_payload) {
756+
Ok(mut json) => {
757+
if let Some(content_val) = json
758+
.get_mut("choices")
759+
.and_then(|c| c.get_mut(0))
760+
.and_then(|c| c.get_mut("delta"))
761+
.and_then(|d| d.get_mut("content"))
762+
{
763+
if let Some(content_str) = content_val.as_str()
764+
{
765+
let replaced = regex
766+
.replace_all(content_str, rep.as_str())
767+
.to_string();
768+
*content_val = Value::String(replaced);
769+
}
770+
}
771+
772+
// Re-serialize + emit, preserving headers
773+
if !header_lines.is_empty() {
774+
out.push_str(&header_lines.join("\n"));
775+
out.push('\n');
776+
}
777+
match serde_json::to_string(&json) {
778+
Ok(s) => {
779+
out.push_str("data: ");
780+
out.push_str(&s);
781+
out.push_str("\n\n");
782+
}
783+
Err(_) => {
784+
// fallback: original event untouched
785+
out.push_str(&event);
786+
out.push_str("\n\n");
787+
}
788+
}
789+
}
790+
Err(_) => {
791+
// Not JSON? Emit original event
792+
out.push_str(&event);
793+
out.push_str("\n\n");
794+
}
795+
}
796+
}
797+
}
798+
799+
// Emit only processed complete events for this chunk
800+
Bytes::from(out)
801+
})
802+
});
803+
804+
Box::pin(stream)
805+
}

0 commit comments

Comments
 (0)