3
3
import pickle
4
4
import selectors
5
5
from collections import deque
6
- from contextlib import AsyncExitStack
7
6
from functools import partial
7
+ from threading import get_ident
8
8
from typing import (
9
9
Any ,
10
10
Callable ,
13
13
14
14
from anyio import (
15
15
Event ,
16
- Lock ,
17
16
TASK_STATUS_IGNORED ,
18
17
create_task_group ,
19
18
get_cancelled_exc_class ,
20
19
sleep ,
21
20
wait_readable ,
22
21
)
23
- from anyio .abc import TaskStatus
22
+ from anyio .abc import TaskGroup , TaskStatus
24
23
from anyioutils import FIRST_COMPLETED , Future , create_task , wait
25
24
26
25
import zmq
@@ -157,15 +156,18 @@ class Socket(zmq.Socket):
157
156
_fd = None
158
157
_exit_stack = None
159
158
_task_group = None
159
+ __task_group = None
160
+ _thread = None
160
161
started = None
161
162
stopped = None
162
- _start_lock = None
163
+ _starting = None
163
164
_exited = None
164
165
165
166
def __init__ (
166
167
self ,
167
168
context_or_socket : zmq .Context | zmq .Socket ,
168
169
socket_type : int = - 1 ,
170
+ task_group : TaskGroup | None = None ,
169
171
** kwargs ,
170
172
) -> None :
171
173
"""
@@ -188,7 +190,7 @@ def __init__(
188
190
self .started = Event ()
189
191
self ._exited = Event ()
190
192
self .stopped = Event ()
191
- self ._start_lock = Lock ()
193
+ self ._task_group = task_group
192
194
193
195
def get (self , key ):
194
196
result = super ().get (key )
@@ -825,44 +827,56 @@ def _update_handler(self, state) -> None:
825
827
self ._schedule_remaining_events ()
826
828
827
829
async def __aenter__ (self ) -> Socket :
828
- assert self ._start_lock is not None
829
- async with self ._start_lock :
830
- if self ._task_group is None :
831
- async with AsyncExitStack () as exit_stack :
832
- self ._task_group = await exit_stack .enter_async_context (
833
- create_task_group ()
834
- )
835
- self ._exit_stack = exit_stack .pop_all ()
836
- await self ._task_group .start (self ._start )
830
+ if self ._starting :
831
+ return
832
+
833
+ self ._starting = True
834
+ if self ._task_group is None :
835
+ self .__task_group = create_task_group ()
836
+ self ._task_group = await self .__task_group .__aenter__ ()
837
+ await self ._task_group .start (self ._start )
837
838
838
839
return self
839
840
840
841
async def __aexit__ (self , exc_type , exc_value , exc_tb ):
841
842
await self .stop ()
842
- return await self ._exit_stack .__aexit__ (exc_type , exc_value , exc_tb )
843
+ if self .__task_group is not None :
844
+ return await self .__task_group .__aexit__ (exc_type , exc_value , exc_tb )
843
845
844
846
async def start (
845
847
self ,
846
848
* ,
847
849
task_status : TaskStatus [None ] = TASK_STATUS_IGNORED ,
848
850
) -> None :
849
- assert self ._start_lock is not None
850
- async with self ._start_lock :
851
- if self ._task_group is None :
852
- async with create_task_group () as self ._task_group :
853
- await self ._task_group .start (self ._start )
854
- task_status .started ()
855
- else :
851
+ if self ._starting :
852
+ return
853
+
854
+ self ._starting = True
855
+ assert self .started is not None
856
+ if self .started .is_set ():
857
+ task_status .started ()
858
+ return
859
+
860
+ if self ._task_group is None :
861
+ async with create_task_group () as self ._task_group :
856
862
await self ._task_group .start (self ._start )
857
863
task_status .started ()
864
+ else :
865
+ await self ._task_group .start (self ._start )
866
+ task_status .started ()
867
+
868
+ async def _start (self , * , task_status : TaskStatus [None ] = TASK_STATUS_IGNORED ):
869
+ assert self .started is not None
870
+ if self .started .is_set ():
871
+ return
858
872
859
- async def _start (self , * , task_status : TaskStatus [None ]):
860
873
assert self .started is not None
861
874
assert self .stopped is not None
862
875
assert self ._exited is not None
863
876
assert self ._task_group is not None
864
877
task_status .started ()
865
878
self .started .set ()
879
+ self ._thread = get_ident ()
866
880
try :
867
881
while True :
868
882
wait_stopped_task = create_task (
@@ -922,6 +936,12 @@ def _check_started(self):
922
936
"Socket must be used with async context manager (or `await sock.start()`)"
923
937
)
924
938
939
+ self ._task_group .start_soon (self ._start )
940
+
941
+ assert self ._thread is not None
942
+ if self ._thread != get_ident ():
943
+ raise RuntimeError ("Socket must be used in the same thread" )
944
+
925
945
926
946
def ignore_exceptions (exc : BaseException ) -> bool :
927
947
return True
0 commit comments