Skip to content

Commit 7acfdc5

Browse files
authored
Merge pull request #28 from AzureAD/release-0.3.0
MSAL Python 0.3.0
2 parents 08b20d4 + 48ba43b commit 7acfdc5

File tree

8 files changed

+342
-54
lines changed

8 files changed

+342
-54
lines changed

msal/application.py

Lines changed: 118 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from urllib.parse import urljoin
66
import logging
77
import sys
8+
import warnings
9+
10+
import requests
811

912
from .oauth2cli import Client, JwtSigner
1013
from .authority import Authority
@@ -15,7 +18,7 @@
1518

1619

1720
# The __init__.py will import this. Not the other way around.
18-
__version__ = "0.2.0"
21+
__version__ = "0.3.0"
1922

2023
logger = logging.getLogger(__name__)
2124

@@ -101,6 +104,14 @@ def __init__(
101104
# Here the self.authority is not the same type as authority in input
102105
self.token_cache = token_cache or TokenCache()
103106
self.client = self._build_client(client_credential, self.authority)
107+
self.authority_groups = self._get_authority_aliases()
108+
109+
def _get_authority_aliases(self):
110+
resp = requests.get(
111+
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize",
112+
headers={'Accept': 'application/json'})
113+
resp.raise_for_status()
114+
return [set(group['aliases']) for group in resp.json()['metadata']]
104115

105116
def _build_client(self, client_credential, authority):
106117
client_assertion = None
@@ -236,19 +247,33 @@ def get_accounts(self, username=None):
236247
Your app can choose to display those information to end user,
237248
and allow user to choose one of his/her accounts to proceed.
238249
"""
239-
# The following implementation finds accounts only from saved accounts,
240-
# but does NOT correlate them with saved RTs. It probably won't matter,
241-
# because in MSAL universe, there are always Accounts and RTs together.
242-
accounts = self.token_cache.find(
243-
self.token_cache.CredentialType.ACCOUNT,
244-
query={"environment": self.authority.instance})
250+
accounts = self._find_msal_accounts(environment=self.authority.instance)
251+
if not accounts: # Now try other aliases of this authority instance
252+
for group in self.authority_groups:
253+
if self.authority.instance in group:
254+
for alias in group:
255+
if alias != self.authority.instance:
256+
accounts = self._find_msal_accounts(environment=alias)
257+
if accounts:
258+
break
245259
if username:
246260
# Federated account["username"] from AAD could contain mixed case
247261
lowercase_username = username.lower()
248262
accounts = [a for a in accounts
249263
if a["username"].lower() == lowercase_username]
264+
# Does not further filter by existing RTs here. It probably won't matter.
265+
# Because in most cases Accounts and RTs co-exist.
266+
# Even in the rare case when an RT is revoked and then removed,
267+
# acquire_token_silent() would then yield no result,
268+
# apps would fall back to other acquire methods. This is the standard pattern.
250269
return accounts
251270

271+
def _find_msal_accounts(self, environment):
272+
return [a for a in self.token_cache.find(
273+
TokenCache.CredentialType.ACCOUNT, query={"environment": environment})
274+
if a["authority_type"] in (
275+
TokenCache.AuthorityType.ADFS, TokenCache.AuthorityType.MSSTS)]
276+
252277
def acquire_token_silent(
253278
self,
254279
scopes, # type: List[str]
@@ -275,19 +300,44 @@ def acquire_token_silent(
275300
- None when cache lookup does not yield anything.
276301
"""
277302
assert isinstance(scopes, list), "Invalid parameter type"
278-
the_authority = Authority(
279-
authority,
280-
verify=self.verify, proxies=self.proxies, timeout=self.timeout,
281-
) if authority else self.authority
282-
303+
if authority:
304+
warnings.warn("We haven't decided how/if this method will accept authority parameter")
305+
# the_authority = Authority(
306+
# authority,
307+
# verify=self.verify, proxies=self.proxies, timeout=self.timeout,
308+
# ) if authority else self.authority
309+
result = self._acquire_token_silent(scopes, account, self.authority, **kwargs)
310+
if result:
311+
return result
312+
for group in self.authority_groups:
313+
if self.authority.instance in group:
314+
for alias in group:
315+
if alias != self.authority.instance:
316+
the_authority = Authority(
317+
"https://" + alias + "/" + self.authority.tenant,
318+
validate_authority=False,
319+
verify=self.verify, proxies=self.proxies,
320+
timeout=self.timeout,)
321+
result = self._acquire_token_silent(
322+
scopes, account, the_authority, **kwargs)
323+
if result:
324+
return result
325+
326+
def _acquire_token_silent(
327+
self,
328+
scopes, # type: List[str]
329+
account, # type: Optional[Account]
330+
authority, # This can be different than self.authority
331+
force_refresh=False, # type: Optional[boolean]
332+
**kwargs):
283333
if not force_refresh:
284334
matches = self.token_cache.find(
285335
self.token_cache.CredentialType.ACCESS_TOKEN,
286336
target=scopes,
287337
query={
288338
"client_id": self.client_id,
289-
"environment": the_authority.instance,
290-
"realm": the_authority.tenant,
339+
"environment": authority.instance,
340+
"realm": authority.tenant,
291341
"home_account_id": (account or {}).get("home_account_id"),
292342
})
293343
now = time.time()
@@ -301,26 +351,71 @@ def acquire_token_silent(
301351
"token_type": "Bearer",
302352
"expires_in": int(expires_in), # OAuth2 specs defines it as int
303353
}
354+
return self._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
355+
authority, decorate_scope(scopes, self.client_id), account,
356+
**kwargs)
304357

358+
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
359+
self, authority, scopes, account, **kwargs):
360+
query = {
361+
"environment": authority.instance,
362+
"home_account_id": (account or {}).get("home_account_id"),
363+
# "realm": authority.tenant, # AAD RTs are tenant-independent
364+
}
365+
apps = self.token_cache.find( # Use find(), rather than token_cache.get(...)
366+
TokenCache.CredentialType.APP_METADATA, query={
367+
"environment": authority.instance, "client_id": self.client_id})
368+
app_metadata = apps[0] if apps else {}
369+
if not app_metadata: # Meaning this app is now used for the first time.
370+
# When/if we have a way to directly detect current app's family,
371+
# we'll rewrite this block, to support multiple families.
372+
# For now, we try existing RTs (*). If it works, we are in that family.
373+
# (*) RTs of a different app/family are not supposed to be
374+
# shared with or accessible by us in the first place.
375+
at = self._acquire_token_silent_by_finding_specific_refresh_token(
376+
authority, scopes,
377+
dict(query, family_id="1"), # A hack, we have only 1 family for now
378+
rt_remover=lambda rt_item: None, # NO-OP b/c RTs are likely not mine
379+
break_condition=lambda response: # Break loop when app not in family
380+
# Based on an AAD-only behavior mentioned in internal doc here
381+
# https://msazure.visualstudio.com/One/_git/ESTS-Docs/pullrequest/1138595
382+
"client_mismatch" in response.get("error_additional_info", []),
383+
**kwargs)
384+
if at:
385+
return at
386+
if app_metadata.get("family_id"): # Meaning this app belongs to this family
387+
at = self._acquire_token_silent_by_finding_specific_refresh_token(
388+
authority, scopes, dict(query, family_id=app_metadata["family_id"]),
389+
**kwargs)
390+
if at:
391+
return at
392+
# Either this app is an orphan, so we will naturally use its own RT;
393+
# or all attempts above have failed, so we fall back to non-foci behavior.
394+
return self._acquire_token_silent_by_finding_specific_refresh_token(
395+
authority, scopes, dict(query, client_id=self.client_id), **kwargs)
396+
397+
def _acquire_token_silent_by_finding_specific_refresh_token(
398+
self, authority, scopes, query,
399+
rt_remover=None, break_condition=lambda response: False, **kwargs):
305400
matches = self.token_cache.find(
306401
self.token_cache.CredentialType.REFRESH_TOKEN,
307402
# target=scopes, # AAD RTs are scope-independent
308-
query={
309-
"client_id": self.client_id,
310-
"environment": the_authority.instance,
311-
"home_account_id": (account or {}).get("home_account_id"),
312-
# "realm": the_authority.tenant, # AAD RTs are tenant-independent
313-
})
314-
client = self._build_client(self.client_credential, the_authority)
403+
query=query)
404+
logger.debug("Found %d RTs matching %s", len(matches), query)
405+
client = self._build_client(self.client_credential, authority)
315406
for entry in matches:
316-
logger.debug("Cache hit an RT")
407+
logger.debug("Cache attempts an RT")
317408
response = client.obtain_token_by_refresh_token(
318409
entry, rt_getter=lambda token_item: token_item["secret"],
319-
scope=decorate_scope(scopes, self.client_id))
410+
on_removing_rt=rt_remover or self.token_cache.remove_rt,
411+
scope=scopes,
412+
**kwargs)
320413
if "error" not in response:
321414
return response
322415
logger.debug(
323416
"Refresh failed. {error}: {error_description}".format(**response))
417+
if break_condition(response):
418+
break
324419

325420

326421
class PublicClientApplication(ClientApplication): # browser app or mobile app

msal/oauth2cli/oauth2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
108108
data=None, # All relevant data, which will go into the http body
109109
headers=None, # a dict to be sent as request headers
110110
timeout=None,
111+
post=None, # A callable to replace requests.post(), for testing.
112+
# Such as: lambda url, **kwargs:
113+
# Mock(status_code=200, json=Mock(return_value={}))
111114
**kwargs # Relay all extra parameters to underlying requests
112115
): # Returns the json object came from the OAUTH2 response
113116
_data = {'client_id': self.client_id, 'grant_type': grant_type}
@@ -133,7 +136,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
133136
raise ValueError("token_endpoint not found in configuration")
134137
_headers = {'Accept': 'application/json'}
135138
_headers.update(headers or {})
136-
resp = self.session.post(
139+
resp = (post or self.session.post)(
137140
self.configuration["token_endpoint"],
138141
headers=_headers, params=params, data=_data, auth=auth,
139142
timeout=timeout or self.timeout,
@@ -393,16 +396,18 @@ def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
393396

394397
def obtain_token_by_refresh_token(self, token_item, scope=None,
395398
rt_getter=lambda token_item: token_item["refresh_token"],
399+
on_removing_rt=None,
396400
**kwargs):
397401
# type: (Union[str, dict], Union[str, list, set, tuple], Callable) -> dict
398402
"""This is an "overload" which accepts a refresh token item as a dict,
399403
therefore this method can relay refresh_token item to event listeners.
400404
401-
:param refresh_token_item: A refresh token item came from storage
405+
:param token_item: A refresh token item came from storage
402406
:param scope: If omitted, is treated as equal to the scope originally
403407
granted by the resource ownser,
404408
according to https://tools.ietf.org/html/rfc6749#section-6
405409
:param rt_getter: A callable used to extract the RT from token_item
410+
:param on_removing_rt: If absent, fall back to the one defined in initialization
406411
"""
407412
if isinstance(token_item, str):
408413
# Satisfy the L of SOLID, although we expect caller uses a dict
@@ -412,7 +417,7 @@ def obtain_token_by_refresh_token(self, token_item, scope=None,
412417
resp = super(Client, self).obtain_token_by_refresh_token(
413418
rt_getter(token_item), scope=scope, **kwargs)
414419
if resp.get('error') == 'invalid_grant':
415-
self.on_removing_rt(token_item) # Discard old RT
420+
(on_removing_rt or self.on_removing_rt)(token_item) # Discard old RT
416421
if 'refresh_token' in resp:
417422
self.on_updating_rt(token_item, resp['refresh_token'])
418423
return resp

msal/token_cache.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class CredentialType:
3030
REFRESH_TOKEN = "RefreshToken"
3131
ACCOUNT = "Account" # Not exactly a credential type, but we put it here
3232
ID_TOKEN = "IdToken"
33+
APP_METADATA = "AppMetadata"
34+
35+
class AuthorityType:
36+
ADFS = "ADFS"
37+
MSSTS = "MSSTS" # MSSTS means AAD v2 for both AAD & MSA
3338

3439
def __init__(self):
3540
self._lock = threading.RLock()
@@ -118,8 +123,8 @@ def add(self, event, now=None):
118123
"oid", decoded_id_token.get("sub")),
119124
"username": decoded_id_token.get("preferred_username"),
120125
"authority_type":
121-
"ADFS" if realm == "adfs"
122-
else "MSSTS", # MSSTS means AAD v2 for both AAD & MSA
126+
self.AuthorityType.ADFS if realm == "adfs"
127+
else self.AuthorityType.MSSTS,
123128
# "client_info": response.get("client_info"), # Optional
124129
}
125130

@@ -158,6 +163,17 @@ def add(self, event, now=None):
158163
rt["family_id"] = response["foci"]
159164
self._cache.setdefault(self.CredentialType.REFRESH_TOKEN, {})[key] = rt
160165

166+
key = self._build_appmetadata_key(environment, event.get("client_id"))
167+
self._cache.setdefault(self.CredentialType.APP_METADATA, {})[key] = {
168+
"client_id": event.get("client_id"),
169+
"environment": environment,
170+
"family_id": response.get("foci"), # None is also valid
171+
}
172+
173+
@staticmethod
174+
def _build_appmetadata_key(environment, client_id):
175+
return "appmetadata-{}-{}".format(environment or "", client_id or "")
176+
161177
@classmethod
162178
def _build_rt_key(
163179
cls,
@@ -192,21 +208,24 @@ class SerializableTokenCache(TokenCache):
192208
Depending on your need,
193209
the following simple recipe for file-based persistence may be sufficient::
194210
195-
import atexit
196-
cache = SerializableTokenCache()
197-
cache.deserialize(open("my_cache.bin", "rb").read())
211+
import os, atexit, msal
212+
cache = msal.SerializableTokenCache()
213+
if os.path.exists("my_cache.bin"):
214+
cache.deserialize(open("my_cache.bin", "r").read())
198215
atexit.register(lambda:
199-
open("my_cache.bin", "wb").write(cache.serialize())
216+
open("my_cache.bin", "w").write(cache.serialize())
200217
# Hint: The following optional line persists only when state changed
201218
if cache.has_state_changed else None
202219
)
203-
app = ClientApplication(..., token_cache=cache)
220+
app = msal.ClientApplication(..., token_cache=cache)
204221
...
205222
206223
:var bool has_state_changed:
207224
Indicates whether the cache state has changed since last
208225
:func:`~serialize` or :func:`~deserialize` call.
209226
"""
227+
has_state_changed = False
228+
210229
def add(self, event, **kwargs):
211230
super(SerializableTokenCache, self).add(event, **kwargs)
212231
self.has_state_changed = True

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
.
2+
mock; python_version < '3.3'

sample/client_credential_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
{
55
"authority": "https://login.microsoftonline.com/organizations",
66
"client_id": "your_client_id",
7+
"scope": ["https://graph.microsoft.com/.default"],
78
"secret": "This is a sample only. You better NOT persist your password."
8-
"scope": ["https://graph.microsoft.com/.default"]
99
}
1010
1111
You can then run this sample with a JSON configuration file:

sample/username_password_sample.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,6 @@
5757
print(result.get("error"))
5858
print(result.get("error_description"))
5959
print(result.get("correlation_id")) # You may need this when reporting a bug
60-
60+
if 65001 in result.get("error_codes", []): # Not mean to be coded programatically, but...
61+
# AAD requires user consent for U/P flow
62+
print("Visit this to consent:", app.get_authorization_request_url(scope))

0 commit comments

Comments
 (0)