@@ -220,7 +220,7 @@ class PendulumEnv(EnvBase):
220
220
221
221
def __init__ (self , td_params = None , seed = None , device = None ):
222
222
if td_params is None :
223
- td_params = self .gen_params ()
223
+ td_params = self .gen_params (device = self . device )
224
224
225
225
super ().__init__ (device = device )
226
226
self ._make_spec (td_params )
@@ -273,7 +273,7 @@ def _reset(self, tensordict):
273
273
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
274
274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275
275
# parameters to get started.
276
- tensordict = self .gen_params (batch_size = batch_size )
276
+ tensordict = self .gen_params (batch_size = batch_size , device = self . device )
277
277
278
278
high_th = torch .tensor (self .DEFAULT_X , device = self .device )
279
279
high_thdot = torch .tensor (self .DEFAULT_Y , device = self .device )
@@ -355,12 +355,12 @@ def make_composite_from_td(td):
355
355
return composite
356
356
357
357
def _set_seed (self , seed : int ):
358
- rng = torch .Generator ()
358
+ rng = torch .Generator (device = self . device )
359
359
rng .manual_seed (seed )
360
360
self .rng = rng
361
361
362
362
@staticmethod
363
- def gen_params (g = 10.0 , batch_size = None ) -> TensorDictBase :
363
+ def gen_params (g = 10.0 , batch_size = None , device = None ) -> TensorDictBase :
364
364
"""Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
365
365
if batch_size is None :
366
366
batch_size = []
@@ -379,6 +379,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
379
379
)
380
380
},
381
381
[],
382
+ device = device ,
382
383
)
383
384
if batch_size :
384
385
td = td .expand (batch_size ).contiguous ()
0 commit comments