1
+ use std:: sync:: Arc ;
2
+
1
3
// Import the base64 crate Engine trait anonymously so we can
2
4
// call its methods without adding to the namespace.
3
5
use base64:: engine:: general_purpose:: STANDARD as BASE64 ;
4
6
use base64:: engine:: Engine as _;
7
+ use tracing:: Instrument ;
8
+
9
+ // Auths in those groups are independent of each other.
10
+ // This lets us reduce mutex contention
11
+ #[ derive( Hash , Eq , PartialEq , Clone ) ]
12
+ struct AuthTimersGroupKey {
13
+ url : String ,
14
+ username : String ,
15
+ }
5
16
6
- lazy_static ! {
7
- static ref AUTH : std:: sync:: Mutex <std:: collections:: HashMap <Handle , Header >> =
8
- std:: sync:: Mutex :: new( std:: collections:: HashMap :: new( ) ) ;
9
- static ref AUTH_TIMERS : std:: sync:: Mutex <AuthTimers > =
10
- std:: sync:: Mutex :: new( std:: collections:: HashMap :: new( ) ) ;
17
+ impl AuthTimersGroupKey {
18
+ fn new ( url : & str , handle : & Handle ) -> Self {
19
+ let ( username, _) = handle. parse ( ) . unwrap_or_default ( ) ;
20
+
21
+ Self {
22
+ url : url. to_string ( ) ,
23
+ username,
24
+ }
25
+ }
11
26
}
12
27
13
- type AuthTimers = std:: collections:: HashMap < ( String , Handle ) , std:: time:: Instant > ;
28
+ // Within a group, we can hold the lock for longer to verify the auth with upstream
29
+ type AuthTimersGroup = std:: collections:: HashMap < Handle , std:: time:: Instant > ;
30
+ type AuthTimers =
31
+ std:: collections:: HashMap < AuthTimersGroupKey , Arc < tokio:: sync:: Mutex < AuthTimersGroup > > > ;
32
+
33
+ lazy_static ! {
34
+ // Note the use of std::sync::Mutex: access to those structures should only be performed
35
+ // shortly, without blocking the async runtime for long time and without holding the
36
+ // lock across an await point.
37
+ static ref AUTH : std:: sync:: Mutex <std:: collections:: HashMap <Handle , Header >> = Default :: default ( ) ;
38
+ static ref AUTH_TIMERS : std:: sync:: Mutex <AuthTimers > = Default :: default ( ) ;
39
+ }
14
40
15
41
// Wrapper struct for storing passwords to avoid having
16
42
// them output to traces by accident
17
- #[ derive( Clone ) ]
43
+ #[ derive( Clone , Default ) ]
18
44
struct Header {
19
45
pub header : Option < hyper:: header:: HeaderValue > ,
20
46
}
21
47
22
- #[ derive( Clone , PartialEq , Eq , Hash , serde:: Serialize , serde:: Deserialize ) ]
48
+ #[ derive( Clone , PartialEq , Eq , Hash , PartialOrd , Ord , serde:: Serialize , serde:: Deserialize ) ]
23
49
pub struct Handle {
24
- pub hash : String ,
50
+ pub hash : Option < String > ,
25
51
}
26
52
27
53
impl std:: fmt:: Debug for Handle {
@@ -32,39 +58,50 @@ impl std::fmt::Debug for Handle {
32
58
33
59
impl Handle {
34
60
// Returns a pair: (username, password)
35
- pub fn parse ( & self ) -> josh:: JoshResult < ( String , String ) > {
36
- let line = josh:: some_or!(
37
- AUTH . lock( )
61
+ pub fn parse ( & self ) -> Option < ( String , String ) > {
62
+ let get_result = || -> josh:: JoshResult < ( String , String ) > {
63
+ let line = AUTH
64
+ . lock ( )
38
65
. unwrap ( )
39
66
. get ( self )
40
67
. and_then ( |h| h. header . as_ref ( ) )
41
- . map( |h| h. as_bytes( ) . to_owned( ) ) ,
42
- {
43
- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
44
- }
45
- ) ;
68
+ . map ( |h| h. as_bytes ( ) . to_owned ( ) )
69
+ . ok_or_else ( || josh:: josh_error ( "no auth found" ) ) ?;
46
70
47
- let u = josh:: ok_or!( String :: from_utf8( line[ 6 ..] . to_vec( ) ) , {
48
- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
49
- } ) ;
50
- let decoded = josh:: ok_or!( BASE64 . decode( u) , {
51
- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
52
- } ) ;
53
- let s = josh:: ok_or!( String :: from_utf8( decoded) , {
54
- return Ok ( ( "" . to_string( ) , "" . to_string( ) ) ) ;
55
- } ) ;
56
- let ( username, password) = s. as_str ( ) . split_once ( ':' ) . unwrap_or ( ( "" , "" ) ) ;
57
- Ok ( ( username. to_string ( ) , password. to_string ( ) ) )
71
+ let line = String :: from_utf8 ( line) ?;
72
+ let ( _, token) = line
73
+ . split_once ( ' ' )
74
+ . ok_or_else ( || josh:: josh_error ( "Unsupported auth type" ) ) ?;
75
+
76
+ let decoded = BASE64 . decode ( token) ?;
77
+ let decoded = String :: from_utf8 ( decoded) ?;
78
+
79
+ let ( username, password) = decoded
80
+ . split_once ( ':' )
81
+ . ok_or_else ( || josh:: josh_error ( "No password found" ) ) ?;
82
+
83
+ Ok ( ( username. to_string ( ) , password. to_string ( ) ) )
84
+ } ;
85
+
86
+ match get_result ( ) {
87
+ Ok ( pair) => Some ( pair) ,
88
+ Err ( e) => {
89
+ tracing:: trace!(
90
+ handle = ?self ,
91
+ "Falling back to default auth: {:?}" ,
92
+ e
93
+ ) ;
94
+
95
+ None
96
+ }
97
+ }
58
98
}
59
99
}
60
100
61
101
pub fn add_auth ( token : & str ) -> josh:: JoshResult < Handle > {
62
102
let header = hyper:: header:: HeaderValue :: from_str ( & format ! ( "Basic {}" , BASE64 . encode( token) ) ) ?;
63
103
let hp = Handle {
64
- hash : format ! (
65
- "{:?}" ,
66
- git2:: Oid :: hash_object( git2:: ObjectType :: Blob , header. as_bytes( ) ) ?
67
- ) ,
104
+ hash : Some ( git2:: Oid :: hash_object ( git2:: ObjectType :: Blob , header. as_bytes ( ) ) ?. to_string ( ) ) ,
68
105
} ;
69
106
let p = Header {
70
107
header : Some ( header) ,
@@ -73,65 +110,122 @@ pub fn add_auth(token: &str) -> josh::JoshResult<Handle> {
73
110
Ok ( hp)
74
111
}
75
112
76
- pub async fn check_auth ( url : & str , auth : & Handle , required : bool ) -> josh:: JoshResult < bool > {
77
- if required && auth. hash . is_empty ( ) {
78
- return Ok ( false ) ;
79
- }
113
+ #[ tracing:: instrument( ) ]
114
+ pub async fn check_http_auth ( url : & str , auth : & Handle , required : bool ) -> josh:: JoshResult < bool > {
115
+ use opentelemetry_semantic_conventions:: trace:: HTTP_RESPONSE_STATUS_CODE ;
80
116
81
- if let Some ( last) = AUTH_TIMERS . lock ( ) ?. get ( & ( url. to_string ( ) , auth. clone ( ) ) ) {
82
- let since = std:: time:: Instant :: now ( ) . duration_since ( * last) ;
83
- tracing:: trace!( "last: {:?}, since: {:?}" , last, since) ;
84
- if since < std:: time:: Duration :: from_secs ( 60 * 30 ) {
85
- tracing:: trace!( "cached auth" ) ;
86
- return Ok ( true ) ;
87
- }
117
+ if required && auth. hash . is_none ( ) {
118
+ return Ok ( false ) ;
88
119
}
89
120
90
- tracing:: trace!( "no cached auth {:?}" , * AUTH_TIMERS . lock( ) ?) ;
121
+ let group_key = AuthTimersGroupKey :: new ( url, & auth) ;
122
+ let auth_timers = AUTH_TIMERS
123
+ . lock ( )
124
+ . unwrap ( )
125
+ . entry ( group_key. clone ( ) )
126
+ . or_default ( )
127
+ . clone ( ) ;
91
128
92
- let https = hyper_tls:: HttpsConnector :: new ( ) ;
93
- let client = hyper:: Client :: builder ( ) . build :: < _ , hyper:: Body > ( https) ;
129
+ let auth_header = AUTH . lock ( ) . unwrap ( ) . get ( auth) . cloned ( ) . unwrap_or_default ( ) ;
94
130
95
- let password = AUTH
96
- . lock ( ) ?
97
- . get ( auth)
98
- . unwrap_or ( & Header { header : None } )
99
- . to_owned ( ) ;
100
131
let refs_url = format ! ( "{}/info/refs?service=git-upload-pack" , url) ;
132
+ let do_request = || {
133
+ let refs_url = refs_url. clone ( ) ;
134
+ let do_request_span = tracing:: info_span!( "check_http_auth: make request" ) ;
101
135
102
- let builder = hyper :: Request :: builder ( )
103
- . method ( hyper :: Method :: GET )
104
- . uri ( & refs_url ) ;
136
+ async move {
137
+ let https = hyper_tls :: HttpsConnector :: new ( ) ;
138
+ let client = hyper :: Client :: builder ( ) . build :: < _ , hyper :: Body > ( https ) ;
105
139
106
- let builder = if let Some ( value) = password. header {
107
- builder. header ( hyper:: header:: AUTHORIZATION , value)
108
- } else {
109
- builder
140
+ let builder = hyper:: Request :: builder ( )
141
+ . method ( hyper:: Method :: GET )
142
+ . uri ( & refs_url) ;
143
+
144
+ let builder = if let Some ( value) = auth_header. header {
145
+ builder. header ( hyper:: header:: AUTHORIZATION , value)
146
+ } else {
147
+ builder
148
+ } ;
149
+
150
+ let request = builder. body ( hyper:: Body :: empty ( ) ) ?;
151
+ let resp = client. request ( request) . await ?;
152
+
153
+ Ok :: < _ , josh:: JoshError > ( resp)
154
+ }
155
+ . instrument ( do_request_span)
110
156
} ;
111
157
112
- let request = builder. body ( hyper:: Body :: empty ( ) ) ?;
113
- let resp = client. request ( request) . await ?;
158
+ // Only lock the mutex if auth handle is not empty, because otherwise
159
+ // for remotes that require auth, we could run into situation where
160
+ // multiple requests are executed essentially sequentially because
161
+ // remote always returns 401 for authenticated requests and we never
162
+ // populate the auth_timers map
163
+ let resp = if auth. hash . is_some ( ) {
164
+ let mut auth_timers = auth_timers. lock ( ) . await ;
165
+
166
+ if let Some ( last) = auth_timers. get ( auth) {
167
+ let since = std:: time:: Instant :: now ( ) . duration_since ( * last) ;
168
+ let expired = since > std:: time:: Duration :: from_secs ( 60 * 30 ) ;
169
+
170
+ tracing:: info!(
171
+ last = ?last,
172
+ since = ?since,
173
+ expired = %expired,
174
+ "check_http_auth: found auth entry"
175
+ ) ;
176
+
177
+ if !expired {
178
+ return Ok ( true ) ;
179
+ }
180
+ }
114
181
115
- let status = resp. status ( ) ;
182
+ tracing:: info!(
183
+ auth_timers = ?auth_timers,
184
+ "check_http_auth: no valid cached auth"
185
+ ) ;
116
186
117
- tracing:: trace!( "http resp.status {:?}" , resp. status( ) ) ;
187
+ let resp = do_request ( ) . await ?;
188
+ if resp. status ( ) . is_success ( ) {
189
+ auth_timers. insert ( auth. clone ( ) , std:: time:: Instant :: now ( ) ) ;
190
+ }
191
+
192
+ resp
193
+ } else {
194
+ do_request ( ) . await ?
195
+ } ;
196
+
197
+ let status = resp. status ( ) ;
118
198
119
- let err_msg = format ! ( "got http response: {} {:?}" , refs_url, resp) ;
199
+ tracing:: event!(
200
+ tracing:: Level :: INFO ,
201
+ { HTTP_RESPONSE_STATUS_CODE } = status. as_u16( ) ,
202
+ "check_http_auth: response"
203
+ ) ;
120
204
121
205
if status == hyper:: StatusCode :: OK {
122
- AUTH_TIMERS
123
- . lock ( ) ?
124
- . insert ( ( url. to_string ( ) , auth. clone ( ) ) , std:: time:: Instant :: now ( ) ) ;
125
206
Ok ( true )
126
207
} else if status == hyper:: StatusCode :: UNAUTHORIZED {
127
- tracing:: warn! ( "resp.status == 401: {:?}" , & err_msg ) ;
128
- tracing:: trace! (
129
- "body: {:?}" ,
130
- std :: str :: from_utf8 ( & hyper :: body :: to_bytes ( resp . into_body ( ) ) . await ? )
208
+ tracing:: event! (
209
+ tracing:: Level :: WARN ,
210
+ { HTTP_RESPONSE_STATUS_CODE } = status . as_u16 ( ) ,
211
+ "check_http_auth: unauthorized"
131
212
) ;
213
+
214
+ let response = hyper:: body:: to_bytes ( resp. into_body ( ) ) . await ?;
215
+ let response = String :: from_utf8_lossy ( & response) ;
216
+
217
+ tracing:: event!(
218
+ tracing:: Level :: TRACE ,
219
+ "http.response.body" = %response,
220
+ "check_http_auth: unauthorized" ,
221
+ ) ;
222
+
132
223
Ok ( false )
133
224
} else {
134
- return Err ( josh:: josh_error ( & err_msg) ) ;
225
+ return Err ( josh:: josh_error ( & format ! (
226
+ "check_http_auth: got http response: {} {:?}" ,
227
+ refs_url, resp
228
+ ) ) ) ;
135
229
}
136
230
}
137
231
@@ -144,9 +238,8 @@ pub fn strip_auth(
144
238
145
239
if let Some ( header) = header {
146
240
let hp = Handle {
147
- hash : format ! (
148
- "{:?}" ,
149
- git2:: Oid :: hash_object( git2:: ObjectType :: Blob , header. as_bytes( ) ) ?
241
+ hash : Some (
242
+ git2:: Oid :: hash_object ( git2:: ObjectType :: Blob , header. as_bytes ( ) ) ?. to_string ( ) ,
150
243
) ,
151
244
} ;
152
245
let p = Header {
@@ -156,10 +249,5 @@ pub fn strip_auth(
156
249
return Ok ( ( hp, req) ) ;
157
250
}
158
251
159
- Ok ( (
160
- Handle {
161
- hash : "" . to_owned ( ) ,
162
- } ,
163
- req,
164
- ) )
252
+ Ok ( ( Handle { hash : None } , req) )
165
253
}
0 commit comments