1
-
2
- use std:: { result:: Result , thread, time, fmt} ;
3
1
use std:: collections:: HashMap ;
2
+ use std:: { fmt, result:: Result , thread, time} ;
4
3
5
- use chrono:: { DateTime , Duration } ;
6
4
use chrono:: offset:: Utc ;
5
+ use chrono:: { DateTime , Duration } ;
7
6
8
7
mod util;
9
8
10
- #[ derive( Debug , Default , Clone , serde_derive:: Serialize ) ]
9
+ #[ derive( Debug , Default , Clone , serde_derive:: Serialize , serde_derive :: Deserialize ) ]
11
10
pub struct Credential {
12
11
pub token : String ,
13
12
pub expiry : String ,
@@ -26,25 +25,24 @@ impl Credential {
26
25
pub fn is_expired ( & self ) -> bool {
27
26
let exp = match DateTime :: parse_from_rfc3339 ( self . expiry . as_str ( ) ) {
28
27
Ok ( time) => time,
29
- Err ( _) => return false
28
+ Err ( _) => return false ,
30
29
} ;
31
30
let now = Utc :: now ( ) ;
32
31
now > exp
33
32
}
34
33
}
35
34
36
-
37
35
#[ derive( Debug , Clone ) ]
38
36
pub enum DeviceFlowError {
39
- HttpError ( String ) ,
40
- GitHubError ( String ) ,
37
+ HttpError ( String ) ,
38
+ GitHubError ( String ) ,
41
39
}
42
40
43
41
impl fmt:: Display for DeviceFlowError {
44
42
fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
45
43
match self {
46
44
DeviceFlowError :: HttpError ( string) => write ! ( f, "DeviceFlowError: {}" , string) ,
47
- DeviceFlowError :: GitHubError ( string) => write ! ( f, "DeviceFlowError: {}" , string)
45
+ DeviceFlowError :: GitHubError ( string) => write ! ( f, "DeviceFlowError: {}" , string) ,
48
46
}
49
47
}
50
48
}
@@ -57,52 +55,83 @@ impl From<reqwest::Error> for DeviceFlowError {
57
55
}
58
56
}
59
57
60
- pub fn authorize ( client_id : String , host : Option < String > ) -> Result < Credential , DeviceFlowError > {
58
+ pub fn authorize (
59
+ client_id : String ,
60
+ host : Option < String > ,
61
+ scope : Option < String > ,
62
+ ) -> Result < Credential , DeviceFlowError > {
61
63
let my_string: String ;
62
64
let thost = match host {
63
65
Some ( string) => {
64
66
my_string = string;
65
67
Some ( my_string. as_str ( ) )
66
- } ,
67
- None => None
68
+ }
69
+ None => None ,
68
70
} ;
69
71
70
- let mut flow = DeviceFlow :: start ( client_id. as_str ( ) , thost) ?;
72
+ let binding: String ;
73
+ let tscope = match scope {
74
+ Some ( string) => {
75
+ binding = string;
76
+ Some ( binding. as_str ( ) )
77
+ }
78
+ None => None ,
79
+ } ;
80
+
81
+ let mut flow = DeviceFlow :: start ( client_id. as_str ( ) , thost, tscope) ?;
71
82
72
83
// eprintln!("res is {:?}", res);
73
- eprintln ! ( "Please visit {} in your browser" , flow. verification_uri. clone( ) . unwrap( ) ) ;
84
+ eprintln ! (
85
+ "Please visit {} in your browser" ,
86
+ flow. verification_uri. clone( ) . unwrap( )
87
+ ) ;
74
88
eprintln ! ( "And enter code: {}" , flow. user_code. clone( ) . unwrap( ) ) ;
75
89
76
90
thread:: sleep ( FIVE_SECONDS ) ;
77
91
78
92
flow. poll ( 20 )
79
93
}
80
94
81
- pub fn refresh ( client_id : & str , refresh_token : & str , host : Option < String > ) -> Result < Credential , DeviceFlowError > {
95
+ pub fn refresh (
96
+ client_id : & str ,
97
+ refresh_token : & str ,
98
+ host : Option < String > ,
99
+ scope : Option < String > ,
100
+ ) -> Result < Credential , DeviceFlowError > {
82
101
let my_string: String ;
83
102
let thost = match host {
84
103
Some ( string) => {
85
104
my_string = string;
86
105
Some ( my_string. as_str ( ) )
87
- } ,
88
- None => None
106
+ }
107
+ None => None ,
89
108
} ;
90
109
91
- refresh_access_token ( client_id, refresh_token, thost)
110
+ let scope_binding;
111
+ let tscope = match scope {
112
+ Some ( string) => {
113
+ scope_binding = string;
114
+ Some ( scope_binding. as_str ( ) )
115
+ }
116
+ None => None ,
117
+ } ;
118
+
119
+ refresh_access_token ( client_id, refresh_token, thost, tscope)
92
120
}
93
121
94
122
#[ derive( Debug , Clone ) ]
95
123
pub enum DeviceFlowState {
96
124
Pending ,
97
125
Processing ( time:: Duration ) ,
98
126
Success ( Credential ) ,
99
- Failure ( DeviceFlowError )
127
+ Failure ( DeviceFlowError ) ,
100
128
}
101
129
102
130
#[ derive( Clone ) ]
103
131
pub struct DeviceFlow {
104
132
pub host : String ,
105
133
pub client_id : String ,
134
+ pub scope : String ,
106
135
pub user_code : Option < String > ,
107
136
pub device_code : Option < String > ,
108
137
pub verification_uri : Option < String > ,
@@ -112,12 +141,16 @@ pub struct DeviceFlow {
112
141
const FIVE_SECONDS : time:: Duration = time:: Duration :: new ( 5 , 0 ) ;
113
142
114
143
impl DeviceFlow {
115
- pub fn new ( client_id : & str , maybe_host : Option < & str > ) -> Self {
116
- Self {
144
+ pub fn new ( client_id : & str , maybe_host : Option < & str > , scope : Option < & str > ) -> Self {
145
+ Self {
117
146
client_id : String :: from ( client_id) ,
147
+ scope : match scope {
148
+ Some ( string) => String :: from ( string) ,
149
+ None => String :: new ( ) ,
150
+ } ,
118
151
host : match maybe_host {
119
152
Some ( string) => String :: from ( string) ,
120
- None => String :: from ( "github.com" )
153
+ None => String :: from ( "github.com" ) ,
121
154
} ,
122
155
user_code : None ,
123
156
device_code : None ,
@@ -126,31 +159,43 @@ impl DeviceFlow {
126
159
}
127
160
}
128
161
129
- pub fn start ( client_id : & str , maybe_host : Option < & str > ) -> Result < DeviceFlow , DeviceFlowError > {
130
- let mut flow = DeviceFlow :: new ( client_id, maybe_host) ;
162
+ pub fn start (
163
+ client_id : & str ,
164
+ maybe_host : Option < & str > ,
165
+ scope : Option < & str > ,
166
+ ) -> Result < DeviceFlow , DeviceFlowError > {
167
+ let mut flow = DeviceFlow :: new ( client_id, maybe_host, scope) ;
131
168
132
169
flow. setup ( ) ;
133
170
134
171
match flow. state {
135
172
DeviceFlowState :: Processing ( _) => Ok ( flow. to_owned ( ) ) ,
136
173
DeviceFlowState :: Failure ( err) => Err ( err) ,
137
- _ => Err ( util:: credential_error ( "Something truly unexpected happened" . into ( ) ) )
174
+ _ => Err ( util:: credential_error (
175
+ "Something truly unexpected happened" . into ( ) ,
176
+ ) ) ,
138
177
}
139
178
}
140
179
141
180
pub fn setup ( & mut self ) {
142
- let body = format ! ( "client_id={}" , & self . client_id) ;
181
+ let body = format ! ( "client_id={}&scope={} " , & self . client_id, & self . scope ) ;
143
182
let entry_url = format ! ( "https://{}/login/device/code" , & self . host) ;
144
183
145
184
if let Some ( res) = util:: send_request ( self , entry_url, body) {
146
- if res. contains_key ( "error" ) && res. contains_key ( "error_description" ) {
147
- self . state = DeviceFlowState :: Failure ( util:: credential_error ( res[ "error_description" ] . as_str ( ) . unwrap ( ) . into ( ) ) )
185
+ if res. contains_key ( "error" ) && res. contains_key ( "error_description" ) {
186
+ self . state = DeviceFlowState :: Failure ( util:: credential_error (
187
+ res[ "error_description" ] . as_str ( ) . unwrap ( ) . into ( ) ,
188
+ ) )
148
189
} else if res. contains_key ( "error" ) {
149
- self . state = DeviceFlowState :: Failure ( util:: credential_error ( format ! ( "Error response: {:?}" , res[ "error" ] . as_str( ) . unwrap( ) ) ) )
190
+ self . state = DeviceFlowState :: Failure ( util:: credential_error ( format ! (
191
+ "Error response: {:?}" ,
192
+ res[ "error" ] . as_str( ) . unwrap( )
193
+ ) ) )
150
194
} else {
151
195
self . user_code = Some ( String :: from ( res[ "user_code" ] . as_str ( ) . unwrap ( ) ) ) ;
152
196
self . device_code = Some ( String :: from ( res[ "device_code" ] . as_str ( ) . unwrap ( ) ) ) ;
153
- self . verification_uri = Some ( String :: from ( res[ "verification_uri" ] . as_str ( ) . unwrap ( ) ) ) ;
197
+ self . verification_uri =
198
+ Some ( String :: from ( res[ "verification_uri" ] . as_str ( ) . unwrap ( ) ) ) ;
154
199
self . state = DeviceFlowState :: Processing ( FIVE_SECONDS ) ;
155
200
}
156
201
} ;
@@ -162,51 +207,57 @@ impl DeviceFlow {
162
207
163
208
if let DeviceFlowState :: Processing ( interval) = self . state {
164
209
if count == iterations {
165
- return Err ( util:: credential_error ( "Max poll iterations reached" . into ( ) ) )
210
+ return Err ( util:: credential_error ( "Max poll iterations reached" . into ( ) ) ) ;
166
211
}
167
212
168
213
thread:: sleep ( interval) ;
169
214
} else {
170
- break
215
+ break ;
171
216
}
172
- } ;
217
+ }
173
218
174
219
match & self . state {
175
220
DeviceFlowState :: Success ( cred) => Ok ( cred. to_owned ( ) ) ,
176
221
DeviceFlowState :: Failure ( err) => Err ( err. to_owned ( ) ) ,
177
- _ => Err ( util:: credential_error ( "Unable to fetch credential, sorry :/" . into ( ) ) )
222
+ _ => Err ( util:: credential_error (
223
+ "Unable to fetch credential, sorry :/" . into ( ) ,
224
+ ) ) ,
178
225
}
179
226
}
180
227
181
228
pub fn update ( & mut self ) {
182
229
let poll_url = format ! ( "https://{}/login/oauth/access_token" , self . host) ;
183
- let poll_payload = format ! ( "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code" ,
230
+ let poll_payload = format ! (
231
+ "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code" ,
184
232
self . client_id,
185
233
& self . device_code. clone( ) . unwrap( )
186
234
) ;
187
235
188
236
if let Some ( res) = util:: send_request ( self , poll_url, poll_payload) {
189
237
if res. contains_key ( "error" ) {
190
238
match res[ "error" ] . as_str ( ) . unwrap ( ) {
191
- "authorization_pending" => { } ,
239
+ "authorization_pending" => { }
192
240
"slow_down" => {
193
241
if let DeviceFlowState :: Processing ( current_interval) = self . state {
194
- self . state = DeviceFlowState :: Processing ( current_interval + FIVE_SECONDS ) ;
242
+ self . state =
243
+ DeviceFlowState :: Processing ( current_interval + FIVE_SECONDS ) ;
195
244
} ;
196
- } ,
245
+ }
197
246
other_reason => {
198
- self . state = DeviceFlowState :: Failure (
199
- util:: credential_error ( format ! ( "Error checking for token: {}" , other_reason) )
200
- ) ;
201
- } ,
247
+ self . state = DeviceFlowState :: Failure ( util:: credential_error ( format ! (
248
+ "Error checking for token: {}" ,
249
+ other_reason
250
+ ) ) ) ;
251
+ }
202
252
}
203
253
} else {
204
254
let mut this_credential = Credential :: empty ( ) ;
205
255
this_credential. token = res[ "access_token" ] . as_str ( ) . unwrap ( ) . to_string ( ) ;
206
256
207
257
if let Some ( expires_in) = res. get ( "expires_in" ) {
208
258
this_credential. expiry = calculate_expiry ( expires_in. as_i64 ( ) . unwrap ( ) ) ;
209
- this_credential. refresh_token = res[ "refresh_token" ] . as_str ( ) . unwrap ( ) . to_string ( ) ;
259
+ this_credential. refresh_token =
260
+ res[ "refresh_token" ] . as_str ( ) . unwrap ( ) . to_string ( ) ;
210
261
}
211
262
212
263
self . state = DeviceFlowState :: Success ( this_credential) ;
@@ -222,25 +273,40 @@ fn calculate_expiry(expires_in: i64) -> String {
222
273
expiry. to_rfc3339 ( )
223
274
}
224
275
225
- fn refresh_access_token ( client_id : & str , refresh_token : & str , maybe_host : Option < & str > ) -> Result < Credential , DeviceFlowError > {
276
+ fn refresh_access_token (
277
+ client_id : & str ,
278
+ refresh_token : & str ,
279
+ maybe_host : Option < & str > ,
280
+ maybe_scope : Option < & str > ,
281
+ ) -> Result < Credential , DeviceFlowError > {
226
282
let host = match maybe_host {
227
283
Some ( string) => string,
228
- None => "github.com"
284
+ None => "github.com" ,
285
+ } ;
286
+
287
+ let scope = match maybe_scope {
288
+ Some ( string) => string,
289
+ None => "" ,
229
290
} ;
230
291
231
292
let client = reqwest:: blocking:: Client :: new ( ) ;
232
293
let entry_url = format ! ( "https://{}/login/oauth/access_token" , & host) ;
233
- let request_body = format ! ( "client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token" ,
234
- & client_id, & refresh_token) ;
294
+ let request_body = format ! (
295
+ "client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token&scope={}" ,
296
+ & client_id, & refresh_token, & scope
297
+ ) ;
235
298
236
- let res = client. post ( & entry_url)
299
+ let res = client
300
+ . post ( & entry_url)
237
301
. header ( "Accept" , "application/json" )
238
302
. body ( request_body)
239
303
. send ( ) ?
240
304
. json :: < HashMap < String , serde_json:: Value > > ( ) ?;
241
305
242
306
if res. contains_key ( "error" ) {
243
- Err ( util:: credential_error ( res[ "error" ] . as_str ( ) . unwrap ( ) . into ( ) ) )
307
+ Err ( util:: credential_error (
308
+ res[ "error" ] . as_str ( ) . unwrap ( ) . into ( ) ,
309
+ ) )
244
310
} else {
245
311
let mut credential = Credential :: empty ( ) ;
246
312
// eprintln!("res: {:?}", &res);
0 commit comments