@@ -22,13 +22,14 @@ use std::{
2222 convert:: Infallible ,
2323 env,
2424 error:: Error ,
25- net:: { Ipv4Addr , SocketAddr } ,
25+ net:: { Ipv4Addr , SocketAddrV4 } ,
2626 pin:: pin,
2727 str:: FromStr ,
2828 sync:: Arc ,
2929 time:: { Duration , Instant } ,
3030} ;
31- use tokio:: { net:: TcpListener , task:: JoinSet } ;
31+ use tokio:: net:: TcpListener ;
32+ use tokio_util:: task:: TaskTracker ;
3233use tracing:: { error, info, trace} ;
3334use twilight_http_ratelimiting:: { Endpoint , Method , RateLimitHeaders , RateLimiter } ;
3435
@@ -55,10 +56,7 @@ static METRIC_KEY: LazyLock<Cow<str>> = LazyLock::new(|| {
5556async fn main ( ) -> Result < ( ) , Box < dyn Error > > {
5657 tracing_subscriber:: fmt:: init ( ) ;
5758
58- let host = parse_env ( "HOST" ) ?. unwrap_or ( Ipv4Addr :: UNSPECIFIED ) ;
59- let port = parse_env ( "PORT" ) ?. unwrap_or ( 80 ) ;
60-
61- let https_connector = {
59+ let client = {
6260 let mut http_connector = TokioHickoryResolver :: default ( ) . into_http_connector ( ) ;
6361 http_connector. enforce_http ( false ) ;
6462
@@ -67,45 +65,42 @@ async fn main() -> Result<(), Box<dyn Error>> {
6765 . https_only ( )
6866 . enable_http1 ( ) ;
6967
70- if env:: var_os ( "DISABLE_HTTP2" ) . is_some ( ) {
68+ let https_connector = if env:: var_os ( "DISABLE_HTTP2" ) . is_some ( ) {
7169 builder. wrap_connector ( http_connector)
7270 } else {
7371 builder. enable_http2 ( ) . wrap_connector ( http_connector)
74- }
72+ } ;
73+
74+ Client :: builder ( TokioExecutor :: new ( ) ) . build ( https_connector)
7575 } ;
7676
77- let client: Client < _ , Incoming > = Client :: builder ( TokioExecutor :: new ( ) ) . build ( https_connector) ;
77+ #[ cfg( feature = "metrics" ) ]
78+ let handle = PrometheusBuilder :: new ( )
79+ . idle_timeout (
80+ MetricKindMask :: COUNTER | MetricKindMask :: HISTOGRAM ,
81+ Some ( Duration :: from_secs (
82+ parse_env ( "METRIC_TIMEOUT" ) ?. unwrap_or ( 300 ) ,
83+ ) ) ,
84+ )
85+ . install_recorder ( )
86+ . expect ( "installed once" ) ;
87+
7888 let ratelimiter_map = Arc :: new ( RatelimiterMap :: new (
7989 env:: var ( "DISCORD_TOKEN" ) ?,
8090 Duration :: from_secs ( parse_env ( "CLIENT_DECAY_TIMEOUT" ) ?. unwrap_or ( 3600 ) ) ,
8191 parse_env ( "CLIENT_CACHE_MAX_SIZE" ) ?,
8292 ) ) ;
8393
84- let address = SocketAddr :: from ( ( host, port) ) ;
85-
86- #[ cfg( feature = "metrics" ) ]
87- let handle: Arc < PrometheusHandle > ;
88-
89- #[ cfg( feature = "metrics" ) ]
90- {
91- let timeout = parse_env ( "METRIC_TIMEOUT" ) ?. unwrap_or ( 300 ) ;
92- let recorder = PrometheusBuilder :: new ( )
93- . idle_timeout (
94- MetricKindMask :: COUNTER | MetricKindMask :: HISTOGRAM ,
95- Some ( Duration :: from_secs ( timeout) ) ,
96- )
97- . build_recorder ( ) ;
98- handle = Arc :: new ( recorder. handle ( ) ) ;
99- metrics:: set_global_recorder ( Box :: new ( recorder) )
100- . expect ( "Failed to create metrics receiver!" ) ;
101- }
94+ let host = parse_env ( "HOST" ) ?. unwrap_or ( Ipv4Addr :: UNSPECIFIED ) ;
95+ let port = parse_env ( "PORT" ) ?. unwrap_or ( 80 ) ;
96+ let address = SocketAddrV4 :: new ( host, port) ;
10297
10398 let listener = TcpListener :: bind ( & address) . await ?;
10499 let mut shutdown_signal = pin ! ( shutdown_signal( ) ) ;
105100
106101 info ! ( "Listening on http://{}" , address) ;
107102
108- let mut tasks = JoinSet :: new ( ) ;
103+ let tracker = TaskTracker :: new ( ) ;
109104
110105 loop {
111106 tokio:: select! {
@@ -115,53 +110,35 @@ async fn main() -> Result<(), Box<dyn Error>> {
115110 continue ;
116111 } ;
117112
118-
119- let ratelimiter_map = ratelimiter_map. clone( ) ;
120- // Cloning a hyper client is fairly cheap by design
113+ let ratelimiter_map = Arc :: clone( & ratelimiter_map) ;
121114 let client = client. clone( ) ;
122-
123115 #[ cfg( feature = "metrics" ) ]
124116 let handle = handle. clone( ) ;
125117
126- tasks . spawn ( async move {
127- trace! ( "Connection from: {:?}" , addr ) ;
128-
129- let service_fn = service :: service_fn ( move |incoming : Request < Incoming >| {
130- let token = incoming
131- . headers ( )
132- . get ( "authorization" )
133- . and_then ( |value| value . to_str ( ) . ok ( ) ) ;
134- let ( ratelimiter , token ) = ratelimiter_map . get_or_insert ( token ) ;
135- let client = client . clone ( ) ;
136-
118+ let service_fn = service :: service_fn ( move |request| {
119+ let token = request
120+ . headers ( )
121+ . get ( header :: AUTHORIZATION )
122+ . and_then ( |value| value . to_str ( ) . ok ( ) ) ;
123+ let ( ratelimiter , token ) = ratelimiter_map . get_or_insert ( token ) ;
124+ let client = client . clone ( ) ;
125+ # [ cfg ( feature = "metrics" ) ]
126+ let handle = handle . clone ( ) ;
127+
128+ async move {
137129 #[ cfg( feature = "metrics" ) ]
138- {
139- let handle = handle. clone( ) ;
140-
141- async move {
142- Ok :: <_, Infallible >( {
143- if incoming. uri( ) . path( ) == "/metrics" {
144- handle_metrics( handle)
145- } else {
146- handle_request( client, ratelimiter, token, incoming)
147- . await
148- . unwrap_or_else( |err| err. as_response( ) )
149- }
150- } )
151- }
130+ if request. uri( ) . path( ) == "/metrics" {
131+ return Ok :: <_, Infallible >( handle_metrics( handle) ) ;
152132 }
153133
154- #[ cfg( not( feature = "metrics" ) ) ]
155- {
156- async move {
157- Ok :: <_, Infallible >(
158- handle_request( client, ratelimiter, token, incoming)
159- . await
160- . unwrap_or_else( |err| err. as_response( ) ) ,
161- )
162- }
163- }
164- } ) ;
134+ Ok :: <_, Infallible >( handle_request( client, ratelimiter, token, request)
135+ . await
136+ . unwrap_or_else( |err| err. as_response( ) ) )
137+ }
138+ } ) ;
139+
140+ tracker. spawn( async move {
141+ trace!( "Connection from: {:?}" , addr) ;
165142
166143 let result = Builder :: new( TokioExecutor :: new( ) )
167144 . serve_connection( TokioIo :: new( stream) , service_fn)
@@ -180,7 +157,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
180157 }
181158 }
182159
183- while tasks. join_next ( ) . await . is_some ( ) { }
160+ tracker. close ( ) ;
161+ info ! ( "waiting for {} task(s) to finish" , tracker. len( ) ) ;
162+ tracker. wait ( ) . await ;
184163
185164 Ok ( ( ) )
186165}
@@ -308,7 +287,7 @@ async fn handle_request(
308287}
309288
310289#[ cfg( feature = "metrics" ) ]
311- fn handle_metrics ( handle : Arc < PrometheusHandle > ) -> Response < BoxBody < Bytes , hyper:: Error > > {
290+ fn handle_metrics ( handle : PrometheusHandle ) -> Response < BoxBody < Bytes , hyper:: Error > > {
312291 Response :: builder ( )
313292 . header (
314293 header:: CONTENT_TYPE ,
@@ -379,10 +358,10 @@ fn parse_headers(
379358 }
380359}
381360
382- fn parse_env < T > ( key : & str ) -> Result < Option < T > , Box < dyn Error > >
361+ fn parse_env < F > ( key : & str ) -> Result < Option < F > , Box < dyn Error > >
383362where
384- T : FromStr ,
385- <T as FromStr >:: Err : Error + ' static ,
363+ F : FromStr ,
364+ <F as FromStr >:: Err : Error + ' static ,
386365{
387366 match env:: var ( key) {
388367 Ok ( s) => match s. parse ( ) {
0 commit comments