11#![ allow( dead_code) ]
22use std:: { collections:: HashMap , time:: Duration } ;
33
4- use futures:: future:: try_join_all ;
4+ use futures:: future:: join_all ;
55use ginepro:: LoadBalancedChannel ;
6+ use reqwest:: StatusCode ;
67use url:: Url ;
78
89use crate :: config:: { ServiceConfig , Tls } ;
@@ -26,16 +27,44 @@ pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
2627const DEFAULT_CONNECT_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
2728const DEFAULT_REQUEST_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
2829
30+ /// Client errors.
2931#[ derive( Debug , thiserror:: Error ) ]
3032pub enum Error {
31- #[ error( "model not found: {0}" ) ]
32- ModelNotFound ( String ) ,
33- #[ error( transparent) ]
34- ReqwestError ( #[ from] reqwest:: Error ) ,
35- #[ error( transparent) ]
36- TonicError ( #[ from] tonic:: Status ) ,
37- #[ error( transparent) ]
38- IoError ( #[ from] std:: io:: Error ) ,
33+ #[ error( "{}" , . 0 . message( ) ) ]
34+ Grpc ( #[ from] tonic:: Status ) ,
35+ #[ error( "{0}" ) ]
36+ Http ( #[ from] reqwest:: Error ) ,
37+ #[ error( "model not found: {model_id}" ) ]
38+ ModelNotFound { model_id : String } ,
39+ }
40+
41+ impl Error {
42+ /// Returns status code.
43+ pub fn status_code ( & self ) -> StatusCode {
44+ use tonic:: Code :: * ;
45+ match self {
46+ // Return equivalent http status code for grpc status code
47+ Error :: Grpc ( error) => match error. code ( ) {
48+ InvalidArgument => StatusCode :: BAD_REQUEST ,
49+ Internal => StatusCode :: INTERNAL_SERVER_ERROR ,
50+ NotFound => StatusCode :: NOT_FOUND ,
51+ DeadlineExceeded => StatusCode :: REQUEST_TIMEOUT ,
52+ Unimplemented => StatusCode :: NOT_IMPLEMENTED ,
53+ Unauthenticated => StatusCode :: UNAUTHORIZED ,
54+ PermissionDenied => StatusCode :: FORBIDDEN ,
55+ Ok => StatusCode :: OK ,
56+ _ => StatusCode :: INTERNAL_SERVER_ERROR ,
57+ } ,
58+ // Return http status code for error responses
59+ // and 500 for other errors
60+ Error :: Http ( error) => match error. status ( ) {
61+ Some ( code) => code,
62+ None => StatusCode :: INTERNAL_SERVER_ERROR ,
63+ } ,
64+ // Return 404 for model not found
65+ Error :: ModelNotFound { .. } => StatusCode :: NOT_FOUND ,
66+ }
67+ }
3968}
4069
4170#[ derive( Clone ) ]
@@ -71,7 +100,7 @@ impl std::ops::Deref for HttpClient {
71100pub async fn create_http_clients (
72101 default_port : u16 ,
73102 config : & [ ( String , ServiceConfig ) ] ,
74- ) -> Result < HashMap < String , HttpClient > , Error > {
103+ ) -> HashMap < String , HttpClient > {
75104 let clients = config
76105 . iter ( )
77106 . map ( |( name, service_config) | async move {
@@ -86,22 +115,25 @@ pub async fn create_http_clients(
86115 let cert_pem = tokio:: fs:: read ( cert_path) . await . unwrap_or_else ( |error| {
87116 panic ! ( "error reading cert from {cert_path:?}: {error}" )
88117 } ) ;
89- let identity = reqwest:: Identity :: from_pem ( & cert_pem) ?;
118+ let identity = reqwest:: Identity :: from_pem ( & cert_pem)
119+ . unwrap_or_else ( |error| panic ! ( "error parsing cert: {error}" ) ) ;
90120 builder = builder. use_rustls_tls ( ) . identity ( identity) ;
91121 }
92- let client = builder. build ( ) ?;
122+ let client = builder
123+ . build ( )
124+ . unwrap_or_else ( |error| panic ! ( "error creating http client for {name}: {error}" ) ) ;
93125 let client = HttpClient :: new ( base_url, client) ;
94- Ok ( ( name. clone ( ) , client) ) as Result < ( String , HttpClient ) , Error >
126+ ( name. clone ( ) , client)
95127 } )
96128 . collect :: < Vec < _ > > ( ) ;
97- Ok ( try_join_all ( clients) . await ? . into_iter ( ) . collect ( ) )
129+ join_all ( clients) . await . into_iter ( ) . collect ( )
98130}
99131
100132async fn create_grpc_clients < C > (
101133 default_port : u16 ,
102134 config : & [ ( String , ServiceConfig ) ] ,
103135 new : fn ( LoadBalancedChannel ) -> C ,
104- ) -> Result < HashMap < String , C > , Error > {
136+ ) -> HashMap < String , C > {
105137 let clients = config
106138 . iter ( )
107139 . map ( |( name, service_config) | async move {
@@ -140,9 +172,9 @@ async fn create_grpc_clients<C>(
140172 if let Some ( client_tls_config) = client_tls_config {
141173 builder = builder. with_tls ( client_tls_config) ;
142174 }
143- let channel = builder. channel ( ) . await . unwrap ( ) ; // TODO: handle error
144- Ok ( ( name. clone ( ) , new ( channel) ) ) as Result < ( String , C ) , Error >
175+ let channel = builder. channel ( ) . await . unwrap_or_else ( |error| panic ! ( "error creating grpc client for {name}: { error}" ) ) ;
176+ ( name. clone ( ) , new ( channel) )
145177 } )
146178 . collect :: < Vec < _ > > ( ) ;
147- Ok ( try_join_all ( clients) . await ? . into_iter ( ) . collect ( ) )
179+ join_all ( clients) . await . into_iter ( ) . collect ( )
148180}
0 commit comments