Skip to content

Commit edd77fc

Browse files
authored
make sure that self._start and self.stop calls are contained within TaskGroup acmgr (#25)
* make sure that self._start and self.stop calls are contained within TaskGroup acmgr * ruff format * restore stack to None after exiting it * declare __stack * fix mypy
1 parent a153426 commit edd77fc

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

src/zmq_anyio/_socket.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from contextlib import AsyncExitStack
34
import pickle
45
import selectors
56
from collections import deque
@@ -156,7 +157,7 @@ class Socket(zmq.Socket):
156157
_fd = None
157158
_exit_stack = None
158159
_task_group = None
159-
__task_group = None
160+
__stack: AsyncExitStack | None = None
160161
_thread = None
161162
started = None
162163
stopped = None
@@ -191,6 +192,7 @@ def __init__(
191192
self._exited = Event()
192193
self.stopped = Event()
193194
self._task_group = task_group
195+
self.__stack = None
194196

195197
def get(self, key):
196198
result = super().get(key)
@@ -831,17 +833,26 @@ async def __aenter__(self) -> Socket:
831833
return
832834

833835
self._starting = True
834-
if self._task_group is None:
835-
self.__task_group = create_task_group()
836-
self._task_group = await self.__task_group.__aenter__()
837-
await self._task_group.start(self._start)
836+
if self._task_group is not None:
837+
return self
838+
839+
async with AsyncExitStack() as stack:
840+
self._task_group = task_group = await stack.enter_async_context(
841+
create_task_group()
842+
)
843+
await task_group.start(self._start)
844+
stack.push_async_callback(self.stop)
845+
self.__stack = stack.pop_all()
838846

839847
return self
840848

841849
async def __aexit__(self, exc_type, exc_value, exc_tb):
850+
if self.__stack is not None:
851+
try:
852+
return await self.__stack.__aexit__(exc_type, exc_value, exc_tb)
853+
finally:
854+
self.__stack = None
842855
await self.stop()
843-
if self.__task_group is not None:
844-
return await self.__task_group.__aexit__(exc_type, exc_value, exc_tb)
845856

846857
async def start(
847858
self,

0 commit comments

Comments
 (0)