Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ninja_jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Django Ninja JWT - JSON Web Token for Django-Ninja"""

__version__ = "5.3.7"
__version__ = "5.3.9"
61 changes: 38 additions & 23 deletions ninja_jwt/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from ninja.schema import DjangoGetter
from ninja_extra import service_resolver
from ninja_extra.context import RouteContext
from pydantic import ConfigDict, model_validator
from pydantic import ConfigDict, ValidationInfo, model_validator
from pydantic.main import BaseModel

import ninja_jwt.exceptions as exceptions
from ninja_jwt.utils import token_error
Expand All @@ -28,14 +29,21 @@


class SchemaInputService:
def __init__(self, values: SCHEMA_INPUT, model_config: ConfigDict) -> None:
def __init__(
self,
values: SCHEMA_INPUT,
model_config: ConfigDict,
request: Optional[HttpRequest] = None,
) -> None:
self.model_config = model_config
self.values = values

self._request: Optional[HttpRequest] = request

def get_request(self) -> HttpRequest:
if self.model_config.get("extra") == "forbid":
if self.model_config.get("extra") == "forbid" and self._request is None:
return service_resolver(RouteContext).request
return self.values._context.get("request")
return self._request

def get_values(self) -> Dict:
if self.model_config.get("extra") == "forbid":
Expand Down Expand Up @@ -75,7 +83,7 @@ def check_user_authentication_rule(self) -> None:
)

@classmethod
def validate_values(cls, request: HttpRequest, values: Dict) -> Dict:
def validate_values(cls, values: Dict) -> Dict:
if user_name_field not in values and "password" not in values:
raise exceptions.ValidationError(
{
Expand All @@ -92,16 +100,16 @@ def validate_values(cls, request: HttpRequest, values: Dict) -> Dict:
if not values.get("password"):
raise exceptions.ValidationError({"password": "password is required"})

_user = authenticate(request, **values)
cls._user = _user
return values

def authenticate(self, request: HttpRequest, credentials: Dict) -> None:
self._user = authenticate(request, **credentials)

if not (_user is not None and _user.is_active):
if not (self._user is not None and self._user.is_active):
raise exceptions.AuthenticationFailed(
cls._default_error_messages["no_active_account"]
self._default_error_messages["no_active_account"]
)

return values

def output_schema(self) -> Schema:
warnings.warn(
"output_schema() is deprecated in favor of to_response_schema()",
Expand All @@ -119,36 +127,45 @@ def get_token(cls, user: AbstractUser) -> Dict:

class TokenObtainInputSchemaBase(ModelSchema, TokenInputSchemaMixin):
class Config:
# extra = "allow"
# extra = "forbid"
model = get_user_model()
model_fields = ["password", user_name_field]
extra = "forbid"

@model_validator(mode="before")
def validate_inputs(cls, values: SCHEMA_INPUT) -> dict:
schema_input = SchemaInputService(values, cls.model_config)
input_values = schema_input.get_values()
request = schema_input.get_request()

if isinstance(input_values, dict):
values.update(cls.validate_values(request=request, values=input_values))
return values
cls.validate_values(values=input_values)
return values

@model_validator(mode="after")
def post_validate(cls, values: Dict) -> dict:
return cls.post_validate_schema(values)
def post_validate(
cls, values: "TokenObtainInputSchemaBase", info: ValidationInfo
) -> BaseModel:
schema_input = SchemaInputService(
values.model_dump(), cls.model_config, info.context.get("request")
)

credentials = schema_input.get_values()
request = schema_input.get_request()

values.authenticate(request, credentials)
cls.post_validate_schema(values)

return values

@classmethod
def post_validate_schema(cls, values: Dict) -> dict:
def post_validate_schema(cls, values: "TokenObtainInputSchemaBase") -> None:
"""
This is a post validate process which is common for any token generating schema.
:param values:
:return:
"""
# get_token can return values that wants to apply to `OutputSchema`

data = cls.get_token(cls._user)
data = cls.get_token(values._user)

if not isinstance(data, dict):
raise Exception("`get_token` must return a `typing.Dict` type.")
Expand All @@ -158,9 +175,7 @@ def post_validate_schema(cls, values: Dict) -> dict:
values.__dict__.update(token_data=data)

if api_settings.UPDATE_LAST_LOGIN:
update_last_login(None, cls._user)

return values
update_last_login(None, values._user)

def get_response_schema_init_kwargs(self) -> dict:
return dict(
Expand Down