diff --git a/changelog/1463.breaking.rst b/changelog/1463.breaking.rst new file mode 100644 index 0000000000..230de8867f --- /dev/null +++ b/changelog/1463.breaking.rst @@ -0,0 +1 @@ +Class properties of :class:`BaseFlag` now return instances of :class:`BaseFlag` rather than :class:`flag_value`. If you are accessing :attr:`flag_value.flag`, change to :attr:`BaseFlag.value`. diff --git a/changelog/1463.feature.rst b/changelog/1463.feature.rst new file mode 100644 index 0000000000..9b9daba3bf --- /dev/null +++ b/changelog/1463.feature.rst @@ -0,0 +1 @@ +Class properties of :class:`BaseFlag` now return instances of :class:`BaseFlag`, allowing you to pass them directly where a flag instance is expected. diff --git a/disnake/flags.py b/disnake/flags.py index 5491641c4c..d9f37ebfb1 100644 --- a/disnake/flags.py +++ b/disnake/flags.py @@ -4,13 +4,12 @@ import functools import operator -from collections.abc import Iterator, Sequence +from collections.abc import Iterator, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, Callable, ClassVar, - Generic, NoReturn, Optional, TypeVar, @@ -19,7 +18,7 @@ ) from .enums import UserFlags -from .utils import MISSING, _generated +from .utils import _generated, deprecated if TYPE_CHECKING: from typing_extensions import Self @@ -46,52 +45,21 @@ "InteractionContextTypes", ) -BF = TypeVar("BF", bound="BaseFlags") T = TypeVar("T", bound="BaseFlags") -class flag_value(Generic[T]): +class flag_value: def __init__(self, func: Callable[[Any], int]) -> None: - self.flag = func(None) + self.flag: int = func(None) self.__doc__ = func.__doc__ - self._parent: type[T] = MISSING - - def __eq__(self, other: Any) -> bool: - if isinstance(other, flag_value): - return self.flag == other.flag - if isinstance(other, BaseFlags): - return self._parent is other.__class__ and self.flag == other.value - return False - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) - - def __or__(self, other: Union[flag_value[T], T]) -> T: - if isinstance(other, BaseFlags): - if self._parent is not other.__class__: - msg = f"unsupported operand type(s) for |: flags of '{self._parent.__name__}' and flags of '{other.__class__.__name__}'" - raise TypeError(msg) - return other._from_value(self.flag | other.value) - if not isinstance(other, flag_value): - msg = f"unsupported operand type(s) for |: flags of '{self._parent.__name__}' and {other.__class__}" - raise TypeError(msg) - if self._parent is not other._parent: - msg = f"unsupported operand type(s) for |: flags of '{self._parent.__name__}' and flags of '{other._parent.__name__}'" - raise TypeError(msg) - return self._parent._from_value(self.flag | other.flag) - - def __invert__(self: flag_value[T]) -> T: - return ~self._parent._from_value(self.flag) @overload - def __get__(self, instance: None, owner: type[BF]) -> flag_value[BF]: ... - + def __get__(self, instance: None, owner: type[T]) -> T: ... @overload - def __get__(self, instance: BF, owner: type[BF]) -> bool: ... - - def __get__(self, instance: Optional[BF], owner: type[BF]) -> Any: + def __get__(self, instance: T, owner: type[T]) -> bool: ... + def __get__(self, instance: Optional[T], owner: type[T]) -> Union[bool, T]: if instance is None: - return self + return owner._from_value(self.flag) return instance._has_flag(self.flag) def __set__(self, instance: BaseFlags, value: bool) -> None: @@ -101,22 +69,29 @@ def __repr__(self) -> str: return f"" -class alias_flag_value(flag_value[T]): +class alias_flag_value(flag_value): pass -def all_flags_value(flags: dict[str, int]) -> int: +def all_flags_value(flags: Mapping[str, int]) -> int: return functools.reduce(operator.or_, flags.values()) class BaseFlags: - VALID_FLAGS: ClassVar[dict[str, int]] - DEFAULT_VALUE: ClassVar[int] + VALID_FLAGS: ClassVar[Mapping[str, int]] = {} + DEFAULT_VALUE: ClassVar[int] = 0 value: int __slots__ = ("value",) + if not TYPE_CHECKING: + + @property + @deprecated("BaseFlags.value") + def flag(self) -> int: + return self.value + def __init__(self, **kwargs: bool) -> None: self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): @@ -126,17 +101,16 @@ def __init__(self, **kwargs: bool) -> None: setattr(self, key, value) @classmethod - def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False) -> type[Self]: + def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False) -> None: # add a way to bypass filling flags, eg for ListBaseFlags. if no_fill_flags: - return cls + return - # use the parent's current flags as a base if they exist - cls.VALID_FLAGS = getattr(cls, "VALID_FLAGS", {}).copy() + # use a copy of the parent's current flags as a base if they exist + cls.VALID_FLAGS = dict(cls.VALID_FLAGS) for name, value in cls.__dict__.items(): if isinstance(value, flag_value): - value._parent = cls cls.VALID_FLAGS[name] = value.flag if not cls.VALID_FLAGS: @@ -145,107 +119,68 @@ def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False) cls.DEFAULT_VALUE = all_flags_value(cls.VALID_FLAGS) if inverted else 0 - return cls - @classmethod def _from_value(cls, value: int) -> Self: self = cls.__new__(cls) self.value = value return self - def __eq__(self, other: Any) -> bool: - if isinstance(other, self.__class__): - return self.value == other.value - if isinstance(other, flag_value): - return self.__class__ is other._parent and self.value == other.flag - return False - - def __ne__(self, other: Any) -> bool: - return not self.__eq__(other) + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + return self.value == other.value def __and__(self, other: Self) -> Self: if not isinstance(other, self.__class__): - msg = f"unsupported operand type(s) for &: '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return self._from_value(self.value & other.value) def __iand__(self, other: Self) -> Self: if not isinstance(other, self.__class__): - msg = f"unsupported operand type(s) for &=: '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented self.value &= other.value return self - def __or__(self, other: Union[Self, flag_value[Self]]) -> Self: - if isinstance(other, flag_value): - if self.__class__ is not other._parent: - msg = f"unsupported operand type(s) for |: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'" - raise TypeError(msg) - return self._from_value(self.value | other.flag) + def __or__(self, other: Self) -> Self: if not isinstance(other, self.__class__): - msg = f"unsupported operand type(s) for |: '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return self._from_value(self.value | other.value) - def __ior__(self, other: Union[Self, flag_value[Self]]) -> Self: - if isinstance(other, flag_value): - if self.__class__ is not other._parent: - msg = f"unsupported operand type(s) for |=: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'" - raise TypeError(msg) - self.value |= other.flag - return self + def __ior__(self, other: Self) -> Self: if not isinstance(other, self.__class__): - msg = f"unsupported operand type(s) for |=: '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented self.value |= other.value return self - def __xor__(self, other: Union[Self, flag_value[Self]]) -> Self: - if isinstance(other, flag_value): - if self.__class__ is not other._parent: - msg = f"unsupported operand type(s) for ^: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'" - raise TypeError(msg) - return self._from_value(self.value ^ other.flag) + def __xor__(self, other: Self) -> Self: if not isinstance(other, self.__class__): - msg = f"unsupported operand type(s) for ^: '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return self._from_value(self.value ^ other.value) - def __ixor__(self, other: Union[Self, flag_value[Self]]) -> Self: - if isinstance(other, flag_value): - if self.__class__ is not other._parent: - msg = f"unsupported operand type(s) for ^=: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'" - raise TypeError(msg) - self.value ^= other.flag - return self + def __ixor__(self, other: Self) -> Self: if not isinstance(other, self.__class__): - msg = f"unsupported operand type(s) for ^=: '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented self.value ^= other.value return self def __le__(self, other: Self) -> bool: if not isinstance(other, self.__class__): - msg = f"'<=' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return (self.value & other.value) == self.value def __ge__(self, other: Self) -> bool: if not isinstance(other, self.__class__): - msg = f"'>=' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return (self.value | other.value) == self.value def __lt__(self, other: Self) -> bool: if not isinstance(other, self.__class__): - msg = f"'<' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return (self.value & other.value) == self.value and self.value != other.value def __gt__(self, other: Self) -> bool: if not isinstance(other, self.__class__): - msg = f"'>' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'" - raise TypeError(msg) + return NotImplemented return (self.value | other.value) == self.value and self.value != other.value def __invert__(self) -> Self: @@ -278,7 +213,7 @@ def _set_flag(self, o: int, toggle: bool) -> None: elif toggle is False: self.value &= ~o else: - msg = f"Value to set for {self.__class__.__name__} must be a bool." + msg = f"Value to set for {self.__class__.__name__} must be a bool, got {toggle!r}." raise TypeError(msg) @@ -1045,7 +980,7 @@ class Intents(BaseFlags): .. versionchanged:: 2.6 - This can be now be provided on initialisation. + This can be now be provided on initialization. """ __slots__ = () diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index 4a7012c430..da219d54cc 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -934,9 +934,7 @@ async def defer( if defer_type is InteractionResponseType.deferred_channel_message: # we only want to set flags if we are sending a message - data["flags"] = 0 - if ephemeral: - data["flags"] |= MessageFlags.ephemeral.flag + data["flags"] = MessageFlags(ephemeral=ephemeral is True).value adapter = async_context.get() await adapter.create_interaction_response( diff --git a/tests/test_flags.py b/tests/test_flags.py index 5deb7eb8cb..0b206a3eeb 100644 --- a/tests/test_flags.py +++ b/tests/test_flags.py @@ -111,7 +111,7 @@ def test_flag_value_or(self) -> None: assert ins.value == 5 assert (TestFlags.two | ins).value == 7 - assert not ins.value & TestFlags.sixteen.flag + assert not ins.value & TestFlags.sixteen.value ins |= TestFlags.sixteen assert ins.value == 21 @@ -421,7 +421,7 @@ def test_set_and_get_flag(self) -> None: ins.two = True assert ins.two is True - assert ins.value == TestFlags.two.flag == 1 << 1 + assert ins.value == TestFlags.two.value == 1 << 1 def test_alias_flag_value(self) -> None: ins = TestFlags(three=True) diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 4350350805..05400f9b82 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -14,7 +14,7 @@ def test_init_permissions_keyword_arguments(self) -> None: assert perms.manage_messages is True # check we only have the manage message permission - assert perms.value == Permissions.manage_messages.flag + assert perms.value == Permissions.manage_messages.value def test_init_permissions_keyword_arguments_with_aliases(self) -> None: assert Permissions(read_messages=True, view_channel=False).value == 0