Skip to content
18 changes: 13 additions & 5 deletions src/corbado_python_sdk/services/implementation/session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from jwt import (
ExpiredSignatureError,
ImmatureSignatureError,
InvalidAlgorithmError,
InvalidSignatureError,
decode,
)
Expand All @@ -16,6 +17,7 @@
)

DEFAULT_SESSION_TOKEN_LENGTH = 300
ALLOWED_ALGS = {"RS256"}


class SessionService(BaseModel):
Expand Down Expand Up @@ -90,7 +92,7 @@ def validate_token(self, session_token: StrictStr) -> UserEntity:

# decode short session (jwt) with signing key
try:
payload = decode(jwt=session_token, key=signing_key.key, algorithms=["RS256"])
payload = decode(jwt=session_token, key=signing_key.key, algorithms=list(ALLOWED_ALGS))

# extract information from decoded payload
token_issuer: str = payload.get("iss")
Expand All @@ -104,15 +106,21 @@ def validate_token(self, session_token: StrictStr) -> UserEntity:
)
except ExpiredSignatureError as error:
raise TokenValidationException(
error_type=ValidationErrorType.CODE_JWT_INVALID_SIGNATURE,
message=f"Error occured during token decode: {session_token}. {ValidationErrorType.CODE_JWT_INVALID_SIGNATURE.value}",
error_type=ValidationErrorType.CODE_JWT_EXPIRED,
message=f"Error occured during token decode: {session_token}. {ValidationErrorType.CODE_JWT_EXPIRED.value}",
original_exception=error,
)

except InvalidSignatureError as error:
raise TokenValidationException(
error_type=ValidationErrorType.CODE_JWT_EXPIRED,
message=f"Error occured during token decode: {session_token}. {ValidationErrorType.CODE_JWT_EXPIRED.value}",
error_type=ValidationErrorType.CODE_JWT_INVALID_SIGNATURE,
message=f"Error occured during token decode: {session_token}. {ValidationErrorType.CODE_JWT_INVALID_SIGNATURE.value}",
original_exception=error,
)
except InvalidAlgorithmError as error:
raise TokenValidationException(
error_type=ValidationErrorType.CODE_JWT_INVALID_SIGNATURE,
message="Algorithm not allowed",
original_exception=error,
)

Expand Down
37 changes: 33 additions & 4 deletions tests/unit/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ def _provide_jwts(self):
None,
None,
),
# Disallowed algorithm "none"
(
False,
self._generate_jwt(iss="https://auth.acme.com", exp=int(time()) + 100, nbf=int(time()) - 100, algorithm="none"),
PyJWKClientError,
'Unable to find a signing key that matches: "None"',
),
# Success with old Frontend API URL in config (2)
(
True,
Expand All @@ -196,7 +203,14 @@ def _provide_jwts(self):
]

@classmethod
def _generate_jwt(cls, iss: str, exp: int, nbf: int, valid_key: bool = True) -> str:
def _generate_jwt(
cls,
iss: str,
exp: int,
nbf: int,
valid_key: bool = True,
algorithm: str = "RS256",
) -> str:
payload = {
"iss": iss,
"iat": int(time()),
Expand All @@ -206,9 +220,24 @@ def _generate_jwt(cls, iss: str, exp: int, nbf: int, valid_key: bool = True) ->
"name": TEST_NAME,
}

if valid_key:
return encode(payload, key=cls.private_key, algorithm="RS256", headers={"kid": "kid123"})
return encode(payload, key=cls.invalid_private_key, algorithm="RS256", headers={"kid": "kid123"})
key_to_use = cls.private_key if valid_key else cls.invalid_private_key

# unsecured JWT (“none”)
if algorithm.lower() == "none":
# key must be None for alg=none
return encode(
payload,
key=None,
headers={"alg": "none", "typ": "JWT"},
)

# signed JWT (RS256 by default)
return encode(
payload,
key=key_to_use,
algorithm=algorithm,
headers={"kid": "kid123"},
)


class TestSessionService(TestBase):
Expand Down
Loading