Skip to content

Commit 4acb2a6

Browse files
Try to start the socket if it has a task group (#21)
1 parent 80a7b84 commit 4acb2a6

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

src/zmq_anyio/_socket.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import pickle
44
import selectors
55
from collections import deque
6-
from contextlib import AsyncExitStack
76
from functools import partial
7+
from threading import get_ident
88
from typing import (
99
Any,
1010
Callable,
@@ -13,14 +13,13 @@
1313

1414
from anyio import (
1515
Event,
16-
Lock,
1716
TASK_STATUS_IGNORED,
1817
create_task_group,
1918
get_cancelled_exc_class,
2019
sleep,
2120
wait_readable,
2221
)
23-
from anyio.abc import TaskStatus
22+
from anyio.abc import TaskGroup, TaskStatus
2423
from anyioutils import FIRST_COMPLETED, Future, create_task, wait
2524

2625
import zmq
@@ -157,15 +156,18 @@ class Socket(zmq.Socket):
157156
_fd = None
158157
_exit_stack = None
159158
_task_group = None
159+
__task_group = None
160+
_thread = None
160161
started = None
161162
stopped = None
162-
_start_lock = None
163+
_starting = None
163164
_exited = None
164165

165166
def __init__(
166167
self,
167168
context_or_socket: zmq.Context | zmq.Socket,
168169
socket_type: int = -1,
170+
task_group: TaskGroup | None = None,
169171
**kwargs,
170172
) -> None:
171173
"""
@@ -188,7 +190,7 @@ def __init__(
188190
self.started = Event()
189191
self._exited = Event()
190192
self.stopped = Event()
191-
self._start_lock = Lock()
193+
self._task_group = task_group
192194

193195
def get(self, key):
194196
result = super().get(key)
@@ -825,44 +827,56 @@ def _update_handler(self, state) -> None:
825827
self._schedule_remaining_events()
826828

827829
async def __aenter__(self) -> Socket:
828-
assert self._start_lock is not None
829-
async with self._start_lock:
830-
if self._task_group is None:
831-
async with AsyncExitStack() as exit_stack:
832-
self._task_group = await exit_stack.enter_async_context(
833-
create_task_group()
834-
)
835-
self._exit_stack = exit_stack.pop_all()
836-
await self._task_group.start(self._start)
830+
if self._starting:
831+
return
832+
833+
self._starting = True
834+
if self._task_group is None:
835+
self.__task_group = create_task_group()
836+
self._task_group = await self.__task_group.__aenter__()
837+
await self._task_group.start(self._start)
837838

838839
return self
839840

840841
async def __aexit__(self, exc_type, exc_value, exc_tb):
841842
await self.stop()
842-
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)
843+
if self.__task_group is not None:
844+
return await self.__task_group.__aexit__(exc_type, exc_value, exc_tb)
843845

844846
async def start(
845847
self,
846848
*,
847849
task_status: TaskStatus[None] = TASK_STATUS_IGNORED,
848850
) -> None:
849-
assert self._start_lock is not None
850-
async with self._start_lock:
851-
if self._task_group is None:
852-
async with create_task_group() as self._task_group:
853-
await self._task_group.start(self._start)
854-
task_status.started()
855-
else:
851+
if self._starting:
852+
return
853+
854+
self._starting = True
855+
assert self.started is not None
856+
if self.started.is_set():
857+
task_status.started()
858+
return
859+
860+
if self._task_group is None:
861+
async with create_task_group() as self._task_group:
856862
await self._task_group.start(self._start)
857863
task_status.started()
864+
else:
865+
await self._task_group.start(self._start)
866+
task_status.started()
867+
868+
async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
869+
assert self.started is not None
870+
if self.started.is_set():
871+
return
858872

859-
async def _start(self, *, task_status: TaskStatus[None]):
860873
assert self.started is not None
861874
assert self.stopped is not None
862875
assert self._exited is not None
863876
assert self._task_group is not None
864877
task_status.started()
865878
self.started.set()
879+
self._thread = get_ident()
866880
try:
867881
while True:
868882
wait_stopped_task = create_task(
@@ -922,6 +936,12 @@ def _check_started(self):
922936
"Socket must be used with async context manager (or `await sock.start()`)"
923937
)
924938

939+
self._task_group.start_soon(self._start)
940+
941+
assert self._thread is not None
942+
if self._thread != get_ident():
943+
raise RuntimeError("Socket must be used in the same thread")
944+
925945

926946
def ignore_exceptions(exc: BaseException) -> bool:
927947
return True

0 commit comments

Comments
 (0)