Skip to content

Commit 73375e3

Browse files
committed
Proof-of-Concept: MI via CCA
1 parent 673d5c3 commit 73375e3

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

msal/application.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .region import _detect_region
2323
from .throttled_http_client import ThrottledHttpClient
2424
from .cloudshell import _is_running_in_cloud_shell
25+
from .imds import ManagedIdentityClient, ManagedIdentity, _scope_to_resource
2526

2627

2728
# The __init__.py will import this. Not the other way around.
@@ -2021,6 +2022,14 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs):
20212022
- an error response would contain "error" and usually "error_description".
20222023
"""
20232024
# TBD: force_refresh behavior
2025+
if ManagedIdentity.is_managed_identity(self.client_id):
2026+
if len(scopes) != 1:
2027+
raise ValueError("Managed Identity supports only one scope/resource")
2028+
if claims_challenge:
2029+
raise ValueError("Managed Identity does not support claims_challenge")
2030+
return ManagedIdentityClient(
2031+
self.http_client, self.client_id, self.token_cache
2032+
).acquire_token(_scope_to_resource(scopes[0]))
20242033
if self.authority.tenant.lower() in ["common", "organizations"]:
20252034
warnings.warn(
20262035
"Using /common or /organizations authority "

tests/test_mi.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
from tests.http_client import MinimalResponse
1313
from msal import (
14+
ConfidentialClientApplication,
1415
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
15-
ManagedIdentityClient)
16+
)
1617

1718

1819
class ManagedIdentityTestCase(unittest.TestCase):
@@ -39,26 +40,22 @@ class ClientTestCase(unittest.TestCase):
3940
maxDiff = None
4041

4142
def setUp(self):
42-
self.app = ManagedIdentityClient(
43-
{ # Here we test it with the raw dict form, to test that
44-
# the client has no hard dependency on ManagedIdentity object
45-
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
46-
},
47-
requests.Session(),
48-
)
43+
system_assigned = {"ManagedIdentityIdType": "SystemAssigned", "Id": None}
44+
self.app = ConfidentialClientApplication(client_id=system_assigned)
4945

5046
def _test_token_cache(self, app):
51-
cache = app._token_cache._cache
47+
cache = app.token_cache._cache
5248
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
5349
at = list(cache["AccessToken"].values())[0]
5450
self.assertEqual(
55-
app._managed_identity.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"),
51+
app.client_id.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"),
5652
at["client_id"],
5753
"Should have expected client_id")
5854
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
5955

6056
def _test_happy_path(self, app, mocked_http):
61-
result = app.acquire_token_for_client(resource="R")
57+
#result = app.acquire_token_for_client(resource="R")
58+
result = app.acquire_token_for_client(["R"])
6259
mocked_http.assert_called()
6360
self.assertEqual({
6461
"access_token": "AT",
@@ -68,29 +65,29 @@ def _test_happy_path(self, app, mocked_http):
6865
}, result, "Should obtain a token response")
6966
self.assertEqual(
7067
result["access_token"],
71-
app.acquire_token_for_client(resource="R").get("access_token"),
68+
app.acquire_token_for_client(["R"]).get("access_token"),
7269
"Should hit the same token from cache")
7370
self._test_token_cache(app)
7471

7572

7673
class VmTestCase(ClientTestCase):
7774

7875
def test_happy_path(self):
79-
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
76+
with patch.object(self.app.http_client, "get", return_value=MinimalResponse(
8077
status_code=200,
8178
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
8279
)) as mocked_method:
8380
self._test_happy_path(self.app, mocked_method)
8481

8582
def test_vm_error_should_be_returned_as_is(self):
8683
raw_error = '{"raw": "error format is undefined"}'
87-
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
84+
with patch.object(self.app.http_client, "get", return_value=MinimalResponse(
8885
status_code=400,
8986
text=raw_error,
9087
)) as mocked_method:
9188
self.assertEqual(
92-
json.loads(raw_error), self.app.acquire_token_for_client(resource="R"))
93-
self.assertEqual({}, self.app._token_cache._cache)
89+
json.loads(raw_error), self.app.acquire_token_for_client(["R"]))
90+
self.assertEqual({}, self.app.token_cache._cache)
9491

9592

9693
@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})

0 commit comments

Comments
 (0)