Skip to content

Commit 02fc71e

Browse files
authored
Fix race condition in MemoryCache (#487)
1 parent dc0072f commit 02fc71e

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/petals/server/memory_cache.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[fl
3131
self.max_alloc_timeout = max_alloc_timeout
3232
self._lock_metadata = mp.Lock()
3333
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
34-
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
34+
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
3535
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
3636
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
3737
self.runtime_pid = os.getpid()
@@ -138,7 +138,8 @@ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float])
138138
start_time = time.perf_counter()
139139
loop = asyncio.get_event_loop()
140140

141-
self.enqueued_size_bytes += alloc_size
141+
with self._enqueued_size.get_lock():
142+
self._enqueued_size.value += alloc_size
142143
allocated = False
143144
try:
144145
context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
@@ -155,13 +156,15 @@ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float])
155156
await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
156157

157158
allocated = True
158-
self.enqueued_size_bytes -= alloc_size
159+
with self._enqueued_size.get_lock():
160+
self._enqueued_size.value -= alloc_size
159161
yield
160162
except asyncio.TimeoutError:
161163
raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
162164
finally:
163165
if not allocated:
164-
self.enqueued_size_bytes -= alloc_size
166+
with self._enqueued_size.get_lock():
167+
self._enqueued_size.value -= alloc_size
165168

166169
def _free(self, alloc_size: int, alloc_task: asyncio.Task):
167170
if alloc_task.exception() is not None:

0 commit comments

Comments
 (0)