Skip to content

refactor: ♻️ Refactor weird checks in CogMeta and fix some typing and other qol things #2730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 143 additions & 107 deletions discord/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,17 @@
import pathlib
import sys
import types
from typing import Any, Callable, ClassVar, Generator, Mapping, TypeVar, overload
from collections.abc import Generator, Mapping
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
TypeVar,
overload,
)

from typing_extensions import TypeGuard

import discord.utils

Expand All @@ -43,6 +53,10 @@
_BaseCommand,
)

if TYPE_CHECKING:
from .ext.bridge import BridgeCommand


__all__ = (
"CogMeta",
"Cog",
Expand All @@ -59,6 +73,118 @@ def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(f"{parent}.")


def _is_bridge_command(command: Any) -> TypeGuard[BridgeCommand]:
return getattr(command, "__bridge__", False)


def _name_filter(c: Any) -> str:
return (
"app"
if isinstance(c, ApplicationCommand)
else ("bridge" if not _is_bridge_command(c) else "ext")
)


def _validate_name_prefix(base_class: type, name: str) -> None:
if name.startswith(("cog_", "bot_")):
raise TypeError(
f"Commands or listeners must not start with cog_ or bot_ (in method {base_class}.{name})"
)


def _process_attributes(
base: type,
) -> tuple[dict[str, Any], dict[str, Any]]: # pyright: ignore[reportExplicitAny]
commands: dict[str, _BaseCommand | BridgeCommand] = {}
listeners: dict[str, Callable[..., Any]] = {}

for attr_name, attr_value in base.__dict__.items():
if attr_name in commands:
del commands[attr_name]
if attr_name in listeners:
del listeners[attr_name]

if getattr(attr_value, "parent", None) and isinstance(
attr_value, ApplicationCommand
):
# Skip application commands if they are a part of a group
# Since they are already added when the group is added
continue

is_static_method = isinstance(attr_value, staticmethod)
if is_static_method:
attr_value = attr_value.__func__

if inspect.iscoroutinefunction(attr_value) and getattr(
attr_value, "__cog_listener__", False
):
_validate_name_prefix(base, attr_name)
listeners[attr_name] = attr_value
continue

if isinstance(attr_value, _BaseCommand) or _is_bridge_command(attr_value):
if is_static_method:
raise TypeError(
f"Command in method {base}.{attr_name!r} must not be staticmethod."
)
_validate_name_prefix(base, attr_name)

if isinstance(attr_value, _BaseCommand):
commands[attr_name] = attr_value

if _is_bridge_command(attr_value) and not attr_value.parent:
commands[f"ext_{attr_name}"] = attr_value.ext_variant
commands[f"app_{attr_name}"] = attr_value.slash_variant
commands[attr_name] = attr_value
for cmd in getattr(attr_value, "subcommands", []):
commands[f"ext_{cmd.ext_variant.qualified_name}"] = cmd.ext_variant

return commands, listeners


def _update_command(
command: _BaseCommand | BridgeCommand,
guild_ids: list[int],
lookup_table: dict[str, _BaseCommand | BridgeCommand],
new_cls: type[Cog],
) -> None:
if isinstance(command, ApplicationCommand) and not command.guild_ids and guild_ids:
command.guild_ids = guild_ids

if not isinstance(command, SlashCommandGroup) and not _is_bridge_command(command):
# ignore bridge commands
cmd: BridgeCommand | _BaseCommand | None = getattr(
new_cls,
command.callback.__name__,
None, # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType,reportAttributeAccessIssue]
)
if _is_bridge_command(cmd):
setattr(
cmd,
f"{_name_filter(command).replace('app', 'slash')}_variant",
command,
)
else:
setattr(
new_cls,
command.callback.__name__,
command, # pyright: ignore [reportAttributeAccessIssue, reportUnknownArgumentType, reportUnknownMemberType]
)

parent: (
BridgeCommand | _BaseCommand | None
) = ( # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
command.parent # pyright: ignore [reportAttributeAccessIssue]
)
if parent is not None:
# Get the latest parent reference
parent = lookup_table[f"{_name_filter(command)}_{parent.qualified_name}"] # type: ignore # pyright: ignore[reportUnknownMemberType]

# Update the parent's reference to our self
parent.remove_command(command.name) # type: ignore # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
parent.add_command(command) # type: ignore # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]


class CogMeta(type):
"""A metaclass for defining a cog.

Expand Down Expand Up @@ -127,7 +253,7 @@ async def bar(self, ctx):

__cog_name__: str
__cog_settings__: dict[str, Any]
__cog_commands__: list[ApplicationCommand]
__cog_commands__: list[_BaseCommand | BridgeCommand]
__cog_listeners__: list[tuple[str, str]]
__cog_guild_ids__: list[int]

Expand All @@ -142,128 +268,38 @@ def __new__(cls: type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description

commands = {}
listeners = {}
no_bot_cog = (
"Commands or listeners must not start with cog_ or bot_ (in method"
" {0.__name__}.{1})"
)
commands: dict[str, _BaseCommand | BridgeCommand] = {}
listeners: dict[str, Callable[..., Any]] = {}

new_cls = super().__new__(cls, name, bases, attrs, **kwargs)

for base in reversed(new_cls.__mro__):
for elem, value in base.__dict__.items():
if elem in commands:
del commands[elem]
if elem in listeners:
del listeners[elem]

if getattr(value, "parent", None) and isinstance(
value, ApplicationCommand
):
# Skip commands if they are a part of a group
continue

is_static_method = isinstance(value, staticmethod)
if is_static_method:
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(
f"Command in method {base}.{elem!r} must not be"
" staticmethod."
)
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value

# a test to see if this value is a BridgeCommand
if hasattr(value, "add_to") and not getattr(value, "parent", None):
if is_static_method:
raise TypeError(
f"Command in method {base}.{elem!r} must not be"
" staticmethod."
)
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))

commands[f"ext_{elem}"] = value.ext_variant
commands[f"app_{elem}"] = value.slash_variant
commands[elem] = value
for cmd in getattr(value, "subcommands", []):
commands[f"ext_{cmd.ext_variant.qualified_name}"] = (
cmd.ext_variant
)

if inspect.iscoroutinefunction(value):
try:
getattr(value, "__cog_listener__")
except AttributeError:
continue
else:
if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_commands, new_listeners = _process_attributes(base)
commands.update(new_commands)
listeners.update(new_listeners)

new_cls.__cog_commands__ = list(commands.values())

listeners_as_list = []
for listener in listeners.values():
for listener_name in listener.__cog_listener_names__:
# I use __name__ instead of just storing the value, so I can inject
# the self attribute when the time comes to add them to the bot
listeners_as_list.append((listener_name, listener.__name__))

new_cls.__cog_listeners__ = listeners_as_list
new_cls.__cog_listeners__ = [
(listener_name, listener.__name__)
for listener in listeners.values()
for listener_name in listener.__cog_listener_names__
]

cmd_attrs = new_cls.__cog_settings__

# Either update the command with the cog provided defaults or copy it.
# r.e type ignore, type-checker complains about overriding a ClassVar
new_cls.__cog_commands__ = tuple(c._update_copy(cmd_attrs) if not hasattr(c, "add_to") else c for c in new_cls.__cog_commands__) # type: ignore

name_filter = lambda c: (
"app"
if isinstance(c, ApplicationCommand)
else ("bridge" if not hasattr(c, "add_to") else "ext")
)
new_cls.__cog_commands__ = list(tuple(c._update_copy(cmd_attrs) if not _is_bridge_command(c) else c for c in new_cls.__cog_commands__)) # type: ignore

lookup = {
f"{name_filter(cmd)}_{cmd.qualified_name}": cmd
f"{_name_filter(cmd)}_{cmd.qualified_name}": cmd
for cmd in new_cls.__cog_commands__
}

# Update the Command instances dynamically as well
for command in new_cls.__cog_commands__:
if (
isinstance(command, ApplicationCommand)
and not command.guild_ids
and new_cls.__cog_guild_ids__
):
command.guild_ids = new_cls.__cog_guild_ids__

if not isinstance(command, SlashCommandGroup) and not hasattr(
command, "add_to"
):
# ignore bridge commands
cmd = getattr(new_cls, command.callback.__name__, None)
if hasattr(cmd, "add_to"):
setattr(
cmd,
f"{name_filter(command).replace('app', 'slash')}_variant",
command,
)
else:
setattr(new_cls, command.callback.__name__, command)

parent = command.parent
if parent is not None:
# Get the latest parent reference
parent = lookup[f"{name_filter(command)}_{parent.qualified_name}"] # type: ignore

# Update our parent's reference to our self
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
_update_command(command, new_cls.__cog_guild_ids__, lookup, new_cls)

return new_cls

Expand Down Expand Up @@ -537,7 +573,7 @@ def _inject(self: CogT, bot) -> CogT:
# we've added so far for some form of atomic loading.

for index, command in enumerate(self.__cog_commands__):
if hasattr(command, "add_to"):
if _is_bridge_command(command):
bot.bridge_commands.append(command)
continue

Expand Down Expand Up @@ -582,7 +618,7 @@ def _eject(self, bot) -> None:

try:
for command in self.__cog_commands__:
if hasattr(command, "add_to"):
if _is_bridge_command(command):
bot.bridge_commands.remove(command)
continue
elif isinstance(command, ApplicationCommand):
Expand Down
2 changes: 2 additions & 0 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,8 @@ class SlashCommand(ApplicationCommand):

type = 1

parent: SlashCommandGroup | None

def __new__(cls, *args, **kwargs) -> SlashCommand:
self = super().__new__(cls)

Expand Down
2 changes: 2 additions & 0 deletions discord/ext/bridge/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ class BridgeCommand:
The prefix-based version of this bridge command.
"""

__bridge__: bool = True

__special_attrs__ = ["slash_variant", "ext_variant", "parent"]

def __init__(self, callback, **kwargs):
Expand Down