
Commit 4e3652e

evaline-ju and declark1 authored
✨ Update to /text/contents API for detector (#62)
* 🔧 Use text contents endpoint
* 🚧 WIP detector API update
* 💡🏷️ Add comments for new contents API types
* 🚧 Updated content analysis handling
* 🚧 Remove clone on chunks
* 💡 Add comments
* 🚧 Update by reference
* ♻️ Use from
* 💡 Add TODO comment
* Update handle_detection_task and index_codepoints, rename DetectorClient method

Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
Co-authored-by: declark1 <daniel.clark@ibm.com>
1 parent cdbb37c commit 4e3652e

File tree

3 files changed: +111 −77 lines changed

config/config.yaml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ chunkers:
 detectors:
   hap-en:
     service:
-      hostname: https://localhost/api/v1/detector
+      hostname: https://localhost/api/v1/text/contents # full url / endpoint currently expected
       port: 8080
       tls: caikit
     chunker_id: en_regex
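Note the new inline comment: `DetectorClient` posts to `client.base_url()` as-is (visible as context in the `detector.rs` hunk below), which is why the full URL, endpoint included, is expected here. As a hypothetical illustration, a request body for that endpoint would look like this, using the field name from the `ContentAnalysisRequest` type introduced below (a sketch assuming `serde_json` is available, as it already is in this crate):

```rust
// Hypothetical request body for the configured /text/contents endpoint;
// the field name follows ContentAnalysisRequest in src/clients/detector.rs.
use serde_json::json;

fn main() {
    let body = json!({
        "contents": ["first chunk of text", "second chunk of text"]
    });
    println!("{body}");
    // => {"contents":["first chunk of text","second chunk of text"]}
}
```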

src/clients/detector.rs

Lines changed: 53 additions & 21 deletions
@@ -28,11 +28,11 @@ impl DetectorClient {
             .clone())
     }

-    pub async fn classify(
+    pub async fn text_contents(
         &self,
         model_id: &str,
-        request: DetectorRequest,
-    ) -> Result<DetectorResponse, Error> {
+        request: ContentAnalysisRequest,
+    ) -> Result<Vec<Vec<ContentAnalysisResponse>>, Error> {
         let client = self.client(model_id)?;
         let url = client.base_url().as_str();
         let response = client
@@ -47,39 +47,71 @@ impl DetectorClient {
     }
 }

-#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
-pub struct DetectorRequest {
-    pub text: String,
-    pub parameters: HashMap<String, serde_json::Value>,
+/// Request for text content analysis
+/// Results of this request will contain analysis / detection of each of the provided documents
+/// in the order they are present in the `contents` object.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ContentAnalysisRequest {
+    /// Field allowing users to provide list of documents for analysis
+    pub contents: Vec<String>,
 }

-impl DetectorRequest {
-    pub fn new(text: String, parameters: HashMap<String, serde_json::Value>) -> Self {
-        Self { text, parameters }
+impl ContentAnalysisRequest {
+    pub fn new(contents: Vec<String>) -> ContentAnalysisRequest {
+        ContentAnalysisRequest { contents }
     }
 }

-#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
-pub struct Detection {
+/// Evidence type
+#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum EvidenceType {
+    Url,
+    Title,
+}
+
+/// Source of the evidence e.g. url
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Evidence {
+    /// Evidence source
+    pub source: String,
+}
+
+/// Evidence in response
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct EvidenceObj {
+    /// Type field signifying the type of evidence provided
+    #[serde(rename = "type")]
+    pub r#type: EvidenceType,
+    /// Evidence currently only containing source
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub evidence: Option<Evidence>,
+}
+
+/// Response of text content analysis endpoint
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ContentAnalysisResponse {
+    /// Start index of detection
     pub start: usize,
+    /// End index of detection
     pub end: usize,
-    pub text: String,
+    /// Relevant detection class
     pub detection: String,
+    /// Detection type or aggregate detection label
     pub detection_type: String,
+    /// Score of detection
     pub score: f64,
+    /// Optional, any applicable evidences for detection
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub evidences: Option<Vec<EvidenceObj>>,
 }

-#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
-pub struct DetectorResponse {
-    pub detections: Vec<Detection>,
-}
-
-impl From<Detection> for crate::models::TokenClassificationResult {
-    fn from(value: Detection) -> Self {
+impl From<ContentAnalysisResponse> for crate::models::TokenClassificationResult {
+    fn from(value: ContentAnalysisResponse) -> Self {
         Self {
             start: value.start as u32,
             end: value.end as u32,
-            word: value.text,
+            word: "".to_string(), // TODO: fill in when provided in the detector API in the next iteration
             entity: value.detection,
             entity_group: value.detection_type,
             score: value.score,
src/orchestrator.rs

Lines changed: 57 additions & 55 deletions
@@ -12,7 +12,7 @@ use uuid::Uuid;

 use crate::{
     clients::{
-        self, detector::DetectorRequest, ChunkerClient, DetectorClient, GenerationClient,
+        self, detector::ContentAnalysisRequest, ChunkerClient, DetectorClient, GenerationClient,
         NlpClient, TgisClient,
     },
     config::{GenerationProvider, OrchestratorConfig},
@@ -198,7 +198,7 @@ async fn chunk_and_detect(
             Ok::<String, Error>(chunker_id)
         })
         .collect::<Result<Vec<_>, Error>>()?;
-    // Spawn chunking tasks, returning a map of chunker_id->chunks.
+    // Spawn chunking tasks, returning a map of chunker_id->chunks
     let chunks = chunk(ctx.clone(), chunker_ids, text_with_offsets).await?;
     // Spawn detection tasks
     let detections = detect(ctx.clone(), detectors, chunks).await?;
@@ -319,69 +319,57 @@ async fn handle_chunk_task(
     Ok((chunker_id, chunks))
 }

-/// Sends a buffered, concurrent stream of requests to a detector service.
+/// Sends a request to a detector service.
 async fn handle_detection_task(
     ctx: Arc<Context>,
     detector_id: String,
     default_threshold: f32,
     detector_params: DetectorParams,
     chunks: Vec<Chunk>,
 ) -> Result<Vec<TokenClassificationResult>, Error> {
-    let detections = stream::iter(chunks)
-        .map(|chunk| {
-            let ctx = ctx.clone();
-            let detector_id = detector_id.clone();
-            let detector_params = detector_params.clone();
-            async move {
-                // NOTE: The detector request is expected to change and not actually
-                // take parameters. Any parameters will be ignored for now
-                // ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
-                let request = DetectorRequest::new(chunk.text.clone(), detector_params.clone());
-                debug!(
-                    %detector_id,
-                    ?request,
-                    "sending detector request"
-                );
-                let response = ctx
-                    .detector_client
-                    .classify(&detector_id, request)
-                    .await
-                    .map_err(|error| Error::DetectorRequestFailed {
-                        detector_id: detector_id.clone(),
-                        error,
-                    })?;
-                debug!(
-                    %detector_id,
-                    ?response,
-                    "received detector response"
-                );
-                // Filter results based on threshold (if applicable) here
-                let results = response
-                    .detections
-                    .into_iter()
-                    .filter_map(|detection| {
-                        let mut result: TokenClassificationResult = detection.into();
-                        result.start += chunk.offset as u32;
-                        result.end += chunk.offset as u32;
-                        let threshold = detector_params
-                            .get("threshold")
-                            .and_then(|v| v.as_f64())
-                            .unwrap_or(default_threshold as f64);
-                        (result.score >= threshold).then_some(result)
-                    })
-                    .collect::<Vec<_>>();
-                Ok::<Vec<TokenClassificationResult>, Error>(results)
-            }
-        })
-        .buffered(5)
-        .collect::<Vec<_>>()
+    let detector_id = detector_id.clone();
+    let threshold = detector_params
+        .get("threshold")
+        .and_then(|v| v.as_f64())
+        .unwrap_or(default_threshold as f64);
+    let contents = chunks.iter().map(|chunk| chunk.text.clone()).collect();
+    let request = ContentAnalysisRequest::new(contents);
+    debug!(
+        %detector_id,
+        ?request,
+        "sending detector request"
+    );
+    let response = ctx
+        .detector_client
+        .text_contents(&detector_id, request)
         .await
+        .map_err(|error| Error::DetectorRequestFailed {
+            detector_id: detector_id.clone(),
+            error,
+        })?;
+    debug!(
+        %detector_id,
+        ?response,
+        "received detector response"
+    );
+    let results = chunks
         .into_iter()
-        .collect::<Result<Vec<_>, Error>>()?
-        .into_iter()
-        .flatten()
+        .zip(response)
+        .flat_map(|(chunk, response)| {
+            response
+                .into_iter()
+                .filter_map(|resp| {
+                    let mut result: TokenClassificationResult = resp.into();
+                    result.word =
+                        index_codepoints(&chunk.text, result.start as usize, result.end as usize);
+                    result.start += chunk.offset as u32;
+                    result.end += chunk.offset as u32;
+                    (result.score >= threshold).then_some(result)
+                })
+                .collect::<Vec<_>>()
+        })
         .collect::<Vec<_>>();
-    Ok(detections)
+    Ok::<Vec<TokenClassificationResult>, Error>(results)
 }

 /// Sends tokenize request to a generation service.
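One subtlety in the rewritten task: `result.word` is sliced out of the chunk using the detector's chunk-local indices, so it has to be computed before `chunk.offset` shifts `start`/`end` into whole-input coordinates. A small sketch of that ordering, with made-up text and numbers, reusing the `index_codepoints` helper added in the next hunk:

```rust
// Sketch of the local-to-global index shift in handle_detection_task.
fn index_codepoints(text: &str, start: usize, end: usize) -> String {
    let chars = text.chars().collect::<Vec<_>>();
    chars[start..end].iter().collect()
}

fn main() {
    let chunk_text = "this is bad"; // hypothetical second chunk of a larger input
    let chunk_offset = 10u32; // where this chunk starts in the original input
    let (start, end) = (8usize, 11usize); // detector indices, local to the chunk

    // Slice the word first, while the span is still chunk-local...
    let word = index_codepoints(chunk_text, start, end);
    assert_eq!(word, "bad");

    // ...then shift the span into whole-input coordinates for the caller.
    let (global_start, global_end) = (start as u32 + chunk_offset, end as u32 + chunk_offset);
    assert_eq!((global_start, global_end), (18, 21));
}
```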
@@ -591,6 +579,12 @@ async fn generate(
     }
 }

+/// Get codepoints of text between start and end indices
+fn index_codepoints(text: &str, start: usize, end: usize) -> String {
+    let chars = text.chars().collect::<Vec<_>>();
+    chars[start..end].iter().collect()
+}
+
 /// Applies masks to input text, returning (offset, masked_text) pairs.
 fn apply_masks(text: &str, masks: &[(usize, usize)]) -> Vec<(usize, String)> {
     let chars = text.chars().collect::<Vec<_>>();
@@ -708,4 +702,12 @@ mod tests {
         ];
         assert_eq!(text_with_offsets, expected_text_with_offsets)
     }
+
+    #[test]
+    fn test_index_codepoints() {
+        let s = "Hello world";
+        assert_eq!(index_codepoints(s, 0, 5), "Hello");
+        let s = "哈囉世界";
+        assert_eq!(index_codepoints(s, 3, 4), "界");
+    }
 }
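The second test case is the interesting one: the helper indexes by codepoint, not by byte, which is what keeps multi-byte text safe. A short sketch of why the distinction matters:

```rust
// Why the helper slices by chars rather than bytes: detector spans are
// codepoint offsets, and byte slicing panics off UTF-8 char boundaries.
fn index_codepoints(text: &str, start: usize, end: usize) -> String {
    let chars = text.chars().collect::<Vec<_>>();
    chars[start..end].iter().collect()
}

fn main() {
    let s = "哈囉世界"; // 4 codepoints, 12 bytes in UTF-8
    assert_eq!(index_codepoints(s, 3, 4), "界");
    // Byte slicing with the same span would panic at runtime:
    // let _ = &s[3..4]; // byte index 4 is not a char boundary
}
```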
