Skip to content

Commit b6f5030

Browse files
authored
Small changes: code reuse, simplify, doc comments, client timeouts (#25)
Signed-off-by: declark1 <daniel.clark@ibm.com>
1 parent 22d4e3e commit b6f5030

File tree

6 files changed

+127
-76
lines changed

6 files changed

+127
-76
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ tracing = "0.1.40"
3333
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
3434
url = "2.5.0"
3535
validator = { version = "0.18.1", features = ["derive"] } # For API validation
36+
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
3637

3738
[build-dependencies]
3839
tonic-build = "0.11.0"

src/clients.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#![allow(dead_code)]
2-
use std::collections::HashMap;
2+
use std::{collections::HashMap, time::Duration};
33

44
use futures::future::try_join_all;
55
use ginepro::LoadBalancedChannel;
@@ -19,6 +19,8 @@ pub use nlp::NlpClient;
1919
pub const DEFAULT_TGIS_PORT: u16 = 8033;
2020
pub const DEFAULT_CAIKIT_NLP_PORT: u16 = 8085;
2121
pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
22+
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
23+
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
2224

2325
#[derive(Debug, thiserror::Error)]
2426
pub enum Error {
@@ -72,7 +74,9 @@ pub async fn create_http_clients(
7274
let port = service_config.port.unwrap_or(default_port);
7375
let mut base_url = Url::parse(&service_config.hostname).unwrap();
7476
base_url.set_port(Some(port)).unwrap();
75-
let mut builder = reqwest::ClientBuilder::new();
77+
let mut builder = reqwest::ClientBuilder::new()
78+
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
79+
.timeout(DEFAULT_REQUEST_TIMEOUT);
7680
if let Some(Tls::Config(tls_config)) = &service_config.tls {
7781
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
7882
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
@@ -100,7 +104,9 @@ async fn create_grpc_clients<C>(
100104
let mut builder = LoadBalancedChannel::builder((
101105
service_config.hostname.clone(),
102106
service_config.port.unwrap_or(default_port),
103-
));
107+
))
108+
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
109+
.timeout(DEFAULT_REQUEST_TIMEOUT);
104110
let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
105111
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
106112
let key_path = tls_config.key_path.as_ref().unwrap().as_path();

src/config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ detectors:
155155
port: 9000
156156
chunker_id: sentence-en
157157
config: {}
158+
tls: {}
158159
"#;
159160
let config: OrchestratorConfig = serde_yml::from_str(s)?;
160161
assert!(config.chunkers.len() == 2 && config.detectors.len() == 1);

src/models.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@ pub struct GuardrailsConfig {
4646
pub output: Option<GuardrailsConfigOutput>,
4747
}
4848

49+
impl GuardrailsConfig {
50+
pub fn input_masks(&self) -> Option<&[(usize, usize)]> {
51+
self.input.as_ref().and_then(|input| input.masks.as_deref())
52+
}
53+
54+
pub fn input_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
55+
self.input.as_ref().and_then(|input| input.models.as_ref())
56+
}
57+
58+
pub fn output_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
59+
self.output
60+
.as_ref()
61+
.and_then(|output| output.models.as_ref())
62+
}
63+
}
64+
4965
/// Configuration for detection on input to a text generation model (e.g. user prompt)
5066
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
5167
pub struct GuardrailsConfigInput {

0 commit comments

Comments
 (0)