Skip to content

Commit d6dc444

Browse files
committed
wip: CategoricalMassEstimator can be build and MixedDensityEstimator too. log_prob has shape issues tho
1 parent 40c657a commit d6dc444

File tree

3 files changed

+240
-8
lines changed

3 files changed

+240
-8
lines changed

sbi/made_mnle.ipynb

Lines changed: 181 additions & 0 deletions
Large diffs are not rendered by default.

sbi/neural_nets/estimators/categorical_net.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,16 @@ def __init__(
2929
use_batch_norm=False,
3030
epsilon=1e-2,
3131
custom_initialization=True,
32+
#TODO: embedding_net: Optional[nn.Module] = None,
3233
):
3334

3435
if use_residual_blocks and random_mask:
3536
raise ValueError("Residual blocks can't be used with random masks.")
3637

3738
self.num_variables = len(categories)
38-
self.max_categories = max(categories)
39+
self.num_categories = int(max(categories))
3940
self.categories = categories
40-
self.mask = torch.zeros(self.num_variables, self.max_categories)
41+
self.mask = torch.zeros(self.num_variables, self.num_categories)
4142
for i, c in enumerate(categories):
4243
self.mask[i, :c] = 1
4344

@@ -46,7 +47,7 @@ def __init__(
4647
hidden_features,
4748
context_features=context_features,
4849
num_blocks=num_blocks,
49-
output_multiplier=self.max_categories,
50+
output_multiplier=self.num_categories,
5051
use_residual_blocks=use_residual_blocks,
5152
random_mask=random_mask,
5253
activation=activation,
@@ -68,10 +69,10 @@ def compute_probs(self, outputs):
6869
ps = ps / ps.sum(dim=-1, keepdim=True)
6970
return ps
7071

71-
# outputs (batch_size, num_variables, max_categories)
72+
# outputs (batch_size, num_variables, num_categories)
7273
def log_prob(self, inputs, context=None):
7374
outputs = self.forward(inputs, context=context)
74-
outputs = outputs.reshape(*inputs.shape, self.max_categories)
75+
outputs = outputs.reshape(*inputs.shape, self.num_categories)
7576
ps = self.compute_probs(outputs)
7677

7778
# categorical log prob
@@ -91,7 +92,7 @@ def sample(self, num_samples, context=None):
9192

9293
for variable in range(self.num_variables):
9394
outputs = self.forward(samples, context)
94-
outputs = outputs.reshape(*samples.shape, self.max_categories)
95+
outputs = outputs.reshape(*samples.shape, self.num_categories)
9596
ps = self.compute_probs(outputs)
9697
samples[:, variable] = Categorical(probs=ps[:,variable]).sample()
9798

sbi/neural_nets/net_builders/categorial.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from typing import Optional
55

6-
from torch import Tensor, nn, unique
6+
from torch import Tensor, nn, unique, tensor
77

8-
from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
8+
from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet, CategoricalMADE
99
from sbi.utils.nn_utils import get_numel
1010
from sbi.utils.sbiutils import (
1111
standardizing_net,
@@ -61,3 +61,53 @@ def build_categoricalmassestimator(
6161
return CategoricalMassEstimator(
6262
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
6363
)
64+
65+
66+
def build_autoregressive_categoricalmassestimator(
67+
batch_x: Tensor,
68+
batch_y: Tensor,
69+
z_score_x: Optional[str] = "none",
70+
z_score_y: Optional[str] = "independent",
71+
num_hidden: int = 20,
72+
num_layers: int = 2,
73+
embedding_net: nn.Module = nn.Identity(),
74+
):
75+
"""Returns a density estimator for a categorical random variable.
76+
77+
Args:
78+
batch_x: A batch of input data.
79+
batch_y: A batch of condition data.
80+
z_score_x: Whether to z-score the input data.
81+
z_score_y: Whether to z-score the condition data.
82+
num_hidden: Number of hidden units per layer.
83+
num_layers: Number of hidden layers.
84+
embedding_net: Embedding net for y.
85+
"""
86+
87+
if z_score_x != "none":
88+
raise ValueError("Categorical input should not be z-scored.")
89+
90+
check_data_device(batch_x, batch_y)
91+
92+
z_score_y_bool, structured_y = z_score_parser(z_score_y)
93+
y_numel = get_numel(batch_y, embedding_net=embedding_net)
94+
95+
if z_score_y_bool:
96+
embedding_net = nn.Sequential(
97+
standardizing_net(batch_y, structured_y), embedding_net
98+
)
99+
100+
101+
categories = tensor([unique(variable).numel() for variable in batch_x.T])
102+
103+
categorical_net = CategoricalMADE(
104+
categories=categories,
105+
context_features=y_numel,
106+
hidden_features=num_hidden,
107+
num_blocks=num_layers,
108+
#TODO: embedding_net=embedding_net,
109+
)
110+
111+
return CategoricalMassEstimator(
112+
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
113+
)

0 commit comments

Comments
 (0)