Skip to content

Commit 6799a7f

Browse files
author
Vincent Moens
committed
[BugFix] Fix pendulum device
ghstack-source-id: bcaf20d Pull Request resolved: #2516
1 parent c851e16 commit 6799a7f

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

test/test_env.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3405,16 +3405,16 @@ def test_tictactoe_env_single(self):
34053405
)
34063406
assert r.shape == (5, 100)
34073407

3408-
def test_pendulum_env(self):
3409-
env = PendulumEnv(device=None)
3410-
assert env.device is None
3411-
env = PendulumEnv(device="cpu")
3412-
assert env.device == torch.device("cpu")
3408+
@pytest.mark.parametrize("device", [None, *get_default_devices()])
3409+
def test_pendulum_env(self, device):
3410+
env = PendulumEnv(device=device)
3411+
assert env.device == device
34133412
check_env_specs(env)
3413+
34143414
for _ in range(10):
34153415
r = env.rollout(10)
34163416
assert r.shape == torch.Size((10,))
3417-
r = env.rollout(10, tensordict=TensorDict(batch_size=[5]))
3417+
r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device))
34183418
assert r.shape == torch.Size((5, 10))
34193419

34203420

torchrl/envs/custom/pendulum.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ class PendulumEnv(EnvBase):
220220

221221
def __init__(self, td_params=None, seed=None, device=None):
222222
if td_params is None:
223-
td_params = self.gen_params()
223+
td_params = self.gen_params(device=self.device)
224224

225225
super().__init__(device=device)
226226
self._make_spec(td_params)
@@ -273,7 +273,7 @@ def _reset(self, tensordict):
273273
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
274274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275
# parameters to get started.
276-
tensordict = self.gen_params(batch_size=batch_size)
276+
tensordict = self.gen_params(batch_size=batch_size, device=self.device)
277277

278278
high_th = torch.tensor(self.DEFAULT_X, device=self.device)
279279
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
@@ -355,12 +355,12 @@ def make_composite_from_td(td):
355355
return composite
356356

357357
def _set_seed(self, seed: int):
358-
rng = torch.Generator()
358+
rng = torch.Generator(device=self.device)
359359
rng.manual_seed(seed)
360360
self.rng = rng
361361

362362
@staticmethod
363-
def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
363+
def gen_params(g=10.0, batch_size=None, device=None) -> TensorDictBase:
364364
"""Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
365365
if batch_size is None:
366366
batch_size = []
@@ -379,6 +379,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
379379
)
380380
},
381381
[],
382+
device=device,
382383
)
383384
if batch_size:
384385
td = td.expand(batch_size).contiguous()

0 commit comments

Comments
 (0)