Skip to content

Commit 36545af

Browse files
author
Vincent Moens
committed
[BugFix] Compatibility with the new Composite distribution log_prob/entropy APIs
ghstack-source-id: a09b6c3
Pull Request resolved: #2435
1 parent d40fa4f commit 36545af

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

test/test_cost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7565,6 +7565,7 @@ def _create_mock_actor(
75657565
"action1": (action_key, "action1"),
75667566
},
75677567
log_prob_key=sample_log_prob_key,
7568+
aggregate_probabilities=True,
75687569
)
75697570
module_out_keys = [
75707571
("params", "action1", "loc"),
@@ -7634,6 +7635,7 @@ def _create_mock_actor_value(
76347635
"action1": ("action", "action1"),
76357636
},
76367637
log_prob_key=sample_log_prob_key,
7638+
aggregate_probabilities=True,
76377639
)
76387640
module_out_keys = [
76397641
("params", "action1", "loc"),
@@ -7690,6 +7692,7 @@ def _create_mock_actor_value_shared(
76907692
"action1": ("action", "action1"),
76917693
},
76927694
log_prob_key=sample_log_prob_key,
7695+
aggregate_probabilities=True,
76937696
)
76947697
module_out_keys = [
76957698
("params", "action1", "loc"),
@@ -8627,6 +8630,7 @@ def _create_mock_actor(
86278630
"action1": (action_key, "action1"),
86288631
},
86298632
log_prob_key=sample_log_prob_key,
8633+
aggregate_probabilities=True,
86308634
)
86318635
module_out_keys = [
86328636
("params", "action1", "loc"),
@@ -8727,6 +8731,7 @@ def _create_mock_common_layer_setup(
87278731
"action1": ("action", "action1"),
87288732
},
87298733
log_prob_key=sample_log_prob_key,
8734+
aggregate_probabilities=True,
87308735
)
87318736
module_out_keys = [
87328737
("params", "action1", "loc"),

torchrl/objectives/a2c.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,8 +420,13 @@ def _log_probs(
420420
if isinstance(action, torch.Tensor):
421421
log_prob = dist.log_prob(action)
422422
else:
423-
tensordict = dist.log_prob(tensordict)
424-
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
423+
maybe_log_prob = dist.log_prob(tensordict)
424+
if not isinstance(maybe_log_prob, torch.Tensor):
425+
# In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
426+
# be a tensor
427+
log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
428+
else:
429+
log_prob = maybe_log_prob
425430
log_prob = log_prob.unsqueeze(-1)
426431
return log_prob, dist
427432

torchrl/objectives/ppo.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,13 @@ def _log_weight(
490490
if isinstance(action, torch.Tensor):
491491
log_prob = dist.log_prob(action)
492492
else:
493-
tensordict = dist.log_prob(tensordict)
494-
log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
493+
maybe_log_prob = dist.log_prob(tensordict)
494+
if not isinstance(maybe_log_prob, torch.Tensor):
495+
# In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
496+
# be a tensor
497+
log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
498+
else:
499+
log_prob = maybe_log_prob
495500

496501
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
497502
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -1130,7 +1135,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
11301135
x = previous_dist.sample((self.samples_mc_kl,))
11311136
previous_log_prob = previous_dist.log_prob(x)
11321137
current_log_prob = current_dist.log_prob(x)
1133-
if is_tensor_collection(x):
1138+
if is_tensor_collection(current_log_prob):
11341139
previous_log_prob = previous_log_prob.get(
11351140
self.tensor_keys.sample_log_prob
11361141
)

0 commit comments

Comments
 (0)