Skip to content

Commit 5847569

Browse files
proboscisclaude
andcommitted
Fix ReAwaitable to use context-specific locks
- Add AsyncContext enum to identify the current async runtime - Create locks dynamically based on detected context - Fix issue when using anyio.run with trio backend 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent c1db704 commit 5847569

File tree

1 file changed

+57
-7
lines changed

1 file changed

+57
-7
lines changed

returns/primitives/reawaitable.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,52 @@ async def __aenter__(self) -> None: ...
1313
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
1414

1515

16-
# Try to use anyio.Lock, fall back to asyncio.Lock
17-
# Note: anyio is required for proper trio support
16+
# Import both libraries if available
17+
import asyncio # noqa: WPS433
18+
from enum import Enum, auto
19+
20+
class AsyncContext(Enum):
21+
"""Enum representing different async context types."""
22+
23+
ASYNCIO = auto()
24+
TRIO = auto()
25+
UNKNOWN = auto()
26+
27+
# Check for anyio and trio availability
1828
try:
1929
import anyio # noqa: WPS433
30+
has_anyio = True
31+
try:
32+
import trio # noqa: WPS433
33+
has_trio = True
34+
except ImportError: # pragma: no cover
35+
has_trio = False
2036
except ImportError: # pragma: no cover
21-
import asyncio # noqa: WPS433
37+
has_anyio = False
38+
has_trio = False
2239

23-
Lock: type[AsyncLock] = asyncio.Lock
24-
else:
25-
Lock = cast(type[AsyncLock], anyio.Lock)
40+
41+
def detect_async_context() -> AsyncContext:
42+
"""Detect which async context we're currently running in.
43+
44+
Returns:
45+
AsyncContext: The current async context type
46+
"""
47+
if not has_anyio: # pragma: no cover
48+
return AsyncContext.ASYNCIO
49+
50+
if has_trio:
51+
try:
52+
# Check if we're in a trio context
53+
# Will raise RuntimeError if not in trio context
54+
trio.lowlevel.current_task()
55+
return AsyncContext.TRIO
56+
except (RuntimeError, AttributeError):
57+
# Not in a trio context or trio API changed
58+
pass
59+
60+
# Default to asyncio
61+
return AsyncContext.ASYNCIO
2662

2763
_ValueType = TypeVar('_ValueType')
2864
_AwaitableT = TypeVar('_AwaitableT', bound=Awaitable)
@@ -78,9 +114,9 @@ class ReAwaitable:
78114

79115
def __init__(self, coro: Awaitable[_ValueType]) -> None:
80116
"""We need just an awaitable to work with."""
81-
self._lock = Lock()
82117
self._coro = coro
83118
self._cache: _ValueType | _Sentinel = _sentinel
119+
self._lock = None # Will be created lazily based on the backend
84120

85121
def __await__(self) -> Generator[None, None, _ValueType]:
86122
"""
@@ -126,8 +162,22 @@ def __repr__(self) -> str:
126162
"""
127163
return repr(self._coro)
128164

165+
def _create_lock(self) -> AsyncLock:
166+
"""Create the appropriate lock based on the current async context."""
167+
context = detect_async_context()
168+
169+
if context == AsyncContext.TRIO and has_anyio:
170+
return anyio.Lock()
171+
172+
# For ASYNCIO or UNKNOWN contexts
173+
return asyncio.Lock()
174+
129175
async def _awaitable(self) -> _ValueType:
130176
"""Caches the once awaited value forever."""
177+
# Create the lock if it doesn't exist
178+
if self._lock is None:
179+
self._lock = self._create_lock()
180+
131181
async with self._lock:
132182
if self._cache is _sentinel:
133183
self._cache = await self._coro

0 commit comments

Comments
 (0)