Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1463.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Class properties of :class:`BaseFlag` now return instances of :class:`BaseFlag` rather than :class:`flag_value`. If you are accessing :attr:`flag_value.flag`, change to :attr:`BaseFlag.value`.
1 change: 1 addition & 0 deletions changelog/1463.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Class properties of :class:`BaseFlag` now return instances of :class:`BaseFlag`, allowing you to pass them directly where a flag instance is expected.
151 changes: 43 additions & 108 deletions disnake/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

import functools
import operator
from collections.abc import Iterator, Sequence
from collections.abc import Iterator, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
NoReturn,
Optional,
TypeVar,
Expand All @@ -19,7 +18,7 @@
)

from .enums import UserFlags
from .utils import MISSING, _generated
from .utils import _generated, deprecated

if TYPE_CHECKING:
from typing_extensions import Self
Expand All @@ -46,52 +45,21 @@
"InteractionContextTypes",
)

BF = TypeVar("BF", bound="BaseFlags")
T = TypeVar("T", bound="BaseFlags")


class flag_value(Generic[T]):
class flag_value:
def __init__(self, func: Callable[[Any], int]) -> None:
self.flag = func(None)
self.flag: int = func(None)
self.__doc__ = func.__doc__
self._parent: type[T] = MISSING

def __eq__(self, other: Any) -> bool:
if isinstance(other, flag_value):
return self.flag == other.flag
if isinstance(other, BaseFlags):
return self._parent is other.__class__ and self.flag == other.value
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def __or__(self, other: Union[flag_value[T], T]) -> T:
if isinstance(other, BaseFlags):
if self._parent is not other.__class__:
msg = f"unsupported operand type(s) for |: flags of '{self._parent.__name__}' and flags of '{other.__class__.__name__}'"
raise TypeError(msg)
return other._from_value(self.flag | other.value)
if not isinstance(other, flag_value):
msg = f"unsupported operand type(s) for |: flags of '{self._parent.__name__}' and {other.__class__}"
raise TypeError(msg)
if self._parent is not other._parent:
msg = f"unsupported operand type(s) for |: flags of '{self._parent.__name__}' and flags of '{other._parent.__name__}'"
raise TypeError(msg)
return self._parent._from_value(self.flag | other.flag)

def __invert__(self: flag_value[T]) -> T:
return ~self._parent._from_value(self.flag)

@overload
def __get__(self, instance: None, owner: type[BF]) -> flag_value[BF]: ...

def __get__(self, instance: None, owner: type[T]) -> T: ...
@overload
def __get__(self, instance: BF, owner: type[BF]) -> bool: ...

def __get__(self, instance: Optional[BF], owner: type[BF]) -> Any:
def __get__(self, instance: T, owner: type[T]) -> bool: ...
def __get__(self, instance: Optional[T], owner: type[T]) -> Union[bool, T]:
if instance is None:
return self
return owner._from_value(self.flag)
return instance._has_flag(self.flag)

def __set__(self, instance: BaseFlags, value: bool) -> None:
Expand All @@ -101,22 +69,29 @@ def __repr__(self) -> str:
return f"<flag_value flag={self.flag!r}>"


class alias_flag_value(flag_value[T]):
class alias_flag_value(flag_value):
pass


def all_flags_value(flags: dict[str, int]) -> int:
def all_flags_value(flags: Mapping[str, int]) -> int:
return functools.reduce(operator.or_, flags.values())


class BaseFlags:
VALID_FLAGS: ClassVar[dict[str, int]]
DEFAULT_VALUE: ClassVar[int]
VALID_FLAGS: ClassVar[Mapping[str, int]] = {}
DEFAULT_VALUE: ClassVar[int] = 0

value: int

__slots__ = ("value",)

if not TYPE_CHECKING:

@property
@deprecated("BaseFlags.value")
def flag(self) -> int:
return self.value

def __init__(self, **kwargs: bool) -> None:
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
Expand All @@ -126,17 +101,16 @@ def __init__(self, **kwargs: bool) -> None:
setattr(self, key, value)

@classmethod
def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False) -> type[Self]:
def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False) -> None:
# add a way to bypass filling flags, eg for ListBaseFlags.
if no_fill_flags:
return cls
return

# use the parent's current flags as a base if they exist
cls.VALID_FLAGS = getattr(cls, "VALID_FLAGS", {}).copy()
# use a copy of the parent's current flags as a base if they exist
cls.VALID_FLAGS = dict(cls.VALID_FLAGS)

for name, value in cls.__dict__.items():
if isinstance(value, flag_value):
value._parent = cls
cls.VALID_FLAGS[name] = value.flag

if not cls.VALID_FLAGS:
Expand All @@ -145,107 +119,68 @@ def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False)

cls.DEFAULT_VALUE = all_flags_value(cls.VALID_FLAGS) if inverted else 0

return cls

@classmethod
def _from_value(cls, value: int) -> Self:
self = cls.__new__(cls)
self.value = value
return self

def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self.value == other.value
if isinstance(other, flag_value):
return self.__class__ is other._parent and self.value == other.flag
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.value == other.value

def __and__(self, other: Self) -> Self:
if not isinstance(other, self.__class__):
msg = f"unsupported operand type(s) for &: '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return self._from_value(self.value & other.value)

def __iand__(self, other: Self) -> Self:
if not isinstance(other, self.__class__):
msg = f"unsupported operand type(s) for &=: '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
self.value &= other.value
return self

def __or__(self, other: Union[Self, flag_value[Self]]) -> Self:
if isinstance(other, flag_value):
if self.__class__ is not other._parent:
msg = f"unsupported operand type(s) for |: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'"
raise TypeError(msg)
return self._from_value(self.value | other.flag)
def __or__(self, other: Self) -> Self:
if not isinstance(other, self.__class__):
msg = f"unsupported operand type(s) for |: '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return self._from_value(self.value | other.value)

def __ior__(self, other: Union[Self, flag_value[Self]]) -> Self:
if isinstance(other, flag_value):
if self.__class__ is not other._parent:
msg = f"unsupported operand type(s) for |=: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'"
raise TypeError(msg)
self.value |= other.flag
return self
def __ior__(self, other: Self) -> Self:
if not isinstance(other, self.__class__):
msg = f"unsupported operand type(s) for |=: '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
self.value |= other.value
return self

def __xor__(self, other: Union[Self, flag_value[Self]]) -> Self:
if isinstance(other, flag_value):
if self.__class__ is not other._parent:
msg = f"unsupported operand type(s) for ^: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'"
raise TypeError(msg)
return self._from_value(self.value ^ other.flag)
def __xor__(self, other: Self) -> Self:
if not isinstance(other, self.__class__):
msg = f"unsupported operand type(s) for ^: '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return self._from_value(self.value ^ other.value)

def __ixor__(self, other: Union[Self, flag_value[Self]]) -> Self:
if isinstance(other, flag_value):
if self.__class__ is not other._parent:
msg = f"unsupported operand type(s) for ^=: flags of '{self.__class__.__name__}' and flags of '{other._parent.__name__}'"
raise TypeError(msg)
self.value ^= other.flag
return self
def __ixor__(self, other: Self) -> Self:
if not isinstance(other, self.__class__):
msg = f"unsupported operand type(s) for ^=: '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
self.value ^= other.value
return self

def __le__(self, other: Self) -> bool:
if not isinstance(other, self.__class__):
msg = f"'<=' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return (self.value & other.value) == self.value

def __ge__(self, other: Self) -> bool:
if not isinstance(other, self.__class__):
msg = f"'>=' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return (self.value | other.value) == self.value

def __lt__(self, other: Self) -> bool:
if not isinstance(other, self.__class__):
msg = f"'<' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return (self.value & other.value) == self.value and self.value != other.value

def __gt__(self, other: Self) -> bool:
if not isinstance(other, self.__class__):
msg = f"'>' not supported between instances of '{self.__class__.__name__}' and '{other.__class__.__name__}'"
raise TypeError(msg)
return NotImplemented
return (self.value | other.value) == self.value and self.value != other.value

def __invert__(self) -> Self:
Expand Down Expand Up @@ -278,7 +213,7 @@ def _set_flag(self, o: int, toggle: bool) -> None:
elif toggle is False:
self.value &= ~o
else:
msg = f"Value to set for {self.__class__.__name__} must be a bool."
msg = f"Value to set for {self.__class__.__name__} must be a bool, got {toggle!r}."
raise TypeError(msg)


Expand Down Expand Up @@ -1045,7 +980,7 @@ class Intents(BaseFlags):

.. versionchanged:: 2.6

This can be now be provided on initialisation.
This can be now be provided on initialization.
"""

__slots__ = ()
Expand Down
4 changes: 1 addition & 3 deletions disnake/interactions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,9 +934,7 @@ async def defer(

if defer_type is InteractionResponseType.deferred_channel_message:
# we only want to set flags if we are sending a message
data["flags"] = 0
if ephemeral:
data["flags"] |= MessageFlags.ephemeral.flag
data["flags"] = MessageFlags(ephemeral=ephemeral is True).value

adapter = async_context.get()
await adapter.create_interaction_response(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_flag_value_or(self) -> None:
assert ins.value == 5
assert (TestFlags.two | ins).value == 7

assert not ins.value & TestFlags.sixteen.flag
assert not ins.value & TestFlags.sixteen.value
ins |= TestFlags.sixteen
assert ins.value == 21

Expand Down Expand Up @@ -421,7 +421,7 @@ def test_set_and_get_flag(self) -> None:

ins.two = True
assert ins.two is True
assert ins.value == TestFlags.two.flag == 1 << 1
assert ins.value == TestFlags.two.value == 1 << 1

def test_alias_flag_value(self) -> None:
ins = TestFlags(three=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_init_permissions_keyword_arguments(self) -> None:
assert perms.manage_messages is True

# check we only have the manage message permission
assert perms.value == Permissions.manage_messages.flag
assert perms.value == Permissions.manage_messages.value

def test_init_permissions_keyword_arguments_with_aliases(self) -> None:
assert Permissions(read_messages=True, view_channel=False).value == 0
Expand Down
Loading