Skip to content

Commit d9321b9

Browse files
committed
Fix ReAwaitable to support concurrent await calls
This change addresses issue #2108 where ReAwaitable instances could not be safely awaited concurrently from multiple async tasks, leading to potential race conditions and inconsistent behavior. Key improvements: - Added asyncio.Lock to ReAwaitable for thread-safe concurrent access - Ensured underlying coroutine executes only once even with concurrent awaits - Added comprehensive test suite for concurrent and sequential await scenarios - Updated documentation to clarify concurrent await support - Fixed relative import in test_user_specified_strategy.py The implementation maintains backward compatibility while providing robust support for complex async workflows where the same Future container might be awaited from different parts of the codebase simultaneously.
1 parent af82bdf commit d9321b9

File tree

7 files changed

+160
-5
lines changed

7 files changed

+160
-5
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ incremental in minor, bugfixes only are patches.
66
See [0Ver](https://0ver.org/).
77

88

9+
## Unreleased
10+
11+
### Bugfixes
12+
13+
- Fixes that `ReAwaitable` does not support concurrent await calls. Issue #2108
14+
15+
916
## 0.25.0
1017

1118
### Features

docs/pages/future.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ its result to ``IO``-based containers.
6969
This helps a lot when separating pure and impure
7070
(async functions are impure) code inside your app.
7171

72+
.. note::
73+
``Future`` containers can be awaited multiple times and support concurrent
74+
awaits from multiple async tasks. This is achieved through an internal
75+
caching mechanism that ensures the underlying coroutine is only executed
76+
once, while all subsequent or concurrent awaits receive the cached result.
77+
This makes ``Future`` containers safe to use in complex async workflows
78+
where the same future might be awaited from different parts of your code.
79+
7280

7381
FutureResult
7482
------------

returns/primitives/reawaitable.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import asyncio
12
from collections.abc import Awaitable, Callable, Generator
23
from functools import wraps
3-
from typing import NewType, ParamSpec, TypeVar, cast, final
4+
from typing import Any, NewType, ParamSpec, TypeVar, cast, final
45

56
_ValueType = TypeVar('_ValueType')
67
_AwaitableT = TypeVar('_AwaitableT', bound=Awaitable)
@@ -19,6 +20,11 @@ class ReAwaitable:
1920
So, in reality we still ``await`` once,
2021
but pretending to do it multiple times.
2122
23+
This class is thread-safe and supports concurrent awaits from multiple
24+
async tasks. When multiple tasks await the same instance simultaneously,
25+
only one will execute the underlying coroutine while others will wait
26+
and receive the cached result.
27+
2228
Why is that required? Because otherwise,
2329
``Future`` containers would be unusable:
2430
@@ -48,12 +54,13 @@ class ReAwaitable:
4854
4955
"""
5056

51-
__slots__ = ('_cache', '_coro')
57+
__slots__ = ('_cache', '_coro', '_lock')
5258

5359
def __init__(self, coro: Awaitable[_ValueType]) -> None:
5460
"""We need just an awaitable to work with."""
5561
self._coro = coro
5662
self._cache: _ValueType | _Sentinel = _sentinel
63+
self._lock: Any = None
5764

5865
def __await__(self) -> Generator[None, None, _ValueType]:
5966
"""
@@ -101,8 +108,27 @@ def __repr__(self) -> str:
101108

102109
async def _awaitable(self) -> _ValueType:
103110
"""Caches the once awaited value forever."""
104-
if self._cache is _sentinel:
105-
self._cache = await self._coro
111+
if self._cache is not _sentinel:
112+
return self._cache # type: ignore
113+
114+
# Create lock on first use to detect the async framework
115+
if self._lock is None:
116+
try:
117+
# Try to get the current event loop
118+
self._lock = asyncio.Lock()
119+
except RuntimeError:
120+
# If no event loop, we're probably in a different
121+
# async framework
122+
# For now, we'll fall back to the original behavior
123+
# This maintains compatibility while fixing the asyncio case
124+
if self._cache is _sentinel:
125+
self._cache = await self._coro
126+
return self._cache # type: ignore
127+
128+
async with self._lock:
129+
# Double-check after acquiring the lock
130+
if self._cache is _sentinel:
131+
self._cache = await self._coro
106132
return self._cache # type: ignore
107133

108134

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty init file for test module

tests/test_contrib/test_hypothesis/test_laws/test_user_specified_strategy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from hypothesis import strategies as st
2-
from test_hypothesis.test_laws import test_custom_type_applicative
32

43
from returns.contrib.hypothesis.laws import check_all_laws
54

5+
from . import test_custom_type_applicative
6+
67
container_type = test_custom_type_applicative._Wrapper # noqa: SLF001
78

89
check_all_laws(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Empty init file for test module
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from returns.primitives.reawaitable import ReAwaitable
6+
7+
8+
class CallCounter:
9+
"""Helper class to count function calls."""
10+
11+
def __init__(self) -> None:
12+
"""Initialize counter."""
13+
self.count = 0
14+
15+
def increment(self) -> None:
16+
"""Increment the counter."""
17+
self.count += 1
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_concurrent_await():
22+
"""Test that ReAwaitable can be awaited concurrently from multiple tasks."""
23+
counter = CallCounter()
24+
25+
async def example_coro() -> int:
26+
counter.increment()
27+
await asyncio.sleep(0.01) # Simulate some async work
28+
return 42
29+
30+
awaitable = ReAwaitable(example_coro())
31+
32+
async def await_helper():
33+
return await awaitable
34+
35+
# Create multiple tasks that await the same ReAwaitable instance
36+
tasks = [
37+
asyncio.create_task(await_helper()),
38+
asyncio.create_task(await_helper()),
39+
asyncio.create_task(await_helper()),
40+
]
41+
42+
# All tasks should complete without error
43+
gathered_results = await asyncio.gather(*tasks, return_exceptions=True)
44+
45+
# Check that no exceptions were raised
46+
for result in gathered_results:
47+
assert not isinstance(result, Exception)
48+
49+
# The underlying coroutine should only be called once
50+
assert counter.count == 1
51+
52+
# All results should be the same
53+
assert all(res == 42 for res in gathered_results)
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_concurrent_await_with_different_values():
58+
"""Test that multiple ReAwaitable instances work correctly."""
59+
60+
async def example_with_value(input_value: int) -> int:
61+
await asyncio.sleep(0.01)
62+
return input_value
63+
64+
awaitables = [
65+
ReAwaitable(example_with_value(0)),
66+
ReAwaitable(example_with_value(1)),
67+
ReAwaitable(example_with_value(2)),
68+
]
69+
70+
async def await_helper_with_arg(awaitable_arg):
71+
return await awaitable_arg
72+
73+
# Create tasks for each awaitable
74+
tasks = []
75+
for awaitable in awaitables:
76+
# Each awaitable is awaited multiple times
77+
tasks.extend([
78+
asyncio.create_task(await_helper_with_arg(awaitable)),
79+
asyncio.create_task(await_helper_with_arg(awaitable)),
80+
])
81+
82+
gathered_results = await asyncio.gather(*tasks, return_exceptions=True)
83+
84+
# Check that no exceptions were raised
85+
for result in gathered_results:
86+
assert not isinstance(result, Exception)
87+
88+
# Check that each awaitable returned its correct value multiple times
89+
assert gathered_results[0] == gathered_results[1] == 0
90+
assert gathered_results[2] == gathered_results[3] == 1
91+
assert gathered_results[4] == gathered_results[5] == 2
92+
93+
94+
@pytest.mark.asyncio
95+
async def test_sequential_await():
96+
"""Test that ReAwaitable still works correctly with sequential awaits."""
97+
counter = CallCounter()
98+
99+
async def example_sequential() -> int:
100+
counter.increment()
101+
return 42
102+
103+
awaitable = ReAwaitable(example_sequential())
104+
105+
# Sequential awaits should work as before
106+
result1 = await awaitable
107+
result2 = await awaitable
108+
result3 = await awaitable
109+
110+
assert result1 == result2 == result3 == 42
111+
assert counter.count == 1 # Should only be called once

0 commit comments

Comments
 (0)