
Commit 38776a1 (2 parents: dd52226 + 36545af)
Author: Vincent Moens

Update (base update)
[ghstack-poisoned]

File tree: 6 files changed, +129 / -9 lines

test/test_cost.py

Lines changed: 67 additions & 1 deletion
@@ -7565,6 +7565,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -7634,6 +7635,7 @@ def _create_mock_actor_value(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -7690,6 +7692,7 @@ def _create_mock_actor_value_shared(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8627,6 +8630,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8727,6 +8731,7 @@ def _create_mock_common_layer_setup(
                 "action1": ("action", "action1"),
             },
             log_prob_key=sample_log_prob_key,
+            aggregate_probabilities=True,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -15277,7 +15282,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
         class MyLoss3(MyLoss2):
             @dataclass
             class _AcceptedKeys:
-                some_key = "some_value"
+                some_key: str = "some_value"

         loss_module = MyLoss3()
         assert loss_module.tensor_keys.some_key == "some_value"
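Note: the one-character fix above matters because @dataclass only turns annotated attributes into fields; a bare `some_key = "some_value"` stays a plain class attribute and never reaches `__dataclass_fields__`, which `set_keys` validates against (see torchrl/objectives/common.py below). A self-contained illustration with a toy class (not the test's actual `_AcceptedKeys`):

from dataclasses import dataclass, fields

@dataclass
class AcceptedKeysDemo:
    plain = "ignored"              # no annotation: class attribute, not a dataclass field
    some_key: str = "some_value"   # annotated: registered in __dataclass_fields__

print([f.name for f in fields(AcceptedKeysDemo)])  # ['some_key']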
@@ -15639,6 +15644,67 @@ def __init__(self):
         assert p.device == dest


+def test_exploration_compile():
+    m = ProbabilisticTensorDictModule(
+        in_keys=["loc", "scale"],
+        out_keys=["sample"],
+        distribution_class=torch.distributions.Normal,
+    )
+
+    # class set_exploration_type_random(set_exploration_type):
+    #     __init__ = object.__init__
+    #     type = ExplorationType.RANDOM
+    it = exploration_type()
+
+    @torch.compile(fullgraph=True)
+    def func(t):
+        with set_exploration_type(ExplorationType.RANDOM):
+            t0 = m(t.clone())
+            t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] != t1["sample"]).any()
+    assert it == exploration_type()
+
+    @torch.compile(fullgraph=True)
+    def func(t):
+        with set_exploration_type(ExplorationType.MEAN):
+            t0 = m(t.clone())
+            t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] == t1["sample"]).all()
+    assert it == exploration_type()
+
+    @torch.compile(fullgraph=True)
+    @set_exploration_type(ExplorationType.RANDOM)
+    def func(t):
+        t0 = m(t.clone())
+        t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] != t1["sample"]).any()
+    assert it == exploration_type()
+
+    @torch.compile(fullgraph=True)
+    @set_exploration_type(ExplorationType.MEAN)
+    def func(t):
+        t0 = m(t.clone())
+        t1 = m(t.clone())
+        return t0, t1
+
+    t = TensorDict(loc=torch.randn(3), scale=torch.rand(3))
+    t0, t1 = func(t)
+    assert (t0["sample"] == t1["sample"]).all()
+    assert it == exploration_type()
+
+
 def test_loss_exploration():
     class DummyLoss(LossModule):
         def forward(self, td, mode):
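For context, the new test exercises `set_exploration_type` both as a context manager and as a decorator under `torch.compile(fullgraph=True)`, checking that the ambient mode is restored afterwards. A minimal sketch of that dual usage in eager mode (imports assumed from `torchrl.envs.utils`):

from torchrl.envs.utils import ExplorationType, exploration_type, set_exploration_type

# as a context manager: the mode is restored on exit
with set_exploration_type(ExplorationType.MEAN):
    assert exploration_type() == ExplorationType.MEAN

# as a decorator: the mode is active only inside the wrapped function
@set_exploration_type(ExplorationType.RANDOM)
def sample_randomly():
    return exploration_type()

assert sample_randomly() == ExplorationType.RANDOM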

torchrl/__init__.py

Lines changed: 41 additions & 0 deletions
@@ -3,13 +3,15 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+import weakref
 from warnings import warn

 import torch

 from tensordict import set_lazy_legacy

 from torch import multiprocessing as mp
+from torch.distributions.transforms import _InverseTransform, ComposeTransform

 set_lazy_legacy(False).set()
@@ -51,3 +53,42 @@
 filter_warnings_subprocess = True

 _THREAD_POOL_INIT = torch.get_num_threads()
+
+# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
+@property
+def inv(self):
+    """
+    Returns the inverse :class:`Transform` of this transform.
+    This should satisfy ``t.inv.inv is t``.
+    """
+    inv = None
+    if self._inv is not None:
+        inv = self._inv()
+    if inv is None:
+        inv = _InverseTransform(self)
+        if not torch.compiler.is_dynamo_compiling():
+            self._inv = weakref.ref(inv)
+    return inv
+
+
+torch.distributions.transforms.Transform.inv = inv
+
+
+@property
+def inv(self):
+    inv = None
+    if self._inv is not None:
+        inv = self._inv()
+    if inv is None:
+        inv = ComposeTransform([p.inv for p in reversed(self.parts)])
+        if not torch.compiler.is_dynamo_compiling():
+            self._inv = weakref.ref(inv)
+            inv._inv = weakref.ref(self)
+        else:
+            # We need inv.inv to be equal to self, but weakref can cause a graph break
+            inv._inv = lambda out=self: out
+
+    return inv
+
+
+ComposeTransform.inv = inv
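The two patched properties reproduce the stock PyTorch behavior (including the documented `t.inv.inv is t` round trip) but skip the weakref caching while Dynamo is tracing, since a weakref assignment there would cause a graph break. A quick eager-mode sanity check with stock transforms (a sketch, not part of the commit):

from torch.distributions.transforms import AffineTransform, ComposeTransform, ExpTransform

t = ExpTransform()
assert t.inv.inv is t  # round trip preserved through the cached inverse

c = ComposeTransform([AffineTransform(0.0, 2.0), ExpTransform()])
assert c.inv.inv is c  # same guarantee for composed transforms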

torchrl/objectives/a2c.py

Lines changed: 7 additions & 2 deletions
@@ -427,8 +427,13 @@ def _log_probs(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            tensordict = dist.log_prob(tensordict)
-            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            maybe_log_prob = dist.log_prob(tensordict)
+            if not isinstance(maybe_log_prob, torch.Tensor):
+                # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+                # be a tensor
+                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = maybe_log_prob
         log_prob = log_prob.unsqueeze(-1)
         return log_prob, dist
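The new branch handles `dist.log_prob(tensordict)` handing back a tensordict of log-probabilities rather than a single tensor (the composite-distribution case mentioned in the comment). A tiny stand-alone sketch of the same dispatch, using a plain TensorDict as a stand-in for that return value and "sample_log_prob" as the key name for illustration:

import torch
from tensordict import TensorDict

maybe_log_prob = TensorDict({"sample_log_prob": torch.randn(3)}, [3])  # stand-in return value
if not isinstance(maybe_log_prob, torch.Tensor):
    log_prob = maybe_log_prob.get("sample_log_prob")
else:
    log_prob = maybe_log_prob
assert isinstance(log_prob, torch.Tensor)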

torchrl/objectives/common.py

Lines changed: 4 additions & 3 deletions
@@ -46,7 +46,6 @@ def _forward_wrapper(func):
     @functools.wraps(func)
     def new_forward(self, *args, **kwargs):
         with set_exploration_type(self.deterministic_sampling_mode):
-            # with nullcontext():
            return func(self, *args, **kwargs)

     return new_forward
@@ -55,7 +54,7 @@ def new_forward(self, *args, **kwargs):
 class _LossMeta(abc.ABCMeta):
     def __init__(cls, name, bases, attr_dict):
         super().__init__(name, bases, attr_dict)
-        # cls.forward = _forward_wrapper(cls.forward)
+        cls.forward = _forward_wrapper(cls.forward)


 class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
@@ -229,7 +228,9 @@ def set_keys(self, **kwargs) -> None:
         """
         for key, value in kwargs.items():
             if key not in self._AcceptedKeys.__dataclass_fields__:
-                raise ValueError(f"{key} is not an accepted tensordict key")
+                raise ValueError(
+                    f"{key} is not an accepted tensordict key. Accepted keys are: {self._AcceptedKeys.__dataclass_fields__}."
+                )
             if value is not None:
                 setattr(self.tensor_keys, key, value)
             else:
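For context, `set_keys` validates every supplied name against the `_AcceptedKeys` dataclass, and the longer message now surfaces that dataclass's fields. A toy reproduction of the check (stand-in dataclass and function, not the real loss module):

from dataclasses import dataclass

@dataclass
class AcceptedKeysDemo:
    advantage: str = "advantage"
    value: str = "state_value"

def set_keys_demo(**kwargs):
    for key in kwargs:
        if key not in AcceptedKeysDemo.__dataclass_fields__:
            raise ValueError(
                f"{key} is not an accepted tensordict key. "
                f"Accepted keys are: {AcceptedKeysDemo.__dataclass_fields__}."
            )

try:
    set_keys_demo(not_a_key="oops")
except ValueError as err:
    print(err)  # now lists every accepted key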

torchrl/objectives/ppo.py

Lines changed: 8 additions & 3 deletions
@@ -495,8 +495,13 @@ def _log_weight(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            tensordict = dist.log_prob(tensordict)
-            log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+            maybe_log_prob = dist.log_prob(tensordict)
+            if not isinstance(maybe_log_prob, torch.Tensor):
+                # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
+                # be a tensor
+                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = maybe_log_prob

         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -1144,7 +1149,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
                 x = previous_dist.sample((self.samples_mc_kl,))
                 previous_log_prob = previous_dist.log_prob(x)
                 current_log_prob = current_dist.log_prob(x)
-                if is_tensor_collection(x):
+                if is_tensor_collection(current_log_prob):
                     previous_log_prob = previous_log_prob.get(
                         self.tensor_keys.sample_log_prob
                     )
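The guard now inspects the log-prob output itself rather than the sampled value, in line with the a2c/ppo comments above: a composite sample can still produce a plain tensor log-prob when probabilities are aggregated. `is_tensor_collection` (exported by `tensordict`) is what distinguishes the two cases, e.g.:

import torch
from tensordict import TensorDict, is_tensor_collection

assert is_tensor_collection(TensorDict({"a": torch.zeros(3)}, [3]))
assert not is_tensor_collection(torch.zeros(3))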

torchrl/objectives/redq.py

Lines changed: 2 additions & 0 deletions
@@ -453,6 +453,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict_select = tensordict.select(
             "next", *obs_keys, self.tensor_keys.action, strict=False
         )
+        # We need to copy bc select does not copy sub-tds
+        tensordict_select = tensordict_select.copy()

         selected_models_idx = torch.randperm(self.num_qvalue_nets)[
             : self.sub_sample_len
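Per the comment, `select` keeps references to the original nested tensordicts, so the extra `copy()` prevents later writes from leaking back into the input. A rough illustration of the intent (made-up keys; the copy behavior follows the rationale stated in the comment):

import torch
from tensordict import TensorDict

td = TensorDict({"next": TensorDict({"obs": torch.zeros(3)}, [3])}, [3])

shared = td.select("next")                 # the nested "next" td is shared, not copied
shared["next", "flag"] = torch.ones(3)     # leaks into td through the shared sub-td
assert ("next", "flag") in td.keys(include_nested=True)

detached = td.select("next").copy()        # copy() duplicates the nested containers
detached["next", "other"] = torch.ones(3)  # stays local to the copy
assert ("next", "other") not in td.keys(include_nested=True)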
