@@ -17,19 +17,23 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
17
17
import asyncio # noqa: WPS433
18
18
from enum import Enum , auto
19
19
20
+
20
21
class AsyncContext (Enum ):
21
22
"""Enum representing different async context types."""
22
-
23
+
23
24
ASYNCIO = auto ()
24
25
TRIO = auto ()
25
26
UNKNOWN = auto ()
26
27
28
+
27
29
# Check for anyio and trio availability
28
30
try :
29
31
import anyio # noqa: WPS433
32
+
30
33
has_anyio = True
31
34
try :
32
35
import trio # noqa: WPS433
36
+
33
37
has_trio = True
34
38
except ImportError : # pragma: no cover
35
39
has_trio = False
@@ -40,13 +44,13 @@ class AsyncContext(Enum):
40
44
41
45
def detect_async_context () -> AsyncContext :
42
46
"""Detect which async context we're currently running in.
43
-
47
+
44
48
Returns:
45
49
AsyncContext: The current async context type
46
50
"""
47
51
if not has_anyio : # pragma: no cover
48
52
return AsyncContext .ASYNCIO
49
-
53
+
50
54
if has_trio :
51
55
try :
52
56
# Check if we're in a trio context
@@ -56,10 +60,11 @@ def detect_async_context() -> AsyncContext:
56
60
except (RuntimeError , AttributeError ):
57
61
# Not in a trio context or trio API changed
58
62
pass
59
-
63
+
60
64
# Default to asyncio
61
65
return AsyncContext .ASYNCIO
62
66
67
+
63
68
_ValueType = TypeVar ('_ValueType' )
64
69
_AwaitableT = TypeVar ('_AwaitableT' , bound = Awaitable )
65
70
_Ps = ParamSpec ('_Ps' )
@@ -165,10 +170,10 @@ def __repr__(self) -> str:
165
170
def _create_lock (self ) -> AsyncLock :
166
171
"""Create the appropriate lock based on the current async context."""
167
172
context = detect_async_context ()
168
-
173
+
169
174
if context == AsyncContext .TRIO and has_anyio :
170
175
return anyio .Lock ()
171
-
176
+
172
177
# For ASYNCIO or UNKNOWN contexts
173
178
return asyncio .Lock ()
174
179
0 commit comments