Skip to content

Commit 76aa9bc

Browse files
kurtamohlerVincent Moens
authored andcommitted
[BugFix] Fix MultiAction reset
ghstack-source-id: a2f7bfd Pull Request resolved: #2789
1 parent 03d6586 commit 76aa9bc

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

torchrl/data/map/query.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,26 @@ class QueryModule(TensorDictModuleBase):
8686
providing the ``clone`` argument to the forward method.
8787
Defaults to ``False``.
8888
89-
Examples:
90-
>>> query_module = QueryModule(
91-
... in_keys=["key1", "key2"],
92-
... index_key="index",
93-
... hash_module=SipHash(),
94-
... )
95-
>>> query = TensorDict(
96-
... {
97-
... "key1": torch.Tensor([[1], [1], [1], [2]]),
98-
... "key2": torch.Tensor([[3], [3], [2], [3]]),
99-
... "other": torch.randn(4),
100-
... },
101-
... batch_size=(4,),
102-
... )
103-
>>> res = query_module(query)
104-
>>> # The first two pairs of key1 and key2 match
105-
>>> assert res["index"][0] == res["index"][1]
106-
>>> # The last three pairs of key1 and key2 have at least one mismatching value
107-
>>> assert res["index"][1] != res["index"][2]
108-
>>> assert res["index"][2] != res["index"][3]
89+
Examples:
90+
>>> query_module = QueryModule(
91+
... in_keys=["key1", "key2"],
92+
... index_key="index",
93+
... hash_module=SipHash(),
94+
... )
95+
>>> query = TensorDict(
96+
... {
97+
... "key1": torch.Tensor([[1], [1], [1], [2]]),
98+
... "key2": torch.Tensor([[3], [3], [2], [3]]),
99+
... "other": torch.randn(4),
100+
... },
101+
... batch_size=(4,),
102+
... )
103+
>>> res = query_module(query)
104+
>>> # The first two pairs of key1 and key2 match
105+
>>> assert res["index"][0] == res["index"][1]
106+
>>> # The last three pairs of key1 and key2 have at least one mismatching value
107+
>>> assert res["index"][1] != res["index"][2]
108+
>>> assert res["index"][2] != res["index"][3]
109109
110110
"""
111111

torchrl/envs/transforms/transforms.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ class Transform(nn.Module):
209209
"""
210210

211211
invertible = False
212+
enable_inv_on_reset = False
212213

213214
def __init__(
214215
self,
@@ -293,6 +294,13 @@ def _reset(
293294
"""Resets a transform if it is stateful."""
294295
return tensordict_reset
295296

297+
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
298+
"""Inverts the input to :meth:`TransformedEnv._reset`, if needed."""
299+
if self.enable_inv_on_reset:
300+
with _set_missing_tolerance(self, True):
301+
tensordict = self.inv(tensordict)
302+
return tensordict
303+
296304
def init(self, tensordict) -> None:
297305
pass
298306

@@ -1018,10 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
10181026
tensordict = tensordict.select(
10191027
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
10201028
)
1021-
# Inputs might be transformed, so need to apply inverse transform
1022-
# before passing to the env reset function.
1023-
with _set_missing_tolerance(self.transform, True):
1024-
tensordict = self.transform.inv(tensordict)
1029+
tensordict = self.transform._reset_env_preprocess(tensordict)
10251030
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
10261031
if tensordict is None:
10271032
# make sure all transforms see a source tensordict
@@ -1369,6 +1374,11 @@ def _reset(
13691374
tensordict_reset = t._reset(tensordict, tensordict_reset)
13701375
return tensordict_reset
13711376

1377+
def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
1378+
for t in reversed(self.transforms):
1379+
tensordict = t._reset_env_preprocess(tensordict)
1380+
return tensordict
1381+
13721382
def init(self, tensordict: TensorDictBase) -> None:
13731383
for t in self.transforms:
13741384
t.init(tensordict)
@@ -4725,6 +4735,7 @@ class UnaryTransform(Transform):
47254735
[torchrl][INFO] check_env_specs succeeded!
47264736
47274737
"""
4738+
enable_inv_on_reset = True
47284739

47294740
def __init__(
47304741
self,

0 commit comments

Comments
 (0)