@@ -13,16 +13,52 @@ async def __aenter__(self) -> None: ...
13
13
async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None : ...
14
14
15
15
16
- # Try to use anyio.Lock, fall back to asyncio.Lock
17
- # Note: anyio is required for proper trio support
16
+ # Import both libraries if available
17
+ import asyncio # noqa: WPS433
18
+ from enum import Enum , auto
19
+
20
+ class AsyncContext (Enum ):
21
+ """Enum representing different async context types."""
22
+
23
+ ASYNCIO = auto ()
24
+ TRIO = auto ()
25
+ UNKNOWN = auto ()
26
+
27
+ # Check for anyio and trio availability
18
28
try :
19
29
import anyio # noqa: WPS433
30
+ has_anyio = True
31
+ try :
32
+ import trio # noqa: WPS433
33
+ has_trio = True
34
+ except ImportError : # pragma: no cover
35
+ has_trio = False
20
36
except ImportError : # pragma: no cover
21
- import asyncio # noqa: WPS433
37
+ has_anyio = False
38
+ has_trio = False
22
39
23
- Lock : type [AsyncLock ] = asyncio .Lock
24
- else :
25
- Lock = cast (type [AsyncLock ], anyio .Lock )
40
+
41
+ def detect_async_context () -> AsyncContext :
42
+ """Detect which async context we're currently running in.
43
+
44
+ Returns:
45
+ AsyncContext: The current async context type
46
+ """
47
+ if not has_anyio : # pragma: no cover
48
+ return AsyncContext .ASYNCIO
49
+
50
+ if has_trio :
51
+ try :
52
+ # Check if we're in a trio context
53
+ # Will raise RuntimeError if not in trio context
54
+ trio .lowlevel .current_task ()
55
+ return AsyncContext .TRIO
56
+ except (RuntimeError , AttributeError ):
57
+ # Not in a trio context or trio API changed
58
+ pass
59
+
60
+ # Default to asyncio
61
+ return AsyncContext .ASYNCIO
26
62
27
63
_ValueType = TypeVar ('_ValueType' )
28
64
_AwaitableT = TypeVar ('_AwaitableT' , bound = Awaitable )
@@ -78,9 +114,9 @@ class ReAwaitable:
78
114
79
115
def __init__ (self , coro : Awaitable [_ValueType ]) -> None :
80
116
"""We need just an awaitable to work with."""
81
- self ._lock = Lock ()
82
117
self ._coro = coro
83
118
self ._cache : _ValueType | _Sentinel = _sentinel
119
+ self ._lock = None # Will be created lazily based on the backend
84
120
85
121
def __await__ (self ) -> Generator [None , None , _ValueType ]:
86
122
"""
@@ -126,8 +162,22 @@ def __repr__(self) -> str:
126
162
"""
127
163
return repr (self ._coro )
128
164
165
+ def _create_lock (self ) -> AsyncLock :
166
+ """Create the appropriate lock based on the current async context."""
167
+ context = detect_async_context ()
168
+
169
+ if context == AsyncContext .TRIO and has_anyio :
170
+ return anyio .Lock ()
171
+
172
+ # For ASYNCIO or UNKNOWN contexts
173
+ return asyncio .Lock ()
174
+
129
175
async def _awaitable (self ) -> _ValueType :
130
176
"""Caches the once awaited value forever."""
177
+ # Create the lock if it doesn't exist
178
+ if self ._lock is None :
179
+ self ._lock = self ._create_lock ()
180
+
131
181
async with self ._lock :
132
182
if self ._cache is _sentinel :
133
183
self ._cache = await self ._coro
0 commit comments