Skip to content

Commit bcc788a

Browse files
authored
refactor main (#102)
* replace joinset with tasktracker * restructure `main`
1 parent 8ffca0b commit bcc788a

File tree

3 files changed

+66
-74
lines changed

3 files changed

+66
-74
lines changed

Cargo.lock

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ hyper-rustls = { version = "0.27", default-features = false, features = ["webpki
1313
hyper-hickory = { version = "0.8", features = ["tokio"] }
1414
hyper-util = { version = "0.1", default-features = false, features = ["tokio", "server", "http1", "http2"] }
1515
tokio = { version = "1.43", features = ["rt-multi-thread", "macros", "signal"] }
16-
tokio-util = { version = "0.7.8", default-features = false, features = ["time"] }
16+
tokio-util = { version = "0.7.8", features = ["rt", "time"] }
1717
tracing = "0.1"
1818
tracing-subscriber = "0.3"
1919
twilight-http-ratelimiting = "0.17"

src/main.rs

Lines changed: 52 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ use std::{
2222
convert::Infallible,
2323
env,
2424
error::Error,
25-
net::{Ipv4Addr, SocketAddr},
25+
net::{Ipv4Addr, SocketAddrV4},
2626
pin::pin,
2727
str::FromStr,
2828
sync::Arc,
2929
time::{Duration, Instant},
3030
};
31-
use tokio::{net::TcpListener, task::JoinSet};
31+
use tokio::net::TcpListener;
32+
use tokio_util::task::TaskTracker;
3233
use tracing::{error, info, trace};
3334
use twilight_http_ratelimiting::{Endpoint, Method, RateLimitHeaders, RateLimiter};
3435

@@ -55,10 +56,7 @@ static METRIC_KEY: LazyLock<Cow<str>> = LazyLock::new(|| {
5556
async fn main() -> Result<(), Box<dyn Error>> {
5657
tracing_subscriber::fmt::init();
5758

58-
let host = parse_env("HOST")?.unwrap_or(Ipv4Addr::UNSPECIFIED);
59-
let port = parse_env("PORT")?.unwrap_or(80);
60-
61-
let https_connector = {
59+
let client = {
6260
let mut http_connector = TokioHickoryResolver::default().into_http_connector();
6361
http_connector.enforce_http(false);
6462

@@ -67,45 +65,42 @@ async fn main() -> Result<(), Box<dyn Error>> {
6765
.https_only()
6866
.enable_http1();
6967

70-
if env::var_os("DISABLE_HTTP2").is_some() {
68+
let https_connector = if env::var_os("DISABLE_HTTP2").is_some() {
7169
builder.wrap_connector(http_connector)
7270
} else {
7371
builder.enable_http2().wrap_connector(http_connector)
74-
}
72+
};
73+
74+
Client::builder(TokioExecutor::new()).build(https_connector)
7575
};
7676

77-
let client: Client<_, Incoming> = Client::builder(TokioExecutor::new()).build(https_connector);
77+
#[cfg(feature = "metrics")]
78+
let handle = PrometheusBuilder::new()
79+
.idle_timeout(
80+
MetricKindMask::COUNTER | MetricKindMask::HISTOGRAM,
81+
Some(Duration::from_secs(
82+
parse_env("METRIC_TIMEOUT")?.unwrap_or(300),
83+
)),
84+
)
85+
.install_recorder()
86+
.expect("installed once");
87+
7888
let ratelimiter_map = Arc::new(RatelimiterMap::new(
7989
env::var("DISCORD_TOKEN")?,
8090
Duration::from_secs(parse_env("CLIENT_DECAY_TIMEOUT")?.unwrap_or(3600)),
8191
parse_env("CLIENT_CACHE_MAX_SIZE")?,
8292
));
8393

84-
let address = SocketAddr::from((host, port));
85-
86-
#[cfg(feature = "metrics")]
87-
let handle: Arc<PrometheusHandle>;
88-
89-
#[cfg(feature = "metrics")]
90-
{
91-
let timeout = parse_env("METRIC_TIMEOUT")?.unwrap_or(300);
92-
let recorder = PrometheusBuilder::new()
93-
.idle_timeout(
94-
MetricKindMask::COUNTER | MetricKindMask::HISTOGRAM,
95-
Some(Duration::from_secs(timeout)),
96-
)
97-
.build_recorder();
98-
handle = Arc::new(recorder.handle());
99-
metrics::set_global_recorder(Box::new(recorder))
100-
.expect("Failed to create metrics receiver!");
101-
}
94+
let host = parse_env("HOST")?.unwrap_or(Ipv4Addr::UNSPECIFIED);
95+
let port = parse_env("PORT")?.unwrap_or(80);
96+
let address = SocketAddrV4::new(host, port);
10297

10398
let listener = TcpListener::bind(&address).await?;
10499
let mut shutdown_signal = pin!(shutdown_signal());
105100

106101
info!("Listening on http://{}", address);
107102

108-
let mut tasks = JoinSet::new();
103+
let tracker = TaskTracker::new();
109104

110105
loop {
111106
tokio::select! {
@@ -115,53 +110,35 @@ async fn main() -> Result<(), Box<dyn Error>> {
115110
continue;
116111
};
117112

118-
119-
let ratelimiter_map = ratelimiter_map.clone();
120-
// Cloning a hyper client is fairly cheap by design
113+
let ratelimiter_map = Arc::clone(&ratelimiter_map);
121114
let client = client.clone();
122-
123115
#[cfg(feature = "metrics")]
124116
let handle = handle.clone();
125117

126-
tasks.spawn(async move {
127-
trace!("Connection from: {:?}", addr);
128-
129-
let service_fn = service::service_fn(move |incoming: Request<Incoming>| {
130-
let token = incoming
131-
.headers()
132-
.get("authorization")
133-
.and_then(|value| value.to_str().ok());
134-
let (ratelimiter, token) = ratelimiter_map.get_or_insert(token);
135-
let client = client.clone();
136-
118+
let service_fn = service::service_fn(move |request| {
119+
let token = request
120+
.headers()
121+
.get(header::AUTHORIZATION)
122+
.and_then(|value| value.to_str().ok());
123+
let (ratelimiter, token) = ratelimiter_map.get_or_insert(token);
124+
let client = client.clone();
125+
#[cfg(feature = "metrics")]
126+
let handle = handle.clone();
127+
128+
async move {
137129
#[cfg(feature = "metrics")]
138-
{
139-
let handle = handle.clone();
140-
141-
async move {
142-
Ok::<_, Infallible>({
143-
if incoming.uri().path() == "/metrics" {
144-
handle_metrics(handle)
145-
} else {
146-
handle_request(client, ratelimiter, token, incoming)
147-
.await
148-
.unwrap_or_else(|err| err.as_response())
149-
}
150-
})
151-
}
130+
if request.uri().path() == "/metrics" {
131+
return Ok::<_, Infallible>(handle_metrics(handle));
152132
}
153133

154-
#[cfg(not(feature = "metrics"))]
155-
{
156-
async move {
157-
Ok::<_, Infallible>(
158-
handle_request(client, ratelimiter, token, incoming)
159-
.await
160-
.unwrap_or_else(|err| err.as_response()),
161-
)
162-
}
163-
}
164-
});
134+
Ok::<_, Infallible>(handle_request(client, ratelimiter, token, request)
135+
.await
136+
.unwrap_or_else(|err| err.as_response()))
137+
}
138+
});
139+
140+
tracker.spawn(async move {
141+
trace!("Connection from: {:?}", addr);
165142

166143
let result = Builder::new(TokioExecutor::new())
167144
.serve_connection(TokioIo::new(stream), service_fn)
@@ -180,7 +157,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
180157
}
181158
}
182159

183-
while tasks.join_next().await.is_some() {}
160+
tracker.close();
161+
info!("waiting for {} task(s) to finish", tracker.len());
162+
tracker.wait().await;
184163

185164
Ok(())
186165
}
@@ -308,7 +287,7 @@ async fn handle_request(
308287
}
309288

310289
#[cfg(feature = "metrics")]
311-
fn handle_metrics(handle: Arc<PrometheusHandle>) -> Response<BoxBody<Bytes, hyper::Error>> {
290+
fn handle_metrics(handle: PrometheusHandle) -> Response<BoxBody<Bytes, hyper::Error>> {
312291
Response::builder()
313292
.header(
314293
header::CONTENT_TYPE,
@@ -379,10 +358,10 @@ fn parse_headers(
379358
}
380359
}
381360

382-
fn parse_env<T>(key: &str) -> Result<Option<T>, Box<dyn Error>>
361+
fn parse_env<F>(key: &str) -> Result<Option<F>, Box<dyn Error>>
383362
where
384-
T: FromStr,
385-
<T as FromStr>::Err: Error + 'static,
363+
F: FromStr,
364+
<F as FromStr>::Err: Error + 'static,
386365
{
387366
match env::var(key) {
388367
Ok(s) => match s.parse() {

0 commit comments

Comments
 (0)