@@ -31,7 +31,7 @@ def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[fl
31
31
self .max_alloc_timeout = max_alloc_timeout
32
32
self ._lock_metadata = mp .Lock ()
33
33
self ._current_size = mp .Value (ctypes .c_int64 , 0 , lock = False )
34
- self ._enqueued_size = mp .Value (ctypes .c_int64 , 0 , lock = False )
34
+ self ._enqueued_size = mp .Value (ctypes .c_int64 , 0 , lock = True )
35
35
self ._handle_counter = mp .Value (ctypes .c_int64 , 0 , lock = False )
36
36
self ._allocated_tensors : Dict [Handle , torch .Tensor ] = {}
37
37
self .runtime_pid = os .getpid ()
@@ -138,7 +138,8 @@ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float])
138
138
start_time = time .perf_counter ()
139
139
loop = asyncio .get_event_loop ()
140
140
141
- self .enqueued_size_bytes += alloc_size
141
+ with self ._enqueued_size .get_lock ():
142
+ self ._enqueued_size .value += alloc_size
142
143
allocated = False
143
144
try :
144
145
context_manager = async_timeout .timeout (timeout ) if timeout != 0 else contextlib .AsyncExitStack ()
@@ -155,13 +156,15 @@ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float])
155
156
await loop .run_in_executor (None , self ._wait_until_available , alloc_size , remaining_timeout )
156
157
157
158
allocated = True
158
- self .enqueued_size_bytes -= alloc_size
159
+ with self ._enqueued_size .get_lock ():
160
+ self ._enqueued_size .value -= alloc_size
159
161
yield
160
162
except asyncio .TimeoutError :
161
163
raise AllocationFailed (f"Could not allocate { alloc_size } within { timeout } seconds" )
162
164
finally :
163
165
if not allocated :
164
- self .enqueued_size_bytes -= alloc_size
166
+ with self ._enqueued_size .get_lock ():
167
+ self ._enqueued_size .value -= alloc_size
165
168
166
169
def _free (self , alloc_size : int , alloc_task : asyncio .Task ):
167
170
if alloc_task .exception () is not None :
0 commit comments