Skip to content

Commit 3043aeb

Browse files
Fix race condition in check_auth (#1331)
When a lot of requests arrive roughly at the same time, several requests can enter the critical section where an HTTP request to upstream is made to check the auth provided by the client. This means that potentially thousands of requests can get through to the remote, leading to rate limits and network errors with some remotes. * Introduce a much more granular lock scope * Fix the issue by extending the lock region * Switch mutex to async to avoid blocking runtime * Improve tracing commit-id:7d950008
1 parent 012f8dc commit 3043aeb

File tree

3 files changed

+177
-92
lines changed

3 files changed

+177
-92
lines changed

josh-proxy/src/auth.rs

+168-80
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,53 @@
1+
use std::sync::Arc;
2+
13
// Import the base64 crate Engine trait anonymously so we can
24
// call its methods without adding to the namespace.
35
use base64::engine::general_purpose::STANDARD as BASE64;
46
use base64::engine::Engine as _;
7+
use tracing::Instrument;
8+
9+
// Auths in those groups are independent of each other.
10+
// This lets us reduce mutex contention
11+
#[derive(Hash, Eq, PartialEq, Clone)]
12+
struct AuthTimersGroupKey {
13+
url: String,
14+
username: String,
15+
}
516

6-
lazy_static! {
7-
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> =
8-
std::sync::Mutex::new(std::collections::HashMap::new());
9-
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> =
10-
std::sync::Mutex::new(std::collections::HashMap::new());
17+
impl AuthTimersGroupKey {
18+
fn new(url: &str, handle: &Handle) -> Self {
19+
let (username, _) = handle.parse().unwrap_or_default();
20+
21+
Self {
22+
url: url.to_string(),
23+
username,
24+
}
25+
}
1126
}
1227

13-
type AuthTimers = std::collections::HashMap<(String, Handle), std::time::Instant>;
28+
// Within a group, we can hold the lock for longer to verify the auth with upstream
29+
type AuthTimersGroup = std::collections::HashMap<Handle, std::time::Instant>;
30+
type AuthTimers =
31+
std::collections::HashMap<AuthTimersGroupKey, Arc<tokio::sync::Mutex<AuthTimersGroup>>>;
32+
33+
lazy_static! {
34+
// Note the use of std::sync::Mutex: access to those structures should only be performed
35+
// shortly, without blocking the async runtime for long time and without holding the
36+
// lock across an await point.
37+
static ref AUTH: std::sync::Mutex<std::collections::HashMap<Handle, Header>> = Default::default();
38+
static ref AUTH_TIMERS: std::sync::Mutex<AuthTimers> = Default::default();
39+
}
1440

1541
// Wrapper struct for storing passwords to avoid having
1642
// them output to traces by accident
17-
#[derive(Clone)]
43+
#[derive(Clone, Default)]
1844
struct Header {
1945
pub header: Option<hyper::header::HeaderValue>,
2046
}
2147

22-
#[derive(Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
48+
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
2349
pub struct Handle {
24-
pub hash: String,
50+
pub hash: Option<String>,
2551
}
2652

2753
impl std::fmt::Debug for Handle {
@@ -32,39 +58,50 @@ impl std::fmt::Debug for Handle {
3258

3359
impl Handle {
3460
// Returns a pair: (username, password)
35-
pub fn parse(&self) -> josh::JoshResult<(String, String)> {
36-
let line = josh::some_or!(
37-
AUTH.lock()
61+
pub fn parse(&self) -> Option<(String, String)> {
62+
let get_result = || -> josh::JoshResult<(String, String)> {
63+
let line = AUTH
64+
.lock()
3865
.unwrap()
3966
.get(self)
4067
.and_then(|h| h.header.as_ref())
41-
.map(|h| h.as_bytes().to_owned()),
42-
{
43-
return Ok(("".to_string(), "".to_string()));
44-
}
45-
);
68+
.map(|h| h.as_bytes().to_owned())
69+
.ok_or_else(|| josh::josh_error("no auth found"))?;
4670

47-
let u = josh::ok_or!(String::from_utf8(line[6..].to_vec()), {
48-
return Ok(("".to_string(), "".to_string()));
49-
});
50-
let decoded = josh::ok_or!(BASE64.decode(u), {
51-
return Ok(("".to_string(), "".to_string()));
52-
});
53-
let s = josh::ok_or!(String::from_utf8(decoded), {
54-
return Ok(("".to_string(), "".to_string()));
55-
});
56-
let (username, password) = s.as_str().split_once(':').unwrap_or(("", ""));
57-
Ok((username.to_string(), password.to_string()))
71+
let line = String::from_utf8(line)?;
72+
let (_, token) = line
73+
.split_once(' ')
74+
.ok_or_else(|| josh::josh_error("Unsupported auth type"))?;
75+
76+
let decoded = BASE64.decode(token)?;
77+
let decoded = String::from_utf8(decoded)?;
78+
79+
let (username, password) = decoded
80+
.split_once(':')
81+
.ok_or_else(|| josh::josh_error("No password found"))?;
82+
83+
Ok((username.to_string(), password.to_string()))
84+
};
85+
86+
match get_result() {
87+
Ok(pair) => Some(pair),
88+
Err(e) => {
89+
tracing::trace!(
90+
handle = ?self,
91+
"Falling back to default auth: {:?}",
92+
e
93+
);
94+
95+
None
96+
}
97+
}
5898
}
5999
}
60100

61101
pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
62102
let header = hyper::header::HeaderValue::from_str(&format!("Basic {}", BASE64.encode(token)))?;
63103
let hp = Handle {
64-
hash: format!(
65-
"{:?}",
66-
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?
67-
),
104+
hash: Some(git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?.to_string()),
68105
};
69106
let p = Header {
70107
header: Some(header),
@@ -73,65 +110,122 @@ pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
73110
Ok(hp)
74111
}
75112

76-
pub async fn check_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
77-
if required && auth.hash.is_empty() {
78-
return Ok(false);
79-
}
113+
#[tracing::instrument()]
114+
pub async fn check_http_auth(url: &str, auth: &Handle, required: bool) -> josh::JoshResult<bool> {
115+
use opentelemetry_semantic_conventions::trace::HTTP_RESPONSE_STATUS_CODE;
80116

81-
if let Some(last) = AUTH_TIMERS.lock()?.get(&(url.to_string(), auth.clone())) {
82-
let since = std::time::Instant::now().duration_since(*last);
83-
tracing::trace!("last: {:?}, since: {:?}", last, since);
84-
if since < std::time::Duration::from_secs(60 * 30) {
85-
tracing::trace!("cached auth");
86-
return Ok(true);
87-
}
117+
if required && auth.hash.is_none() {
118+
return Ok(false);
88119
}
89120

90-
tracing::trace!("no cached auth {:?}", *AUTH_TIMERS.lock()?);
121+
let group_key = AuthTimersGroupKey::new(url, &auth);
122+
let auth_timers = AUTH_TIMERS
123+
.lock()
124+
.unwrap()
125+
.entry(group_key.clone())
126+
.or_default()
127+
.clone();
91128

92-
let https = hyper_tls::HttpsConnector::new();
93-
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
129+
let auth_header = AUTH.lock().unwrap().get(auth).cloned().unwrap_or_default();
94130

95-
let password = AUTH
96-
.lock()?
97-
.get(auth)
98-
.unwrap_or(&Header { header: None })
99-
.to_owned();
100131
let refs_url = format!("{}/info/refs?service=git-upload-pack", url);
132+
let do_request = || {
133+
let refs_url = refs_url.clone();
134+
let do_request_span = tracing::info_span!("check_http_auth: make request");
101135

102-
let builder = hyper::Request::builder()
103-
.method(hyper::Method::GET)
104-
.uri(&refs_url);
136+
async move {
137+
let https = hyper_tls::HttpsConnector::new();
138+
let client = hyper::Client::builder().build::<_, hyper::Body>(https);
105139

106-
let builder = if let Some(value) = password.header {
107-
builder.header(hyper::header::AUTHORIZATION, value)
108-
} else {
109-
builder
140+
let builder = hyper::Request::builder()
141+
.method(hyper::Method::GET)
142+
.uri(&refs_url);
143+
144+
let builder = if let Some(value) = auth_header.header {
145+
builder.header(hyper::header::AUTHORIZATION, value)
146+
} else {
147+
builder
148+
};
149+
150+
let request = builder.body(hyper::Body::empty())?;
151+
let resp = client.request(request).await?;
152+
153+
Ok::<_, josh::JoshError>(resp)
154+
}
155+
.instrument(do_request_span)
110156
};
111157

112-
let request = builder.body(hyper::Body::empty())?;
113-
let resp = client.request(request).await?;
158+
// Only lock the mutex if auth handle is not empty, because otherwise
159+
// for remotes that require auth, we could run into situation where
160+
// multiple requests are executed essentially sequentially because
161+
// remote always returns 401 for authenticated requests and we never
162+
// populate the auth_timers map
163+
let resp = if auth.hash.is_some() {
164+
let mut auth_timers = auth_timers.lock().await;
165+
166+
if let Some(last) = auth_timers.get(auth) {
167+
let since = std::time::Instant::now().duration_since(*last);
168+
let expired = since > std::time::Duration::from_secs(60 * 30);
169+
170+
tracing::info!(
171+
last = ?last,
172+
since = ?since,
173+
expired = %expired,
174+
"check_http_auth: found auth entry"
175+
);
176+
177+
if !expired {
178+
return Ok(true);
179+
}
180+
}
114181

115-
let status = resp.status();
182+
tracing::info!(
183+
auth_timers = ?auth_timers,
184+
"check_http_auth: no valid cached auth"
185+
);
116186

117-
tracing::trace!("http resp.status {:?}", resp.status());
187+
let resp = do_request().await?;
188+
if resp.status().is_success() {
189+
auth_timers.insert(auth.clone(), std::time::Instant::now());
190+
}
191+
192+
resp
193+
} else {
194+
do_request().await?
195+
};
196+
197+
let status = resp.status();
118198

119-
let err_msg = format!("got http response: {} {:?}", refs_url, resp);
199+
tracing::event!(
200+
tracing::Level::INFO,
201+
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
202+
"check_http_auth: response"
203+
);
120204

121205
if status == hyper::StatusCode::OK {
122-
AUTH_TIMERS
123-
.lock()?
124-
.insert((url.to_string(), auth.clone()), std::time::Instant::now());
125206
Ok(true)
126207
} else if status == hyper::StatusCode::UNAUTHORIZED {
127-
tracing::warn!("resp.status == 401: {:?}", &err_msg);
128-
tracing::trace!(
129-
"body: {:?}",
130-
std::str::from_utf8(&hyper::body::to_bytes(resp.into_body()).await?)
208+
tracing::event!(
209+
tracing::Level::WARN,
210+
{ HTTP_RESPONSE_STATUS_CODE } = status.as_u16(),
211+
"check_http_auth: unauthorized"
131212
);
213+
214+
let response = hyper::body::to_bytes(resp.into_body()).await?;
215+
let response = String::from_utf8_lossy(&response);
216+
217+
tracing::event!(
218+
tracing::Level::TRACE,
219+
"http.response.body" = %response,
220+
"check_http_auth: unauthorized",
221+
);
222+
132223
Ok(false)
133224
} else {
134-
return Err(josh::josh_error(&err_msg));
225+
return Err(josh::josh_error(&format!(
226+
"check_http_auth: got http response: {} {:?}",
227+
refs_url, resp
228+
)));
135229
}
136230
}
137231

@@ -144,9 +238,8 @@ pub fn strip_auth(
144238

145239
if let Some(header) = header {
146240
let hp = Handle {
147-
hash: format!(
148-
"{:?}",
149-
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?
241+
hash: Some(
242+
git2::Oid::hash_object(git2::ObjectType::Blob, header.as_bytes())?.to_string(),
150243
),
151244
};
152245
let p = Header {
@@ -156,10 +249,5 @@ pub fn strip_auth(
156249
return Ok((hp, req));
157250
}
158251

159-
Ok((
160-
Handle {
161-
hash: "".to_owned(),
162-
},
163-
req,
164-
))
252+
Ok((Handle { hash: None }, req))
165253
}

josh-proxy/src/bin/josh-proxy.rs

+8-11
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,13 @@ async fn fetch_upstream(
264264

265265
match (fetch_result, remote_auth) {
266266
(Ok(_), RemoteAuth::Http { auth }) => {
267-
let (auth_user, _) = auth.parse().map_err(FetchError::from_josh_error)?;
268-
269-
if matches!(&ARGS.poll_user, Some(user) if auth_user == user.as_str()) {
270-
service
271-
.poll
272-
.lock()?
273-
.insert((upstream_repo, auth.clone(), remote_url));
267+
if let Some((auth_user, _)) = auth.parse() {
268+
if matches!(&ARGS.poll_user, Some(user) if auth_user == user.as_str()) {
269+
service
270+
.poll
271+
.lock()?
272+
.insert((upstream_repo, auth.clone(), remote_url));
273+
}
274274
}
275275

276276
Ok(())
@@ -1275,10 +1275,7 @@ async fn call_service(
12751275

12761276
let http_auth_required = ARGS.require_auth && parsed_url.pathinfo == "/git-receive-pack";
12771277

1278-
if !josh_proxy::auth::check_auth(&remote_url, &auth, http_auth_required)
1279-
.in_current_span()
1280-
.await?
1281-
{
1278+
if !josh_proxy::auth::check_http_auth(&remote_url, &auth, http_auth_required).await? {
12821279
tracing::trace!("require-auth");
12831280
let builder = Response::builder()
12841281
.header(

josh-proxy/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ pub fn run_git_with_auth(
684684
Ok(shell.command_env(cmd, &env, &env_notrace))
685685
}
686686
RemoteAuth::Http { auth } => {
687-
let (username, password) = auth.parse()?;
687+
let (username, password) = auth.parse().unwrap_or_default();
688688
let env_notrace = [
689689
[
690690
("GIT_PASSWORD", password.as_str()),

0 commit comments

Comments
 (0)