30
30
T = TypeVar ("T" )
31
31
32
32
33
- class PendingType :
34
- def __repr__ (self ) -> str :
35
- return "AsyncPending"
36
-
37
-
38
- Pending = PendingType ()
39
-
40
-
41
- class PendingValueError (Exception ):
42
- """Exception raised when a value is accessed before it is ready."""
43
-
44
-
45
- class SoonValue (Generic [T ]):
46
- """Holds a value that will be available soon after an async operation."""
47
-
48
- def __init__ (self ) -> None :
49
- self ._stored_value : Union [T , PendingType ] = Pending
50
-
51
- @property
52
- def value (self ) -> "T" :
53
- if isinstance (self ._stored_value , PendingType ):
54
- msg = "The return value of this task is still pending."
55
- raise PendingValueError (msg )
56
- return self ._stored_value
57
-
58
- @property
59
- def ready (self ) -> bool :
60
- return not isinstance (self ._stored_value , PendingType )
61
-
62
-
63
- class TaskGroup :
64
- """Manages a group of asyncio tasks, allowing them to be run concurrently."""
65
-
66
- def __init__ (self ) -> None :
67
- self ._tasks : set [asyncio .Task [Any ]] = set ()
68
- self ._exceptions : list [BaseException ] = []
69
- self ._closed = False
70
-
71
- async def __aenter__ (self ) -> "TaskGroup" :
72
- if self ._closed :
73
- msg = "Cannot enter a task group that has already been closed."
74
- raise RuntimeError (msg )
75
- return self
76
-
77
- async def __aexit__ (
78
- self ,
79
- exc_type : "Optional[type[BaseException]]" , # noqa: PYI036
80
- exc_val : "Optional[BaseException]" , # noqa: PYI036
81
- exc_tb : "Optional[TracebackType]" , # noqa: PYI036
82
- ) -> None :
83
- self ._closed = True
84
- if exc_val :
85
- self ._exceptions .append (exc_val )
86
-
87
- if self ._tasks :
88
- await asyncio .wait (self ._tasks )
89
-
90
- if self ._exceptions :
91
- # Re-raise the first exception encountered.
92
- raise self ._exceptions [0 ]
93
-
94
- def create_task (self , coro : "Coroutine[Any, Any, Any]" ) -> "asyncio.Task[Any]" :
95
- """Create and add a coroutine as a task to the task group.
96
-
97
- Args:
98
- coro (Coroutine): The coroutine to be added as a task.
99
-
100
- Returns:
101
- asyncio.Task: The created asyncio task.
102
-
103
- Raises:
104
- RuntimeError: If the task group has already been closed.
105
- """
106
- if self ._closed :
107
- msg = "Cannot create a task in a task group that has already been closed."
108
- raise RuntimeError (msg )
109
- task = asyncio .create_task (coro )
110
- self ._tasks .add (task )
111
- task .add_done_callback (self ._tasks .discard )
112
- task .add_done_callback (self ._check_result )
113
- return task
114
-
115
- def _check_result (self , task : "asyncio.Task[Any]" ) -> None :
116
- """Check and store exceptions from a completed task.
117
-
118
- Args:
119
- task (asyncio.Task): The task to check for exceptions.
120
- """
121
- try :
122
- task .result () # This will raise the exception if one occurred.
123
- except Exception as e : # noqa: BLE001
124
- self ._exceptions .append (e )
125
-
126
- def start_soon_ (
127
- self ,
128
- async_function : "Callable[ParamSpecT, Awaitable[T]]" ,
129
- name : object = None ,
130
- ) -> "Callable[ParamSpecT, SoonValue[T]]" :
131
- """Create a function to start a new task in this task group.
132
-
133
- Args:
134
- async_function (Callable): An async function to call soon.
135
- name (object, optional): Name of the task for introspection and debugging.
136
-
137
- Returns:
138
- Callable: A function that starts the task and returns a SoonValue object.
139
- """
140
-
141
- @functools .wraps (async_function )
142
- def wrapper (* args : "ParamSpecT.args" , ** kwargs : "ParamSpecT.kwargs" ) -> "SoonValue[T]" :
143
- partial_f = functools .partial (async_function , * args , ** kwargs )
144
- soon_value : SoonValue [T ] = SoonValue ()
145
-
146
- @functools .wraps (partial_f )
147
- async def value_wrapper (* _args : "Any" ) -> None :
148
- value = await partial_f ()
149
- soon_value ._stored_value = value # pyright: ignore[reportPrivateUsage] # noqa: SLF001
150
-
151
- self .create_task (value_wrapper ) # type: ignore[arg-type]
152
- return soon_value
153
-
154
- return wrapper
155
-
156
-
157
- def create_task_group () -> "TaskGroup" :
158
- """Create a TaskGroup for managing multiple concurrent async tasks.
159
-
160
- Returns:
161
- TaskGroup: A new TaskGroup instance.
162
- """
163
- return TaskGroup ()
164
-
165
-
166
33
class CapacityLimiter :
167
34
"""Limits the number of concurrent operations using a semaphore."""
168
35
@@ -195,7 +62,7 @@ async def __aexit__(
195
62
self .release ()
196
63
197
64
198
- _default_limiter = CapacityLimiter (40 )
65
+ _default_limiter = CapacityLimiter (15 )
199
66
200
67
201
68
def run_ (async_function : "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]" ) -> "Callable[ParamSpecT, ReturnT]" :
@@ -237,6 +104,7 @@ def await_(
237
104
Args:
238
105
async_function (Callable): The async function to convert.
239
106
raise_sync_error (bool, optional): If False, runs in a new event loop if no loop is present.
107
+ If True (default), raises RuntimeError if no loop is running.
240
108
241
109
Returns:
242
110
Callable: A blocking function that runs the async function.
@@ -248,12 +116,39 @@ def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT
248
116
try :
249
117
loop = asyncio .get_running_loop ()
250
118
except RuntimeError :
251
- loop = None
252
-
253
- if loop is None and raise_sync_error is False :
119
+ # No running event loop
120
+ if raise_sync_error :
121
+ msg = "await_ called without a running event loop and raise_sync_error=True"
122
+ raise RuntimeError (msg ) from None
123
+ return asyncio .run (partial_f ())
124
+ else :
125
+ # Running in an existing event loop.
126
+ if loop .is_running ():
127
+ try :
128
+ # Check if the current context is within a task managed by this loop
129
+ current_task = asyncio .current_task (loop = loop )
130
+ except RuntimeError :
131
+ # Not running inside a task managed by this loop
132
+ current_task = None
133
+
134
+ if current_task is not None :
135
+ # Called from within the event loop's execution context (a task).
136
+ # Blocking here would deadlock the loop.
137
+ msg = "await_ cannot be called from within an async task running on the same event loop. Use 'await' instead."
138
+ raise RuntimeError (msg )
139
+ # Called from a different thread than the loop's thread.
140
+ # It's safe to block this thread and wait for the loop.
141
+ future = asyncio .run_coroutine_threadsafe (partial_f (), loop )
142
+ # This blocks the *calling* thread, not the loop thread.
143
+ return future .result ()
144
+ # This case should ideally not happen if get_running_loop() succeeded
145
+ # but the loop isn't running, but handle defensively.
146
+ # loop is not running
147
+ if raise_sync_error :
148
+ msg = "await_ found a non-running loop via get_running_loop()"
149
+ raise RuntimeError (msg )
150
+ # Fallback to running in a new loop
254
151
return asyncio .run (partial_f ())
255
- # Running in an existing event loop
256
- return asyncio .run (partial_f ())
257
152
258
153
return wrapper
259
154
@@ -286,7 +181,7 @@ async def wrapper(
286
181
return wrapper
287
182
288
183
289
- def maybe_async_ (
184
+ def ensure_async_ (
290
185
function : "Callable[ParamSpecT, Union[Awaitable[ReturnT], ReturnT]]" ,
291
186
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]" :
292
187
"""Convert a function to an async one if it is not already.
@@ -309,24 +204,6 @@ async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "R
309
204
return wrapper
310
205
311
206
312
- def wrap_sync (fn : "Callable[ParamSpecT, ReturnT]" ) -> "Callable[ParamSpecT, Awaitable[ReturnT]]" :
313
- """Convert a sync function to an async one.
314
-
315
- Args:
316
- fn (Callable): The function to convert.
317
-
318
- Returns:
319
- Callable: An async function that runs the original function.
320
- """
321
- if inspect .iscoroutinefunction (fn ):
322
- return fn
323
-
324
- async def wrapped (* args : "ParamSpecT.args" , ** kwargs : "ParamSpecT.kwargs" ) -> ReturnT :
325
- return await async_ (functools .partial (fn , * args , ** kwargs ))()
326
-
327
- return wrapped
328
-
329
-
330
207
class _ContextManagerWrapper (Generic [T ]):
331
208
def __init__ (self , cm : AbstractContextManager [T ]) -> None :
332
209
self ._cm = cm
@@ -343,7 +220,7 @@ async def __aexit__(
343
220
return self ._cm .__exit__ (exc_type , exc_val , exc_tb )
344
221
345
222
346
- def maybe_async_context (
223
+ def with_ensure_async_ (
347
224
obj : "Union[AbstractContextManager[T], AbstractAsyncContextManager[T]]" ,
348
225
) -> "AbstractAsyncContextManager[T]" :
349
226
"""Convert a context manager to an async one if it is not already.
0 commit comments