1111from collections .abc import Sequence
1212from functools import partial
1313from itertools import groupby
14- from typing import TYPE_CHECKING , Callable , ClassVar , Optional
14+ from typing import (
15+ TYPE_CHECKING ,
16+ Callable ,
17+ ClassVar ,
18+ Generic ,
19+ Optional ,
20+ TypeVar ,
21+ )
1522
1623from ..components import (
1724 VALID_ACTION_ROW_MESSAGE_COMPONENT_TYPES ,
3845 from .item import ItemCallbackType
3946
4047
48+ V_co = TypeVar ("V_co" , bound = "View" , covariant = True )
49+
4150_log = logging .getLogger (__name__ )
4251
4352
44- def _component_to_item (component : ActionRowMessageComponent ) -> Item :
53+ def _component_to_item (component : ActionRowMessageComponent ) -> Item [ V_co ] :
4554 if item := _message_component_to_item (component ):
4655 return item
4756 else :
48- return Item .from_component (component )
57+ return Item [ V_co ] .from_component (component )
4958
5059
51- class _ViewWeights :
60+ class _ViewWeights ( Generic [ V_co ]) :
5261 __slots__ = ("weights" ,)
5362
54- def __init__ (self , children : list [Item ]) -> None :
63+ def __init__ (self , children : list [Item [ V_co ] ]) -> None :
5564 self .weights : list [int ] = [0 , 0 , 0 , 0 , 0 ]
5665
57- key : Callable [[Item [View ]], int ] = lambda i : sys .maxsize if i .row is None else i .row
66+ key : Callable [[Item [V_co ]], int ] = lambda i : sys .maxsize if i .row is None else i .row
5867 children = sorted (children , key = key )
5968 for _ , group in groupby (children , key = key ):
6069 for item in group :
6170 self .add_item (item )
6271
63- def find_open_space (self , item : Item ) -> int :
72+ def find_open_space (self , item : Item [ V_co ] ) -> int :
6473 for index , weight in enumerate (self .weights ):
6574 if weight + item .width <= 5 :
6675 return index
6776
6877 msg = "could not find open space for item"
6978 raise ValueError (msg )
7079
71- def add_item (self , item : Item ) -> None :
80+ def add_item (self , item : Item [ V_co ] ) -> None :
7281 if item .row is not None :
7382 total = self .weights [item .row ] + item .width
7483 if total > 5 :
@@ -81,7 +90,7 @@ def add_item(self, item: Item) -> None:
8190 self .weights [index ] += item .width
8291 item ._rendered_row = index
8392
84- def remove_item (self , item : Item ) -> None :
93+ def remove_item (self , item : Item [ V_co ] ) -> None :
8594 if item ._rendered_row is not None :
8695 self .weights [item ._rendered_row ] -= item .width
8796 item ._rendered_row = None
@@ -142,7 +151,7 @@ def __init__(self, *, timeout: Optional[float] = 180.0) -> None:
142151 setattr (self , func .__name__ , item )
143152 self .children .append (item )
144153
145- self .__weights = _ViewWeights (self .children )
154+ self .__weights : _ViewWeights = _ViewWeights (self .children )
146155 loop = asyncio .get_running_loop ()
147156 self .id : str = os .urandom (16 ).hex ()
148157 self .__cancel_callback : Optional [Callable [[View ], None ]] = None
@@ -173,7 +182,7 @@ async def __timeout_task_impl(self) -> None:
173182 await asyncio .sleep (self .__timeout_expiry - now )
174183
175184 def to_components (self ) -> list [ActionRowPayload ]:
176- def key (item : Item [View ]) -> int :
185+ def key (item : Item [Self ]) -> int :
177186 return item ._rendered_row or 0
178187
179188 children = sorted (self .children , key = key )
@@ -239,7 +248,7 @@ def _expires_at(self) -> Optional[float]:
239248 return time .monotonic () + self .timeout
240249 return None
241250
242- def add_item (self , item : Item ) -> Self :
251+ def add_item (self , item : Item [ Self ] ) -> Self :
243252 """Adds an item to the view.
244253
245254 This function returns the class instance to allow for fluent-style
@@ -272,7 +281,7 @@ def add_item(self, item: Item) -> Self:
272281 self .children .append (item )
273282 return self
274283
275- def remove_item (self , item : Item ) -> Self :
284+ def remove_item (self , item : Item [ Self ] ) -> Self :
276285 """Removes an item from the view.
277286
278287 This function returns the class instance to allow for fluent-style
@@ -336,7 +345,9 @@ async def on_timeout(self) -> None:
336345 """
337346 pass
338347
339- async def on_error (self , error : Exception , item : Item , interaction : MessageInteraction ) -> None :
348+ async def on_error (
349+ self , error : Exception , item : Item [Self ], interaction : MessageInteraction
350+ ) -> None :
340351 """|coro|
341352
342353 A callback that is called when an item's callback or :meth:`interaction_check`
@@ -356,7 +367,7 @@ async def on_error(self, error: Exception, item: Item, interaction: MessageInter
356367 print (f"Ignoring exception in view { self } for item { item } :" , file = sys .stderr )
357368 traceback .print_exception (error .__class__ , error , error .__traceback__ , file = sys .stderr )
358369
359- async def _scheduled_task (self , item : Item , interaction : MessageInteraction ) -> None :
370+ async def _scheduled_task (self , item : Item [ Self ] , interaction : MessageInteraction ) -> None :
360371 try :
361372 if self .timeout :
362373 self .__timeout_expiry = time .monotonic () + self .timeout
@@ -386,7 +397,7 @@ def _dispatch_timeout(self) -> None:
386397 self .__stopped .set_result (True )
387398 asyncio .create_task (self .on_timeout (), name = f"disnake-ui-view-timeout-{ self .id } " )
388399
389- def _dispatch_item (self , item : Item , interaction : MessageInteraction ) -> None :
400+ def _dispatch_item (self , item : Item [ Self ] , interaction : MessageInteraction ) -> None :
390401 if self .__stopped .done ():
391402 return
392403
@@ -396,15 +407,15 @@ def _dispatch_item(self, item: Item, interaction: MessageInteraction) -> None:
396407
397408 def refresh (self , components : list [ActionRowComponent [ActionRowMessageComponent ]]) -> None :
398409 # TODO: this is pretty hacky at the moment, see https://github.com/DisnakeDev/disnake/commit/9384a72acb8c515b13a600592121357e165368da
399- old_state : dict [tuple [int , str ], Item ] = {
410+ old_state : dict [tuple [int , str ], Item [ Self ] ] = {
400411 (item .type .value , item .custom_id ): item # pyright: ignore[reportAttributeAccessIssue]
401412 for item in self .children
402413 if item .is_dispatchable ()
403414 }
404415
405- children : list [Item ] = []
416+ children : list [Item [ Self ] ] = []
406417 for component in (c for row in components for c in row .children ):
407- older : Optional [Item ] = None
418+ older : Optional [Item [ Self ] ] = None
408419 try :
409420 older = old_state [component .type .value , component .custom_id ] # pyright: ignore[reportArgumentType]
410421 except (KeyError , AttributeError ):
@@ -490,7 +501,7 @@ async def wait(self) -> bool:
490501class ViewStore :
491502 def __init__ (self , state : ConnectionState ) -> None :
492503 # (component_type, message_id, custom_id): (View, Item)
493- self ._views : dict [tuple [int , Optional [ int ] , str ], tuple [View , Item ]] = {}
504+ self ._views : dict [tuple [int , int | None , str ], tuple [View , Item [ View ] ]] = {}
494505 # message_id: View
495506 self ._synced_message_views : dict [int , View ] = {}
496507 self ._state : ConnectionState = state
0 commit comments