Skip to content

Commit e3985d5

Browse files
authored
🐛 Allow input detection on whole input for streaming text generation endpoint (#388)
* 🐛🔧 Allow input detection on text generation

  Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

* ✅ Valid whole-doc input detection test

  Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>

---------

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
1 parent d473b00 commit e3985d5

File tree

2 files changed

+96
-37
lines changed

2 files changed

+96
-37
lines changed

src/orchestrator/handlers/streaming_classification_with_gen.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,25 @@ impl Handle<StreamingClassificationWithGenTask> for Orchestrator {
7474
let input_detectors = task.guardrails_config.input_detectors();
7575
let output_detectors = task.guardrails_config.output_detectors();
7676

77-
// input detectors validation
77+
// Input detectors validation
78+
// Allow `whole_doc_chunker` detectors on input detection
79+
// because the input detection call is unary
7880
if let Err(error) = validate_detectors(
7981
&input_detectors,
8082
&ctx.config.detectors,
8183
&[DetectorType::TextContents],
82-
false,
84+
true,
8385
) {
8486
let _ = response_tx.send(Err(error)).await;
8587
return;
8688
}
8789

88-
// output detectors validation
90+
// Output detectors validation
91+
// Disallow `whole_doc_chunker` detectors on output detection
92+
// for now until results of these detectors are handled as
93+
// planned for chat completions, with detection results
94+
// provided separately at the end but not blocking other
95+
// detection results that may be provided on smaller chunks
8996
if let Err(error) = validate_detectors(
9097
&output_detectors,
9198
&ctx.config.detectors,

tests/streaming_classification_with_gen.rs

Lines changed: 86 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,40 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
404404
then.pb(mock_tokenization_response.clone());
405405
});
406406

407+
// Detector on whole doc / entire input for multi-detector scenario
408+
let whole_doc_mock_detection_response = ContentAnalysisResponse {
409+
start: 0,
410+
end: 61,
411+
text: "This sentence does not have a detection. But <this one does>.".into(),
412+
detection: "has_angle_brackets_1".into(),
413+
detection_type: "angle_brackets_1".into(),
414+
detector_id: Some(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into()),
415+
score: 1.0,
416+
evidence: None,
417+
metadata: Metadata::new(),
418+
};
419+
let mut whole_doc_detection_mocks = MockSet::new();
420+
whole_doc_detection_mocks.mock(|when, then| {
421+
when.path(TEXT_CONTENTS_DETECTOR_ENDPOINT)
422+
.json(ContentAnalysisRequest {
423+
contents: vec![
424+
"This sentence does not have a detection. But <this one does>.".into(),
425+
],
426+
detector_params: DetectorParams::new(),
427+
});
428+
then.json([vec![&whole_doc_mock_detection_response]]);
429+
});
430+
let mock_whole_doc_detector_server = MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC)
431+
.with_mocks(whole_doc_detection_mocks);
432+
407433
// Start orchestrator server and its dependencies
408434
let mock_chunker_server = MockServer::new(chunker_id).grpc().with_mocks(chunker_mocks);
409435
let mock_detector_server = MockServer::new(detector_name).with_mocks(detection_mocks);
410436
let generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks);
411437
let orchestrator_server = TestOrchestratorServer::builder()
412438
.config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
413439
.generation_server(&generation_server)
414-
.detector_servers([&mock_detector_server])
440+
.detector_servers([&mock_detector_server, &mock_whole_doc_detector_server])
415441
.chunker_servers([&mock_chunker_server])
416442
.build()
417443
.await?;
@@ -471,6 +497,65 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
471497
}])
472498
);
473499

500+
// Multi-detector scenario with detector that uses content from entire input
501+
let response = orchestrator_server
502+
.post(ORCHESTRATOR_STREAMING_ENDPOINT)
503+
.json(&GuardrailsHttpRequest {
504+
model_id: model_id.into(),
505+
inputs: "This sentence does not have a detection. But <this one does>.".into(),
506+
guardrail_config: Some(GuardrailsConfig {
507+
input: Some(GuardrailsConfigInput {
508+
models: HashMap::from([
509+
(detector_name.into(), DetectorParams::new()),
510+
(
511+
DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(),
512+
DetectorParams::new(),
513+
),
514+
]),
515+
masks: None,
516+
}),
517+
output: None,
518+
}),
519+
text_gen_parameters: None,
520+
})
521+
.send()
522+
.await?;
523+
let sse_stream: SseStream<ClassifiedGeneratedTextStreamResult> =
524+
SseStream::new(response.bytes_stream());
525+
let messages = sse_stream.try_collect::<Vec<_>>().await?;
526+
debug!("{messages:#?}");
527+
528+
assert_eq!(messages.len(), 1);
529+
assert!(messages[0].generated_text.is_none());
530+
assert_eq!(
531+
messages[0].token_classification_results,
532+
TextGenTokenClassificationResults {
533+
input: Some(vec![
534+
TokenClassificationResult {
535+
start: 0,
536+
end: 61,
537+
word: whole_doc_mock_detection_response.text,
538+
entity: whole_doc_mock_detection_response.detection,
539+
entity_group: whole_doc_mock_detection_response.detection_type,
540+
detector_id: whole_doc_mock_detection_response.detector_id,
541+
score: whole_doc_mock_detection_response.score,
542+
token_count: None
543+
},
544+
TokenClassificationResult {
545+
start: 46, // index of first token of detected text, relative to the `inputs` string sent in the orchestrator request.
546+
end: 59, // index of last token (+1) of detected text, relative to the `inputs` string sent in the orchestrator request.
547+
word: "this one does".into(),
548+
entity: "has_angle_brackets".into(),
549+
entity_group: "angle_brackets".into(),
550+
detector_id: Some(detector_name.to_string()),
551+
score: mock_detection_response.score,
552+
token_count: None
553+
}
554+
]),
555+
output: None
556+
}
557+
);
558+
474559
Ok(())
475560
}
476561

@@ -727,39 +812,6 @@ async fn orchestrator_validation_error() -> Result<(), anyhow::Error> {
727812
"failed at invalid input detector scenario"
728813
);
729814

730-
// Invalid chunker on input detector scenario
731-
let response = orchestrator_server
732-
.post(ORCHESTRATOR_STREAMING_ENDPOINT)
733-
.json(&GuardrailsHttpRequest {
734-
model_id: model_id.into(),
735-
inputs: "This request contains a detector with an invalid chunker".into(),
736-
guardrail_config: Some(GuardrailsConfig {
737-
input: Some(GuardrailsConfigInput {
738-
models: HashMap::from([(
739-
DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC.into(),
740-
DetectorParams::new(),
741-
)]),
742-
masks: None,
743-
}),
744-
output: None,
745-
}),
746-
text_gen_parameters: None,
747-
})
748-
.send()
749-
.await?;
750-
debug!("{response:#?}");
751-
752-
assert_eq!(response.status(), 200);
753-
let sse_stream: SseStream<OrchestratorError> = SseStream::new(response.bytes_stream());
754-
let messages = sse_stream.try_collect::<Vec<_>>().await?;
755-
debug!("{messages:#?}");
756-
assert_eq!(messages.len(), 1);
757-
assert_eq!(
758-
messages[0],
759-
OrchestratorError::chunker_not_supported(DETECTOR_NAME_ANGLE_BRACKETS_WHOLE_DOC),
760-
"failed on input detector with invalid chunker scenario"
761-
);
762-
763815
// Non-existing input detector scenario
764816
let response = orchestrator_server
765817
.post(ORCHESTRATOR_STREAMING_ENDPOINT)

0 commit comments

Comments (0)