diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f6d8873..97be353 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: pip install -e ".[test]" + run: | + pip install -e ".[test]" + pip install git+https://github.com/agronholm/anyio.git#egg=anyio --ignore-installed - name: Check with mypy and ruff if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }} run: | diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 164fb99..e56c15c 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -19,9 +19,11 @@ get_cancelled_exc_class, sleep, wait_readable, + ClosedResourceError, + notify_closing, ) from anyio.abc import TaskGroup, TaskStatus -from anyioutils import FIRST_COMPLETED, Future, create_task, wait +from anyioutils import Future, create_task import zmq from zmq import EVENTS, POLLIN, POLLOUT @@ -890,36 +892,36 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): task_status.started() self.started.set() self._thread = get_ident() + + async def wait_or_cancel() -> None: + assert self.stopped is not None + await self.stopped.wait() + tg.cancel_scope.cancel() + + def fileno() -> int: + if self.closed: + return -1 + try: + return self._shadow_sock.fileno() + except zmq.ZMQError: + return -1 + try: - while True: - wait_stopped_task = create_task( - self.stopped.wait(), - self._task_group, - exception_handler=ignore_exceptions, - ) - tasks = [ - create_task( - wait_readable(self._shadow_sock), # type: ignore[arg-type] - self._task_group, - exception_handler=ignore_exceptions, - ), - wait_stopped_task, - ] - done, pending = await wait( - tasks, self._task_group, return_when=FIRST_COMPLETED - ) - for task in pending: - task.cancel() - if wait_stopped_task in done: + while (fd := fileno()) > 0: + async with create_task_group() as tg: + tg.start_soon(wait_or_cancel) + try: + await wait_readable(fd) + except ClosedResourceError: + break + finally: + tg.cancel_scope.cancel() + if self.stopped.is_set(): break await self._handle_events() - except BaseException: - pass finally: self._exited.set() - - assert self.stopped is not None - self.stopped.set() + self.stopped.set() async def stop(self): assert self._exited is not None @@ -933,11 +935,13 @@ async def stop(self): self.close() def close(self, linger: int | None = None) -> None: - try: - if not self.closed and self._fd is not None: + fd = self._fd + if not self.closed and fd is not None: + notify_closing(fd) + try: super().close(linger=linger) - except BaseException: - pass + except BaseException: + pass assert self.stopped is not None self.stopped.set() diff --git a/tests/conftest.py b/tests/conftest.py index 8f68ac6..99510ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,7 +66,7 @@ def context(contexts): @pytest.fixture -def sockets(contexts): +async def sockets(contexts): sockets = [] yield sockets # ensure any tracked sockets get their contexts cleaned up