@@ -11192,6 +11192,9 @@ def __init__(self, out_keys: Sequence[NestedKey] = None, time_key: str = "time")
11192
11192
self .last_inv_time = None
11193
11193
self .last_call_time = None
11194
11194
self .last_reset_time = None
11195
+ self .time_step_key = self .out_keys [0 ]
11196
+ self .time_policy_key = self .out_keys [1 ]
11197
+ self .time_reset_key = self .out_keys [2 ]
11195
11198
11196
11199
def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
11197
11200
self .last_reset_time = self .last_inv_time = time .time ()
@@ -11219,13 +11222,17 @@ def _reset(
11219
11222
time_elapsed = torch .tensor (
11220
11223
current_time - self .last_reset_time , device = tensordict .device
11221
11224
)
11222
- self ._maybe_expand_and_set (self .out_keys [2 ], time_elapsed , tensordict_reset )
11223
11225
self ._maybe_expand_and_set (
11224
- self .out_keys [0 ], time_elapsed * 0 , tensordict_reset
11226
+ self .time_reset_key , time_elapsed , tensordict_reset
11227
+ )
11228
+ self ._maybe_expand_and_set (
11229
+ self .time_step_key , time_elapsed * 0 , tensordict_reset
11225
11230
)
11226
11231
self .last_call_time = current_time
11227
11232
# Placeholder
11228
- self ._maybe_expand_and_set (self .out_keys [1 ], time_elapsed * 0 , tensordict_reset )
11233
+ self ._maybe_expand_and_set (
11234
+ self .time_policy_key , time_elapsed * 0 , tensordict_reset
11235
+ )
11229
11236
return tensordict_reset
11230
11237
11231
11238
def _inv_call (self , tensordict : TensorDictBase ) -> TensorDictBase :
@@ -11234,7 +11241,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
11234
11241
time_elapsed = torch .tensor (
11235
11242
current_time - self .last_call_time , device = tensordict .device
11236
11243
)
11237
- self ._maybe_expand_and_set (self .out_keys [ 1 ] , time_elapsed , tensordict )
11244
+ self ._maybe_expand_and_set (self .time_policy_key , time_elapsed , tensordict )
11238
11245
self .last_inv_time = current_time
11239
11246
return tensordict
11240
11247
@@ -11246,23 +11253,25 @@ def _step(
11246
11253
time_elapsed = torch .tensor (
11247
11254
current_time - self .last_inv_time , device = tensordict .device
11248
11255
)
11249
- self ._maybe_expand_and_set (self .out_keys [0 ], time_elapsed , next_tensordict )
11250
11256
self ._maybe_expand_and_set (
11251
- self .out_keys [2 ], time_elapsed * 0 , next_tensordict
11257
+ self .time_step_key , time_elapsed , next_tensordict
11258
+ )
11259
+ self ._maybe_expand_and_set (
11260
+ self .time_reset_key , time_elapsed * 0 , next_tensordict
11252
11261
)
11253
11262
self .last_call_time = current_time
11254
11263
# presumbly no need to worry about batch size incongruencies here
11255
- next_tensordict .set (self .out_keys [ 1 ] , tensordict .get (self .out_keys [ 1 ] ))
11264
+ next_tensordict .set (self .time_policy_key , tensordict .get (self .time_policy_key ))
11256
11265
return next_tensordict
11257
11266
11258
11267
def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
11259
- observation_spec [self .out_keys [ 0 ] ] = Unbounded (
11268
+ observation_spec [self .time_step_key ] = Unbounded (
11260
11269
shape = observation_spec .shape , device = observation_spec .device
11261
11270
)
11262
- observation_spec [self .out_keys [ 1 ] ] = Unbounded (
11271
+ observation_spec [self .time_policy_key ] = Unbounded (
11263
11272
shape = observation_spec .shape , device = observation_spec .device
11264
11273
)
11265
- observation_spec [self .out_keys [ 2 ] ] = Unbounded (
11274
+ observation_spec [self .time_reset_key ] = Unbounded (
11266
11275
shape = observation_spec .shape , device = observation_spec .device
11267
11276
)
11268
11277
return observation_spec
0 commit comments