Skip to content

Commit 65e6215

Browse files
committed
Refactor error handling
Signed-off-by: declark1 <daniel.clark@ibm.com>
1 parent 5c5346e commit 65e6215

File tree

9 files changed

+263
-98
lines changed

9 files changed

+263
-98
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ url = "2.5.0"
3434
validator = { version = "0.18.1", features = ["derive"] } # For API validation
3535
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }
3636
anyhow = "1.0.83"
37+
thiserror = "1.0.60"
3738

3839
[build-dependencies]
3940
tonic-build = "0.11.0"

src/clients.rs

+62-15
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#![allow(dead_code)]
22
use std::{collections::HashMap, time::Duration};
33

4-
use anyhow::{Context, Error};
5-
use futures::future::try_join_all;
4+
use futures::future::join_all;
65
use ginepro::LoadBalancedChannel;
6+
use reqwest::StatusCode;
77
use url::Url;
88

99
use crate::config::{ServiceConfig, Tls};
@@ -27,6 +27,54 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
2727
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
2828
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
2929

30+
/// Client errors.
31+
#[derive(Debug, thiserror::Error)]
32+
pub enum Error {
33+
#[error("{}", .0.message())]
34+
Grpc(#[from] tonic::Status),
35+
#[error("{0}")]
36+
Http(#[from] reqwest::Error),
37+
#[error("invalid model id: {model_id}")]
38+
InvalidModelId { model_id: String },
39+
}
40+
41+
impl Error {
42+
/// Returns status code.
43+
pub fn status_code(&self) -> StatusCode {
44+
use tonic::Code::*;
45+
match self {
46+
// Return equivalent http status code for grpc status code
47+
Error::Grpc(error) => match error.code() {
48+
InvalidArgument => StatusCode::BAD_REQUEST,
49+
Internal => StatusCode::INTERNAL_SERVER_ERROR,
50+
NotFound => StatusCode::NOT_FOUND,
51+
DeadlineExceeded => StatusCode::REQUEST_TIMEOUT,
52+
Unimplemented => StatusCode::NOT_IMPLEMENTED,
53+
Unauthenticated => StatusCode::UNAUTHORIZED,
54+
PermissionDenied => StatusCode::FORBIDDEN,
55+
Ok => StatusCode::OK,
56+
_ => StatusCode::INTERNAL_SERVER_ERROR,
57+
},
58+
// Return http status code for error responses
59+
// and 500 for other errors
60+
Error::Http(error) => match error.status() {
61+
Some(code) => code,
62+
None => StatusCode::INTERNAL_SERVER_ERROR,
63+
},
64+
// Return 422 for invalid model id
65+
Error::InvalidModelId { .. } => StatusCode::UNPROCESSABLE_ENTITY,
66+
}
67+
}
68+
69+
/// Returns true for validation-type errors (400/422) and false for other types.
70+
pub fn is_validation_error(&self) -> bool {
71+
matches!(
72+
self.status_code(),
73+
StatusCode::BAD_REQUEST | StatusCode::UNPROCESSABLE_ENTITY
74+
)
75+
}
76+
}
77+
3078
#[derive(Clone)]
3179
pub enum GenerationClient {
3280
Tgis(TgisClient),
@@ -60,7 +108,7 @@ impl std::ops::Deref for HttpClient {
60108
pub async fn create_http_clients(
61109
default_port: u16,
62110
config: &[(String, ServiceConfig)],
63-
) -> Result<HashMap<String, HttpClient>, Error> {
111+
) -> HashMap<String, HttpClient> {
64112
let clients = config
65113
.iter()
66114
.map(|(name, service_config)| async move {
@@ -75,26 +123,25 @@ pub async fn create_http_clients(
75123
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
76124
panic!("error reading cert from {cert_path:?}: {error}")
77125
});
78-
let identity = reqwest::Identity::from_pem(&cert_pem)?;
126+
let identity = reqwest::Identity::from_pem(&cert_pem)
127+
.unwrap_or_else(|error| panic!("error creating identity: {error}"));
79128
builder = builder.use_rustls_tls().identity(identity);
80129
}
81-
let client = builder.build().with_context(|| {
82-
format!(
83-
"error creating http client, name={name}, service_config={service_config:?}"
84-
)
85-
})?;
130+
let client = builder
131+
.build()
132+
.unwrap_or_else(|error| panic!("error creating http client for {name}: {error}"));
86133
let client = HttpClient::new(base_url, client);
87-
Ok((name.clone(), client)) as Result<(String, HttpClient), Error>
134+
(name.clone(), client)
88135
})
89136
.collect::<Vec<_>>();
90-
Ok(try_join_all(clients).await?.into_iter().collect())
137+
join_all(clients).await.into_iter().collect()
91138
}
92139

93140
async fn create_grpc_clients<C>(
94141
default_port: u16,
95142
config: &[(String, ServiceConfig)],
96143
new: fn(LoadBalancedChannel) -> C,
97-
) -> Result<HashMap<String, C>, Error> {
144+
) -> HashMap<String, C> {
98145
let clients = config
99146
.iter()
100147
.map(|(name, service_config)| async move {
@@ -133,9 +180,9 @@ async fn create_grpc_clients<C>(
133180
if let Some(client_tls_config) = client_tls_config {
134181
builder = builder.with_tls(client_tls_config);
135182
}
136-
let channel = builder.channel().await.with_context(|| format!("error creating grpc client, name={name}, service_config={service_config:?}"))?;
137-
Ok((name.clone(), new(channel))) as Result<(String, C), Error>
183+
let channel = builder.channel().await.unwrap_or_else(|error| panic!("error creating grpc client for {name}: {error}"));
184+
(name.clone(), new(channel))
138185
})
139186
.collect::<Vec<_>>();
140-
Ok(try_join_all(clients).await?.into_iter().collect())
187+
join_all(clients).await.into_iter().collect()
141188
}

src/clients/chunker.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use std::{collections::HashMap, pin::Pin};
22

3-
use anyhow::{Context, Error};
43
use futures::{Stream, StreamExt};
54
use ginepro::LoadBalancedChannel;
65
use tokio::sync::mpsc;
76
use tokio_stream::wrappers::ReceiverStream;
87
use tonic::Request;
98

10-
use super::create_grpc_clients;
9+
use super::{create_grpc_clients, Error};
1110
use crate::{
1211
config::ServiceConfig,
1312
pb::{
@@ -27,16 +26,18 @@ pub struct ChunkerClient {
2726
}
2827

2928
impl ChunkerClient {
30-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
31-
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await?;
32-
Ok(Self { clients })
29+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
30+
let clients = create_grpc_clients(default_port, config, ChunkersServiceClient::new).await;
31+
Self { clients }
3332
}
3433

3534
fn client(&self, model_id: &str) -> Result<ChunkersServiceClient<LoadBalancedChannel>, Error> {
3635
Ok(self
3736
.clients
3837
.get(model_id)
39-
.context(format!("model not found, model_id={model_id}"))?
38+
.ok_or_else(|| Error::InvalidModelId {
39+
model_id: model_id.to_string(),
40+
})?
4041
.clone())
4142
}
4243

src/clients/detector.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use std::collections::HashMap;
22

3-
use anyhow::{Context, Error};
43
use serde::{Deserialize, Serialize};
54

6-
use super::{create_http_clients, HttpClient};
5+
use super::{create_http_clients, Error, HttpClient};
76
use crate::config::ServiceConfig;
87

98
const DETECTOR_ID_HEADER_NAME: &str = "detector-id";
@@ -14,17 +13,18 @@ pub struct DetectorClient {
1413
}
1514

1615
impl DetectorClient {
17-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
18-
let clients: HashMap<String, HttpClient> =
19-
create_http_clients(default_port, config).await?;
20-
Ok(Self { clients })
16+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
17+
let clients: HashMap<String, HttpClient> = create_http_clients(default_port, config).await;
18+
Self { clients }
2119
}
2220

2321
fn client(&self, model_id: &str) -> Result<HttpClient, Error> {
2422
Ok(self
2523
.clients
2624
.get(model_id)
27-
.context(format!("model not found, model_id={model_id}"))?
25+
.ok_or_else(|| Error::InvalidModelId {
26+
model_id: model_id.to_string(),
27+
})?
2828
.clone())
2929
}
3030

src/clients/nlp.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use std::collections::HashMap;
22

3-
use anyhow::{Context, Error};
43
use futures::StreamExt;
54
use ginepro::LoadBalancedChannel;
65
use tokio::sync::mpsc;
76
use tokio_stream::wrappers::ReceiverStream;
87
use tonic::Request;
98

10-
use super::create_grpc_clients;
9+
use super::{create_grpc_clients, Error};
1110
use crate::{
1211
config::ServiceConfig,
1312
pb::{
@@ -30,16 +29,18 @@ pub struct NlpClient {
3029
}
3130

3231
impl NlpClient {
33-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
34-
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await?;
35-
Ok(Self { clients })
32+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
33+
let clients = create_grpc_clients(default_port, config, NlpServiceClient::new).await;
34+
Self { clients }
3635
}
3736

3837
fn client(&self, model_id: &str) -> Result<NlpServiceClient<LoadBalancedChannel>, Error> {
3938
Ok(self
4039
.clients
4140
.get(model_id)
42-
.context(format!("model not found, model_id={model_id}"))?
41+
.ok_or_else(|| Error::InvalidModelId {
42+
model_id: model_id.to_string(),
43+
})?
4344
.clone())
4445
}
4546

src/clients/tgis.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
use std::collections::HashMap;
22

3-
use anyhow::{Context, Error};
43
use futures::StreamExt;
54
use ginepro::LoadBalancedChannel;
65
use tokio::sync::mpsc;
76
use tokio_stream::wrappers::ReceiverStream;
87

9-
use super::create_grpc_clients;
8+
use super::{create_grpc_clients, Error};
109
use crate::{
1110
config::ServiceConfig,
1211
pb::fmaas::{
@@ -22,10 +21,9 @@ pub struct TgisClient {
2221
}
2322

2423
impl TgisClient {
25-
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Result<Self, Error> {
26-
let clients =
27-
create_grpc_clients(default_port, config, GenerationServiceClient::new).await?;
28-
Ok(Self { clients })
24+
pub async fn new(default_port: u16, config: &[(String, ServiceConfig)]) -> Self {
25+
let clients = create_grpc_clients(default_port, config, GenerationServiceClient::new).await;
26+
Self { clients }
2927
}
3028

3129
fn client(
@@ -37,7 +35,9 @@ impl TgisClient {
3735
Ok(self
3836
.clients
3937
.get(model_id)
40-
.context(format!("model not found, model_id={model_id}"))?
38+
.ok_or_else(|| Error::InvalidModelId {
39+
model_id: model_id.to_string(),
40+
})?
4141
.clone())
4242
}
4343

0 commit comments

Comments
 (0)