Skip to content

Commit 7ef8494

Browse files
committed
WIP: provide typing default for optional typevar
1 parent d0abf46 commit 7ef8494

File tree

8 files changed

+90
-40
lines changed

8 files changed

+90
-40
lines changed

disnake/ext/commands/bot_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import traceback
1111
import warnings
1212
from collections.abc import Iterable
13-
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
13+
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, cast
1414

1515
import disnake
1616
from disnake.utils import iscoroutinefunction
@@ -26,6 +26,7 @@
2626
if TYPE_CHECKING:
2727
from typing_extensions import Self
2828

29+
from disnake.ext.commands.bot import AutoShardedBot, Bot
2930
from disnake.message import Message
3031

3132
from ._types import Check, CoroFunc, MaybeCoro
@@ -507,7 +508,9 @@ class be provided, it must be similar enough to :class:`.Context`\'s
507508
``cls`` parameter.
508509
"""
509510
view = StringView(message.content)
510-
ctx = cls(prefix=None, view=view, bot=self, message=message)
511+
ctx = cls(
512+
prefix=None, view=view, bot=cast("Union[Bot, AutoShardedBot]", self), message=message
513+
)
511514

512515
if message.author.id == self.user.id: # pyright: ignore[reportAttributeAccessIssue]
513516
return ctx

disnake/ext/commands/context.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,21 @@
3232

3333

3434
T = TypeVar("T")
35-
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
3635
CogT = TypeVar("CogT", bound="Cog")
3736

3837
if TYPE_CHECKING:
38+
from typing_extensions import TypeVar # noqa: TC004
39+
3940
P = ParamSpec("P")
41+
BotT = TypeVar(
42+
"BotT",
43+
bound="Union[Bot, AutoShardedBot]",
44+
covariant=True,
45+
default=Union[Bot, AutoShardedBot],
46+
)
4047
else:
4148
P = TypeVar("P")
49+
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
4250

4351

4452
class Context(disnake.abc.Messageable, Generic[BotT]):

disnake/ext/commands/core.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,25 +97,30 @@
9797

9898
MISSING: Any = disnake.utils.MISSING
9999

100-
T = TypeVar("T")
101100
VT = TypeVar("VT")
102-
CogT = TypeVar("CogT", bound="Optional[Cog]")
103101
CommandT = TypeVar("CommandT", bound="Command")
104-
ContextT = TypeVar("ContextT", bound="Context")
105102
GroupT = TypeVar("GroupT", bound="Group")
106103
HookT = TypeVar("HookT", bound="Hook")
107104
ErrorT = TypeVar("ErrorT", bound="Error")
108105

109106

110107
if TYPE_CHECKING:
111-
P = ParamSpec("P")
108+
from typing_extensions import TypeVar # noqa: TC004
112109

110+
P = ParamSpec("P", default=...)
111+
T = TypeVar("T", default=Any)
112+
113+
CogT = TypeVar("CogT", bound="Optional[Cog]", default="Optional[Cog]")
114+
ContextT = TypeVar("ContextT", bound="Context", default="Context")
113115
CommandCallback = Union[
114116
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
115117
Callable[Concatenate[ContextT, P], Coro[T]],
116118
]
117119
else:
120+
T = TypeVar("T")
118121
P = TypeVar("P")
122+
CogT = TypeVar("CogT", bound="Optional[Cog]")
123+
ContextT = TypeVar("ContextT", bound="Context")
119124

120125

121126
def wrap_callback(coro: Callable[..., Coro[T]]) -> Callable[..., Coro[Optional[T]]]:
@@ -1418,7 +1423,7 @@ def copy(self: GroupT) -> GroupT:
14181423
"""
14191424
ret = super().copy()
14201425
for cmd in self.commands:
1421-
ret.add_command(cmd.copy())
1426+
ret.add_command(cast("Command[CogT, Any, Any]", cmd.copy()))
14221427
return ret
14231428

14241429
async def invoke(self, ctx: Context) -> None:

disnake/ui/_types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
77

88
if TYPE_CHECKING:
9-
from typing_extensions import TypeAlias
9+
from typing_extensions import (
10+
TypeAlias,
11+
TypeVar, # noqa: TC004
12+
)
1013

1114
from . import (
1215
ActionRow,
@@ -24,7 +27,9 @@
2427
from .select import ChannelSelect, MentionableSelect, RoleSelect, StringSelect, UserSelect
2528
from .view import View
2629

27-
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
30+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True, default=Optional[View])
31+
else:
32+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
2833

2934
AnySelect = Union[
3035
"ChannelSelect[V_co]",

disnake/ui/button.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,23 @@
1717
)
1818

1919
if TYPE_CHECKING:
20-
from typing_extensions import ParamSpec, Self
20+
from typing_extensions import (
21+
ParamSpec,
22+
Self,
23+
TypeVar, # noqa: TC004
24+
)
2125

2226
from ..emoji import Emoji
2327
from .item import ItemCallbackType
2428
from .view import View
2529

30+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True, default=Optional[View])
2631
else:
2732
ParamSpec = TypeVar
33+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
2834

2935
B = TypeVar("B", bound="Button")
3036
B_co = TypeVar("B_co", bound="Button", covariant=True)
31-
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
3237
P = ParamSpec("P")
3338

3439

disnake/ui/item.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
"Item",
2323
)
2424

25-
I = TypeVar("I", bound="Item[Any]") # noqa: E741
26-
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
2725

2826
if TYPE_CHECKING:
29-
from typing_extensions import Self
27+
from typing_extensions import (
28+
Self,
29+
TypeVar, # noqa: TC004
30+
)
3031

3132
from ..client import Client
3233
from ..components import ActionRowChildComponent, Component
@@ -35,8 +36,18 @@
3536
from ..types.components import ActionRowChildComponent as ActionRowChildComponentPayload
3637
from .view import View
3738

39+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True, default=Optional[View])
40+
I = TypeVar("I", bound="Item[Any]", default="Item[Any]") # noqa: E741
3841
ItemCallbackType = Callable[[V_co, I, MessageInteraction], Coroutine[Any, Any, Any]]
3942

43+
SelfViewT = TypeVar("SelfViewT", bound="Optional[View]", default=Optional[View])
44+
else:
45+
I = TypeVar("I", bound="Item[Any]") # noqa: E741
46+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
47+
48+
SelfViewT = TypeVar("SelfViewT", bound="Optional[View]")
49+
50+
4051
ClientT = TypeVar("ClientT", bound="Client")
4152
UIComponentT = TypeVar("UIComponentT", bound="UIComponent")
4253

@@ -217,9 +228,6 @@ async def callback(self, interaction: MessageInteraction[ClientT], /) -> None:
217228
pass
218229

219230

220-
SelfViewT = TypeVar("SelfViewT", bound="Optional[View]")
221-
222-
223231
# While the decorators don't actually return a descriptor that matches this protocol,
224232
# this protocol ensures that type checkers don't complain about statements like `self.button.disabled = True`,
225233
# which work as `View.__init__` replaces the handler with the item.

disnake/ui/select/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,24 @@
2424
__all__ = ("BaseSelect",)
2525

2626
if TYPE_CHECKING:
27-
from typing_extensions import ParamSpec, Self
27+
from typing_extensions import (
28+
ParamSpec,
29+
Self,
30+
TypeVar, # noqa: TC004
31+
)
2832

2933
from ...abc import Snowflake
3034
from ...interactions import MessageInteraction
3135
from ..item import ItemCallbackType
3236
from ..view import View
3337

38+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True, default=Optional[View])
3439
else:
3540
ParamSpec = TypeVar
41+
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
3642

3743

3844
S_co = TypeVar("S_co", bound="BaseSelect", covariant=True)
39-
V_co = TypeVar("V_co", bound="Optional[View]", covariant=True)
4045
SelectMenuT = TypeVar("SelectMenuT", bound=AnySelectMenu)
4146
SelectValueT = TypeVar("SelectValueT")
4247
P = ParamSpec("P")

disnake/ui/view.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
from collections.abc import Sequence
1212
from functools import partial
1313
from itertools import groupby
14-
from typing import TYPE_CHECKING, Callable, ClassVar, Optional
14+
from typing import (
15+
TYPE_CHECKING,
16+
Callable,
17+
ClassVar,
18+
Generic,
19+
Optional,
20+
TypeVar,
21+
)
1522

1623
from ..components import (
1724
VALID_ACTION_ROW_MESSAGE_COMPONENT_TYPES,
@@ -38,37 +45,39 @@
3845
from .item import ItemCallbackType
3946

4047

48+
V_co = TypeVar("V_co", bound="View", covariant=True)
49+
4150
_log = logging.getLogger(__name__)
4251

4352

44-
def _component_to_item(component: ActionRowMessageComponent) -> Item:
53+
def _component_to_item(component: ActionRowMessageComponent) -> Item[V_co]:
4554
if item := _message_component_to_item(component):
4655
return item
4756
else:
48-
return Item.from_component(component)
57+
return Item[V_co].from_component(component)
4958

5059

51-
class _ViewWeights:
60+
class _ViewWeights(Generic[V_co]):
5261
__slots__ = ("weights",)
5362

54-
def __init__(self, children: list[Item]) -> None:
63+
def __init__(self, children: list[Item[V_co]]) -> None:
5564
self.weights: list[int] = [0, 0, 0, 0, 0]
5665

57-
key: Callable[[Item[View]], int] = lambda i: sys.maxsize if i.row is None else i.row
66+
key: Callable[[Item[V_co]], int] = lambda i: sys.maxsize if i.row is None else i.row
5867
children = sorted(children, key=key)
5968
for _, group in groupby(children, key=key):
6069
for item in group:
6170
self.add_item(item)
6271

63-
def find_open_space(self, item: Item) -> int:
72+
def find_open_space(self, item: Item[V_co]) -> int:
6473
for index, weight in enumerate(self.weights):
6574
if weight + item.width <= 5:
6675
return index
6776

6877
msg = "could not find open space for item"
6978
raise ValueError(msg)
7079

71-
def add_item(self, item: Item) -> None:
80+
def add_item(self, item: Item[V_co]) -> None:
7281
if item.row is not None:
7382
total = self.weights[item.row] + item.width
7483
if total > 5:
@@ -81,7 +90,7 @@ def add_item(self, item: Item) -> None:
8190
self.weights[index] += item.width
8291
item._rendered_row = index
8392

84-
def remove_item(self, item: Item) -> None:
93+
def remove_item(self, item: Item[V_co]) -> None:
8594
if item._rendered_row is not None:
8695
self.weights[item._rendered_row] -= item.width
8796
item._rendered_row = None
@@ -142,7 +151,7 @@ def __init__(self, *, timeout: Optional[float] = 180.0) -> None:
142151
setattr(self, func.__name__, item)
143152
self.children.append(item)
144153

145-
self.__weights = _ViewWeights(self.children)
154+
self.__weights: _ViewWeights = _ViewWeights(self.children)
146155
loop = asyncio.get_running_loop()
147156
self.id: str = os.urandom(16).hex()
148157
self.__cancel_callback: Optional[Callable[[View], None]] = None
@@ -173,7 +182,7 @@ async def __timeout_task_impl(self) -> None:
173182
await asyncio.sleep(self.__timeout_expiry - now)
174183

175184
def to_components(self) -> list[ActionRowPayload]:
176-
def key(item: Item[View]) -> int:
185+
def key(item: Item[Self]) -> int:
177186
return item._rendered_row or 0
178187

179188
children = sorted(self.children, key=key)
@@ -239,7 +248,7 @@ def _expires_at(self) -> Optional[float]:
239248
return time.monotonic() + self.timeout
240249
return None
241250

242-
def add_item(self, item: Item) -> Self:
251+
def add_item(self, item: Item[Self]) -> Self:
243252
"""Adds an item to the view.
244253
245254
This function returns the class instance to allow for fluent-style
@@ -272,7 +281,7 @@ def add_item(self, item: Item) -> Self:
272281
self.children.append(item)
273282
return self
274283

275-
def remove_item(self, item: Item) -> Self:
284+
def remove_item(self, item: Item[Self]) -> Self:
276285
"""Removes an item from the view.
277286
278287
This function returns the class instance to allow for fluent-style
@@ -336,7 +345,9 @@ async def on_timeout(self) -> None:
336345
"""
337346
pass
338347

339-
async def on_error(self, error: Exception, item: Item, interaction: MessageInteraction) -> None:
348+
async def on_error(
349+
self, error: Exception, item: Item[Self], interaction: MessageInteraction
350+
) -> None:
340351
"""|coro|
341352
342353
A callback that is called when an item's callback or :meth:`interaction_check`
@@ -356,7 +367,7 @@ async def on_error(self, error: Exception, item: Item, interaction: MessageInter
356367
print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr)
357368
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
358369

359-
async def _scheduled_task(self, item: Item, interaction: MessageInteraction) -> None:
370+
async def _scheduled_task(self, item: Item[Self], interaction: MessageInteraction) -> None:
360371
try:
361372
if self.timeout:
362373
self.__timeout_expiry = time.monotonic() + self.timeout
@@ -386,7 +397,7 @@ def _dispatch_timeout(self) -> None:
386397
self.__stopped.set_result(True)
387398
asyncio.create_task(self.on_timeout(), name=f"disnake-ui-view-timeout-{self.id}")
388399

389-
def _dispatch_item(self, item: Item, interaction: MessageInteraction) -> None:
400+
def _dispatch_item(self, item: Item[Self], interaction: MessageInteraction) -> None:
390401
if self.__stopped.done():
391402
return
392403

@@ -396,15 +407,15 @@ def _dispatch_item(self, item: Item, interaction: MessageInteraction) -> None:
396407

397408
def refresh(self, components: list[ActionRowComponent[ActionRowMessageComponent]]) -> None:
398409
# TODO: this is pretty hacky at the moment, see https://github.com/DisnakeDev/disnake/commit/9384a72acb8c515b13a600592121357e165368da
399-
old_state: dict[tuple[int, str], Item] = {
410+
old_state: dict[tuple[int, str], Item[Self]] = {
400411
(item.type.value, item.custom_id): item # pyright: ignore[reportAttributeAccessIssue]
401412
for item in self.children
402413
if item.is_dispatchable()
403414
}
404415

405-
children: list[Item] = []
416+
children: list[Item[Self]] = []
406417
for component in (c for row in components for c in row.children):
407-
older: Optional[Item] = None
418+
older: Optional[Item[Self]] = None
408419
try:
409420
older = old_state[component.type.value, component.custom_id] # pyright: ignore[reportArgumentType]
410421
except (KeyError, AttributeError):
@@ -490,7 +501,7 @@ async def wait(self) -> bool:
490501
class ViewStore:
491502
def __init__(self, state: ConnectionState) -> None:
492503
# (component_type, message_id, custom_id): (View, Item)
493-
self._views: dict[tuple[int, Optional[int], str], tuple[View, Item]] = {}
504+
self._views: dict[tuple[int, int | None, str], tuple[View, Item[View]]] = {}
494505
# message_id: View
495506
self._synced_message_views: dict[int, View] = {}
496507
self._state: ConnectionState = state

0 commit comments

Comments
 (0)