Skip to content

Commit af57f66

Browse files
authored
🥅🏷️ Use more defined struct for threshold (#66)
* 🥅🏷️ Use more defined struct for threshold Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * 🏷️ Update threshold to f64 Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> * 📝 Update TODOs Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --------- Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com>
1 parent e82dfbf commit af57f66

File tree

4 files changed

+9
-17
lines changed

4 files changed

+9
-17
lines changed

TODOs.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
# TODOs
2-
- [ ] Design detector map config and implement loading this config file
32
- [ ] Add TLS support for orchestrator server
43
- [ ] Implement Health probes
5-
- [ ] Add request models (objects for request and response)
64
- [ ] Add unit tests
7-
- [ ] Add request validation for classification with text generation endpoint
85
- [ ] Add request validation for streaming classification with text generation endpoint
9-
- [ ] Host API in swagger pages in the repo
106
- [ ] Tokenization REST API and client will need to be updated for bidirectional streaming if/when available for REST use
117
- [ ] There is currently NO WAY for us to know the prefix ID required by TGIS for inferencing on tuned prompts
128
- [ ] Add configurable request timeouts for all clients
13-
- [ ] [OSS public] Copyright headers
14-
- [ ] [OSS public] CI - github actions

src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub struct DetectorConfig {
8080
/// ID of chunker that this detector will use
8181
pub chunker_id: String,
8282
/// Default threshold with which to filter detector results by score
83-
pub default_threshold: f32,
83+
pub default_threshold: f64,
8484
}
8585

8686
/// Overall orchestrator server configuration

src/models.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
use crate::pb;
44
use std::collections::HashMap;
55

6-
// TODO: When detector API is updated, consider if fields
7-
// like 'threshold' can be named options instead of the
8-
// use a generic HashMap with Values here
9-
// ref. https://github.com/foundation-model-stack/fms-guardrails-orchestrator/issues/37
10-
pub type DetectorParams = HashMap<String, serde_json::Value>;
6+
/// Parameters relevant to each detector
7+
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
8+
pub struct DetectorParams {
9+
/// Threshold with which to filter detector results by score
10+
pub threshold: Option<f64>,
11+
}
1112

1213
/// User request to orchestrator
1314
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]

src/orchestrator.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,15 +323,12 @@ async fn handle_chunk_task(
323323
async fn handle_detection_task(
324324
ctx: Arc<Context>,
325325
detector_id: String,
326-
default_threshold: f32,
326+
default_threshold: f64,
327327
detector_params: DetectorParams,
328328
chunks: Vec<Chunk>,
329329
) -> Result<Vec<TokenClassificationResult>, Error> {
330330
let detector_id = detector_id.clone();
331-
let threshold = detector_params
332-
.get("threshold")
333-
.and_then(|v| v.as_f64())
334-
.unwrap_or(default_threshold as f64);
331+
let threshold = detector_params.threshold.unwrap_or(default_threshold);
335332
let contents = chunks.iter().map(|chunk| chunk.text.clone()).collect();
336333
let request = ContentAnalysisRequest::new(contents);
337334
debug!(

0 commit comments

Comments
 (0)