Skip to content

Commit 45e99d3

Browse files
authored
Implement error handling (#31)
Signed-off-by: declark1 <daniel.clark@ibm.com>
1 parent 3a01d11 commit 45e99d3

File tree

11 files changed

+321
-131
lines changed

11 files changed

+321
-131
lines changed

Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ name = "fms-guardrails-orchestr8"
1414
path = "src/main.rs"
1515

1616
[dependencies]
17+
anyhow = "1.0.83"
1718
axum = { version = "0.7.5", features = ["json"] }
1819
clap = { version = "4.5.3", features = ["derive", "env"] }
1920
futures = "0.3.30"
@@ -25,15 +26,15 @@ rustls-webpki = "0.102.2"
2526
serde = { version = "1.0.200", features = ["derive"] }
2627
serde_json = "1.0.116"
2728
serde_yml = "0.0.5"
28-
thiserror = "1.0.59"
29+
thiserror = "1.0.60"
2930
tokio = { version = "1.37.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
3031
tokio-stream = "0.1.14"
3132
tonic = { version = "0.11.0", features = ["tls"] }
3233
tracing = "0.1.40"
3334
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
3435
url = "2.5.0"
35-
validator = { version = "0.18.1", features = ["derive"] } # For API validation
3636
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
37+
validator = { version = "0.18.1", features = ["derive"] } # For API validation
3738

3839
[build-dependencies]
3940
tonic-build = "0.11.0"

src/clients.rs

+50-18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#![allow(dead_code)]
22
use std::{collections::HashMap, time::Duration};
33

4-
use futures::future::try_join_all;
4+
use futures::future::join_all;
55
use ginepro::LoadBalancedChannel;
6+
use reqwest::StatusCode;
67
use url::Url;
78

89
use crate::config::{ServiceConfig, Tls};
@@ -26,16 +27,44 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
2627
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
2728
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
2829

30+
/// Client errors.
2931
#[derive(Debug, thiserror::Error)]
3032
pub enum Error {
31-
#[error("model not found: {0}")]
32-
ModelNotFound(String),
33-
#[error(transparent)]
34-
ReqwestError(#[from] reqwest::Error),
35-
#[error(transparent)]
36-
TonicError(#[from] tonic::Status),
37-
#[error(transparent)]
38-
IoError(#[from] std::io::Error),
33+
#[error("{}", .0.message())]
34+
Grpc(#[from] tonic::Status),
35+
#[error("{0}")]
36+
Http(#[from] reqwest::Error),
37+
#[error("model not found: {model_id}")]
38+
ModelNotFound { model_id: String },
39+
}
40+
41+
impl Error {
42+
/// Returns status code.
43+
pub fn status_code(&self) -> StatusCode {
44+
use tonic::Code::*;
45+
match self {
46+
// Return equivalent http status code for grpc status code
47+
Error::Grpc(error) => match error.code() {
48+
InvalidArgument => StatusCode::BAD_REQUEST,
49+
Internal => StatusCode::INTERNAL_SERVER_ERROR,
50+
NotFound => StatusCode::NOT_FOUND,
51+
DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
52+
Unimplemented => StatusCode::NOT_IMPLEMENTED,
53+
Unauthenticated => StatusCode::UNAUTHORIZED,
54+
PermissionDenied => StatusCode::FORBIDDEN,
55+
Ok => StatusCode::OK,
56+
_ => StatusCode::INTERNAL_SERVER_ERROR,
57+
},
58+
// Return http status code for error responses
59+
// and 500 for other errors
60+
Error::Http(error) => match error.status() {
61+
Some(code) => code,
62+
None => StatusCode::INTERNAL_SERVER_ERROR,
63+
},
64+
// Return 404 for model not found
65+
Error::ModelNotFound { .. } => StatusCode::NOT_FOUND,
66+
}
67+
}
3968
}
4069

4170
#[derive(Clone)]
@@ -71,7 +100,7 @@ impl std::ops::Deref for HttpClient {
71100
pub async fn create_http_clients(
72101
default_port: u16,
73102
config: &[(String, ServiceConfig)],
74-
) -> Result<HashMap<String, HttpClient>, Error> {
103+
) -> HashMap<String, HttpClient> {
75104
let clients = config
76105
.iter()
77106
.map(|(name, service_config)| async move {
@@ -86,22 +115,25 @@ pub async fn create_http_clients(
86115
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
87116
panic!("error reading cert from {cert_path:?}: {error}")
88117
});
89-
let identity = reqwest::Identity::from_pem(&cert_pem)?;
118+
let identity = reqwest::Identity::from_pem(&cert_pem)
119+
.unwrap_or_else(|error| panic!("error parsing cert: {error}"));
90120
builder = builder.use_rustls_tls().identity(identity);
91121
}
92-
let client = builder.build()?;
122+
let client = builder
123+
.build()
124+
.unwrap_or_else(|error| panic!("error creating http client for {name}: {error}"));
93125
let client = HttpClient::new(base_url, client);
94-
Ok((name.clone(), client)) as Result<(String, HttpClient), Error>
126+
(name.clone(), client)
95127
})
96128
.collect::<Vec<_>>();
97-
Ok(try_join_all(clients).await?.into_iter().collect())
129+
join_all(clients).await.into_iter().collect()
98130
}
99131

100132
async fn create_grpc_clients<C>(
101133
default_port: u16,
102134
config: &[(String, ServiceConfig)],
103135
new: fn(LoadBalancedChannel) -> C,
104-
) -> Result<HashMap<String, C>, Error> {
136+
) -> HashMap<String, C> {
105137
let clients = config
106138
.iter()
107139
.map(|(name, service_config)| async move {
@@ -140,9 +172,9 @@ async fn create_grpc_clients<C>(
140172
if let Some(client_tls_config) = client_tls_config {
141173
builder = builder.with_tls(client_tls_config);
142174
}
143-
let channel = builder.channel().await.unwrap(); // TODO: handle error
144-
Ok((name.clone(), new(channel))) as Result<(String, C), Error>
175+
let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}"));
176+
(name.clone(), new(channel))
145177
})
146178
.collect::<Vec<_>>();
147-
Ok(try_join_all(clients).await?.into_iter().collect())
179+
join_all(clients).await.into_iter().collect()
148180
}

src/clients/chunker.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@ pub struct ChunkerClient {
2626
}
2727

2828
impl ChunkerClient {
29-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
30-
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?;
31-
Ok(Self { clients })
29+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
30+
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await;
31+
Self { clients }
3232
}
3333

3434
fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
3535
Ok(self
3636
.clients
3737
.get(model_id)
38-
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
38+
.ok_or_else(|| Error::ModelNotFound {
39+
model_id: model_id.to_string(),
40+
})?
3941
.clone())
4042
}
4143

src/clients/detector.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@ pub struct DetectorClient {
1313
}
1414

1515
impl DetectorClient {
16-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
17-
let clients: HashMap<String, HttpClient> =
18-
create_http_clients(default_port, config).await?;
19-
Ok(Self { clients })
16+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
17+
let clients: HashMap<String, HttpClient> = create_http_clients(default_port, config).await;
18+
Self { clients }
2019
}
2120

2221
fn client(&self, model_id: &str) -> Result<HttpClient, Error> {
2322
Ok(self
2423
.clients
2524
.get(model_id)
26-
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
25+
.ok_or_else(|| Error::ModelNotFound {
26+
model_id: model_id.to_string(),
27+
})?
2728
.clone())
2829
}
2930

src/clients/nlp.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,18 @@ pub struct NlpClient {
2929
}
3030

3131
impl NlpClient {
32-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
33-
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await?;
34-
Ok(Self { clients })
32+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
33+
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await;
34+
Self { clients }
3535
}
3636

3737
fn client(&self, model_id: &str) -> Result<NlpServiceClient<LoadBalancedChannel>, Error> {
3838
Ok(self
3939
.clients
4040
.get(model_id)
41-
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
41+
.ok_or_else(|| Error::ModelNotFound {
42+
model_id: model_id.to_string(),
43+
})?
4244
.clone())
4345
}
4446

src/clients/tgis.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ pub struct TgisClient {
2121
}
2222

2323
impl TgisClient {
24-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
25-
let clients =
26-
create_grpc_clients(default_port, config, GenerationServiceClient::new).await?;
27-
Ok(Self { clients })
24+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
25+
let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
26+
Self { clients }
2827
}
2928

3029
fn client(
@@ -36,7 +35,9 @@ impl TgisClient {
3635
Ok(self
3736
.clients
3837
.get(model_id)
39-
.ok_or_else(|| Error::ModelNotFound(model_id.into()))?
38+
.ok_or_else(|| Error::ModelNotFound {
39+
model_id: model_id.to_string(),
40+
})?
4041
.clone())
4142
}
4243

src/config.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ impl OrchestratorConfig {
108108
todo!()
109109
}
110110

111-
pub fn get_chunker_id(&self, detector_id: &str) -> String {
112-
self.detectors.get(detector_id).unwrap().chunker_id.clone()
111+
pub fn get_chunker_id(&self, detector_id: &str) -> Option<String> {
112+
self.detectors
113+
.get(detector_id)
114+
.map(|detector_config| detector_config.chunker_id.clone())
113115
}
114116
}
115117

@@ -126,8 +128,9 @@ fn service_tls_name_to_config(
126128

127129
#[cfg(test)]
128130
mod tests {
131+
use anyhow::Error;
132+
129133
use super::*;
130-
use crate::Error;
131134

132135
#[test]
133136
fn test_deserialize_config() -> Result<(), Error> {

src/lib.rs

-35
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,8 @@
11
#![allow(clippy::iter_kv_map, clippy::enum_variant_names)]
22

3-
use axum::{http::StatusCode, Json};
4-
53
mod clients;
64
mod config;
75
mod models;
86
mod orchestrator;
97
mod pb;
108
pub mod server;
11-
12-
#[derive(Debug, thiserror::Error)]
13-
pub enum Error {
14-
#[error(transparent)]
15-
ClientError(#[from] crate::clients::Error),
16-
#[error(transparent)]
17-
IoError(#[from] std::io::Error),
18-
#[error(transparent)]
19-
YamlError(#[from] serde_yml::Error),
20-
}
21-
22-
// TODO: create better errors and properly convert
23-
impl From<Error> for (StatusCode, Json<String>) {
24-
fn from(value: Error) -> Self {
25-
use Error::*;
26-
match value {
27-
ClientError(error) => match error {
28-
clients::Error::ModelNotFound(message) => {
29-
(StatusCode::UNPROCESSABLE_ENTITY, Json(message))
30-
}
31-
clients::Error::ReqwestError(error) => {
32-
(StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string()))
33-
}
34-
clients::Error::TonicError(error) => {
35-
(StatusCode::INTERNAL_SERVER_ERROR, Json(error.to_string()))
36-
}
37-
clients::Error::IoError(_) => todo!(),
38-
},
39-
IoError(_) => todo!(),
40-
YamlError(_) => todo!(),
41-
}
42-
}
43-
}

0 commit comments

Comments
 (0)