Skip to content

Commit 1d45117

Browse files
author
Vincent Moens
committed
[BugFix] Fix device transfer for collectors with init_random_frames mixed devices
ghstack-source-id: 1684399 Pull Request resolved: #2704
1 parent afb81de commit 1d45117

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

torchrl/collectors/collectors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,21 @@ def rollout(self) -> TensorDictBase:
11481148
and self._frames < self.init_random_frames
11491149
):
11501150
self.env.rand_action(self._shuttle)
1151+
if (
1152+
self.policy_device is not None
1153+
and self.policy_device != self.env_device
1154+
):
1155+
# TODO: This may break with exclusive / ragged lazy stacks
1156+
self._shuttle.apply(
1157+
lambda name, val: val.to(
1158+
device=self.policy_device, non_blocking=True
1159+
)
1160+
if name in self._policy_output_keys
1161+
else val,
1162+
out=self._shuttle,
1163+
named=True,
1164+
nested_keys=True,
1165+
)
11511166
else:
11521167
if self._cast_to_policy_device:
11531168
if self.policy_device is not None:

0 commit comments

Comments
 (0)