16
16
from collections import UserDict
17
17
except :
18
18
UserDict = dict # The real UserDict is an old-style class which fails super()
19
+ from .token_cache import TokenCache
19
20
20
21
21
22
logger = logging .getLogger (__name__ )
@@ -104,6 +105,7 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
104
105
105
106
106
107
def _obtain_token (http_client , managed_identity , resource ):
108
+ # A unified low-level API that talks to different Managed Identity
107
109
if ("IDENTITY_ENDPOINT" in os .environ and "IDENTITY_HEADER" in os .environ
108
110
and "IDENTITY_SERVER_THUMBPRINT" in os .environ
109
111
):
@@ -303,6 +305,12 @@ def _obtain_token_on_arc(http_client, endpoint, resource):
303
305
304
306
305
307
class ManagedIdentityClient (object ):
308
+ """A low level API that encapulate multiple managed identity backends:
309
+ VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric,
310
+ and Azure Arc.
311
+
312
+ It also provides token cache support.
313
+ """
306
314
_instance , _tenant = socket .getfqdn (), "managed_identity" # Placeholders
307
315
308
316
def __init__ (self , http_client , managed_identity , token_cache = None ):
@@ -319,16 +327,17 @@ def __init__(self, http_client, managed_identity, token_cache=None):
319
327
320
328
:param token_cache:
321
329
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
330
+ It will use an in-memory token cache by default.
322
331
323
- Example : Hard code a managed identity for your app::
332
+ Recipe 1 : Hard code a managed identity for your app::
324
333
325
334
import msal, requests
326
335
client = msal.ManagedIdentityClient(
327
336
requests.Session(),
328
337
msal.UserAssignedManagedIdentity(client_id="foo"),
329
338
)
330
339
331
- Recipe: Write once, run everywhere.
340
+ Recipe 2 : Write once, run everywhere.
332
341
If you use different managed identity on different deployment,
333
342
you may use an environment variable (such as AZURE_MANAGED_IDENTITY)
334
343
to store a json blob like
@@ -346,7 +355,7 @@ def __init__(self, http_client, managed_identity, token_cache=None):
346
355
"""
347
356
self ._http_client = http_client
348
357
self ._managed_identity = managed_identity
349
- self ._token_cache = token_cache
358
+ self ._token_cache = token_cache or TokenCache ()
350
359
351
360
def acquire_token (self , resource = None ):
352
361
"""Acquire token for the managed identity.
@@ -361,7 +370,8 @@ def acquire_token(self, resource=None):
361
370
access_token_from_cache = None
362
371
client_id_in_cache = self ._managed_identity .get (
363
372
ManagedIdentity .ID , "SYSTEM_ASSIGNED_MANAGED_IDENTITY" )
364
- if self ._token_cache :
373
+ if True : # Does not offer an "if not force_refresh" option, because
374
+ # there would be built-in token cache in the service side anyway
365
375
matches = self ._token_cache .find (
366
376
self ._token_cache .CredentialType .ACCESS_TOKEN ,
367
377
target = [resource ],
@@ -386,17 +396,26 @@ def acquire_token(self, resource=None):
386
396
if "refresh_on" in entry and int (entry ["refresh_on" ]) < now : # aging
387
397
break # With a fallback in hand, we break here to go refresh
388
398
return access_token_from_cache # It is still good as new
389
- result = _obtain_token (self ._http_client , self ._managed_identity , resource )
390
- if self ._token_cache and "access_token" in result :
391
- self ._token_cache .add (dict (
392
- client_id = client_id_in_cache ,
393
- scope = [resource ],
394
- token_endpoint = "https://{}/{}" .format (self ._instance , self ._tenant ),
395
- response = result ,
396
- params = {},
397
- data = {},
398
- #grant_type="placeholder",
399
- ))
400
- return result
401
- return access_token_from_cache or result
399
+ try :
400
+ result = _obtain_token (self ._http_client , self ._managed_identity , resource )
401
+ if "access_token" in result :
402
+ expires_in = result .get ("expires_in" , 3600 )
403
+ if "refresh_in" not in result and expires_in >= 7200 :
404
+ result ["refresh_in" ] = int (expires_in / 2 )
405
+ self ._token_cache .add (dict (
406
+ client_id = client_id_in_cache ,
407
+ scope = [resource ],
408
+ token_endpoint = "https://{}/{}" .format (self ._instance , self ._tenant ),
409
+ response = result ,
410
+ params = {},
411
+ data = {},
412
+ #grant_type="placeholder",
413
+ ))
414
+ if (result and "error" not in result ) or (not access_token_from_cache ):
415
+ return result
416
+ except : # The exact HTTP exception is transportation-layer dependent
417
+ # Typically network error. Potential AAD outage?
418
+ if not access_token_from_cache : # It means there is no fall back option
419
+ raise # We choose to bubble up the exception
420
+ return access_token_from_cache
402
421
0 commit comments