|
| 1 | +import json |
| 2 | +import base64 |
| 3 | +import time |
| 4 | + |
| 5 | +from . import oauth2 |
| 6 | + |
| 7 | + |
| 8 | +def base64decode(raw): |
| 9 | + """A helper can handle a padding-less raw input""" |
| 10 | + raw += '=' * (-len(raw) % 4) # https://stackoverflow.com/a/32517907/728675 |
| 11 | + return base64.b64decode(raw).decode("utf-8") |
| 12 | + |
| 13 | + |
| 14 | +def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None): |
| 15 | + """Decodes and validates an id_token and returns its claims as a dictionary. |
| 16 | +
|
| 17 | + ID token claims would at least contain: "iss", "sub", "aud", "exp", "iat", |
| 18 | + per `specs <https://openid.net/specs/openid-connect-core-1_0.html#IDToken>`_ |
| 19 | + and it may contain other optional content such as "preferred_username", |
| 20 | + `maybe more <https://openid.net/specs/openid-connect-core-1_0.html#Claims>`_ |
| 21 | + """ |
| 22 | + decoded = json.loads(base64decode(id_token.split('.')[1])) |
| 23 | + err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation |
| 24 | + if issuer and issuer != decoded["iss"]: |
| 25 | + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse |
| 26 | + err = ('2. The Issuer Identifier for the OpenID Provider, "%s", ' |
| 27 | + "(which is typically obtained during Discovery), " |
| 28 | + "MUST exactly match the value of the iss (issuer) Claim.") % issuer |
| 29 | + if client_id: |
| 30 | + valid_aud = client_id in decoded["aud"] if isinstance( |
| 31 | + decoded["aud"], list) else client_id == decoded["aud"] |
| 32 | + if not valid_aud: |
| 33 | + err = "3. The aud (audience) Claim must contain this client's client_id." |
| 34 | + # Per specs: |
| 35 | + # 6. If the ID Token is received via direct communication between |
| 36 | + # the Client and the Token Endpoint (which it is in this flow), |
| 37 | + # the TLS server validation MAY be used to validate the issuer |
| 38 | + # in place of checking the token signature. |
| 39 | + if (now or time.time()) > decoded["exp"]: |
| 40 | + err = "9. The current time MUST be before the time represented by the exp Claim." |
| 41 | + if nonce and nonce != decoded.get("nonce"): |
| 42 | + err = ("11. Nonce must be the same value " |
| 43 | + "as the one that was sent in the Authentication Request") |
| 44 | + if err: |
| 45 | + raise RuntimeError("%s id_token was: %s" % ( |
| 46 | + err, json.dumps(decoded, indent=2))) |
| 47 | + return decoded |
| 48 | + |
| 49 | + |
| 50 | +class Client(oauth2.Client): |
| 51 | + """OpenID Connect is a layer on top of the OAuth2. |
| 52 | +
|
| 53 | + See its specs at https://openid.net/connect/ |
| 54 | + """ |
| 55 | + |
| 56 | + def decode_id_token(self, id_token, nonce=None): |
| 57 | + """See :func:`~decode_id_token`.""" |
| 58 | + return decode_id_token( |
| 59 | + id_token, nonce=nonce, |
| 60 | + client_id=self.client_id, issuer=self.configuration.get("issuer")) |
| 61 | + |
| 62 | + def _obtain_token(self, grant_type, *args, **kwargs): |
| 63 | + """The result will also contain one more key "id_token_claims", |
| 64 | + whose value will be a dictionary returned by :func:`~decode_id_token`. |
| 65 | + """ |
| 66 | + ret = super(Client, self)._obtain_token(grant_type, *args, **kwargs) |
| 67 | + if "id_token" in ret: |
| 68 | + ret["id_token_claims"] = self.decode_id_token(ret["id_token"]) |
| 69 | + return ret |
| 70 | + |
0 commit comments