
Commit c851e16

Authored by Louis Faury (louisfaury)
[Feature] Adds ordinal distributions (#2520)
Co-authored-by: Louis Faury <louis.faury@helsing.ai>
1 parent d524d0d commit c851e16

File tree

5 files changed: +215 -10 lines changed


docs/source/reference/modules.rst

Lines changed: 2 additions & 0 deletions
@@ -553,6 +553,8 @@ Some distributions are typically used in RL scripts.
     OneHotCategorical
     MaskedCategorical
     MaskedOneHotCategorical
+    Ordinal
+    OneHotOrdinal

 Utils
 -----

test/test_distributions.py

Lines changed: 122 additions & 0 deletions
@@ -17,6 +17,8 @@
 from torchrl.modules import (
     NormalParamWrapper,
     OneHotCategorical,
+    OneHotOrdinal,
+    Ordinal,
     ReparamGradientStrategy,
     TanhNormal,
     TruncatedNormal,
@@ -28,6 +30,7 @@
     TanhDelta,
 )
 from torchrl.modules.distributions.continuous import SafeTanhTransform
+from torchrl.modules.distributions.discrete import _generate_ordinal_logits

 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
@@ -677,6 +680,125 @@ def test_reparam(self, grad_method, sparse):
         assert logits.grad is not None and logits.grad.norm() > 0


+class TestOrdinal:
+    @pytest.mark.parametrize("dtype", [torch.float, torch.double])
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)])
+    def test_correct_sampling_shape(
+        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+    ) -> None:
+        logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
+
+        sampler = Ordinal(scores=logits)
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+        log_probs = sampler.log_prob(actions)  # type: ignore[no-untyped-call]
+
+        expected_log_prob_shape = logit_shape[:-1]
+        expected_action_shape = logit_shape[:-1]
+
+        assert actions.size() == torch.Size(expected_action_shape)
+        assert log_probs.size() == torch.Size(expected_log_prob_shape)
+
+    @pytest.mark.parametrize("num_categories", [1, 10, 20])
+    def test_correct_range(self, num_categories: int) -> None:
+        seq_size = 10
+        batch_size = 100
+        logits = torch.ones((batch_size, seq_size, num_categories))
+
+        sampler = Ordinal(scores=logits)
+
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+
+        assert actions.min() >= 0
+        assert actions.max() < num_categories
+
+    def test_bounded_gradients(self) -> None:
+        logits = torch.tensor(
+            [[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
+            requires_grad=True,
+            dtype=torch.float32,
+        )
+
+        sampler = Ordinal(scores=logits)
+
+        actions = sampler.sample()
+        log_probs = sampler.log_prob(actions)
+
+        dummy_objective = log_probs.sum()
+        dummy_objective.backward()
+
+        assert logits.grad is not None
+        assert not torch.isnan(logits.grad).any()
+
+    def test_generate_ordinal_logits_numerical(self) -> None:
+        logits = torch.ones((3, 4))
+
+        ordinal_logits = _generate_ordinal_logits(scores=logits)
+
+        expected_ordinal_logits = torch.tensor(
+            [
+                [-4.2530, -3.2530, -2.2530, -1.2530],
+                [-4.2530, -3.2530, -2.2530, -1.2530],
+                [-4.2530, -3.2530, -2.2530, -1.2530],
+            ]
+        )
+
+        torch.testing.assert_close(
+            ordinal_logits, expected_ordinal_logits, atol=1e-4, rtol=1e-6
+        )
+
+
+class TestOneHotOrdinal:
+    @pytest.mark.parametrize("dtype", [torch.float, torch.double])
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
+    def test_correct_sampling_shape(
+        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+    ) -> None:
+        logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
+
+        sampler = OneHotOrdinal(scores=logits)
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+        log_probs = sampler.log_prob(actions)  # type: ignore[no-untyped-call]
+        expected_log_prob_shape = logit_shape[:-1]
+
+        expected_action_shape = logit_shape
+
+        assert actions.size() == torch.Size(expected_action_shape)
+        assert log_probs.size() == torch.Size(expected_log_prob_shape)
+
+    @pytest.mark.parametrize("num_categories", [2, 10, 20])
+    def test_correct_range(self, num_categories: int) -> None:
+        seq_size = 10
+        batch_size = 100
+        logits = torch.ones((batch_size, seq_size, num_categories))
+
+        sampler = OneHotOrdinal(scores=logits)
+
+        actions = sampler.sample()  # type: ignore[no-untyped-call]
+
+        assert torch.all(actions.sum(-1))
+        assert actions.shape[-1] == num_categories
+
+    def test_bounded_gradients(self) -> None:
+        logits = torch.tensor(
+            [[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
+            requires_grad=True,
+            dtype=torch.float32,
+        )
+
+        sampler = OneHotOrdinal(scores=logits)
+
+        actions = sampler.sample()
+        log_probs = sampler.log_prob(actions)
+
+        dummy_objective = log_probs.sum()
+        dummy_objective.backward()
+
+        assert logits.grad is not None
+        assert not torch.isnan(logits.grad).any()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

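For context, the snippet below is a minimal standalone sketch of what these new tests exercise: building `Ordinal` and `OneHotOrdinal` from a batch of scores and checking the sample and log-probability shapes. The batch size, category count, and random scores are illustrative and not taken from the commit; only the constructor keyword `scores` and the `sample`/`log_prob` calls come from the diff above.

import torch

from torchrl.modules import OneHotOrdinal, Ordinal

batch_size, num_categories = 32, 5  # illustrative sizes, not from the commit
scores = torch.randn(batch_size, num_categories)

ordinal = Ordinal(scores=scores)
actions = ordinal.sample()               # integer categories in [0, num_categories)
log_probs = ordinal.log_prob(actions)
assert actions.shape == (batch_size,)
assert log_probs.shape == (batch_size,)

one_hot = OneHotOrdinal(scores=scores)
oh_actions = one_hot.sample()            # one-hot encoded samples
oh_log_probs = one_hot.log_prob(oh_actions)
assert oh_actions.shape == (batch_size, num_categories)
assert oh_log_probs.shape == (batch_size,)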
torchrl/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
     NormalParamExtractor,
     NormalParamWrapper,
     OneHotCategorical,
+    OneHotOrdinal,
+    Ordinal,
     ReparamGradientStrategy,
     TanhDelta,
     TanhNormal,

torchrl/modules/distributions/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,8 @@
     MaskedCategorical,
     MaskedOneHotCategorical,
     OneHotCategorical,
+    OneHotOrdinal,
+    Ordinal,
     ReparamGradientStrategy,
 )

@@ -31,5 +33,7 @@
         MaskedCategorical,
         MaskedOneHotCategorical,
         OneHotCategorical,
+        Ordinal,
+        OneHotOrdinal,
     )
 }

torchrl/modules/distributions/discrete.py

Lines changed: 85 additions & 10 deletions
@@ -9,11 +9,9 @@

 import torch
 import torch.distributions as D
+import torch.nn.functional as F

-__all__ = [
-    "OneHotCategorical",
-    "MaskedCategorical",
-]
+__all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"]


 def _treat_categorical_params(
@@ -56,7 +54,7 @@ class ReparamGradientStrategy(Enum):
 class OneHotCategorical(D.Categorical):
     """One-hot categorical distribution.

-    This class behaves excacly as torch.distributions.Categorical except that it reads and produces one-hot encodings
+    This class behaves exactly as torch.distributions.Categorical except that it reads and produces one-hot encodings
     of the discrete tensors.

     Args:
@@ -66,7 +64,7 @@ class OneHotCategorical(D.Categorical):
             reparameterized samples.
             ``ReparamGradientStrategy.PassThrough`` will compute the sample gradients
             by using the softmax valued log-probability as a proxy to the
-            samples gradients.
+            sample gradients.
             ``ReparamGradientStrategy.RelaxedOneHot`` will use
             :class:`torch.distributions.RelaxedOneHot` to sample from the distribution.

@@ -81,8 +79,6 @@ class OneHotCategorical(D.Categorical):

     """

-    num_params: int = 1
-
     def __init__(
         self,
         logits: Optional[torch.Tensor] = None,
@@ -155,7 +151,7 @@ class MaskedCategorical(D.Categorical):
     Args:
         logits (torch.Tensor): event log probabilities (unnormalized)
         probs (torch.Tensor): event probabilities. If provided, the probabilities
-            corresponding to to masked items will be zeroed and the probability
+            corresponding to masked items will be zeroed and the probability
             re-normalized along its last dimension.

     Keyword Args:
@@ -306,7 +302,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
     Args:
         logits (torch.Tensor): event log probabilities (unnormalized)
         probs (torch.Tensor): event probabilities. If provided, the probabilities
-            corresponding to to masked items will be zeroed and the probability
+            corresponding to masked items will be zeroed and the probability
             re-normalized along its last dimension.

     Keyword Args:
@@ -469,3 +465,82 @@ def rsample(self, sample_shape: Union[torch.Size, Sequence] = None) -> torch.Ten
             raise ValueError(
                 f"Unknown reparametrization strategy {self.reparam_strategy}."
             )
+
+
+class Ordinal(D.Categorical):
+    """A discrete distribution for learning to sample from finite ordered sets.
+
+    It is defined in contrast with the `Categorical` distribution, which does
+    not impose any notion of proximity or ordering over its support's atoms.
+    The `Ordinal` distribution explicitly encodes those concepts, which is
+    useful for learning discrete sampling from continuous sets. See §5 of
+    `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
+
+    .. note::
+        This class is mostly useful when you want to learn a distribution over
+        a finite set which is obtained by discretising a continuous set.
+
+    Args:
+        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
+            Typically, the output of a neural network parametrising the distribution.
+
+    Examples:
+        >>> num_atoms, num_samples = 5, 20
+        >>> mean = (num_atoms - 1) / 2  # Target mean for samples, centered around the middle atom
+        >>> torch.manual_seed(42)
+        >>> logits = torch.ones((num_atoms), requires_grad=True)
+        >>> optimizer = torch.optim.Adam([logits], lr=0.1)
+        >>>
+        >>> # Perform optimisation loop to minimise deviation from `mean`
+        >>> for _ in range(20):
+        >>>     sampler = Ordinal(scores=logits)
+        >>>     samples = sampler.sample((num_samples,))
+        >>>     # Define loss to encourage samples around the mean by penalising deviation from mean
+        >>>     loss = torch.mean((samples - mean) ** 2 * sampler.log_prob(samples))
+        >>>     loss.backward()
+        >>>     optimizer.step()
+        >>>     optimizer.zero_grad()
+        >>>
+        >>> sampler.probs
+        tensor([0.0308, 0.1586, 0.4727, 0.2260, 0.1120], ...)
+        >>> # Print histogram to observe sample distribution frequency across 5 bins (0, 1, 2, 3, and 4)
+        >>> torch.histogram(sampler.sample((1000,)).reshape(-1).float(), bins=num_atoms)
+        torch.return_types.histogram(
+            hist=tensor([ 24., 158., 478., 228., 112.]),
+            bin_edges=tensor([0.0000, 0.8000, 1.6000, 2.4000, 3.2000, 4.0000]))
+    """
+
+    def __init__(self, scores: torch.Tensor):
+        logits = _generate_ordinal_logits(scores)
+        super().__init__(logits=logits)
+
+
+class OneHotOrdinal(OneHotCategorical):
+    """The one-hot version of the :class:`~tensordict.nn.distributions.Ordinal` distribution.
+
+    Args:
+        scores (torch.Tensor): a tensor of shape [..., N] where N is the size of the set which supports the distributions.
+            Typically, the output of a neural network parametrising the distribution.
+    """
+
+    def __init__(self, scores: torch.Tensor):
+        logits = _generate_ordinal_logits(scores)
+        super().__init__(logits=logits)
+
+
+def _generate_ordinal_logits(scores: torch.Tensor) -> torch.Tensor:
+    """Implements Eq. 4 of `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`__."""
+    # Assigns Bernoulli-like probabilities for each class in the set
+    log_probs = F.logsigmoid(scores)
+    complementary_log_probs = F.logsigmoid(-scores)
+
+    # Total log-probability for being "larger than k"
+    larger_than_log_probs = log_probs.cumsum(dim=-1)
+
+    # Total log-probability for being "smaller than k"
+    smaller_than_log_probs = (
+        complementary_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+        - complementary_log_probs
+    )
+
+    return larger_than_log_probs + smaller_than_log_probs

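To make the numbers in `test_generate_ordinal_logits_numerical` concrete, the sketch below reproduces the `logsigmoid`/`cumsum` construction used by `_generate_ordinal_logits` (Eq. 4 of Tang & Agrawal, 2020) with plain `torch` operations and checks it against the expected matrix from that test. It only assumes `torch` is installed; the variable names are illustrative.

import torch
import torch.nn.functional as F

# Same input as the new test: a (3, 4) tensor of unit scores.
scores = torch.ones((3, 4))

# Per-threshold Bernoulli log-probabilities.
log_probs = F.logsigmoid(scores)        # log sigma(s_j), "larger than" terms
comp_log_probs = F.logsigmoid(-scores)  # log sigma(-s_j), "smaller than" terms

# logit_k = sum_{j <= k} log sigma(s_j) + sum_{j > k} log sigma(-s_j)
larger_than = log_probs.cumsum(dim=-1)
smaller_than = (
    comp_log_probs.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - comp_log_probs
)
ordinal_logits = larger_than + smaller_than

# With unit scores, log sigma(1) ~ -0.3133 and log sigma(-1) ~ -1.3133, so the
# k-th logit is (k + 1) * (-0.3133) + (3 - k) * (-1.3133): the row
# [-4.2530, -3.2530, -2.2530, -1.2530] asserted in the test above.
expected = torch.tensor([-4.2530, -3.2530, -2.2530, -1.2530]).expand(3, 4)
torch.testing.assert_close(ordinal_logits, expected, atol=1e-4, rtol=1e-6)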