diff --git a/discord_webhook/__init__.py b/discord_webhook/__init__.py index b6a966f..08ffb46 100644 --- a/discord_webhook/__init__.py +++ b/discord_webhook/__init__.py @@ -1,5 +1,12 @@ -__all__ = ["DiscordWebhook", "DiscordEmbed", "AsyncDiscordWebhook"] +__all__ = [ + "AsyncDiscordWebhook", + "DiscordEmbed", + "DiscordWebhook", + "DiscordComponentActionRow", + "DiscordComponentButton", +] -from .webhook import DiscordWebhook, DiscordEmbed -from .async_webhook import AsyncDiscordWebhook +from .components import DiscordComponentButton, DiscordComponentActionRow +from .webhook import DiscordEmbed, DiscordWebhook # isort:skip +from .async_webhook import AsyncDiscordWebhook # isort:skip diff --git a/discord_webhook/components.py b/discord_webhook/components.py new file mode 100644 index 0000000..2113415 --- /dev/null +++ b/discord_webhook/components.py @@ -0,0 +1,136 @@ +from typing import Optional, List, Union + +from . import constants +from .webhook_exceptions import ComponentException + + +class BaseDiscordComponent: + """ + A base class for discord components. + """ + + custom_id: str + label: str + type: int + + def __init__(self, **kwargs): + self.custom_id = kwargs.get("custom_id") + self.label = kwargs.get("label") + + if ( + type(self.type) is not int + or self.type not in constants.DISCORD_COMPONENT_TYPES + ): + raise ComponentException( + "The provided component type is invalid. A valid component type is an" + " integer between 1 and 8." + ) + if self.custom_id and len(self.custom_id) > 100: + raise ComponentException("custom_id can be a maximum of 100 characters.") + + +class DiscordComponentButton(BaseDiscordComponent): + """ + Represent a button that can be used in a message. + """ + + disabled: Optional[bool] + emoji = None + label: Optional[str] + style: int + type: int + url: Optional[str] + + def __init__( + self, style: int = constants.DISCORD_COMPONENT_BUTTON_STYLE_PRIMARY, **kwargs + ): + """ + :param style: button style (int 1 - 5) + :keyword disabled: Whether the button is disabled (defaults to false) + :keyword custom_id: developer-defined identifier for the button + :keyword label: Text that appears on the button + :keyword url: URL for DISCORD_COMPONENT_BUTTON_STYLE_LINK (int 5) buttons + """ + self.type = constants.DISCORD_COMPONENT_TYPE_BUTTON + self.style = style + self.disabled = kwargs.get("disabled", False) + self.custom_id = kwargs.get("custom_id") + self.emoji = kwargs.get("emoji") + self.label = kwargs.get("label") + self.url = kwargs.get("url") + + if ( + type(self.style) is not int + or self.style not in constants.DISCORD_COMPONENT_BUTTON_STYLES + ): + raise ComponentException( + "The provided button style is invalid. A valid button style is an" + " integer between 1 and 5." + ) + if ( + constants.DISCORD_COMPONENT_BUTTON_STYLE_PRIMARY + <= self.style + <= constants.DISCORD_COMPONENT_BUTTON_STYLE_DANGER + and not self.custom_id + ): + raise ComponentException("custom_id needs to be provided as a kwarg.") + if self.style == constants.DISCORD_COMPONENT_BUTTON_STYLE_LINK and not self.url: + raise ComponentException("url needs to be provided as a kwarg.") + if self.label and len(self.label) > 80: + raise ComponentException( + "The label can be a maximum of 80 characters long." + ) + + super().__init__(**kwargs) + + +class DiscordComponentActionRow: + """ + Represent an action row that can be used in a message. + """ + + components: list + type: int = constants.DISCORD_COMPONENT_TYPE_ACTION_ROW + + def __init__( + self, + components: Optional[List[Union[dict, BaseDiscordComponent]]] = None, + **kwargs, + ): + """ + :keyword components: displayed components in an action row + """ + if components is None: + components = [] + + self.components = components + self.type = constants.DISCORD_COMPONENT_TYPE_ACTION_ROW + + super().__init__(**kwargs) + + def add_component(self, component): + """ + Add a component to the row + :param component: discord component instance + """ + if isinstance(component, DiscordComponentActionRow): + raise ComponentException("An action row can't contain another action row.") + if ( + isinstance(component, DiscordComponentButton) + and sum( + 1 + for comp in self.components + if comp.get("type") == constants.DISCORD_COMPONENT_TYPE_BUTTON + ) + >= 5 + ): + raise ComponentException("An Action Row can contain up to 5 buttons.") + + if not isinstance(component, dict): + component = { + key: value + for key, value in component.__dict__.items() + if value is not None + } + + self.components.append(component) diff --git a/discord_webhook/constants.py b/discord_webhook/constants.py new file mode 100644 index 0000000..8da5673 --- /dev/null +++ b/discord_webhook/constants.py @@ -0,0 +1,33 @@ +DISCORD_COMPONENT_TYPE_ACTION_ROW = 1 +DISCORD_COMPONENT_TYPE_BUTTON = 2 +DISCORD_COMPONENT_TYPE_STRING_SELECT = 3 +DISCORD_COMPONENT_TYPE_TEXT_INPUT = 4 +DISCORD_COMPONENT_TYPE_USER_SELECT = 5 +DISCORD_COMPONENT_TYPE_ROLE_SELECT = 6 +DISCORD_COMPONENT_TYPE_MENTIONABLE_SELECT = 7 +DISCORD_COMPONENT_TYPE_CHANNEL_SELECT = 8 + +DISCORD_COMPONENT_TYPES = [ + DISCORD_COMPONENT_TYPE_ACTION_ROW, + DISCORD_COMPONENT_TYPE_BUTTON, + DISCORD_COMPONENT_TYPE_STRING_SELECT, + DISCORD_COMPONENT_TYPE_TEXT_INPUT, + DISCORD_COMPONENT_TYPE_USER_SELECT, + DISCORD_COMPONENT_TYPE_ROLE_SELECT, + DISCORD_COMPONENT_TYPE_MENTIONABLE_SELECT, + DISCORD_COMPONENT_TYPE_CHANNEL_SELECT, +] + +DISCORD_COMPONENT_BUTTON_STYLE_PRIMARY = 1 +DISCORD_COMPONENT_BUTTON_STYLE_SECONDARY = 2 +DISCORD_COMPONENT_BUTTON_STYLE_SUCCESS = 3 +DISCORD_COMPONENT_BUTTON_STYLE_DANGER = 4 +DISCORD_COMPONENT_BUTTON_STYLE_LINK = 5 + +DISCORD_COMPONENT_BUTTON_STYLES = [ + DISCORD_COMPONENT_BUTTON_STYLE_PRIMARY, + DISCORD_COMPONENT_BUTTON_STYLE_SECONDARY, + DISCORD_COMPONENT_BUTTON_STYLE_SUCCESS, + DISCORD_COMPONENT_BUTTON_STYLE_DANGER, + DISCORD_COMPONENT_BUTTON_STYLE_LINK, +] diff --git a/discord_webhook/webhook.py b/discord_webhook/webhook.py index 591e07e..e74d6d1 100644 --- a/discord_webhook/webhook.py +++ b/discord_webhook/webhook.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import requests +from . import DiscordComponentActionRow from .webhook_exceptions import ColorNotInRangeException logger = logging.getLogger(__name__) @@ -242,6 +243,7 @@ def __init__(self, url: str, **kwargs) -> None: :keyword list allowed_mentions: allowed mentions for the message :keyword dict attachments: attachments that should be included :keyword str avatar_url: override the default avatar of the webhook + :keyword list components: list of components :keyword str content: the message contents :keyword list embeds: list of embedded rich content :keyword dict files: to apply file(s) with message @@ -255,6 +257,7 @@ def __init__(self, url: str, **kwargs) -> None: self.allowed_mentions = kwargs.get("allowed_mentions", []) self.attachments = kwargs.get("attachments", []) self.avatar_url = kwargs.get("avatar_url") + self.components = kwargs.get("components", []) self.content = kwargs.get("content") self.embeds = kwargs.get("embeds", []) self.files = kwargs.get("files", {}) @@ -266,6 +269,19 @@ def __init__(self, url: str, **kwargs) -> None: self.url = url self.username = kwargs.get("username", False) + def add_component_row( + self, action_row: Union[DiscordComponentActionRow, Dict[str, Any]] + ) -> None: + """ + Add a component row to the webhook. + :param action_row: action row instance + """ + self.components.append( + action_row.__dict__ + if isinstance(action_row, DiscordComponentActionRow) + else action_row + ) + def add_embed(self, embed: Union[DiscordEmbed, Dict[str, Any]]) -> None: """ Add an embedded rich content. diff --git a/discord_webhook/webhook_exceptions.py b/discord_webhook/webhook_exceptions.py index 6342cf1..b1c3036 100644 --- a/discord_webhook/webhook_exceptions.py +++ b/discord_webhook/webhook_exceptions.py @@ -16,3 +16,11 @@ def __init__(self, color: Union[str, int], message=None) -> None: " (HEXADECIMAL)." ) super().__init__(message) + + +class ComponentException(Exception): + """ + This Exception will be raised for components. + """ + + pass diff --git a/tests/components/__init__.py b/tests/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/components/test_action_row.py b/tests/components/test_action_row.py new file mode 100644 index 0000000..4244f6d --- /dev/null +++ b/tests/components/test_action_row.py @@ -0,0 +1,28 @@ +import pytest + +from discord_webhook import DiscordComponentButton, DiscordComponentActionRow +from discord_webhook.webhook_exceptions import ComponentException + + +def test__action_row_in_action_row(): + action_row_1 = DiscordComponentActionRow() + action_row_2 = DiscordComponentActionRow() + + with pytest.raises(ComponentException) as excinfo: + action_row_1.add_component(action_row_2) + + assert str(excinfo.value) == "An action row can't contain another action row." + assert len(action_row_1.components) == 0 + + +def test__max_buttons(): + action_row = DiscordComponentActionRow() + button = DiscordComponentButton(custom_id="test") + for _ in range(0, 5): + action_row.add_component(button) + + with pytest.raises(ComponentException) as excinfo: + action_row.add_component(button) + + assert str(excinfo.value) == "An Action Row can contain up to 5 buttons." + assert len(action_row.components) == 5 diff --git a/tests/components/test_button.py b/tests/components/test_button.py new file mode 100644 index 0000000..acf98d4 --- /dev/null +++ b/tests/components/test_button.py @@ -0,0 +1,58 @@ +import pytest +from pytest import mark + +from discord_webhook import DiscordComponentButton, constants +from discord_webhook.webhook_exceptions import ComponentException + + +@mark.parametrize( + "style, field, error_message", + [ + ( + constants.DISCORD_COMPONENT_BUTTON_STYLE_PRIMARY, + "custom_id", + "custom_id needs to be provided as a kwarg.", + ), + ( + constants.DISCORD_COMPONENT_BUTTON_STYLE_SECONDARY, + "custom_id", + "custom_id needs to be provided as a kwarg.", + ), + ( + constants.DISCORD_COMPONENT_BUTTON_STYLE_SUCCESS, + "custom_id", + "custom_id needs to be provided as a kwarg.", + ), + ( + constants.DISCORD_COMPONENT_BUTTON_STYLE_DANGER, + "custom_id", + "custom_id needs to be provided as a kwarg.", + ), + ( + constants.DISCORD_COMPONENT_BUTTON_STYLE_LINK, + "url", + "url needs to be provided as a kwarg.", + ), + ], +) +def test__styles__required_fields(style, field, error_message): + # valid button + DiscordComponentButton(**{"style": style, field: "test_string"}) + + # required field is missing + with pytest.raises(ComponentException) as excinfo: + DiscordComponentButton(style=style) + + assert str(excinfo.value) == error_message + + +@mark.parametrize("invalid_style", [0, 6, "a", True, None]) +def test__styles__invalid(invalid_style): + with pytest.raises(ComponentException) as excinfo: + DiscordComponentButton(style=invalid_style) + + assert ( + str(excinfo.value) + == "The provided button style is invalid. A valid button style is an integer" + " between 1 and 5." + ) diff --git a/tests/components/test_component.py b/tests/components/test_component.py new file mode 100644 index 0000000..9193a89 --- /dev/null +++ b/tests/components/test_component.py @@ -0,0 +1,43 @@ +import pytest + +from discord_webhook.components import BaseDiscordComponent +from discord_webhook import constants +from discord_webhook.webhook_exceptions import ComponentException + + +def test__component__types(): + for component_type in constants.DISCORD_COMPONENT_TYPES: + + class TestDiscordComponent(BaseDiscordComponent): + type = component_type + + TestDiscordComponent() + + for component_type in [0, 9, "a", True]: + + class TestDiscordComponent(BaseDiscordComponent): + type = component_type + + with pytest.raises(ComponentException) as excinfo: + TestDiscordComponent() + + assert ( + str(excinfo.value) + == "The provided component type is invalid. A valid component type is an" + " integer between 1 and 8." + ) + + +def test__component__custom_id_max_length(): + class TestDiscordComponent(BaseDiscordComponent): + type = constants.DISCORD_COMPONENT_TYPE_BUTTON + + custom_id = "".join("a" for i in range(100)) + + TestDiscordComponent(custom_id=custom_id) + + # total length of 101 chars + with pytest.raises(ComponentException) as excinfo: + TestDiscordComponent(custom_id=f"{custom_id}a") + + assert str(excinfo.value) == "custom_id can be a maximum of 100 characters."