5
5
from urllib .parse import urljoin
6
6
import logging
7
7
import sys
8
+ import warnings
9
+
10
+ import requests
8
11
9
12
from .oauth2cli import Client , JwtSigner
10
13
from .authority import Authority
15
18
16
19
17
20
# The __init__.py will import this. Not the other way around.
18
- __version__ = "0.2 .0"
21
+ __version__ = "0.3 .0"
19
22
20
23
logger = logging .getLogger (__name__ )
21
24
@@ -101,6 +104,14 @@ def __init__(
101
104
# Here the self.authority is not the same type as authority in input
102
105
self .token_cache = token_cache or TokenCache ()
103
106
self .client = self ._build_client (client_credential , self .authority )
107
+ self .authority_groups = self ._get_authority_aliases ()
108
+
109
+ def _get_authority_aliases (self ):
110
+ resp = requests .get (
111
+ "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize" ,
112
+ headers = {'Accept' : 'application/json' })
113
+ resp .raise_for_status ()
114
+ return [set (group ['aliases' ]) for group in resp .json ()['metadata' ]]
104
115
105
116
def _build_client (self , client_credential , authority ):
106
117
client_assertion = None
@@ -236,19 +247,33 @@ def get_accounts(self, username=None):
236
247
Your app can choose to display those information to end user,
237
248
and allow user to choose one of his/her accounts to proceed.
238
249
"""
239
- # The following implementation finds accounts only from saved accounts,
240
- # but does NOT correlate them with saved RTs. It probably won't matter,
241
- # because in MSAL universe, there are always Accounts and RTs together.
242
- accounts = self .token_cache .find (
243
- self .token_cache .CredentialType .ACCOUNT ,
244
- query = {"environment" : self .authority .instance })
250
+ accounts = self ._find_msal_accounts (environment = self .authority .instance )
251
+ if not accounts : # Now try other aliases of this authority instance
252
+ for group in self .authority_groups :
253
+ if self .authority .instance in group :
254
+ for alias in group :
255
+ if alias != self .authority .instance :
256
+ accounts = self ._find_msal_accounts (environment = alias )
257
+ if accounts :
258
+ break
245
259
if username :
246
260
# Federated account["username"] from AAD could contain mixed case
247
261
lowercase_username = username .lower ()
248
262
accounts = [a for a in accounts
249
263
if a ["username" ].lower () == lowercase_username ]
264
+ # Does not further filter by existing RTs here. It probably won't matter.
265
+ # Because in most cases Accounts and RTs co-exist.
266
+ # Even in the rare case when an RT is revoked and then removed,
267
+ # acquire_token_silent() would then yield no result,
268
+ # apps would fall back to other acquire methods. This is the standard pattern.
250
269
return accounts
251
270
271
+ def _find_msal_accounts (self , environment ):
272
+ return [a for a in self .token_cache .find (
273
+ TokenCache .CredentialType .ACCOUNT , query = {"environment" : environment })
274
+ if a ["authority_type" ] in (
275
+ TokenCache .AuthorityType .ADFS , TokenCache .AuthorityType .MSSTS )]
276
+
252
277
def acquire_token_silent (
253
278
self ,
254
279
scopes , # type: List[str]
@@ -275,19 +300,44 @@ def acquire_token_silent(
275
300
- None when cache lookup does not yield anything.
276
301
"""
277
302
assert isinstance (scopes , list ), "Invalid parameter type"
278
- the_authority = Authority (
279
- authority ,
280
- verify = self .verify , proxies = self .proxies , timeout = self .timeout ,
281
- ) if authority else self .authority
282
-
303
+ if authority :
304
+ warnings .warn ("We haven't decided how/if this method will accept authority parameter" )
305
+ # the_authority = Authority(
306
+ # authority,
307
+ # verify=self.verify, proxies=self.proxies, timeout=self.timeout,
308
+ # ) if authority else self.authority
309
+ result = self ._acquire_token_silent (scopes , account , self .authority , ** kwargs )
310
+ if result :
311
+ return result
312
+ for group in self .authority_groups :
313
+ if self .authority .instance in group :
314
+ for alias in group :
315
+ if alias != self .authority .instance :
316
+ the_authority = Authority (
317
+ "https://" + alias + "/" + self .authority .tenant ,
318
+ validate_authority = False ,
319
+ verify = self .verify , proxies = self .proxies ,
320
+ timeout = self .timeout ,)
321
+ result = self ._acquire_token_silent (
322
+ scopes , account , the_authority , ** kwargs )
323
+ if result :
324
+ return result
325
+
326
+ def _acquire_token_silent (
327
+ self ,
328
+ scopes , # type: List[str]
329
+ account , # type: Optional[Account]
330
+ authority , # This can be different than self.authority
331
+ force_refresh = False , # type: Optional[boolean]
332
+ ** kwargs ):
283
333
if not force_refresh :
284
334
matches = self .token_cache .find (
285
335
self .token_cache .CredentialType .ACCESS_TOKEN ,
286
336
target = scopes ,
287
337
query = {
288
338
"client_id" : self .client_id ,
289
- "environment" : the_authority .instance ,
290
- "realm" : the_authority .tenant ,
339
+ "environment" : authority .instance ,
340
+ "realm" : authority .tenant ,
291
341
"home_account_id" : (account or {}).get ("home_account_id" ),
292
342
})
293
343
now = time .time ()
@@ -301,26 +351,71 @@ def acquire_token_silent(
301
351
"token_type" : "Bearer" ,
302
352
"expires_in" : int (expires_in ), # OAuth2 specs defines it as int
303
353
}
354
+ return self ._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
355
+ authority , decorate_scope (scopes , self .client_id ), account ,
356
+ ** kwargs )
304
357
358
+ def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family (
359
+ self , authority , scopes , account , ** kwargs ):
360
+ query = {
361
+ "environment" : authority .instance ,
362
+ "home_account_id" : (account or {}).get ("home_account_id" ),
363
+ # "realm": authority.tenant, # AAD RTs are tenant-independent
364
+ }
365
+ apps = self .token_cache .find ( # Use find(), rather than token_cache.get(...)
366
+ TokenCache .CredentialType .APP_METADATA , query = {
367
+ "environment" : authority .instance , "client_id" : self .client_id })
368
+ app_metadata = apps [0 ] if apps else {}
369
+ if not app_metadata : # Meaning this app is now used for the first time.
370
+ # When/if we have a way to directly detect current app's family,
371
+ # we'll rewrite this block, to support multiple families.
372
+ # For now, we try existing RTs (*). If it works, we are in that family.
373
+ # (*) RTs of a different app/family are not supposed to be
374
+ # shared with or accessible by us in the first place.
375
+ at = self ._acquire_token_silent_by_finding_specific_refresh_token (
376
+ authority , scopes ,
377
+ dict (query , family_id = "1" ), # A hack, we have only 1 family for now
378
+ rt_remover = lambda rt_item : None , # NO-OP b/c RTs are likely not mine
379
+ break_condition = lambda response : # Break loop when app not in family
380
+ # Based on an AAD-only behavior mentioned in internal doc here
381
+ # https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595
382
+ "client_mismatch" in response .get ("error_additional_info" , []),
383
+ ** kwargs )
384
+ if at :
385
+ return at
386
+ if app_metadata .get ("family_id" ): # Meaning this app belongs to this family
387
+ at = self ._acquire_token_silent_by_finding_specific_refresh_token (
388
+ authority , scopes , dict (query , family_id = app_metadata ["family_id" ]),
389
+ ** kwargs )
390
+ if at :
391
+ return at
392
+ # Either this app is an orphan, so we will naturally use its own RT;
393
+ # or all attempts above have failed, so we fall back to non-foci behavior.
394
+ return self ._acquire_token_silent_by_finding_specific_refresh_token (
395
+ authority , scopes , dict (query , client_id = self .client_id ), ** kwargs )
396
+
397
+ def _acquire_token_silent_by_finding_specific_refresh_token (
398
+ self , authority , scopes , query ,
399
+ rt_remover = None , break_condition = lambda response : False , ** kwargs ):
305
400
matches = self .token_cache .find (
306
401
self .token_cache .CredentialType .REFRESH_TOKEN ,
307
402
# target=scopes, # AAD RTs are scope-independent
308
- query = {
309
- "client_id" : self .client_id ,
310
- "environment" : the_authority .instance ,
311
- "home_account_id" : (account or {}).get ("home_account_id" ),
312
- # "realm": the_authority.tenant, # AAD RTs are tenant-independent
313
- })
314
- client = self ._build_client (self .client_credential , the_authority )
403
+ query = query )
404
+ logger .debug ("Found %d RTs matching %s" , len (matches ), query )
405
+ client = self ._build_client (self .client_credential , authority )
315
406
for entry in matches :
316
- logger .debug ("Cache hit an RT" )
407
+ logger .debug ("Cache attempts an RT" )
317
408
response = client .obtain_token_by_refresh_token (
318
409
entry , rt_getter = lambda token_item : token_item ["secret" ],
319
- scope = decorate_scope (scopes , self .client_id ))
410
+ on_removing_rt = rt_remover or self .token_cache .remove_rt ,
411
+ scope = scopes ,
412
+ ** kwargs )
320
413
if "error" not in response :
321
414
return response
322
415
logger .debug (
323
416
"Refresh failed. {error}: {error_description}" .format (** response ))
417
+ if break_condition (response ):
418
+ break
324
419
325
420
326
421
class PublicClientApplication (ClientApplication ): # browser app or mobile app
0 commit comments