Skip to content

Commit bfb9ffc

Browse files
authored
simplify wait_readable loop + notify_closing (#31)
* simplify wait_readable loop * handle more close possibilities * install graingert's notify_closing * issue close notifications * side step ExceptionGroup stuff * close sockets inside the event loop * catch exception from .fileno() * don't check closed-ness via zmq if possible * Update .github/workflows/test.yml * Update src/zmq_anyio/_socket.py * Update src/zmq_anyio/_socket.py * remove pre-release dep * Update anyio dependency for notify_closing * Update .github/workflows/test.yml * Update .github/workflows/test.yml * Replace trio with anyio[trio] in optional dependencies
1 parent b4758a1 commit bfb9ffc

File tree

3 files changed

+37
-33
lines changed

3 files changed

+37
-33
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
]
2828
requires-python = ">= 3.9"
2929
dependencies = [
30-
"anyio >=4.8.0,<5.0.0",
30+
"anyio >=4.10.0,<5.0.0",
3131
"anyioutils >=0.7.1,<0.8.0",
3232
"pyzmq >=26.0.0,<28.0.0",
3333
]
@@ -36,7 +36,7 @@ dependencies = [
3636
test = [
3737
"pytest >=8,<9",
3838
"pytest-timeout",
39-
"trio >=0.27.0,<0.28",
39+
"anyio[trio]",
4040
"mypy",
4141
"ruff",
4242
"coverage[toml] >=7,<8",

src/zmq_anyio/_socket.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
get_cancelled_exc_class,
2020
sleep,
2121
wait_readable,
22+
ClosedResourceError,
23+
notify_closing,
2224
)
2325
from anyio.abc import TaskGroup, TaskStatus
24-
from anyioutils import FIRST_COMPLETED, Future, create_task, wait
26+
from anyioutils import Future, create_task
2527

2628
import zmq
2729
from zmq import EVENTS, POLLIN, POLLOUT
@@ -890,36 +892,36 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
890892
task_status.started()
891893
self.started.set()
892894
self._thread = get_ident()
895+
896+
async def wait_or_cancel() -> None:
897+
assert self.stopped is not None
898+
await self.stopped.wait()
899+
tg.cancel_scope.cancel()
900+
901+
def fileno() -> int:
902+
if self.closed:
903+
return -1
904+
try:
905+
return self._shadow_sock.fileno()
906+
except zmq.ZMQError:
907+
return -1
908+
893909
try:
894-
while True:
895-
wait_stopped_task = create_task(
896-
self.stopped.wait(),
897-
self._task_group,
898-
exception_handler=ignore_exceptions,
899-
)
900-
tasks = [
901-
create_task(
902-
wait_readable(self._shadow_sock), # type: ignore[arg-type]
903-
self._task_group,
904-
exception_handler=ignore_exceptions,
905-
),
906-
wait_stopped_task,
907-
]
908-
done, pending = await wait(
909-
tasks, self._task_group, return_when=FIRST_COMPLETED
910-
)
911-
for task in pending:
912-
task.cancel()
913-
if wait_stopped_task in done:
910+
while (fd := fileno()) > 0:
911+
async with create_task_group() as tg:
912+
tg.start_soon(wait_or_cancel)
913+
try:
914+
await wait_readable(fd)
915+
except ClosedResourceError:
916+
break
917+
finally:
918+
tg.cancel_scope.cancel()
919+
if self.stopped.is_set():
914920
break
915921
await self._handle_events()
916-
except BaseException:
917-
pass
918922
finally:
919923
self._exited.set()
920-
921-
assert self.stopped is not None
922-
self.stopped.set()
924+
self.stopped.set()
923925

924926
async def stop(self):
925927
assert self._exited is not None
@@ -933,11 +935,13 @@ async def stop(self):
933935
self.close()
934936

935937
def close(self, linger: int | None = None) -> None:
936-
try:
937-
if not self.closed and self._fd is not None:
938+
fd = self._fd
939+
if not self.closed and fd is not None:
940+
notify_closing(fd)
941+
try:
938942
super().close(linger=linger)
939-
except BaseException:
940-
pass
943+
except BaseException:
944+
pass
941945

942946
assert self.stopped is not None
943947
self.stopped.set()

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def context(contexts):
6666

6767

6868
@pytest.fixture
69-
def sockets(contexts):
69+
async def sockets(contexts):
7070
sockets = []
7171
yield sockets
7272
# ensure any tracked sockets get their contexts cleaned up

0 commit comments

Comments
 (0)