Skip to content

Commit d524d0d

Browse files
author
Vincent Moens
committed
[Feature] Send info dict to the storage device in RBs
ghstack-source-id: 4ed60d6 Pull Request resolved: #2527
1 parent da0bf18 commit d524d0d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
712712
"for a proper usage of the batch-size arguments."
713713
)
714714
if not self._prefetch:
715-
ret = self._sample(batch_size)
715+
result = self._sample(batch_size)
716716
else:
717717
with self._futures_lock:
718718
while (
@@ -722,11 +722,15 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
722722
) or not len(self._prefetch_queue):
723723
fut = self._prefetch_executor.submit(self._sample, batch_size)
724724
self._prefetch_queue.append(fut)
725-
ret = self._prefetch_queue.popleft().result()
725+
result = self._prefetch_queue.popleft().result()
726726

727727
if return_info:
728-
return ret
729-
return ret[0]
728+
out, info = result
729+
if getattr(self.storage, "device", None) is not None:
730+
device = self.storage.device
731+
info = tree_map(lambda x: x.to(device) if hasattr(x, "to") else x, info)
732+
return out, info
733+
return result[0]
730734

731735
def mark_update(self, index: Union[int, torch.Tensor]) -> None:
732736
self._sampler.mark_update(index, storage=self._storage)

0 commit comments

Comments
 (0)