1
1
#![ allow( dead_code) ]
2
2
use std:: { collections:: HashMap , time:: Duration } ;
3
3
4
- use anyhow:: { Context , Error } ;
5
- use futures:: future:: try_join_all;
4
+ use futures:: future:: join_all;
6
5
use ginepro:: LoadBalancedChannel ;
6
+ use reqwest:: StatusCode ;
7
7
use url:: Url ;
8
8
9
9
use crate :: config:: { ServiceConfig , Tls } ;
@@ -27,6 +27,54 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
27
27
const DEFAULT_CONNECT_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
28
28
const DEFAULT_REQUEST_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
29
29
30
+ /// Client errors.
31
+ #[ derive( Debug , thiserror:: Error ) ]
32
+ pub enum Error {
33
+ #[ error( "{}" , . 0 . message( ) ) ]
34
+ Grpc ( #[ from] tonic:: Status ) ,
35
+ #[ error( "{0}" ) ]
36
+ Http ( #[ from] reqwest:: Error ) ,
37
+ #[ error( "invalid model id: {model_id}" ) ]
38
+ InvalidModelId { model_id : String } ,
39
+ }
40
+
41
+ impl Error {
42
+ /// Returns status code.
43
+ pub fn status_code ( & self ) -> StatusCode {
44
+ use tonic:: Code :: * ;
45
+ match self {
46
+ // Return equivalent http status code for grpc status code
47
+ Error :: Grpc ( error) => match error. code ( ) {
48
+ InvalidArgument => StatusCode :: BAD_REQUEST ,
49
+ Internal => StatusCode :: INTERNAL_SERVER_ERROR ,
50
+ NotFound => StatusCode :: NOT_FOUND ,
51
+ DeadlineExceeded => StatusCode :: REQUEST_TIMEOUT ,
52
+ Unimplemented => StatusCode :: NOT_IMPLEMENTED ,
53
+ Unauthenticated => StatusCode :: UNAUTHORIZED ,
54
+ PermissionDenied => StatusCode :: FORBIDDEN ,
55
+ Ok => StatusCode :: OK ,
56
+ _ => StatusCode :: INTERNAL_SERVER_ERROR ,
57
+ } ,
58
+ // Return http status code for error responses
59
+ // and 500 for other errors
60
+ Error :: Http ( error) => match error. status ( ) {
61
+ Some ( code) => code,
62
+ None => StatusCode :: INTERNAL_SERVER_ERROR ,
63
+ } ,
64
+ // Return 422 for invalid model id
65
+ Error :: InvalidModelId { .. } => StatusCode :: UNPROCESSABLE_ENTITY ,
66
+ }
67
+ }
68
+
69
+ /// Returns true for validation-type errors (400/422) and false for other types.
70
+ pub fn is_validation_error ( & self ) -> bool {
71
+ matches ! (
72
+ self . status_code( ) ,
73
+ StatusCode :: BAD_REQUEST | StatusCode :: UNPROCESSABLE_ENTITY
74
+ )
75
+ }
76
+ }
77
+
30
78
#[ derive( Clone ) ]
31
79
pub enum GenerationClient {
32
80
Tgis ( TgisClient ) ,
@@ -60,7 +108,7 @@ impl std::ops::Deref for HttpClient {
60
108
pub async fn create_http_clients (
61
109
default_port : u16 ,
62
110
config : & [ ( String , ServiceConfig ) ] ,
63
- ) -> Result < HashMap < String , HttpClient > , Error > {
111
+ ) -> HashMap < String , HttpClient > {
64
112
let clients = config
65
113
. iter ( )
66
114
. map ( |( name, service_config) | async move {
@@ -75,26 +123,25 @@ pub async fn create_http_clients(
75
123
let cert_pem = tokio:: fs:: read ( cert_path) . await . unwrap_or_else ( |error| {
76
124
panic ! ( "error reading cert from {cert_path:?}: {error}" )
77
125
} ) ;
78
- let identity = reqwest:: Identity :: from_pem ( & cert_pem) ?;
126
+ let identity = reqwest:: Identity :: from_pem ( & cert_pem)
127
+ . unwrap_or_else ( |error| panic ! ( "error creating identity: {error}" ) ) ;
79
128
builder = builder. use_rustls_tls ( ) . identity ( identity) ;
80
129
}
81
- let client = builder. build ( ) . with_context ( || {
82
- format ! (
83
- "error creating http client, name={name}, service_config={service_config:?}"
84
- )
85
- } ) ?;
130
+ let client = builder
131
+ . build ( )
132
+ . unwrap_or_else ( |error| panic ! ( "error creating http client for {name}: {error}" ) ) ;
86
133
let client = HttpClient :: new ( base_url, client) ;
87
- Ok ( ( name. clone ( ) , client) ) as Result < ( String , HttpClient ) , Error >
134
+ ( name. clone ( ) , client)
88
135
} )
89
136
. collect :: < Vec < _ > > ( ) ;
90
- Ok ( try_join_all ( clients) . await ? . into_iter ( ) . collect ( ) )
137
+ join_all ( clients) . await . into_iter ( ) . collect ( )
91
138
}
92
139
93
140
async fn create_grpc_clients < C > (
94
141
default_port : u16 ,
95
142
config : & [ ( String , ServiceConfig ) ] ,
96
143
new : fn ( LoadBalancedChannel ) -> C ,
97
- ) -> Result < HashMap < String , C > , Error > {
144
+ ) -> HashMap < String , C > {
98
145
let clients = config
99
146
. iter ( )
100
147
. map ( |( name, service_config) | async move {
@@ -133,9 +180,9 @@ async fn create_grpc_clients<C>(
133
180
if let Some ( client_tls_config) = client_tls_config {
134
181
builder = builder. with_tls ( client_tls_config) ;
135
182
}
136
- let channel = builder. channel ( ) . await . with_context ( || format ! ( "error creating grpc client, name= {name}, service_config={service_config:? }" ) ) ? ;
137
- Ok ( ( name. clone ( ) , new ( channel) ) ) as Result < ( String , C ) , Error >
183
+ let channel = builder. channel ( ) . await . unwrap_or_else ( |error| panic ! ( "error creating grpc client for {name}: {error }" ) ) ;
184
+ ( name. clone ( ) , new ( channel) )
138
185
} )
139
186
. collect :: < Vec < _ > > ( ) ;
140
- Ok ( try_join_all ( clients) . await ? . into_iter ( ) . collect ( ) )
187
+ join_all ( clients) . await . into_iter ( ) . collect ( )
141
188
}
0 commit comments