Skip to content

Commit f2a4aa5

Browse files
committed
Managed Identity for Machine Learning
1 parent 42cea95 commit f2a4aa5

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

msal/managed_identity.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,15 @@ def _obtain_token(http_client, managed_identity, resource):
313313
managed_identity,
314314
resource,
315315
)
316+
if "MSI_ENDPOINT" in os.environ and "MSI_SECRET" in os.environ:
317+
# Back ported from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.15.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py
318+
return _obtain_token_on_machine_learning(
319+
http_client,
320+
os.environ["MSI_ENDPOINT"],
321+
os.environ["MSI_SECRET"],
322+
managed_identity,
323+
resource,
324+
)
316325
if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ:
317326
if ManagedIdentity.is_user_assigned(managed_identity):
318327
raise ValueError( # Note: Azure Identity for Python raised exception too
@@ -329,6 +338,7 @@ def _obtain_token(http_client, managed_identity, resource):
329338

330339

331340
def _adjust_param(params, managed_identity):
341+
# Modify the params dict in place
332342
id_name = ManagedIdentity._types_mapping.get(
333343
managed_identity.get(ManagedIdentity.ID_TYPE))
334344
if id_name:
@@ -405,6 +415,36 @@ def _obtain_token_on_app_service(
405415
logger.debug("IMDS emits unexpected payload: %s", resp.text)
406416
raise
407417

418+
def _obtain_token_on_machine_learning(
419+
http_client, endpoint, secret, managed_identity, resource,
420+
):
421+
# Could not find protocol docs from https://docs.microsoft.com/en-us/azure/machine-learning
422+
# The following implementation is back ported from Azure Identity 1.15.0
423+
logger.debug("Obtaining token via managed identity on Azure Machine Learning")
424+
params = {"api-version": "2017-09-01", "resource": resource}
425+
_adjust_param(params, managed_identity)
426+
resp = http_client.get(
427+
endpoint,
428+
params=params,
429+
headers={"secret": secret},
430+
)
431+
try:
432+
payload = json.loads(resp.text)
433+
if payload.get("access_token") and payload.get("expires_on"):
434+
return { # Normalizing the payload into OAuth2 format
435+
"access_token": payload["access_token"],
436+
"expires_in": int(payload["expires_on"]) - int(time.time()),
437+
"resource": payload.get("resource"),
438+
"token_type": payload.get("token_type", "Bearer"),
439+
}
440+
return {
441+
"error": "invalid_scope", # TODO: To be tested
442+
"error_description": "{}".format(payload),
443+
}
444+
except ValueError:
445+
logger.debug("IMDS emits unexpected payload: %s", resp.text)
446+
raise
447+
408448

409449
def _obtain_token_on_service_fabric(
410450
http_client, endpoint, identity_header, server_thumbprint, resource,

tests/test_mi.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,30 @@ def test_app_service_error_should_be_normalized(self):
117117
self.assertEqual({}, self.app._token_cache._cache)
118118

119119

120+
@patch.dict(os.environ, {"MSI_ENDPOINT": "http://localhost", "MSI_SECRET": "foo"})
121+
class MachineLearningTestCase(ClientTestCase):
122+
123+
def test_happy_path(self):
124+
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
125+
status_code=200,
126+
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
127+
int(time.time()) + 1234),
128+
)) as mocked_method:
129+
self._test_happy_path(self.app, mocked_method)
130+
131+
def test_machine_learning_error_should_be_normalized(self):
132+
raw_error = '{"error": "placeholder", "message": "placeholder"}'
133+
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
134+
status_code=500,
135+
text=raw_error,
136+
)) as mocked_method:
137+
self.assertEqual({
138+
"error": "invalid_scope",
139+
"error_description": "{'error': 'placeholder', 'message': 'placeholder'}",
140+
}, self.app.acquire_token_for_client(resource="R"))
141+
self.assertEqual({}, self.app._token_cache._cache)
142+
143+
120144
@patch.dict(os.environ, {
121145
"IDENTITY_ENDPOINT": "http://localhost",
122146
"IDENTITY_HEADER": "foo",

0 commit comments

Comments
 (0)