@@ -209,6 +209,7 @@ class Transform(nn.Module):
209
209
"""
210
210
211
211
invertible = False
212
+ enable_inv_on_reset = False
212
213
213
214
def __init__ (
214
215
self ,
@@ -293,6 +294,13 @@ def _reset(
293
294
"""Resets a transform if it is stateful."""
294
295
return tensordict_reset
295
296
297
+ def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
298
+ """Inverts the input to :meth:`TransformedEnv._reset`, if needed."""
299
+ if self .enable_inv_on_reset :
300
+ with _set_missing_tolerance (self , True ):
301
+ tensordict = self .inv (tensordict )
302
+ return tensordict
303
+
296
304
def init (self , tensordict ) -> None :
297
305
pass
298
306
@@ -1018,10 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
1018
1026
tensordict = tensordict .select (
1019
1027
* self .reset_keys , * self .state_spec .keys (True , True ), strict = False
1020
1028
)
1021
- # Inputs might be transformed, so need to apply inverse transform
1022
- # before passing to the env reset function.
1023
- with _set_missing_tolerance (self .transform , True ):
1024
- tensordict = self .transform .inv (tensordict )
1029
+ tensordict = self .transform ._reset_env_preprocess (tensordict )
1025
1030
tensordict_reset = self .base_env ._reset (tensordict , ** kwargs )
1026
1031
if tensordict is None :
1027
1032
# make sure all transforms see a source tensordict
@@ -1369,6 +1374,11 @@ def _reset(
1369
1374
tensordict_reset = t ._reset (tensordict , tensordict_reset )
1370
1375
return tensordict_reset
1371
1376
1377
+ def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
1378
+ for t in reversed (self .transforms ):
1379
+ tensordict = t ._reset_env_preprocess (tensordict )
1380
+ return tensordict
1381
+
1372
1382
def init (self , tensordict : TensorDictBase ) -> None :
1373
1383
for t in self .transforms :
1374
1384
t .init (tensordict )
@@ -4725,6 +4735,7 @@ class UnaryTransform(Transform):
4725
4735
[torchrl][INFO] check_env_specs succeeded!
4726
4736
4727
4737
"""
4738
+ enable_inv_on_reset = True
4728
4739
4729
4740
def __init__ (
4730
4741
self ,
0 commit comments