Skip to content

Commit bb80636

Browse files
authored
Merge pull request #82 from AzureAD/release-0.6.0
MSAL Python 0.6.0
2 parents 171d67a + 7d8c029 commit bb80636

File tree

6 files changed

+144
-42
lines changed

6 files changed

+144
-42
lines changed

msal/application.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
# The __init__.py will import this. Not the other way around.
21-
__version__ = "0.5.1"
21+
__version__ = "0.6.0"
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -400,20 +400,21 @@ def acquire_token_silent(
400400
# authority,
401401
# verify=self.verify, proxies=self.proxies, timeout=self.timeout,
402402
# ) if authority else self.authority
403-
result = self._acquire_token_silent(scopes, account, self.authority, **kwargs)
403+
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
404+
scopes, account, self.authority, **kwargs)
404405
if result:
405406
return result
406407
for alias in self._get_authority_aliases(self.authority.instance):
407408
the_authority = Authority(
408409
"https://" + alias + "/" + self.authority.tenant,
409410
validate_authority=False,
410411
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
411-
result = self._acquire_token_silent(
412+
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
412413
scopes, account, the_authority, **kwargs)
413414
if result:
414415
return result
415416

416-
def _acquire_token_silent(
417+
def _acquire_token_silent_from_cache_and_possibly_refresh_it(
417418
self,
418419
scopes, # type: List[str]
419420
account, # type: Optional[Account]

msal/authority.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ def __init__(self, authority_url, validate_authority=True,
3838
self.proxies = proxies
3939
self.timeout = timeout
4040
canonicalized, self.instance, tenant = canonicalize(authority_url)
41-
tenant_discovery_endpoint = ( # Hard code a V2 pattern as default value
42-
'https://{}/{}/v2.0/.well-known/openid-configuration'
43-
.format(self.instance, tenant))
44-
if validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS:
41+
tenant_discovery_endpoint = (
42+
'https://{}/{}{}/.well-known/openid-configuration'.format(
43+
self.instance,
44+
tenant,
45+
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
46+
))
47+
if (tenant != "adfs" and validate_authority
48+
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
4549
tenant_discovery_endpoint = instance_discovery(
4650
canonicalized + "/oauth2/v2.0/authorize",
4751
verify=verify, proxies=proxies, timeout=timeout)

msal/oauth2cli/oauth2.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -380,26 +380,34 @@ def __init__(self,
380380
self.on_removing_rt = on_removing_rt
381381
self.on_updating_rt = on_updating_rt
382382

383-
def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
383+
def _obtain_token(self, grant_type, params=None, data=None,
384+
rt_getter=lambda token_item: token_item["refresh_token"],
385+
*args, **kwargs):
386+
RT = "refresh_token"
387+
_data = data.copy() # to prevent side effect
388+
refresh_token = _data.get(RT)
389+
if grant_type == RT and isinstance(refresh_token, dict):
390+
_data[RT] = rt_getter(refresh_token) # Put raw RT in _data
384391
resp = super(Client, self)._obtain_token(
385-
grant_type, params, data, *args, **kwargs)
392+
grant_type, params, _data, *args, **kwargs)
386393
if "error" not in resp:
387394
_resp = resp.copy()
388-
if grant_type == "refresh_token" and "refresh_token" in _resp:
389-
_resp.pop("refresh_token") # We'll handle this in its own method
395+
if grant_type == RT and RT in _resp and isinstance(refresh_token, dict):
396+
_resp.pop(RT) # So we skip it in on_obtaining_tokens(); it will
397+
# be handled in self.obtain_token_by_refresh_token()
390398
if "scope" in _resp:
391399
scope = _resp["scope"].split() # It is conceptually a set,
392400
# but we represent it as a list which can be persisted to JSON
393401
else:
394402
# TODO: Deal with absent scope in authorization grant
395-
scope = data.get("scope")
403+
scope = _data.get("scope")
396404
self.on_obtaining_tokens({
397405
"client_id": self.client_id,
398406
"scope": scope,
399407
"token_endpoint": self.configuration["token_endpoint"],
400408
"grant_type": grant_type, # can be used to know an IdToken-less
401409
# response is for an app or for a user
402-
"response": _resp, "params": params, "data": data,
410+
"response": _resp, "params": params, "data": _data,
403411
})
404412
return resp
405413

@@ -411,26 +419,29 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
411419
"""This is an "overload" which accepts a refresh token item as a dict,
412420
therefore this method can relay refresh_token item to event listeners.
413421
414-
:param token_item: A refresh token item came from storage
422+
:param token_item:
423+
A refresh token item as a dict, came from the cache managed by this lib.
424+
425+
Alternatively, you can still use a refresh token (RT) as a string,
426+
supposedly came from a token cache managed by a different library,
427+
then this library will store the new RT (if Authority Server issued one)
428+
into this lib's cache. This is a way to migrate from other lib to us.
415429
:param scope: If omitted, is treated as equal to the scope originally
416430
granted by the resource ownser,
417431
according to https://tools.ietf.org/html/rfc6749#section-6
418432
:param rt_getter: A callable used to extract the RT from token_item
419433
:param on_removing_rt: If absent, fall back to the one defined in initialization
420434
"""
421-
if isinstance(token_item, str):
422-
# Satisfy the L of SOLID, although we expect caller uses a dict
423-
return super(Client, self).obtain_token_by_refresh_token(
424-
token_item, scope=scope, **kwargs)
435+
resp = super(Client, self).obtain_token_by_refresh_token(
436+
token_item, scope=scope,
437+
rt_getter=rt_getter, # Wire up this for _obtain_token()
438+
**kwargs)
425439
if isinstance(token_item, dict):
426-
resp = super(Client, self).obtain_token_by_refresh_token(
427-
rt_getter(token_item), scope=scope, **kwargs)
428440
if resp.get('error') == 'invalid_grant':
429441
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
430442
if 'refresh_token' in resp:
431443
self.on_updating_rt(token_item, resp['refresh_token'])
432-
return resp
433-
raise ValueError("token_item should not be a type %s" % type(token_item))
444+
return resp
434445

435446
def obtain_token_by_assertion(
436447
self, assertion, grant_type, scope=None, **kwargs):

msal/token_cache.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,25 @@ def add(self, event, now=None):
111111
event, indent=4, sort_keys=True,
112112
default=str, # A workaround when assertion is in bytes in Python 3
113113
))
114+
environment = realm = None
115+
if "token_endpoint" in event:
116+
_, environment, realm = canonicalize(event["token_endpoint"])
114117
response = event.get("response", {})
115118
access_token = response.get("access_token")
116119
refresh_token = response.get("refresh_token")
117120
id_token = response.get("id_token")
121+
id_token_claims = (
122+
decode_id_token(id_token, client_id=event["client_id"])
123+
if id_token else {})
118124
client_info = {}
119-
home_account_id = None
120-
if "client_info" in response:
125+
home_account_id = None # It would remain None in client_credentials flow
126+
if "client_info" in response: # We asked for it, and AAD will provide it
121127
client_info = json.loads(base64decode(response["client_info"]))
122128
home_account_id = "{uid}.{utid}".format(**client_info)
123-
environment = realm = None
124-
if "token_endpoint" in event:
125-
_, environment, realm = canonicalize(event["token_endpoint"])
129+
elif id_token_claims: # This would be an end user on ADFS-direct scenario
130+
client_info["uid"] = id_token_claims.get("sub")
131+
home_account_id = id_token_claims.get("sub")
132+
126133
target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it
127134

128135
with self._lock:
@@ -148,15 +155,15 @@ def add(self, event, now=None):
148155
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)
149156

150157
if client_info:
151-
decoded_id_token = decode_id_token(
152-
id_token, client_id=event["client_id"]) if id_token else {}
153158
account = {
154159
"home_account_id": home_account_id,
155160
"environment": environment,
156161
"realm": realm,
157-
"local_account_id": decoded_id_token.get(
158-
"oid", decoded_id_token.get("sub")),
159-
"username": decoded_id_token.get("preferred_username"),
162+
"local_account_id": id_token_claims.get(
163+
"oid", id_token_claims.get("sub")),
164+
"username": id_token_claims.get("preferred_username") # AAD
165+
or id_token_claims.get("upn") # ADFS 2019
166+
or "", # The schema does not like null
160167
"authority_type":
161168
self.AuthorityType.ADFS if realm == "adfs"
162169
else self.AuthorityType.MSSTS,

sample/device_flow_sample.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,18 @@
5151

5252
if not result:
5353
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
54+
5455
flow = app.initiate_device_flow(scopes=config["scope"])
56+
if "user_code" not in flow:
57+
raise ValueError(
58+
"Fail to create device flow. Err: %s" % json.dumps(flow, indent=4))
59+
5560
print(flow["message"])
61+
sys.stdout.flush() # Some terminal needs this to ensure the message is shown
62+
5663
# Ideally you should wait here, in order to save some unnecessary polling
57-
# input("Press Enter after you successfully login from another device...")
64+
# input("Press Enter after signing in from another device to proceed, CTRL+C to abort.")
65+
5866
result = app.acquire_token_by_device_flow(flow) # By default it will block
5967
# You can follow this instruction to shorten the block time
6068
# https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.acquire_token_by_device_flow

tests/test_token_cache.py

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,29 @@ class TokenCacheTestCase(unittest.TestCase):
1616
@staticmethod
1717
def build_id_token(
1818
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
19-
preferred_username="me", **claims):
19+
**claims): # AAD issues "preferred_username", ADFS issues "upn"
2020
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
2121
"iss": iss,
2222
"sub": sub,
2323
"aud": aud,
2424
"exp": exp or (time.time() + 100),
2525
"iat": iat or time.time(),
26-
"preferred_username": preferred_username,
2726
}, **claims)).encode()).decode('utf-8')
2827

2928
@staticmethod
3029
def build_response( # simulate a response from AAD
31-
uid="uid", utid="utid", # They will form client_info
30+
uid=None, utid=None, # If present, they will form client_info
3231
access_token=None, expires_in=3600, token_type="some type",
3332
refresh_token=None,
3433
foci=None,
3534
id_token=None, # or something generated by build_id_token()
3635
error=None,
3736
):
38-
response = {
39-
"client_info": base64.b64encode(json.dumps({
37+
response = {}
38+
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
39+
response["client_info"] = base64.b64encode(json.dumps({
4040
"uid": uid, "utid": utid,
41-
}).encode()).decode('utf-8'),
42-
}
41+
}).encode()).decode('utf-8')
4342
if error:
4443
response["error"] = error
4544
if access_token:
@@ -59,7 +58,7 @@ def build_response( # simulate a response from AAD
5958
def setUp(self):
6059
self.cache = TokenCache()
6160

62-
def testAdd(self):
61+
def testAddByAad(self):
6362
client_id = "my_client_id"
6463
id_token = self.build_id_token(
6564
oid="object1234", preferred_username="John Doe", aud=client_id)
@@ -132,6 +131,78 @@ def testAdd(self):
132131
"appmetadata-login.example.com-my_client_id")
133132
)
134133

134+
def testAddByAdfs(self):
135+
client_id = "my_client_id"
136+
id_token = self.build_id_token(aud=client_id, upn="JaneDoe@example.com")
137+
self.cache.add({
138+
"client_id": client_id,
139+
"scope": ["s2", "s1", "s3"], # Not in particular order
140+
"token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token",
141+
"response": self.build_response(
142+
uid=None, utid=None, # ADFS will provide no client_info
143+
expires_in=3600, access_token="an access token",
144+
id_token=id_token, refresh_token="a refresh token"),
145+
}, now=1000)
146+
self.assertEqual(
147+
{
148+
'cached_at': "1000",
149+
'client_id': 'my_client_id',
150+
'credential_type': 'AccessToken',
151+
'environment': 'fs.msidlab8.com',
152+
'expires_on': "4600",
153+
'extended_expires_on': "4600",
154+
'home_account_id': "subject",
155+
'realm': 'adfs',
156+
'secret': 'an access token',
157+
'target': 's2 s1 s3',
158+
},
159+
self.cache._cache["AccessToken"].get(
160+
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
161+
)
162+
self.assertEqual(
163+
{
164+
'client_id': 'my_client_id',
165+
'credential_type': 'RefreshToken',
166+
'environment': 'fs.msidlab8.com',
167+
'home_account_id': "subject",
168+
'secret': 'a refresh token',
169+
'target': 's2 s1 s3',
170+
},
171+
self.cache._cache["RefreshToken"].get(
172+
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
173+
)
174+
self.assertEqual(
175+
{
176+
'home_account_id': "subject",
177+
'environment': 'fs.msidlab8.com',
178+
'realm': 'adfs',
179+
'local_account_id': "subject",
180+
'username': "JaneDoe@example.com",
181+
'authority_type': "ADFS",
182+
},
183+
self.cache._cache["Account"].get('subject-fs.msidlab8.com-adfs')
184+
)
185+
self.assertEqual(
186+
{
187+
'credential_type': 'IdToken',
188+
'secret': id_token,
189+
'home_account_id': "subject",
190+
'environment': 'fs.msidlab8.com',
191+
'realm': 'adfs',
192+
'client_id': 'my_client_id',
193+
},
194+
self.cache._cache["IdToken"].get(
195+
'subject-fs.msidlab8.com-idtoken-my_client_id-adfs-')
196+
)
197+
self.assertEqual(
198+
{
199+
"client_id": "my_client_id",
200+
'environment': 'fs.msidlab8.com',
201+
},
202+
self.cache._cache.get("AppMetadata", {}).get(
203+
"appmetadata-fs.msidlab8.com-my_client_id")
204+
)
205+
135206

136207
class SerializableTokenCacheTestCase(TokenCacheTestCase):
137208
# Run all inherited test methods, and have extra check in tearDown()

0 commit comments

Comments
 (0)