Skip to content

Commit ae7f794

Browse files
authored
Merge pull request #109 from AzureAD/release-0.8.0
Release 0.8.0
2 parents c319ea3 + f387d41 commit ae7f794

File tree

8 files changed

+277
-89
lines changed

8 files changed

+277
-89
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ matrix:
1111
- python: 3.7
1212
dist: xenial
1313
sudo: true
14+
- python: 3.8
15+
dist: xenial
16+
sudo: true
1417

1518
install:
1619
- pip install -r requirements.txt

msal/application.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
# The __init__.py will import this. Not the other way around.
21-
__version__ = "0.7.0"
21+
__version__ = "0.8.0"
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -194,8 +194,6 @@ def get_authorization_request_url(
194194
login_hint=None, # type: Optional[str]
195195
state=None, # Recommended by OAuth2 for CSRF protection
196196
redirect_uri=None,
197-
authority=None, # By default, it will use self.authority;
198-
# Multi-tenant app can use new authority on demand
199197
response_type="code", # Can be "token" if you use Implicit Grant
200198
**kwargs):
201199
"""Constructs a URL for you to start a Authorization Code Grant.
@@ -207,6 +205,9 @@ def get_authorization_request_url(
207205
Identifier of the user. Generally a User Principal Name (UPN).
208206
:param str redirect_uri:
209207
Address to return to upon receiving a response from the authority.
208+
:param str response_type:
209+
Default value is "code" for an OAuth2 Authorization Code grant.
210+
You can use other content such as "id_token".
210211
:return: The authorization url as a string.
211212
"""
212213
""" # TBD: this would only be meaningful in a new acquire_token_interactive()
@@ -217,15 +218,22 @@ def get_authorization_request_url(
217218
(Under the hood, we simply merge scope and additional_scope before
218219
sending them on the wire.)
219220
"""
221+
authority = kwargs.pop("authority", None) # Historically we support this
222+
if authority:
223+
warnings.warn(
224+
"We haven't decided if this method will accept authority parameter")
225+
# The previous implementation is, it will use self.authority by default.
226+
# Multi-tenant app can use new authority on demand
220227
the_authority = Authority(
221228
authority,
222229
verify=self.verify, proxies=self.proxies, timeout=self.timeout,
223230
) if authority else self.authority
231+
224232
client = Client(
225233
{"authorization_endpoint": the_authority.authorization_endpoint},
226234
self.client_id)
227235
return client.build_auth_request_uri(
228-
response_type="code", # Using Authorization Code grant
236+
response_type=response_type,
229237
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
230238
scope=decorate_scope(scopes, self.client_id),
231239
)
@@ -269,6 +277,7 @@ def acquire_token_by_authorization_code(
269277
# one scope. But, MSAL decorates your scope anyway, so they are never
270278
# really empty.
271279
assert isinstance(scopes, list), "Invalid parameter type"
280+
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
272281
return self.client.obtain_token_by_authorization_code(
273282
code, redirect_uri=redirect_uri,
274283
data=dict(
@@ -396,6 +405,7 @@ def acquire_token_silent(
396405
- None when cache lookup does not yield anything.
397406
"""
398407
assert isinstance(scopes, list), "Invalid parameter type"
408+
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
399409
if authority:
400410
warnings.warn("We haven't decided how/if this method will accept authority parameter")
401411
# the_authority = Authority(
@@ -412,7 +422,7 @@ def acquire_token_silent(
412422
validate_authority=False,
413423
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
414424
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
415-
scopes, account, the_authority, **kwargs)
425+
scopes, account, the_authority, force_refresh=force_refresh, **kwargs)
416426
if result:
417427
return result
418428

@@ -424,15 +434,19 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
424434
force_refresh=False, # type: Optional[boolean]
425435
**kwargs):
426436
if not force_refresh:
427-
matches = self.token_cache.find(
428-
self.token_cache.CredentialType.ACCESS_TOKEN,
429-
target=scopes,
430-
query={
437+
query={
431438
"client_id": self.client_id,
432439
"environment": authority.instance,
433440
"realm": authority.tenant,
434441
"home_account_id": (account or {}).get("home_account_id"),
435-
})
442+
}
443+
key_id = kwargs.get("data", {}).get("key_id")
444+
if key_id: # Some token types (SSH-certs, POP) are bound to a key
445+
query["key_id"] = key_id
446+
matches = self.token_cache.find(
447+
self.token_cache.CredentialType.ACCESS_TOKEN,
448+
target=scopes,
449+
query=query)
436450
now = time.time()
437451
for entry in matches:
438452
expires_in = int(entry["expires_on"]) - now
@@ -513,6 +527,20 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
513527
if break_condition(response):
514528
break
515529

530+
def _validate_ssh_cert_input_data(self, data):
531+
if data.get("token_type") == "ssh-cert":
532+
if not data.get("req_cnf"):
533+
raise ValueError(
534+
"When requesting an SSH certificate, "
535+
"you must include a string parameter named 'req_cnf' "
536+
"containing the public key in JWK format "
537+
"(https://tools.ietf.org/html/rfc7517).")
538+
if not data.get("key_id"):
539+
raise ValueError(
540+
"When requesting an SSH certificate, "
541+
"you must include a string parameter named 'key_id' "
542+
"which identifies the key in the 'req_cnf' argument.")
543+
516544

517545
class PublicClientApplication(ClientApplication): # browser app or mobile app
518546

msal/authority.py

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
import re
1+
try:
2+
from urllib.parse import urlparse
3+
except ImportError: # Fall back to Python 2
4+
from urlparse import urlparse
25
import logging
36

47
import requests
@@ -15,14 +18,21 @@
1518
'login.microsoftonline.us',
1619
'login.microsoftonline.de',
1720
])
18-
21+
WELL_KNOWN_B2C_HOSTS = [
22+
"b2clogin.com",
23+
"b2clogin.cn",
24+
"b2clogin.us",
25+
"b2clogin.de",
26+
]
1927

2028
class Authority(object):
2129
"""This class represents an (already-validated) authority.
2230
2331
Once constructed, it contains members named "*_endpoint" for this instance.
2432
TODO: It will also cache the previously-validated authority instances.
2533
"""
34+
_domains_without_user_realm_discovery = set([])
35+
2636
def __init__(self, authority_url, validate_authority=True,
2737
verify=True, proxies=None, timeout=None,
2838
):
@@ -37,18 +47,30 @@ def __init__(self, authority_url, validate_authority=True,
3747
self.verify = verify
3848
self.proxies = proxies
3949
self.timeout = timeout
40-
canonicalized, self.instance, tenant = canonicalize(authority_url)
41-
tenant_discovery_endpoint = (
42-
'https://{}/{}{}/.well-known/openid-configuration'.format(
43-
self.instance,
44-
tenant,
45-
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
46-
))
47-
if (tenant != "adfs" and validate_authority
50+
authority, self.instance, tenant = canonicalize(authority_url)
51+
is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS)
52+
if (tenant != "adfs" and (not is_b2c) and validate_authority
4853
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
49-
tenant_discovery_endpoint = instance_discovery(
50-
canonicalized + "/oauth2/v2.0/authorize",
54+
payload = instance_discovery(
55+
"https://{}{}/oauth2/v2.0/authorize".format(
56+
self.instance, authority.path),
5157
verify=verify, proxies=proxies, timeout=timeout)
58+
if payload.get("error") == "invalid_instance":
59+
raise ValueError(
60+
"invalid_instance: "
61+
"The authority you provided, %s, is not whitelisted. "
62+
"If it is indeed your legit customized domain name, "
63+
"you can turn off this check by passing in "
64+
"validate_authority=False"
65+
% authority_url)
66+
tenant_discovery_endpoint = payload['tenant_discovery_endpoint']
67+
else:
68+
tenant_discovery_endpoint = (
69+
'https://{}{}{}/.well-known/openid-configuration'.format(
70+
self.instance,
71+
authority.path, # In B2C scenario, it is "/tenant/policy"
72+
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
73+
))
5274
openid_config = tenant_discovery(
5375
tenant_discovery_endpoint,
5476
verify=verify, proxies=proxies, timeout=timeout)
@@ -58,42 +80,44 @@ def __init__(self, authority_url, validate_authority=True,
5880
_, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID
5981
self.is_adfs = self.tenant.lower() == 'adfs'
6082

61-
def user_realm_discovery(self, username):
62-
resp = requests.get(
63-
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
64-
netloc=self.instance, username=username),
65-
headers={'Accept':'application/json'},
66-
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
67-
resp.raise_for_status()
68-
return resp.json()
69-
# It will typically contain "ver", "account_type",
83+
def user_realm_discovery(self, username, response=None):
84+
# It will typically return a dict containing "ver", "account_type",
7085
# "federation_protocol", "cloud_audience_urn",
7186
# "federation_metadata_url", "federation_active_auth_url", etc.
87+
if self.instance not in self.__class__._domains_without_user_realm_discovery:
88+
resp = response or requests.get(
89+
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
90+
netloc=self.instance, username=username),
91+
headers={'Accept':'application/json'},
92+
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
93+
if resp.status_code != 404:
94+
resp.raise_for_status()
95+
return resp.json()
96+
self.__class__._domains_without_user_realm_discovery.add(self.instance)
97+
return {} # This can guide the caller to fall back normal ROPC flow
98+
7299

73-
def canonicalize(url):
74-
# Returns (canonicalized_url, netloc, tenant). Raises ValueError on errors.
75-
match_object = re.match(r'https://([^/]+)/([^/?#]+)', url.lower())
76-
if not match_object:
100+
def canonicalize(authority_url):
101+
authority = urlparse(authority_url)
102+
parts = authority.path.split("/")
103+
if authority.scheme != "https" or len(parts) < 2 or not parts[1]:
77104
raise ValueError(
78105
"Your given address (%s) should consist of "
79106
"an https url with a minimum of one segment in a path: e.g. "
80-
"https://login.microsoftonline.com/<tenant_name>" % url)
81-
return match_object.group(0), match_object.group(1), match_object.group(2)
107+
"https://login.microsoftonline.com/<tenant> "
108+
"or https://<tenant_name>.b2clogin.com/<tenant_name>.onmicrosoft.com/policy"
109+
% authority_url)
110+
return authority, authority.netloc, parts[1]
82111

83-
def instance_discovery(url, response=None, **kwargs):
84-
# Returns tenant discovery endpoint
85-
resp = requests.get( # Note: This URL seemingly returns V1 endpoint only
112+
def instance_discovery(url, **kwargs):
113+
return requests.get( # Note: This URL seemingly returns V1 endpoint only
86114
'https://{}/common/discovery/instance'.format(
87115
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
88116
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
89117
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
90118
),
91119
params={'authorization_endpoint': url, 'api-version': '1.0'},
92-
**kwargs)
93-
payload = response or resp.json()
94-
if 'tenant_discovery_endpoint' not in payload:
95-
raise MsalServiceError(status_code=resp.status_code, **payload)
96-
return payload['tenant_discovery_endpoint']
120+
**kwargs).json()
97121

98122
def tenant_discovery(tenant_discovery_endpoint, **kwargs):
99123
# Returns Openid Configuration

msal/token_cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __add(self, event, now=None):
127127
if "token_endpoint" in event:
128128
_, environment, realm = canonicalize(event["token_endpoint"])
129129
response = event.get("response", {})
130+
data = event.get("data", {})
130131
access_token = response.get("access_token")
131132
refresh_token = response.get("refresh_token")
132133
id_token = response.get("id_token")
@@ -165,6 +166,8 @@ def __add(self, event, now=None):
165166
"expires_on": str(now + expires_in), # Same here
166167
"extended_expires_on": str(now + ext_expires_in) # Same here
167168
}
169+
if data.get("key_id"): # It happens in SSH-cert or POP scenario
170+
at["key_id"] = data.get("key_id")
168171
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)
169172

170173
if client_info:

sample/device_flow_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
The configuration file would look like this:
33
44
{
5-
"authority": "https://login.microsoftonline.com/organizations",
5+
"authority": "https://login.microsoftonline.com/common",
66
"client_id": "your_client_id",
77
"scope": ["User.Read"]
88
}

tests/test_authority.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import os
2+
13
from msal.authority import *
24
from msal.exceptions import MsalServiceError
35
from tests import unittest
46

57

8+
@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release")
69
class TestAuthority(unittest.TestCase):
710

811
def test_wellknown_host_and_tenant(self):
@@ -26,7 +29,7 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self):
2629
self.assertNotIn('v2.0', a.token_endpoint)
2730

2831
def test_unknown_host_wont_pass_instance_discovery(self):
29-
with self.assertRaisesRegexp(MsalServiceError, "invalid_instance"):
32+
with self.assertRaisesRegexp(ValueError, "invalid_instance"):
3033
Authority('https://unknown.host/tenant_doesnt_matter_in_this_case')
3134

3235
def test_invalid_host_skipping_validation_meets_connection_error_down_the_road(self):
@@ -37,19 +40,19 @@ def test_invalid_host_skipping_validation_meets_connection_error_down_the_road(s
3740
class TestAuthorityInternalHelperCanonicalize(unittest.TestCase):
3841

3942
def test_canonicalize_tenant_followed_by_extra_paths(self):
40-
self.assertEqual(
41-
canonicalize("https://example.com/tenant/subpath?foo=bar#fragment"),
42-
("https://example.com/tenant", "example.com", "tenant"))
43+
_, i, t = canonicalize("https://example.com/tenant/subpath?foo=bar#fragment")
44+
self.assertEqual("example.com", i)
45+
self.assertEqual("tenant", t)
4346

4447
def test_canonicalize_tenant_followed_by_extra_query(self):
45-
self.assertEqual(
46-
canonicalize("https://example.com/tenant?foo=bar#fragment"),
47-
("https://example.com/tenant", "example.com", "tenant"))
48+
_, i, t = canonicalize("https://example.com/tenant?foo=bar#fragment")
49+
self.assertEqual("example.com", i)
50+
self.assertEqual("tenant", t)
4851

4952
def test_canonicalize_tenant_followed_by_extra_fragment(self):
50-
self.assertEqual(
51-
canonicalize("https://example.com/tenant#fragment"),
52-
("https://example.com/tenant", "example.com", "tenant"))
53+
_, i, t = canonicalize("https://example.com/tenant#fragment")
54+
self.assertEqual("example.com", i)
55+
self.assertEqual("tenant", t)
5356

5457
def test_canonicalize_rejects_non_https(self):
5558
with self.assertRaises(ValueError):
@@ -64,20 +67,22 @@ def test_canonicalize_rejects_tenantless_host_with_trailing_slash(self):
6467
canonicalize("https://no.tenant.example.com/")
6568

6669

67-
class TestAuthorityInternalHelperInstanceDiscovery(unittest.TestCase):
68-
69-
def test_instance_discovery_happy_case(self):
70-
self.assertEqual(
71-
instance_discovery("https://login.windows.net/tenant"),
72-
"https://login.windows.net/tenant/.well-known/openid-configuration")
73-
74-
def test_instance_discovery_with_unknown_instance(self):
75-
with self.assertRaisesRegexp(MsalServiceError, "invalid_instance"):
76-
instance_discovery('https://unknown.host/tenant_doesnt_matter_here')
77-
78-
def test_instance_discovery_with_mocked_response(self):
79-
mock_response = {'tenant_discovery_endpoint': 'http://a.com/t/openid'}
80-
endpoint = instance_discovery(
81-
"https://login.microsoftonline.in/tenant.com", response=mock_response)
82-
self.assertEqual(endpoint, mock_response['tenant_discovery_endpoint'])
70+
@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release")
71+
class TestAuthorityInternalHelperUserRealmDiscovery(unittest.TestCase):
72+
def test_memorize(self):
73+
# We use a real authority so the constructor can finish tenant discovery
74+
authority = "https://login.microsoftonline.com/common"
75+
self.assertNotIn(authority, Authority._domains_without_user_realm_discovery)
76+
a = Authority(authority, validate_authority=False)
77+
78+
# We now pretend this authority supports no User Realm Discovery
79+
class MockResponse(object):
80+
status_code = 404
81+
a.user_realm_discovery("john.doe@example.com", response=MockResponse())
82+
self.assertIn(
83+
"login.microsoftonline.com",
84+
Authority._domains_without_user_realm_discovery,
85+
"user_realm_discovery() should memorize domains not supporting URD")
86+
a.user_realm_discovery("john.doe@example.com",
87+
response="This would cause exception if memorization did not work")
8388

0 commit comments

Comments
 (0)