Skip to content

Commit 44e4401

Browse files
committed
refresh_in defaults to half of expires_in if expires_in >= 7200
1 parent f1676d2 commit 44e4401

File tree

2 files changed

+38
-20
lines changed

2 files changed

+38
-20
lines changed

msal/imds.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import UserDict
1717
except:
1818
UserDict = dict # The real UserDict is an old-style class which fails super()
19+
from .token_cache import TokenCache
1920

2021

2122
logger = logging.getLogger(__name__)
@@ -104,6 +105,7 @@ def _scope_to_resource(scope): # This is an experimental reasonable-effort appr
104105

105106

106107
def _obtain_token(http_client, managed_identity, resource):
108+
# A unified low-level API that talks to different Managed Identity
107109
if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ
108110
and "IDENTITY_SERVER_THUMBPRINT" in os.environ
109111
):
@@ -303,6 +305,12 @@ def _obtain_token_on_arc(http_client, endpoint, resource):
303305

304306

305307
class ManagedIdentityClient(object):
308+
"""A low level API that encapulate multiple managed identity backends:
309+
VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric,
310+
and Azure Arc.
311+
312+
It also provides token cache support.
313+
"""
306314
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
307315

308316
def __init__(self, http_client, managed_identity, token_cache=None):
@@ -319,16 +327,17 @@ def __init__(self, http_client, managed_identity, token_cache=None):
319327
320328
:param token_cache:
321329
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
330+
It will use an in-memory token cache by default.
322331
323-
Example: Hard code a managed identity for your app::
332+
Recipe 1: Hard code a managed identity for your app::
324333
325334
import msal, requests
326335
client = msal.ManagedIdentityClient(
327336
requests.Session(),
328337
msal.UserAssignedManagedIdentity(client_id="foo"),
329338
)
330339
331-
Recipe: Write once, run everywhere.
340+
Recipe 2: Write once, run everywhere.
332341
If you use different managed identity on different deployment,
333342
you may use an environment variable (such as AZURE_MANAGED_IDENTITY)
334343
to store a json blob like
@@ -346,7 +355,7 @@ def __init__(self, http_client, managed_identity, token_cache=None):
346355
"""
347356
self._http_client = http_client
348357
self._managed_identity = managed_identity
349-
self._token_cache = token_cache
358+
self._token_cache = token_cache or TokenCache()
350359

351360
def acquire_token(self, resource=None):
352361
"""Acquire token for the managed identity.
@@ -361,7 +370,8 @@ def acquire_token(self, resource=None):
361370
access_token_from_cache = None
362371
client_id_in_cache = self._managed_identity.get(
363372
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
364-
if self._token_cache:
373+
if True: # Does not offer an "if not force_refresh" option, because
374+
# there would be built-in token cache in the service side anyway
365375
matches = self._token_cache.find(
366376
self._token_cache.CredentialType.ACCESS_TOKEN,
367377
target=[resource],
@@ -386,17 +396,26 @@ def acquire_token(self, resource=None):
386396
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
387397
break # With a fallback in hand, we break here to go refresh
388398
return access_token_from_cache # It is still good as new
389-
result = _obtain_token(self._http_client, self._managed_identity, resource)
390-
if self._token_cache and "access_token" in result:
391-
self._token_cache.add(dict(
392-
client_id=client_id_in_cache,
393-
scope=[resource],
394-
token_endpoint="https://{}/{}".format(self._instance, self._tenant),
395-
response=result,
396-
params={},
397-
data={},
398-
#grant_type="placeholder",
399-
))
400-
return result
401-
return access_token_from_cache or result
399+
try:
400+
result = _obtain_token(self._http_client, self._managed_identity, resource)
401+
if "access_token" in result:
402+
expires_in = result.get("expires_in", 3600)
403+
if "refresh_in" not in result and expires_in >= 7200:
404+
result["refresh_in"] = int(expires_in / 2)
405+
self._token_cache.add(dict(
406+
client_id=client_id_in_cache,
407+
scope=[resource],
408+
token_endpoint="https://{}/{}".format(self._instance, self._tenant),
409+
response=result,
410+
params={},
411+
data={},
412+
#grant_type="placeholder",
413+
))
414+
if (result and "error" not in result) or (not access_token_from_cache):
415+
return result
416+
except: # The exact HTTP exception is transportation-layer dependent
417+
# Typically network error. Potential AAD outage?
418+
if not access_token_from_cache: # It means there is no fall back option
419+
raise # We choose to bubble up the exception
420+
return access_token_from_cache
402421

tests/test_mi.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from tests.http_client import MinimalResponse
1313
from msal import (
14-
TokenCache,
1514
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
1615
ManagedIdentityClient)
1716

@@ -46,7 +45,7 @@ def setUp(self):
4645
# the client has no hard dependency on ManagedIdentity object
4746
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
4847
},
49-
token_cache=TokenCache())
48+
)
5049

5150
def _test_token_cache(self, app):
5251
cache = app._token_cache._cache
@@ -140,7 +139,7 @@ def test_unified_api_service_should_ignore_unnecessary_client_id(self):
140139
self._test_happy_path(ManagedIdentityClient(
141140
requests.Session(),
142141
{"ManagedIdentityIdType": "ClientId", "Id": "foo"},
143-
token_cache=TokenCache()))
142+
))
144143

145144
def test_sf_error_should_be_normalized(self):
146145
raw_error = '''

0 commit comments

Comments
 (0)