Skip to content

feat(commands): add support for typing.Literal[...] as command choices #2782

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ These changes are available on the `master` branch, but have not yet been releas
([#2714](https://github.com/Pycord-Development/pycord/pull/2714))
- Added the ability to pass a `datetime.time` object to `format_dt`.
([#2747](https://github.com/Pycord-Development/pycord/pull/2747))
- Added support for type hinting slash command options with `typing.Annotated`.
([#2782](https://github.com/Pycord-Development/pycord/pull/2782))
- Added `discord.Interaction.created_at`.
([#2801](https://github.com/Pycord-Development/pycord/pull/2801))

Expand Down
25 changes: 23 additions & 2 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@
from .options import Option, OptionChoice

if sys.version_info >= (3, 11):
from typing import Annotated, get_args, get_origin
from typing import Annotated, Literal, get_args, get_origin
else:
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, Literal, get_args, get_origin

__all__ = (
"_BaseCommand",
Expand Down Expand Up @@ -806,6 +806,24 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]:
if option == inspect.Parameter.empty:
option = str

if self._is_typing_literal(option):
literal_values = get_args(option)
if not all(isinstance(v, (str, int, float)) for v in literal_values):
raise TypeError(
"Literal values must be str, int, or float for Discord choices."
)

value_type = type(literal_values[0])
if not all(isinstance(v, value_type) for v in literal_values):
raise TypeError("All Literal values must be of the same type.")

option = Option(
value_type,
choices=[
OptionChoice(name=str(v), value=v) for v in literal_values
],
)

if self._is_typing_annotated(option):
type_hint = get_args(option)[0]
metadata = option.__metadata__
Expand Down Expand Up @@ -908,6 +926,9 @@ def _is_typing_union(self, annotation):
def _is_typing_optional(self, annotation):
return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore

def _is_typing_literal(self, annotation):
return get_origin(annotation) is Literal

def _is_typing_annotated(self, annotation):
return get_origin(annotation) is Annotated

Expand Down