diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d76398448..6ce974b01b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,6 +95,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2624](https://github.com/Pycord-Development/pycord/pull/2624)) - Fixed editing `ForumChannel` flags not working. ([#2641](https://github.com/Pycord-Development/pycord/pull/2641)) +- Fixed Async I/O errors that could be raised when using `Client.run`. + ([#2645](https://github.com/Pycord-Development/pycord/pull/2645)) - Fixed `AttributeError` when accessing `Member.guild_permissions` for user installed apps. ([#2650](https://github.com/Pycord-Development/pycord/pull/2650)) - Fixed type annotations of cached properties. diff --git a/discord/client.py b/discord/client.py index 2d0f1b8770..0d03797f4f 100644 --- a/discord/client.py +++ b/discord/client.py @@ -27,7 +27,6 @@ import asyncio import logging -import signal import sys import traceback from types import TracebackType @@ -124,6 +123,12 @@ class Client: A number of options can be passed to the :class:`Client`. + .. container:: operations + + .. describe:: async with x + + Asynchronously initializes the client. + Parameters ----------- max_messages: Optional[:class:`int`] @@ -228,9 +233,7 @@ def __init__( ): # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + self._loop: asyncio.AbstractEventLoop | None = loop self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = ( {} ) @@ -246,7 +249,7 @@ def __init__( proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, - loop=self.loop, + loop=self._loop, ) self._handlers: dict[str, Callable] = {"ready": self._handle_ready} @@ -258,7 +261,8 @@ def __init__( self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count - self._closed: bool = False + self._closed: asyncio.Event = asyncio.Event() + self._closing_task: asyncio.Lock = asyncio.Lock() self._ready: asyncio.Event = asyncio.Event() self._connection._get_websocket = self._get_websocket self._connection._get_client = lambda: self @@ -272,12 +276,23 @@ def __init__( self._tasks = set() async def __aenter__(self) -> Client: - loop = asyncio.get_running_loop() - self.loop = loop - self.http.loop = loop - self._connection.loop = loop + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + # No event loop was found, this should not happen + # because entering on this context manager means a + # loop is already active, but we need to handle it + # anyways just to prevent future errors. + + # Maybe handle different system event loop policies? + self._loop = asyncio.new_event_loop() + + self.http.loop = self.loop + self._connection.loop = self.loop self._ready = asyncio.Event() + self._closed = asyncio.Event() return self @@ -310,6 +325,21 @@ def _get_state(self, **options: Any) -> ConnectionState: def _handle_ready(self) -> None: self._ready.set() + @property + def loop(self) -> asyncio.AbstractEventLoop: + """The event loop that the client uses for asynchronous operations.""" + if self._loop is None: + raise RuntimeError("loop is not set") + return self._loop + + @loop.setter + def loop(self, value: asyncio.AbstractEventLoop) -> None: + if not isinstance(value, asyncio.AbstractEventLoop): + raise TypeError( + f"expected a AbstractEventLoop object, got {value.__class__.__name__!r} instead" + ) + self._loop = value + @property def latency(self) -> float: """Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. If no websocket @@ -746,23 +776,24 @@ async def close(self) -> None: Closes the connection to Discord. """ - if self._closed: - return + async with self._closing_task: + if self.is_closed(): + return - await self.http.close() - self._closed = True + await self.http.close() - for voice in self.voice_clients: - try: - await voice.disconnect(force=True) - except Exception: - # if an error happens during disconnects, disregard it. - pass + for voice in self.voice_clients: + try: + await voice.disconnect(force=True) + except Exception: + # if an error happens during disconnects, disregard it. + pass - if self.ws is not None and self.ws.open: - await self.ws.close(code=1000) + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) - self._ready.clear() + self._ready.clear() + self._closed.set() def clear(self) -> None: """Clears the internal state of the bot. @@ -771,7 +802,7 @@ def clear(self) -> None: and :meth:`is_ready` both return ``False`` along with the bot's internal cache cleared. """ - self._closed = False + self._closed.clear() self._ready.clear() self._connection.clear() self.http.recreate() @@ -786,10 +817,11 @@ async def start(self, token: str, *, reconnect: bool = True) -> None: TypeError An unexpected keyword argument was received. """ + # Update the loop to get the running one in case the one set is MISSING await self.login(token) await self.connect(reconnect=reconnect) - def run(self, *args: Any, **kwargs: Any) -> None: + def run(self, token: str, *, reconnect: bool = True) -> None: """A blocking call that abstracts away the event loop initialisation from you. @@ -800,12 +832,20 @@ def run(self, *args: Any, **kwargs: Any) -> None: Roughly Equivalent to: :: try: - loop.run_until_complete(start(*args, **kwargs)) + asyncio.run(start(token)) except KeyboardInterrupt: - loop.run_until_complete(close()) - # cancel all tasks lingering - finally: - loop.close() + return + + Parameters + ---------- + token: :class:`str` + The authentication token. Do not prefix this token with + anything as the library will do it for you. + reconnect: :class:`bool` + If we should attempt reconnecting to the gateway, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). .. warning:: @@ -813,47 +853,36 @@ def run(self, *args: Any, **kwargs: Any) -> None: is blocking. That means that registration of events or anything being called after this function call will not execute until it returns. """ - loop = self.loop - - try: - loop.add_signal_handler(signal.SIGINT, loop.stop) - loop.add_signal_handler(signal.SIGTERM, loop.stop) - except (NotImplementedError, RuntimeError): - pass async def runner(): - try: - await self.start(*args, **kwargs) - finally: - if not self.is_closed(): - await self.close() + async with self: + await self.start(token=token, reconnect=reconnect) - def stop_loop_on_completion(f): - loop.stop() + try: + run = self.loop.run_until_complete + requires_cleanup = True + except RuntimeError: + run = asyncio.run + requires_cleanup = False - future = asyncio.ensure_future(runner(), loop=loop) - future.add_done_callback(stop_loop_on_completion) try: - loop.run_forever() - except KeyboardInterrupt: - _log.info("Received signal to terminate bot and event loop.") + run(runner()) finally: - future.remove_done_callback(stop_loop_on_completion) - _log.info("Cleaning up tasks.") - _cleanup_loop(loop) + # Ensure the bot is closed + if not self.is_closed(): + self.loop.run_until_complete(self.close()) - if not future.cancelled(): - try: - return future.result() - except KeyboardInterrupt: - # I am unsure why this gets raised here but suppress it anyway - return None + # asyncio.run automatically does the cleanup tasks, so if we use + # it we don't need to clean up the tasks. + if requires_cleanup: + _log.info("Cleaning up tasks.") + _cleanup_loop(self.loop) # properties def is_closed(self) -> bool: """Indicates if the WebSocket connection is closed.""" - return self._closed + return self._closed.is_set() @property def activity(self) -> ActivityTypes | None: diff --git a/discord/http.py b/discord/http.py index 145a3c411b..e0aac98804 100644 --- a/discord/http.py +++ b/discord/http.py @@ -175,9 +175,7 @@ def __init__( loop: asyncio.AbstractEventLoop | None = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + self.loop: asyncio.AbstractEventLoop = loop or MISSING self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()