16
16
except :
17
17
UserDict = dict # The real UserDict is an old-style class which fails super()
18
18
from .token_cache import TokenCache
19
- from .throttled_http_client import ThrottledHttpClient
19
+ from .individual_cache import _IndividualCache as IndividualCache
20
+ from .throttled_http_client import ThrottledHttpClientBase , _parse_http_429_5xx_retry_after
20
21
21
22
22
23
logger = logging .getLogger (__name__ )
@@ -107,6 +108,22 @@ def __init__(self, client_id=None, resource_id=None, object_id=None):
107
108
"client_id, resource_id, object_id" )
108
109
109
110
111
+ class _ThrottledHttpClient (ThrottledHttpClientBase ):
112
+ def __init__ (self , http_client , http_cache ):
113
+ super (_ThrottledHttpClient , self ).__init__ (http_client , http_cache )
114
+ self .get = IndividualCache ( # All MIs (except Cloud Shell) use GETs
115
+ mapping = self ._expiring_mapping ,
116
+ key_maker = lambda func , args , kwargs : "POST {} hash={} 429/5xx/Retry-After" .format (
117
+ args [0 ], # It is the endpoint, typically a constant per MI type
118
+ _hash (
119
+ # Managed Identity flavors have inconsistent parameters.
120
+ # We simply choose to hash them all.
121
+ str (kwargs .get ("params" )) + str (kwargs .get ("data" ))),
122
+ ),
123
+ expires_in = _parse_http_429_5xx_retry_after ,
124
+ )(http_client .get )
125
+
126
+
110
127
class ManagedIdentityClient (object ):
111
128
"""This API encapulates multiple managed identity backends:
112
129
VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric,
@@ -116,7 +133,8 @@ class ManagedIdentityClient(object):
116
133
"""
117
134
_instance , _tenant = socket .getfqdn (), "managed_identity" # Placeholders
118
135
119
- def __init__ (self , managed_identity , * , http_client , token_cache = None ):
136
+ def __init__ (
137
+ self , managed_identity , * , http_client , token_cache = None , http_cache = None ):
120
138
"""Create a managed identity client.
121
139
122
140
:param dict managed_identity:
@@ -142,6 +160,10 @@ def __init__(self, managed_identity, *, http_client, token_cache=None):
142
160
Optional. It accepts a :class:`msal.TokenCache` instance to store tokens.
143
161
It will use an in-memory token cache by default.
144
162
163
+ :param http_cache:
164
+ Optional. It has the same characteristics as the
165
+ :paramref:`msal.ClientApplication.http_cache`.
166
+
145
167
Recipe 1: Hard code a managed identity for your app::
146
168
147
169
import msal, requests
@@ -169,12 +191,21 @@ def __init__(self, managed_identity, *, http_client, token_cache=None):
169
191
token = client.acquire_token_for_client("resource")
170
192
"""
171
193
self ._managed_identity = managed_identity
172
- if isinstance (http_client , ThrottledHttpClient ):
173
- raise ValueError (
174
- # It is a precaution to reject application.py's throttled http_client,
175
- # whose cache life on HTTP GET 200 is too long for Managed Identity.
176
- "This class does not currently accept a ThrottledHttpClient." )
177
- self ._http_client = http_client
194
+ self ._http_client = _ThrottledHttpClient (
195
+ # This class only throttles excess token acquisition requests.
196
+ # It does not provide retry.
197
+ # Retry is the http_client or caller's responsibility, not MSAL's.
198
+ #
199
+ # FWIW, here is the inconsistent retry recommendation.
200
+ # 1. Only MI on VM defines exotic 404 and 410 retry recommendations
201
+ # ( https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling )
202
+ # (especially for 410 which was supposed to be a permanent failure).
203
+ # 2. MI on Service Fabric specifically suggests to not retry on 404.
204
+ # ( https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-cluster-managed-identity-service-fabric-app-code#error-handling )
205
+ http_client .http_client # Patch the raw (unpatched) http client
206
+ if isinstance (http_client , ThrottledHttpClientBase ) else http_client ,
207
+ {} if http_cache is None else http_cache , # Default to an in-memory dict
208
+ )
178
209
self ._token_cache = token_cache or TokenCache ()
179
210
180
211
def acquire_token_for_client (self , resource = None ):
0 commit comments