diff --git a/rest_framework-stubs/serializers.pyi b/rest_framework-stubs/serializers.pyi index d17daccc3..f73920f8b 100644 --- a/rest_framework-stubs/serializers.pyi +++ b/rest_framework-stubs/serializers.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence -from typing import Any, Generic, NoReturn, TypeVar +from typing import Any, Generic, NoReturn, TypeVar, overload from _typeshed import Self from django.core.exceptions import ValidationError as DjangoValidationError @@ -81,14 +81,33 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): instance: _IN | None initial_data: Any _context: dict[str, Any] - def __new__(cls: type[Self], *args: Any, **kwargs: Any) -> Self: ... def __class_getitem__(cls, *args, **kwargs): ... - def __init__( - self, + # When both __init__ and __new__ are present, mypy will prefer __init__ + @overload + def __new__( + cls: type[Self], + instance: Iterable[_IN] | None, + many: Literal[True], + allow_empty: bool = ..., + context: dict[str, Any] = ..., + read_only: bool = ..., + write_only: bool = ..., + required: bool = ..., + default: Any = ..., + initial: Any = ..., + source: str = ..., + label: str = ..., + help_text: str = ..., + style: dict[str, Any] = ..., + error_messages: dict[str, str] = ..., + validators: Sequence[Validator[Any]] | None = ..., + allow_null: bool = ..., + ) -> ListSerializer[_IN]: ... + @overload + def __new__( + cls: type[Self], instance: _IN | None = ..., - data: Any = ..., - partial: bool = ..., - many: bool = ..., + many: Literal[False] = ..., allow_empty: bool = ..., context: dict[str, Any] = ..., read_only: bool = ..., @@ -103,7 +122,7 @@ class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): error_messages: dict[str, str] = ..., validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., - ): ... + ) -> Self: ... @classmethod def many_init(cls, *args: Any, **kwargs: Any) -> BaseSerializer: ... def is_valid(self, raise_exception: bool = ...) -> bool: ... @@ -159,7 +178,7 @@ class ListSerializer( allow_empty: bool | None def __init__( self, - instance: _IN | None = ..., + instance: Iterable[_IN] | None = ..., data: Any = ..., partial: bool = ..., context: dict[str, Any] = ..., @@ -177,7 +196,7 @@ class ListSerializer( error_messages: dict[str, str] = ..., validators: Sequence[Validator[list[Any]]] | None = ..., allow_null: bool = ..., - ): ... + ) -> None: ... def get_initial(self) -> list[Mapping[Any, Any]]: ... def validate(self, attrs: Any) -> Any: ... @property @@ -203,27 +222,40 @@ class ModelSerializer(Serializer, BaseSerializer[_MT]): exclude: Sequence[str] | None depth: int | None extra_kwargs: dict[str, dict[str, Any]] # type: ignore[override] - def __init__( - self, + @overload + def __new__( # type: ignore[misc] + cls: type[Self], + instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT], + many: Literal[True], + ) -> ListSerializer[_MT]: ... + @overload + def __new__( + cls: type[Self], + instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT], + many: Literal[False], + ) -> Self: ... + @overload + def __new__( + cls: type[Self], instance: None | _MT | Sequence[_MT] | QuerySet[_MT] | Manager[_MT] = ..., data: Any = ..., partial: bool = ..., many: bool = ..., + allow_empty: bool = ..., context: dict[str, Any] = ..., read_only: bool = ..., write_only: bool = ..., required: bool = ..., - default: _MT | Sequence[_MT] | Callable[[], _MT | Sequence[_MT]] = ..., - initial: _MT | Sequence[_MT] | Callable[[], _MT | Sequence[_MT]] = ..., + default: Any = ..., + initial: Any = ..., source: str = ..., label: str = ..., help_text: str = ..., style: dict[str, Any] = ..., error_messages: dict[str, str] = ..., - validators: Sequence[Validator[_MT]] | None = ..., + validators: Sequence[Validator[Any]] | None = ..., allow_null: bool = ..., - allow_empty: bool = ..., - ): ... + ) -> Self: ... def update(self, instance: _MT, validated_data: Any) -> _MT: ... # type: ignore[override] def create(self, validated_data: Any) -> _MT: ... # type: ignore[override] def save(self, **kwargs: Any) -> _MT: ... # type: ignore[override] diff --git a/tests/typecheck/test_serializers.yml b/tests/typecheck/test_serializers.yml index 59e6d5179..f5fd11472 100644 --- a/tests/typecheck/test_serializers.yml +++ b/tests/typecheck/test_serializers.yml @@ -75,3 +75,21 @@ @cached_property def fields(self) -> BindingDict: return super().fields +- case: test_serializer_many_equals_false + main: | + from rest_framework import serializers + + class TestSerializer(serializers.Serializer[int]): + pass + + test_serializer = TestSerializer(1) + reveal_type(test_serializer) # N: Revealed type is "main.TestSerializer" +- case: test_serializer_many_equals_true + main: | + from rest_framework import serializers + + class TestSerializer(serializers.Serializer[int]): + pass + + test_serializer = TestSerializer(instance=[1, 2], many=True) + reveal_type(test_serializer) # N: Revealed type is "rest_framework.serializers.ListSerializer[builtins.int]"