Skip to content

Commit 42cea95

Browse files
committed
Refactor throttling and add it to Managed Identity
1 parent 5926b11 commit 42cea95

File tree

2 files changed

+66
-24
lines changed

2 files changed

+66
-24
lines changed

msal/managed_identity.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
except:
1717
UserDict = dict # The real UserDict is an old-style class which fails super()
1818
from .token_cache import TokenCache
19-
from .throttled_http_client import ThrottledHttpClient
19+
from .individual_cache import _IndividualCache as IndividualCache
20+
from .throttled_http_client import ThrottledHttpClientBase, _parse_http_429_5xx_retry_after
2021

2122

2223
logger = logging.getLogger(__name__)
@@ -107,6 +108,22 @@ def __init__(self, client_id=None, resource_id=None, object_id=None):
107108
"client_id, resource_id, object_id")
108109

109110

111+
class _ThrottledHttpClient(ThrottledHttpClientBase):
112+
def __init__(self, http_client, http_cache):
113+
super(_ThrottledHttpClient, self).__init__(http_client, http_cache)
114+
self.get = IndividualCache( # All MIs (except Cloud Shell) use GETs
115+
mapping=self._expiring_mapping,
116+
key_maker=lambda func, args, kwargs: "POST {} hash={} 429/5xx/Retry-After".format(
117+
args[0], # It is the endpoint, typically a constant per MI type
118+
_hash(
119+
# Managed Identity flavors have inconsistent parameters.
120+
# We simply choose to hash them all.
121+
str(kwargs.get("params")) + str(kwargs.get("data"))),
122+
),
123+
expires_in=_parse_http_429_5xx_retry_after,
124+
)(http_client.get)
125+
126+
110127
class ManagedIdentityClient(object):
111128
"""This API encapsulates multiple managed identity backends:
112129
VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric,
@@ -116,7 +133,8 @@ class ManagedIdentityClient(object):
116133
"""
117134
_instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders
118135

119-
def __init__(self, managed_identity, *, http_client, token_cache=None):
136+
def __init__(
137+
self, managed_identity, *, http_client, token_cache=None, http_cache=None):
120138
"""Create a managed identity client.
121139
122140
:param dict managed_identity:
@@ -142,6 +160,10 @@ def __init__(self, managed_identity, *, http_client, token_cache=None):
142160
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
143161
It will use an in-memory token cache by default.
144162
163+
:param http_cache:
164+
Optional. It has the same characteristics as the
165+
:paramref:`msal.ClientApplication.http_cache`.
166+
145167
Recipe 1: Hard code a managed identity for your app::
146168
147169
import msal, requests
@@ -169,12 +191,21 @@ def __init__(self, managed_identity, *, http_client, token_cache=None):
169191
token = client.acquire_token_for_client("resource")
170192
"""
171193
self._managed_identity = managed_identity
172-
if isinstance(http_client, ThrottledHttpClient):
173-
raise ValueError(
174-
# It is a precaution to reject application.py's throttled http_client,
175-
# whose cache life on HTTP GET 200 is too long for Managed Identity.
176-
"This class does not currently accept a ThrottledHttpClient.")
177-
self._http_client = http_client
194+
self._http_client = _ThrottledHttpClient(
195+
# This class only throttles excess token acquisition requests.
196+
# It does not provide retry.
197+
# Retry is the http_client or caller's responsibility, not MSAL's.
198+
#
199+
# FWIW, here are the inconsistent retry recommendations:
200+
# 1. Only MI on VM defines exotic 404 and 410 retry recommendations
201+
# ( https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling )
202+
# (especially for 410 which was supposed to be a permanent failure).
203+
# 2. MI on Service Fabric specifically suggests to not retry on 404.
204+
# ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
205+
http_client.http_client # Patch the raw (unpatched) http client
206+
if isinstance(http_client, ThrottledHttpClientBase) else http_client,
207+
{} if http_cache is None else http_cache, # Default to an in-memory dict
208+
)
178209
self._token_cache = token_cache or TokenCache()
179210

180211
def acquire_token_for_client(self, resource=None):

msal/throttled_http_client.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,42 @@ def _extract_data(kwargs, key, default=None):
4545
return data.get(key) if isinstance(data, dict) else default
4646

4747

48-
class ThrottledHttpClient(object):
49-
def __init__(self, http_client, http_cache):
50-
"""Throttle the given http_client by storing and retrieving data from cache.
48+
class ThrottledHttpClientBase(object):
49+
"""Throttle the given http_client by storing and retrieving data from cache.
5150
52-
This wrapper exists so that our patching post() and get() would prevent
53-
re-patching side effect when/if same http_client being reused.
54-
"""
55-
expiring_mapping = ExpiringMapping( # It will automatically clean up
51+
This wrapper exists so that patching post() and get() on it avoids the
52+
re-patching side effect when/if the same http_client is being reused.
53+
54+
Subclasses should override post() and/or get() to add the actual throttling.
55+
"""
56+
def __init__(self, http_client, http_cache):
57+
self.http_client = http_client
58+
self._expiring_mapping = ExpiringMapping( # It will automatically clean up
5659
mapping=http_cache if http_cache is not None else {},
5760
capacity=1024, # To prevent cache blowing up especially for CCA
5861
lock=Lock(), # TODO: This should ideally also allow customization
5962
)
6063

64+
def post(self, *args, **kwargs):
65+
return self.http_client.post(*args, **kwargs)
66+
67+
def get(self, *args, **kwargs):
68+
return self.http_client.get(*args, **kwargs)
69+
70+
def close(self):
71+
return self.http_client.close()
72+
73+
74+
class ThrottledHttpClient(ThrottledHttpClientBase):
75+
def __init__(self, http_client, http_cache):
76+
super(ThrottledHttpClient, self).__init__(http_client, http_cache)
77+
6178
_post = http_client.post # We'll patch _post, and keep original post() intact
6279

6380
_post = IndividualCache(
6481
# Internal specs requires throttling on at least token endpoint,
6582
# here we have a generic patch for POST on all endpoints.
66-
mapping=expiring_mapping,
83+
mapping=self._expiring_mapping,
6784
key_maker=lambda func, args, kwargs:
6885
"POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
6986
args[0], # It is the url, typically containing authority and tenant
@@ -81,7 +98,7 @@ def __init__(self, http_client, http_cache):
8198
)(_post)
8299

83100
_post = IndividualCache( # It covers the "UI required cache"
84-
mapping=expiring_mapping,
101+
mapping=self._expiring_mapping,
85102
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
86103
args[0], # It is the url, typically containing authority and tenant
87104
_hash(
@@ -120,7 +137,7 @@ def __init__(self, http_client, http_cache):
120137
self.post = _post
121138

122139
self.get = IndividualCache( # Typically those discovery GETs
123-
mapping=expiring_mapping,
140+
mapping=self._expiring_mapping,
124141
key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
125142
args[0], # It is the url, sometimes containing inline params
126143
_hash(kwargs.get("params", "")),
@@ -129,13 +146,7 @@ def __init__(self, http_client, http_cache):
129146
3600*24 if 200 <= result.status_code < 300 else 0,
130147
)(http_client.get)
131148

132-
self._http_client = http_client
133-
134149
# The following 2 methods have been defined dynamically by __init__()
135150
#def post(self, *args, **kwargs): pass
136151
#def get(self, *args, **kwargs): pass
137152

138-
def close(self):
139-
"""MSAL won't need this. But we allow throttled_http_client.close() anyway"""
140-
return self._http_client.close()
141-

0 commit comments

Comments
 (0)