Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions provider/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _clean_fields(self):
"""
try:
super(OAuthForm, self)._clean_fields()
except OAuthValidationError, e:
except OAuthValidationError as e:
self._errors.update(e.args[0])

def _clean_form(self):
Expand All @@ -60,5 +60,5 @@ def _clean_form(self):
"""
try:
super(OAuthForm, self)._clean_form()
except OAuthValidationError, e:
except OAuthValidationError as e:
self._errors.update(e.args[0])
14 changes: 8 additions & 6 deletions provider/oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import backends
import forms
import managers
import models
import urls
import views
from __future__ import absolute_import

from . import backends
from . import forms
from . import managers
from . import models
from . import urls
from . import views
5 changes: 4 additions & 1 deletion provider/oauth2/backends.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import codecs

from ..utils import now
from .forms import ClientAuthForm, PublicPasswordGrantForm
from .models import AccessToken
Expand Down Expand Up @@ -29,7 +31,8 @@ def authenticate(self, request=None):

try:
basic, base64 = auth.split(' ')
client_id, client_secret = base64.decode('base64').split(':')
base64_base64 = codecs.decode(base64.encode('utf-8'), 'base64').decode('utf-8')
client_id, client_secret = base64_base64.split(':')

form = ClientAuthForm({
'client_id': client_id,
Expand Down
8 changes: 7 additions & 1 deletion provider/oauth2/forms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from django import forms
from django.contrib.auth import authenticate
from django.utils.encoding import smart_unicode

try:
from django.utils.encoding import smart_unicode
except ImportError:
# Django >=1.5
from django.utils.encoding import smart_text as smart_unicode

from django.utils.translation import ugettext as _
from .. import scope
from ..constants import RESPONSE_TYPE_CHOICES, SCOPES
Expand Down
70 changes: 39 additions & 31 deletions provider/oauth2/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import urlparse
from six.moves.urllib import parse as urlparse
import six
import codecs
import datetime
from django.http import QueryDict
from django.conf import settings
Expand All @@ -16,6 +18,10 @@
from .backends import BasicClientBackend, RequestParamsClientBackend
from .backends import AccessTokenBackend

def _py3_base64_str_shim(val):
val_b = val.encode('utf-8')
return codecs.encode(val_b, 'base64').decode('utf-8')


@skipIfCustomUser
class BaseOAuth2TestCase(TestCase):
Expand Down Expand Up @@ -89,31 +95,31 @@ def test_authorization_requires_client_id(self):
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue("An unauthorized client tried to access your resources." in response.content)
self.assertTrue("An unauthorized client tried to access your resources." in response.content.decode('utf-8'))

def test_authorization_rejects_invalid_client_id(self):
self.login()
response = self.client.get(self.auth_url() + '?client_id=123')
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue("An unauthorized client tried to access your resources." in response.content)
self.assertTrue("An unauthorized client tried to access your resources." in response.content.decode('utf-8'))

def test_authorization_requires_response_type(self):
self.login()
response = self.client.get(self.auth_url() + '?client_id=%s' % self.get_client().client_id)
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"No 'response_type' supplied.") in response.content)
self.assertTrue(escape(u"No 'response_type' supplied.") in response.content.decode('utf-8'))

def test_authorization_requires_supported_response_type(self):
self.login()
response = self.client.get(self.auth_url() + '?client_id=%s&response_type=unsupported' % self.get_client().client_id)
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content)
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content.decode('utf-8'))

response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code' % self.get_client().client_id)
response = self.client.get(self.auth_url2())
Expand All @@ -132,7 +138,7 @@ def test_authorization_requires_a_valid_redirect_uri(self):
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content)
self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content.decode('utf-8'))

response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&redirect_uri=%s' % (
self.get_client().client_id,
Expand All @@ -148,7 +154,7 @@ def test_authorization_requires_a_valid_scope(self):
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content)
self.assertTrue(escape(u"'invalid' is not a valid scope.") in response.content.decode('utf-8'))

response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&scope=%s' % (
self.get_client().client_id,
Expand Down Expand Up @@ -223,7 +229,7 @@ def test_fetching_access_token_with_invalid_client(self):
'client_secret': self.get_client().client_secret, })

self.assertEqual(400, response.status_code, response.content)
self.assertEqual('invalid_client', json.loads(response.content)['error'])
self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error'])

def test_fetching_access_token_with_invalid_grant(self):
self.login()
Expand All @@ -236,7 +242,7 @@ def test_fetching_access_token_with_invalid_grant(self):
'code': '123'})

self.assertEqual(400, response.status_code, response.content)
self.assertEqual('invalid_grant', json.loads(response.content)['error'])
self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error'])

def _login_authorize_get_token(self):
required_props = ['access_token', 'token_type']
Expand All @@ -256,7 +262,7 @@ def _login_authorize_get_token(self):

self.assertEqual(200, response.status_code, response.content)

token = json.loads(response.content)
token = json.loads(response.content.decode('utf-8'))

for prop in required_props:
self.assertIn(prop, token, "Access token response missing "
Expand All @@ -283,7 +289,7 @@ def test_fetching_access_token_with_invalid_grant_type(self):
})

self.assertEqual(400, response.status_code)
self.assertEqual('unsupported_grant_type', json.loads(response.content)['error'],
self.assertEqual('unsupported_grant_type', json.loads(response.content.decode('utf-8'))['error'],
response.content)

def test_fetching_single_access_token(self):
Expand Down Expand Up @@ -324,7 +330,7 @@ def test_fetching_access_token_multiple_times(self):
'code': code})

self.assertEqual(400, response.status_code)
self.assertEqual('invalid_grant', json.loads(response.content)['error'])
self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error'])

def test_escalating_the_scope(self):
self.login()
Expand All @@ -339,7 +345,7 @@ def test_escalating_the_scope(self):
'scope': 'read write'})

self.assertEqual(400, response.status_code)
self.assertEqual('invalid_scope', json.loads(response.content)['error'])
self.assertEqual('invalid_scope', json.loads(response.content.decode('utf-8'))['error'])

def test_refreshing_an_access_token(self):
token = self._login_authorize_get_token()
Expand All @@ -361,7 +367,7 @@ def test_refreshing_an_access_token(self):
})

self.assertEqual(400, response.status_code)
self.assertEqual('invalid_grant', json.loads(response.content)['error'],
self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error'],
response.content)

def test_password_grant_public(self):
Expand All @@ -378,8 +384,8 @@ def test_password_grant_public(self):
})

self.assertEqual(200, response.status_code, response.content)
self.assertNotIn('refresh_token', json.loads(response.content))
expires_in = json.loads(response.content)['expires_in']
self.assertNotIn('refresh_token', json.loads(response.content.decode('utf-8')))
expires_in = json.loads(response.content.decode('utf-8'))['expires_in']
expires_in_days = round(expires_in / (60.0 * 60.0 * 24.0))
self.assertEqual(expires_in_days, constants.EXPIRE_DELTA_PUBLIC.days)

Expand All @@ -397,7 +403,7 @@ def test_password_grant_confidential(self):
})

self.assertEqual(200, response.status_code, response.content)
self.assertTrue(json.loads(response.content)['refresh_token'])
self.assertTrue(json.loads(response.content.decode('utf-8'))['refresh_token'])

def test_password_grant_confidential_no_secret(self):
c = self.get_client()
Expand All @@ -411,7 +417,7 @@ def test_password_grant_confidential_no_secret(self):
'password': self.get_password(),
})

self.assertEqual('invalid_client', json.loads(response.content)['error'])
self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error'])

def test_password_grant_invalid_password_public(self):
c = self.get_client()
Expand All @@ -426,7 +432,7 @@ def test_password_grant_invalid_password_public(self):
})

self.assertEqual(400, response.status_code, response.content)
self.assertEqual('invalid_client', json.loads(response.content)['error'])
self.assertEqual('invalid_client', json.loads(response.content.decode('utf-8'))['error'])

def test_password_grant_invalid_password_confidential(self):
c = self.get_client()
Expand All @@ -442,7 +448,7 @@ def test_password_grant_invalid_password_confidential(self):
})

self.assertEqual(400, response.status_code, response.content)
self.assertEqual('invalid_grant', json.loads(response.content)['error'])
self.assertEqual('invalid_grant', json.loads(response.content.decode('utf-8'))['error'])

def test_access_token_response_valid_token_type(self):
token = self._login_authorize_get_token()
Expand All @@ -454,9 +460,10 @@ class AuthBackendTest(BaseOAuth2TestCase):

def test_basic_client_backend(self):
request = type('Request', (object,), {'META': {}})()
request.META['HTTP_AUTHORIZATION'] = "Basic " + "{0}:{1}".format(
self.get_client().client_id,
self.get_client().client_secret).encode('base64')
request.META['HTTP_AUTHORIZATION'] = "Basic " + _py3_base64_str_shim(
"{0}:{1}".format(
self.get_client().client_id,
self.get_client().client_secret))

self.assertEqual(BasicClientBackend().authenticate(request).id,
2, "Didn't return the right client.")
Expand Down Expand Up @@ -496,13 +503,13 @@ def test_authorization_enforces_SSL(self):
response = self.client.get(self.auth_url())

self.assertEqual(400, response.status_code)
self.assertTrue("A secure connection is required." in response.content)
self.assertTrue("A secure connection is required." in response.content.decode('utf-8'))

def test_access_token_enforces_SSL(self):
response = self.client.post(self.access_token_url(), {})

self.assertEqual(400, response.status_code)
self.assertTrue("A secure connection is required." in response.content)
self.assertTrue("A secure connection is required." in response.content.decode('utf-8'))


class ClientFormTest(TestCase):
Expand Down Expand Up @@ -578,7 +585,8 @@ def test_clear_expired(self):
self.assertTrue('code' in location)

# verify that Grant with code exists
code = urlparse.parse_qs(location)['code'][0]
location_qs = urlparse.urlparse(location)[4]
code = urlparse.parse_qs(location_qs)['code'][0]
self.assertTrue(Grant.objects.filter(code=code).exists())

# use the code/grant
Expand All @@ -587,8 +595,8 @@ def test_clear_expired(self):
'client_id': self.get_client().client_id,
'client_secret': self.get_client().client_secret,
'code': code})
self.assertEquals(200, response.status_code)
token = json.loads(response.content)
self.assertEqual(200, response.status_code)
token = json.loads(response.content.decode('utf-8'))
self.assertTrue('access_token' in token)
access_token = token['access_token']
self.assertTrue('refresh_token' in token)
Expand All @@ -610,11 +618,11 @@ def test_clear_expired(self):
'client_secret': self.get_client().client_secret,
})
self.assertEqual(200, response.status_code)
token = json.loads(response.content)
token = json.loads(response.content.decode('utf-8'))
self.assertTrue('access_token' in token)
self.assertNotEquals(access_token, token['access_token'])
self.assertNotEqual(access_token, token['access_token'])
self.assertTrue('refresh_token' in token)
self.assertNotEquals(refresh_token, token['refresh_token'])
self.assertNotEqual(refresh_token, token['refresh_token'])

# make sure the orig AccessToken and RefreshToken are gone
self.assertFalse(AccessToken.objects.filter(token=access_token)
Expand Down
4 changes: 3 additions & 1 deletion provider/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"""

from .constants import SCOPES
from functools import reduce
from six import iteritems

SCOPE_NAMES = [(name, name) for (value, name) in SCOPES]
SCOPE_NAME_DICT = dict([(name, value) for (value, name) in SCOPES])
Expand Down Expand Up @@ -73,7 +75,7 @@ def to_names(scope):
"""
return [
name
for (name, value) in SCOPE_NAME_DICT.iteritems()
for (name, value) in iteritems(SCOPE_NAME_DICT)
if check(value, scope)
]

Expand Down
9 changes: 5 additions & 4 deletions provider/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import shortuuid
import six
from datetime import datetime, tzinfo
from django.conf import settings
from django.utils import dateparse
Expand Down Expand Up @@ -31,17 +32,17 @@ def short_token():
"""
Generate a hash that can be used as an application identifier
"""
hash = hashlib.sha1(shortuuid.uuid())
hash.update(settings.SECRET_KEY)
hash = hashlib.sha1(shortuuid.uuid().encode('utf-8'))
hash.update(settings.SECRET_KEY.encode('utf-8'))
return hash.hexdigest()[::2]


def long_token():
"""
Generate a hash that can be used as an application secret
"""
hash = hashlib.sha1(shortuuid.uuid())
hash.update(settings.SECRET_KEY)
hash = hashlib.sha1(shortuuid.uuid().encode('utf-8'))
hash.update(settings.SECRET_KEY.encode('utf-8'))
return hash.hexdigest()


Expand Down
Loading