From 41b308611fe337d3b0565ec3827eb3ba47f7b624 Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:34:51 +0800 Subject: [PATCH 1/9] Add Six as a dependency. We use this to shim up some of the modules that changed locations between Python 2 and Python 3. --- requirements.txt | 1 + setup.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a79a5d9f..37a92814 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ Django>=1.4 shortuuid>=0.3 +six>=1.7 diff --git a/setup.py b/setup.py index f9f55b85..884227a9 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ 'Framework :: Django', ], install_requires=[ - "shortuuid>=0.3" + "shortuuid>=0.3", + "six>=1.7", ], include_package_data=True, zip_safe=False, From 3eb9191dc2d7bda33e33e0f367624ae201f648cf Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:37:10 +0800 Subject: [PATCH 2/9] Use proper import syntax for relative imports. --- provider/oauth2/__init__.py | 14 ++++++++------ provider/views.py | 4 +++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/provider/oauth2/__init__.py b/provider/oauth2/__init__.py index 34796220..57369dc2 100644 --- a/provider/oauth2/__init__.py +++ b/provider/oauth2/__init__.py @@ -1,6 +1,8 @@ -import backends -import forms -import managers -import models -import urls -import views \ No newline at end of file +from __future__ import absolute_import + +from . import backends +from . import forms +from . import managers +from . import models +from . import urls +from . import views diff --git a/provider/views.py b/provider/views.py index dd1200df..d8239f78 100644 --- a/provider/views.py +++ b/provider/views.py @@ -1,3 +1,5 @@ +from __future__ import absolute_import + import json import urlparse from django.http import HttpResponse @@ -5,7 +7,7 @@ from django.utils.translation import ugettext as _ from django.views.generic.base import TemplateView from django.core.exceptions import ObjectDoesNotExist -from oauth2.models import Client +from .oauth2.models import Client from . import constants, scope From 546637033233f4f0736e03f2b450c01026dcd053 Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:38:41 +0800 Subject: [PATCH 3/9] Use Python 3 syntax for exception handling. Breakage warning: this won't work on Python versions < 2.6. --- provider/forms.py | 4 ++-- provider/views.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/provider/forms.py b/provider/forms.py index 9b7fd609..f7c68d8a 100644 --- a/provider/forms.py +++ b/provider/forms.py @@ -51,7 +51,7 @@ def _clean_fields(self): """ try: super(OAuthForm, self)._clean_fields() - except OAuthValidationError, e: + except OAuthValidationError as e: self._errors.update(e.args[0]) def _clean_form(self): @@ -60,5 +60,5 @@ def _clean_form(self): """ try: super(OAuthForm, self)._clean_form() - except OAuthValidationError, e: + except OAuthValidationError as e: self._errors.update(e.args[0]) diff --git a/provider/views.py b/provider/views.py index d8239f78..c23e3401 100644 --- a/provider/views.py +++ b/provider/views.py @@ -257,7 +257,7 @@ def handle(self, request, post_data=None): try: client, data = self._validate_client(request, data) - except OAuthError, e: + except OAuthError as e: return self.error_response(request, e.args[0], status=400) authorization_form = self.get_authorization_form(request, client, @@ -598,5 +598,5 @@ def post(self, request): try: return handler(request, request.POST, client) - except OAuthError, e: + except OAuthError as e: return self.error_response(e.args[0]) From a8714c2c87c4f795408fd8eafa2dbe1c3295ac61 Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:42:18 +0800 Subject: [PATCH 4/9] Handle module name changes for Django >=1.5 To allow Python 3 compatibility, Django 1.5 renamed smart_unicode in django.utils.encoding as smart_text (as Python 3's str type is unicode by default). To accomodate running on earlier Django versions, we do an import of the old name; if we fail, we import smart_text as smart_unicode instead (with the assumption that Django is actually in the path). --- provider/oauth2/forms.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index bb4dcb89..d474be09 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -1,6 +1,12 @@ from django import forms from django.contrib.auth import authenticate -from django.utils.encoding import smart_unicode + +try: + from django.utils.encoding import smart_unicode +except ImportError: + # Django >=1.5 + from django.utils.encoding import smart_text as smart_unicode + from django.utils.translation import ugettext as _ from .. import scope from ..constants import RESPONSE_TYPE_CHOICES, SCOPES From 39717bc595a751acc3a9be46f09b588d56e7245a Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:46:14 +0800 Subject: [PATCH 5/9] Fix up usage of urlparse to import via Six. urlparse in Python 3 has been subsumed into urllib; to support both Python 2 and Python 3, we import urlparse via Six. This commit also fixes a bug in the tests. There is an incorrect assumption that parse_qs parses just the query string of a full URL, when given a full URL; it doesn't, and instead parse_qs assumes the caller passes in just the query string part of the URL. When given a full URL, the rest of the URL including the first attribute in the query string is assigned as the first key of the resulting dictionary returned by parse_qs. Because the 'code' key incidentally ends up being the second attribute in the query string, the test passes. However, because there was a change in the ordering of the query string, no 'code' key could be found when running in Python 3. --- provider/oauth2/tests.py | 6 ++++-- provider/views.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 53de8af2..456a2f3a 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -1,5 +1,6 @@ import json -import urlparse +from six.moves.urllib import parse as urlparse +import six import datetime from django.http import QueryDict from django.conf import settings @@ -578,7 +579,8 @@ def test_clear_expired(self): self.assertTrue('code' in location) # verify that Grant with code exists - code = urlparse.parse_qs(location)['code'][0] + location_qs = urlparse.urlparse(location)[4] + code = urlparse.parse_qs(location_qs)['code'][0] self.assertTrue(Grant.objects.filter(code=code).exists()) # use the code/grant diff --git a/provider/views.py b/provider/views.py index c23e3401..b99ffd3c 100644 --- a/provider/views.py +++ b/provider/views.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import json -import urlparse +from six.moves.urllib import parse as urlparse from django.http import HttpResponse from django.http import HttpResponseRedirect, QueryDict from django.utils.translation import ugettext as _ From b9f9f31519a834e69b275db5e7ef8d5b6a5e8c30 Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:52:17 +0800 Subject: [PATCH 6/9] Fix bytes vs str usage. Because Python 3 is strict when it comes to byte strings vs. text strings, and because Python 3 does not do any automatic coercion between the two types, we need to be more explicit. In particular, when using response.content (which is in really a byte string), we need to convert it to a particular character encoding before treating it as a text string. We assume UTF-8 for the tests, which *will* break in other encodings, but since these are the tests, we don't mind assuming UTF-8. As well, because encode() in the str type in Python 3 no longer supports non-character encodings such as Base64, we need to use codecs to do the conversion. We also fix the assumption that shortuuid.uuid() returns a byte string (which it doesn't) when passing to hashlib. --- provider/oauth2/backends.py | 5 +++- provider/oauth2/tests.py | 58 ++++++++++++++++++++----------------- provider/utils.py | 9 +++--- 3 files changed, 41 insertions(+), 31 deletions(-) diff --git a/provider/oauth2/backends.py b/provider/oauth2/backends.py index db0fb853..00477327 100644 --- a/provider/oauth2/backends.py +++ b/provider/oauth2/backends.py @@ -1,3 +1,5 @@ +import codecs + from ..utils import now from .forms import ClientAuthForm, PublicPasswordGrantForm from .models import AccessToken @@ -29,7 +31,8 @@ def authenticate(self, request=None): try: basic, base64 = auth.split(' ') - client_id, client_secret = base64.decode('base64').split(':') + base64_base64 = codecs.decode(base64.encode('utf-8'), 'base64').decode('utf-8') + client_id, client_secret = base64_base64.split(':') form = ClientAuthForm({ 'client_id': client_id, diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 456a2f3a..08d46b61 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -1,6 +1,7 @@ import json from six.moves.urllib import parse as urlparse import six +import codecs import datetime from django.http import QueryDict from django.conf import settings @@ -17,6 +18,10 @@ from .backends import BasicClientBackend, RequestParamsClientBackend from .backends import AccessTokenBackend +def _py3_base64_str_shim(val): + val_b = val.encode('utf-8') + return codecs.encode(val_b, 'base64').decode('utf-8') + @skipIfCustomUser class BaseOAuth2TestCase(TestCase): @@ -90,7 +95,7 @@ def test_authorization_requires_client_id(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue("An unauthorized client tried to access your resources." in response.content) + self.assertTrue("An unauthorized client tried to access your resources." in response.content.decode('utf-8')) def test_authorization_rejects_invalid_client_id(self): self.login() @@ -98,7 +103,7 @@ def test_authorization_rejects_invalid_client_id(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue("An unauthorized client tried to access your resources." in response.content) + self.assertTrue("An unauthorized client tried to access your resources." in response.content.decode('utf-8')) def test_authorization_requires_response_type(self): self.login() @@ -106,7 +111,7 @@ def test_authorization_requires_response_type(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"No 'response_type' supplied.") in response.content) + self.assertTrue(escape(u"No 'response_type' supplied.") in response.content.decode('utf-8')) def test_authorization_requires_supported_response_type(self): self.login() @@ -114,7 +119,7 @@ def test_authorization_requires_supported_response_type(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content) + self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content.decode('utf-8')) response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code' % self.get_client().client_id) response = self.client.get(self.auth_url2()) @@ -133,7 +138,7 @@ def test_authorization_requires_a_valid_redirect_uri(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content) + self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content.decode('utf-8')) response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&redirect_uri=%s' % ( self.get_client().client_id, @@ -149,7 +154,7 @@ def test_authorization_requires_a_valid_scope(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content) + self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content.decode('utf-8')) response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&scope=%s' % ( self.get_client().client_id, @@ -224,7 +229,7 @@ def test_fetching_access_token_with_invalid_client(self): 'client_secret': self.get_client().client_secret, }) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_client', json.loads(response.content)['error']) + self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error']) def test_fetching_access_token_with_invalid_grant(self): self.login() @@ -237,7 +242,7 @@ def test_fetching_access_token_with_invalid_grant(self): 'code': '123'}) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_grant', json.loads(response.content)['error']) + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error']) def _login_authorize_get_token(self): required_props = ['access_token', 'token_type'] @@ -257,7 +262,7 @@ def _login_authorize_get_token(self): self.assertEqual(200, response.status_code, response.content) - token = json.loads(response.content) + token = json.loads(response.content.decode('utf-8')) for prop in required_props: self.assertIn(prop, token, "Access token response missing " @@ -284,7 +289,7 @@ def test_fetching_access_token_with_invalid_grant_type(self): }) self.assertEqual(400, response.status_code) - self.assertEqual('unsupported_grant_type', json.loads(response.content)['error'], + self.assertEqual('unsupported_grant_type', json.loads(response.content.decode('utf-8'))['error'], response.content) def test_fetching_single_access_token(self): @@ -325,7 +330,7 @@ def test_fetching_access_token_multiple_times(self): 'code': code}) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_grant', json.loads(response.content)['error']) + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error']) def test_escalating_the_scope(self): self.login() @@ -340,7 +345,7 @@ def test_escalating_the_scope(self): 'scope': 'read write'}) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_scope', json.loads(response.content)['error']) + self.assertEqual('invalid_scope', json.loads(response.content.decode('utf-8'))['error']) def test_refreshing_an_access_token(self): token = self._login_authorize_get_token() @@ -362,7 +367,7 @@ def test_refreshing_an_access_token(self): }) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_grant', json.loads(response.content)['error'], + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error'], response.content) def test_password_grant_public(self): @@ -379,8 +384,8 @@ def test_password_grant_public(self): }) self.assertEqual(200, response.status_code, response.content) - self.assertNotIn('refresh_token', json.loads(response.content)) - expires_in = json.loads(response.content)['expires_in'] + self.assertNotIn('refresh_token', json.loads(response.content.decode('utf-8'))) + expires_in = json.loads(response.content.decode('utf-8'))['expires_in'] expires_in_days = round(expires_in / (60.0 * 60.0 * 24.0)) self.assertEqual(expires_in_days, constants.EXPIRE_DELTA_PUBLIC.days) @@ -398,7 +403,7 @@ def test_password_grant_confidential(self): }) self.assertEqual(200, response.status_code, response.content) - self.assertTrue(json.loads(response.content)['refresh_token']) + self.assertTrue(json.loads(response.content.decode('utf-8'))['refresh_token']) def test_password_grant_confidential_no_secret(self): c = self.get_client() @@ -412,7 +417,7 @@ def test_password_grant_confidential_no_secret(self): 'password': self.get_password(), }) - self.assertEqual('invalid_client', json.loads(response.content)['error']) + self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error']) def test_password_grant_invalid_password_public(self): c = self.get_client() @@ -427,7 +432,7 @@ def test_password_grant_invalid_password_public(self): }) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_client', json.loads(response.content)['error']) + self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error']) def test_password_grant_invalid_password_confidential(self): c = self.get_client() @@ -443,7 +448,7 @@ def test_password_grant_invalid_password_confidential(self): }) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_grant', json.loads(response.content)['error']) + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error']) def test_access_token_response_valid_token_type(self): token = self._login_authorize_get_token() @@ -455,9 +460,10 @@ class AuthBackendTest(BaseOAuth2TestCase): def test_basic_client_backend(self): request = type('Request', (object,), {'META': {}})() - request.META['HTTP_AUTHORIZATION'] = "Basic " + "{0}:{1}".format( - self.get_client().client_id, - self.get_client().client_secret).encode('base64') + request.META['HTTP_AUTHORIZATION'] = "Basic " + _py3_base64_str_shim( + "{0}:{1}".format( + self.get_client().client_id, + self.get_client().client_secret)) self.assertEqual(BasicClientBackend().authenticate(request).id, 2, "Didn't return the right client.") @@ -497,13 +503,13 @@ def test_authorization_enforces_SSL(self): response = self.client.get(self.auth_url()) self.assertEqual(400, response.status_code) - self.assertTrue("A secure connection is required." in response.content) + self.assertTrue("A secure connection is required." in response.content.decode('utf-8')) def test_access_token_enforces_SSL(self): response = self.client.post(self.access_token_url(), {}) self.assertEqual(400, response.status_code) - self.assertTrue("A secure connection is required." in response.content) + self.assertTrue("A secure connection is required." in response.content.decode('utf-8')) class ClientFormTest(TestCase): @@ -590,7 +596,7 @@ def test_clear_expired(self): 'client_secret': self.get_client().client_secret, 'code': code}) self.assertEquals(200, response.status_code) - token = json.loads(response.content) + token = json.loads(response.content.decode('utf-8')) self.assertTrue('access_token' in token) access_token = token['access_token'] self.assertTrue('refresh_token' in token) @@ -612,7 +618,7 @@ def test_clear_expired(self): 'client_secret': self.get_client().client_secret, }) self.assertEqual(200, response.status_code) - token = json.loads(response.content) + token = json.loads(response.content.decode('utf-8')) self.assertTrue('access_token' in token) self.assertNotEquals(access_token, token['access_token']) self.assertTrue('refresh_token' in token) diff --git a/provider/utils.py b/provider/utils.py index 957a5c75..1cacc737 100644 --- a/provider/utils.py +++ b/provider/utils.py @@ -1,5 +1,6 @@ import hashlib import shortuuid +import six from datetime import datetime, tzinfo from django.conf import settings from django.utils import dateparse @@ -31,8 +32,8 @@ def short_token(): """ Generate a hash that can be used as an application identifier """ - hash = hashlib.sha1(shortuuid.uuid()) - hash.update(settings.SECRET_KEY) + hash = hashlib.sha1(shortuuid.uuid().encode('utf-8')) + hash.update(settings.SECRET_KEY.encode('utf-8')) return hash.hexdigest()[::2] @@ -40,8 +41,8 @@ def long_token(): """ Generate a hash that can be used as an application secret """ - hash = hashlib.sha1(shortuuid.uuid()) - hash.update(settings.SECRET_KEY) + hash = hashlib.sha1(shortuuid.uuid().encode('utf-8')) + hash.update(settings.SECRET_KEY.encode('utf-8')) return hash.hexdigest() From 12f802e167b75e2add0cc26bb1903fb759ba7765 Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 11:59:35 +0800 Subject: [PATCH 7/9] Fix reduce and iteritems usage. --- provider/scope.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/provider/scope.py b/provider/scope.py index 46c9d4b7..83ce0ebb 100644 --- a/provider/scope.py +++ b/provider/scope.py @@ -9,6 +9,8 @@ """ from .constants import SCOPES +from functools import reduce +from six import iteritems SCOPE_NAMES = [(name, name) for (value, name) in SCOPES] SCOPE_NAME_DICT = dict([(name, value) for (value, name) in SCOPES]) @@ -73,7 +75,7 @@ def to_names(scope): """ return [ name - for (name, value) in SCOPE_NAME_DICT.iteritems() + for (name, value) in iteritems(SCOPE_NAME_DICT) if check(value, scope) ] From ce933d9a4cdb2d0d63c9e4c40e6430af25feb90d Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 12:00:06 +0800 Subject: [PATCH 8/9] Don't iterate through the session keys while modifying the session. This breaks in Python 3, as modifying a dictionary while iterating through its keys raises an exception, because dict.keys() now returns a view instead of a materialized list. We instead pull out the session keys as a list and iterate through that instead, which works for both Python 3 and Python 2 (although in Python 2 this creates an additional list). --- provider/views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/provider/views.py b/provider/views.py index b99ffd3c..0a8b4abc 100644 --- a/provider/views.py +++ b/provider/views.py @@ -73,7 +73,8 @@ def clear_data(self, request): """ Clear all OAuth related data from the session store. """ - for key in request.session.keys(): + session_keys = list(request.session.keys()) + for key in session_keys: if key.startswith(constants.SESSION_KEY): del request.session[key] From a94f575e59e1e460ca1b33dea36ec5b9307fcf64 Mon Sep 17 00:00:00 2001 From: JM Ibanez Date: Wed, 6 Aug 2014 12:02:56 +0800 Subject: [PATCH 9/9] Remove use of deprecated assert{Not}Equals. --- provider/oauth2/tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 08d46b61..3b5966ac 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -595,7 +595,7 @@ def test_clear_expired(self): 'client_id': self.get_client().client_id, 'client_secret': self.get_client().client_secret, 'code': code}) - self.assertEquals(200, response.status_code) + self.assertEqual(200, response.status_code) token = json.loads(response.content.decode('utf-8')) self.assertTrue('access_token' in token) access_token = token['access_token'] @@ -620,9 +620,9 @@ def test_clear_expired(self): self.assertEqual(200, response.status_code) token = json.loads(response.content.decode('utf-8')) self.assertTrue('access_token' in token) - self.assertNotEquals(access_token, token['access_token']) + self.assertNotEqual(access_token, token['access_token']) self.assertTrue('refresh_token' in token) - self.assertNotEquals(refresh_token, token['refresh_token']) + self.assertNotEqual(refresh_token, token['refresh_token']) # make sure the orig AccessToken and RefreshToken are gone self.assertFalse(AccessToken.objects.filter(token=access_token)