diff --git a/provider/forms.py b/provider/forms.py index 9b7fd609..f7c68d8a 100644 --- a/provider/forms.py +++ b/provider/forms.py @@ -51,7 +51,7 @@ def _clean_fields(self): """ try: super(OAuthForm, self)._clean_fields() - except OAuthValidationError, e: + except OAuthValidationError as e: self._errors.update(e.args[0]) def _clean_form(self): @@ -60,5 +60,5 @@ def _clean_form(self): """ try: super(OAuthForm, self)._clean_form() - except OAuthValidationError, e: + except OAuthValidationError as e: self._errors.update(e.args[0]) diff --git a/provider/oauth2/__init__.py b/provider/oauth2/__init__.py index 34796220..57369dc2 100644 --- a/provider/oauth2/__init__.py +++ b/provider/oauth2/__init__.py @@ -1,6 +1,8 @@ -import backends -import forms -import managers -import models -import urls -import views \ No newline at end of file +from __future__ import absolute_import + +from . import backends +from . import forms +from . import managers +from . import models +from . import urls +from . import views diff --git a/provider/oauth2/backends.py b/provider/oauth2/backends.py index db0fb853..00477327 100644 --- a/provider/oauth2/backends.py +++ b/provider/oauth2/backends.py @@ -1,3 +1,5 @@ +import codecs + from ..utils import now from .forms import ClientAuthForm, PublicPasswordGrantForm from .models import AccessToken @@ -29,7 +31,8 @@ def authenticate(self, request=None): try: basic, base64 = auth.split(' ') - client_id, client_secret = base64.decode('base64').split(':') + base64_base64 = codecs.decode(base64.encode('utf-8'), 'base64').decode('utf-8') + client_id, client_secret = base64_base64.split(':') form = ClientAuthForm({ 'client_id': client_id, diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index bb4dcb89..d474be09 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -1,6 +1,12 @@ from django import forms from django.contrib.auth import authenticate -from django.utils.encoding import smart_unicode + +try: + from django.utils.encoding import smart_unicode +except ImportError: + # Django >=1.5 + from django.utils.encoding import smart_text as smart_unicode + from django.utils.translation import ugettext as _ from .. import scope from ..constants import RESPONSE_TYPE_CHOICES, SCOPES diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 53de8af2..3b5966ac 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -1,5 +1,7 @@ import json -import urlparse +from six.moves.urllib import parse as urlparse +import six +import codecs import datetime from django.http import QueryDict from django.conf import settings @@ -16,6 +18,10 @@ from .backends import BasicClientBackend, RequestParamsClientBackend from .backends import AccessTokenBackend +def _py3_base64_str_shim(val): + val_b = val.encode('utf-8') + return codecs.encode(val_b, 'base64').decode('utf-8') + @skipIfCustomUser class BaseOAuth2TestCase(TestCase): @@ -89,7 +95,7 @@ def test_authorization_requires_client_id(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue("An unauthorized client tried to access your resources." in response.content) + self.assertTrue("An unauthorized client tried to access your resources." in response.content.decode('utf-8')) def test_authorization_rejects_invalid_client_id(self): self.login() @@ -97,7 +103,7 @@ def test_authorization_rejects_invalid_client_id(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue("An unauthorized client tried to access your resources." in response.content) + self.assertTrue("An unauthorized client tried to access your resources." in response.content.decode('utf-8')) def test_authorization_requires_response_type(self): self.login() @@ -105,7 +111,7 @@ def test_authorization_requires_response_type(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"No 'response_type' supplied.") in response.content) + self.assertTrue(escape(u"No 'response_type' supplied.") in response.content.decode('utf-8')) def test_authorization_requires_supported_response_type(self): self.login() @@ -113,7 +119,7 @@ def test_authorization_requires_supported_response_type(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content) + self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content.decode('utf-8')) response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code' % self.get_client().client_id) response = self.client.get(self.auth_url2()) @@ -132,7 +138,7 @@ def test_authorization_requires_a_valid_redirect_uri(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content) + self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content.decode('utf-8')) response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&redirect_uri=%s' % ( self.get_client().client_id, @@ -148,7 +154,7 @@ def test_authorization_requires_a_valid_scope(self): response = self.client.get(self.auth_url2()) self.assertEqual(400, response.status_code) - self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content) + self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content.decode('utf-8')) response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&scope=%s' % ( self.get_client().client_id, @@ -223,7 +229,7 @@ def test_fetching_access_token_with_invalid_client(self): 'client_secret': self.get_client().client_secret, }) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_client', json.loads(response.content)['error']) + self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error']) def test_fetching_access_token_with_invalid_grant(self): self.login() @@ -236,7 +242,7 @@ def test_fetching_access_token_with_invalid_grant(self): 'code': '123'}) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_grant', json.loads(response.content)['error']) + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error']) def _login_authorize_get_token(self): required_props = ['access_token', 'token_type'] @@ -256,7 +262,7 @@ def _login_authorize_get_token(self): self.assertEqual(200, response.status_code, response.content) - token = json.loads(response.content) + token = json.loads(response.content.decode('utf-8')) for prop in required_props: self.assertIn(prop, token, "Access token response missing " @@ -283,7 +289,7 @@ def test_fetching_access_token_with_invalid_grant_type(self): }) self.assertEqual(400, response.status_code) - self.assertEqual('unsupported_grant_type', json.loads(response.content)['error'], + self.assertEqual('unsupported_grant_type', json.loads(response.content.decode('utf-8'))['error'], response.content) def test_fetching_single_access_token(self): @@ -324,7 +330,7 @@ def test_fetching_access_token_multiple_times(self): 'code': code}) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_grant', json.loads(response.content)['error']) + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error']) def test_escalating_the_scope(self): self.login() @@ -339,7 +345,7 @@ def test_escalating_the_scope(self): 'scope': 'read write'}) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_scope', json.loads(response.content)['error']) + self.assertEqual('invalid_scope', json.loads(response.content.decode('utf-8'))['error']) def test_refreshing_an_access_token(self): token = self._login_authorize_get_token() @@ -361,7 +367,7 @@ def test_refreshing_an_access_token(self): }) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_grant', json.loads(response.content)['error'], + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error'], response.content) def test_password_grant_public(self): @@ -378,8 +384,8 @@ def test_password_grant_public(self): }) self.assertEqual(200, response.status_code, response.content) - self.assertNotIn('refresh_token', json.loads(response.content)) - expires_in = json.loads(response.content)['expires_in'] + self.assertNotIn('refresh_token', json.loads(response.content.decode('utf-8'))) + expires_in = json.loads(response.content.decode('utf-8'))['expires_in'] expires_in_days = round(expires_in / (60.0 * 60.0 * 24.0)) self.assertEqual(expires_in_days, constants.EXPIRE_DELTA_PUBLIC.days) @@ -397,7 +403,7 @@ def test_password_grant_confidential(self): }) self.assertEqual(200, response.status_code, response.content) - self.assertTrue(json.loads(response.content)['refresh_token']) + self.assertTrue(json.loads(response.content.decode('utf-8'))['refresh_token']) def test_password_grant_confidential_no_secret(self): c = self.get_client() @@ -411,7 +417,7 @@ def test_password_grant_confidential_no_secret(self): 'password': self.get_password(), }) - self.assertEqual('invalid_client', json.loads(response.content)['error']) + self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error']) def test_password_grant_invalid_password_public(self): c = self.get_client() @@ -426,7 +432,7 @@ def test_password_grant_invalid_password_public(self): }) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_client', json.loads(response.content)['error']) + self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error']) def test_password_grant_invalid_password_confidential(self): c = self.get_client() @@ -442,7 +448,7 @@ def test_password_grant_invalid_password_confidential(self): }) self.assertEqual(400, response.status_code, response.content) - self.assertEqual('invalid_grant', json.loads(response.content)['error']) + self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error']) def test_access_token_response_valid_token_type(self): token = self._login_authorize_get_token() @@ -454,9 +460,10 @@ class AuthBackendTest(BaseOAuth2TestCase): def test_basic_client_backend(self): request = type('Request', (object,), {'META': {}})() - request.META['HTTP_AUTHORIZATION'] = "Basic " + "{0}:{1}".format( - self.get_client().client_id, - self.get_client().client_secret).encode('base64') + request.META['HTTP_AUTHORIZATION'] = "Basic " + _py3_base64_str_shim( + "{0}:{1}".format( + self.get_client().client_id, + self.get_client().client_secret)) self.assertEqual(BasicClientBackend().authenticate(request).id, 2, "Didn't return the right client.") @@ -496,13 +503,13 @@ def test_authorization_enforces_SSL(self): response = self.client.get(self.auth_url()) self.assertEqual(400, response.status_code) - self.assertTrue("A secure connection is required." in response.content) + self.assertTrue("A secure connection is required." in response.content.decode('utf-8')) def test_access_token_enforces_SSL(self): response = self.client.post(self.access_token_url(), {}) self.assertEqual(400, response.status_code) - self.assertTrue("A secure connection is required." in response.content) + self.assertTrue("A secure connection is required." in response.content.decode('utf-8')) class ClientFormTest(TestCase): @@ -578,7 +585,8 @@ def test_clear_expired(self): self.assertTrue('code' in location) # verify that Grant with code exists - code = urlparse.parse_qs(location)['code'][0] + location_qs = urlparse.urlparse(location)[4] + code = urlparse.parse_qs(location_qs)['code'][0] self.assertTrue(Grant.objects.filter(code=code).exists()) # use the code/grant @@ -587,8 +595,8 @@ def test_clear_expired(self): 'client_id': self.get_client().client_id, 'client_secret': self.get_client().client_secret, 'code': code}) - self.assertEquals(200, response.status_code) - token = json.loads(response.content) + self.assertEqual(200, response.status_code) + token = json.loads(response.content.decode('utf-8')) self.assertTrue('access_token' in token) access_token = token['access_token'] self.assertTrue('refresh_token' in token) @@ -610,11 +618,11 @@ def test_clear_expired(self): 'client_secret': self.get_client().client_secret, }) self.assertEqual(200, response.status_code) - token = json.loads(response.content) + token = json.loads(response.content.decode('utf-8')) self.assertTrue('access_token' in token) - self.assertNotEquals(access_token, token['access_token']) + self.assertNotEqual(access_token, token['access_token']) self.assertTrue('refresh_token' in token) - self.assertNotEquals(refresh_token, token['refresh_token']) + self.assertNotEqual(refresh_token, token['refresh_token']) # make sure the orig AccessToken and RefreshToken are gone self.assertFalse(AccessToken.objects.filter(token=access_token) diff --git a/provider/scope.py b/provider/scope.py index 46c9d4b7..83ce0ebb 100644 --- a/provider/scope.py +++ b/provider/scope.py @@ -9,6 +9,8 @@ """ from .constants import SCOPES +from functools import reduce +from six import iteritems SCOPE_NAMES = [(name, name) for (value, name) in SCOPES] SCOPE_NAME_DICT = dict([(name, value) for (value, name) in SCOPES]) @@ -73,7 +75,7 @@ def to_names(scope): """ return [ name - for (name, value) in SCOPE_NAME_DICT.iteritems() + for (name, value) in iteritems(SCOPE_NAME_DICT) if check(value, scope) ] diff --git a/provider/utils.py b/provider/utils.py index 957a5c75..1cacc737 100644 --- a/provider/utils.py +++ b/provider/utils.py @@ -1,5 +1,6 @@ import hashlib import shortuuid +import six from datetime import datetime, tzinfo from django.conf import settings from django.utils import dateparse @@ -31,8 +32,8 @@ def short_token(): """ Generate a hash that can be used as an application identifier """ - hash = hashlib.sha1(shortuuid.uuid()) - hash.update(settings.SECRET_KEY) + hash = hashlib.sha1(shortuuid.uuid().encode('utf-8')) + hash.update(settings.SECRET_KEY.encode('utf-8')) return hash.hexdigest()[::2] @@ -40,8 +41,8 @@ def long_token(): """ Generate a hash that can be used as an application secret """ - hash = hashlib.sha1(shortuuid.uuid()) - hash.update(settings.SECRET_KEY) + hash = hashlib.sha1(shortuuid.uuid().encode('utf-8')) + hash.update(settings.SECRET_KEY.encode('utf-8')) return hash.hexdigest() diff --git a/provider/views.py b/provider/views.py index dd1200df..0a8b4abc 100644 --- a/provider/views.py +++ b/provider/views.py @@ -1,11 +1,13 @@ +from __future__ import absolute_import + import json -import urlparse +from six.moves.urllib import parse as urlparse from django.http import HttpResponse from django.http import HttpResponseRedirect, QueryDict from django.utils.translation import ugettext as _ from django.views.generic.base import TemplateView from django.core.exceptions import ObjectDoesNotExist -from oauth2.models import Client +from .oauth2.models import Client from . import constants, scope @@ -71,7 +73,8 @@ def clear_data(self, request): """ Clear all OAuth related data from the session store. """ - for key in request.session.keys(): + session_keys = list(request.session.keys()) + for key in session_keys: if key.startswith(constants.SESSION_KEY): del request.session[key] @@ -255,7 +258,7 @@ def handle(self, request, post_data=None): try: client, data = self._validate_client(request, data) - except OAuthError, e: + except OAuthError as e: return self.error_response(request, e.args[0], status=400) authorization_form = self.get_authorization_form(request, client, @@ -596,5 +599,5 @@ def post(self, request): try: return handler(request, request.POST, client) - except OAuthError, e: + except OAuthError as e: return self.error_response(e.args[0]) diff --git a/requirements.txt b/requirements.txt index a79a5d9f..37a92814 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ Django>=1.4 shortuuid>=0.3 +six>=1.7 diff --git a/setup.py b/setup.py index f9f55b85..884227a9 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ 'Framework :: Django', ], install_requires=[ - "shortuuid>=0.3" + "shortuuid>=0.3", + "six>=1.7", ], include_package_data=True, zip_safe=False,