We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent afb81de commit 1d45117Copy full SHA for 1d45117
torchrl/collectors/collectors.py
@@ -1148,6 +1148,21 @@ def rollout(self) -> TensorDictBase:
1148
and self._frames < self.init_random_frames
1149
):
1150
self.env.rand_action(self._shuttle)
1151
+ if (
1152
+ self.policy_device is not None
1153
+ and self.policy_device != self.env_device
1154
+ ):
1155
+ # TODO: This may break with exclusive / ragged lazy stacks
1156
+ self._shuttle.apply(
1157
+ lambda name, val: val.to(
1158
+ device=self.policy_device, non_blocking=True
1159
+ )
1160
+ if name in self._policy_output_keys
1161
+ else val,
1162
+ out=self._shuttle,
1163
+ named=True,
1164
+ nested_keys=True,
1165
1166
else:
1167
if self._cast_to_policy_device:
1168
if self.policy_device is not None:
0 commit comments