Skip to content

Commit 0188cea

Browse files
committed
fix: address comments
1 parent bcc75db commit 0188cea

File tree

5 files changed

+75
-31
lines changed

5 files changed

+75
-31
lines changed
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
22
from sbi.neural_nets.estimators.categorical_net import (
3+
CategoricalMADE,
34
CategoricalMassEstimator,
45
CategoricalNet,
5-
CategoricalMADE,
66
)
77
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
8-
from sbi.neural_nets.estimators.mixed_density_estimator import (
9-
MixedDensityEstimator,
10-
)
8+
from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
119
from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
1210
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
1311
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow

sbi/neural_nets/estimators/categorical_net.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Optional
4+
from typing import Callable, Optional
55

66
import torch
77
from nflows.nn.nde.made import MADE
@@ -15,29 +15,46 @@
1515

1616

1717
class CategoricalMADE(MADE):
18+
"""Conditional density (mass) estimation for an n-dim categorical random variable.
19+
20+
Takes as input parameters theta and learns the parameters p of a Categorical.
21+
22+
Defines log prob and sample functions.
23+
"""
24+
1825
def __init__(
1926
self,
20-
categories, # Tensor[int]
21-
hidden_features,
22-
context_features=None,
23-
num_blocks=2,
24-
use_residual_blocks=True,
25-
random_mask=False,
26-
activation=F.relu,
27-
dropout_probability=0.0,
28-
use_batch_norm=False,
29-
epsilon=1e-2,
30-
custom_initialization=True,
27+
num_categories: Tensor, # Tensor[int]
28+
hidden_features: int,
29+
context_features: Optional[int] = None,
30+
num_blocks: int = 2,
31+
use_residual_blocks: bool = True,
32+
random_mask: bool = False,
33+
activation: Callable = F.relu,
34+
dropout_probability: float = 0.0,
35+
use_batch_norm: bool = False,
36+
epsilon: float = 1e-2,
37+
custom_initialization: bool = True,
3138
embedding_net: Optional[nn.Module] = nn.Identity(),
3239
):
40+
"""Initialize the neural net.
41+
42+
Args:
43+
num_categories: number of categories for each variable. len(num_categories)
44+
defines the number of input units, i.e., dimensionality of the features.
45+
max(num_categories) defines the number of output units, i.e., the largest
46+
number of categories.
47+
hidden_features: number of hidden units per layer.
48+
num_blocks: number of hidden blocks.
49+
embedding_net: embedding net for the context input.
50+
"""
3351
if use_residual_blocks and random_mask:
3452
raise ValueError("Residual blocks can't be used with random masks.")
3553

36-
self.num_variables = len(categories)
37-
self.num_categories = int(max(categories))
38-
self.categories = categories
54+
self.num_variables = len(num_categories)
55+
self.num_categories = int(torch.max(num_categories))
3956
self.mask = torch.zeros(self.num_variables, self.num_categories)
40-
for i, c in enumerate(categories):
57+
for i, c in enumerate(num_categories):
4158
self.mask[i, :c] = 1
4259

4360
super().__init__(
@@ -60,7 +77,18 @@ def __init__(
6077
if custom_initialization:
6178
self._initialize()
6279

63-
def forward(self, inputs, context=None):
80+
def forward(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
81+
r"""Forward pass of the categorical density estimator network to compute the
82+
conditional probabilities of the categorical variables.
83+
84+
Args:
85+
inputs: Categorical input data. (batch_size, *input_shape)
86+
context: Conditioning variable. (batch_size, *condition_shape)
87+
88+
Returns:
89+
Predicted categorical probabilities. (batch_size, *input_shape,
90+
num_categories)
91+
"""
6492
embedded_context = self.embedding_net.forward(context)
6593
return super().forward(inputs, context=embedded_context)
6694

@@ -69,8 +97,16 @@ def compute_probs(self, outputs):
6997
ps = ps / ps.sum(dim=-1, keepdim=True)
7098
return ps
7199

72-
# outputs (batch_size, num_variables, num_categories)
73-
def log_prob(self, inputs, context=None):
100+
def log_prob(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
101+
r"""Return log-probability of samples.
102+
103+
Args:
104+
input: Input datapoints of shape `(batch_size, *input_shape)`.
105+
context: Context of shape `(batch_size, *condition_shape)`.
106+
107+
Returns:
108+
Log-probabilities of shape `(batch_size, num_variables, num_categories)`.
109+
"""
74110
outputs = self.forward(inputs, context=context)
75111
outputs = outputs.reshape(*inputs.shape, self.num_categories)
76112
ps = self.compute_probs(outputs)

sbi/neural_nets/estimators/mixed_density_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
147147
f"{input_batch_dim} do not match."
148148
)
149149

150-
num_disc = self.discrete_net.net.num_variables
151-
cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
150+
num_discrete_variables = self.discrete_net.net.num_variables
151+
cont_input, disc_input = _separate_input(input, num_discrete_variables)
152152
# Embed continuous condition
153153
embedded_condition = self.condition_embedding(condition)
154154
# expand and repeat to match batch of inputs.

sbi/neural_nets/net_builders/categorial.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def build_autoregressive_categoricalmassestimator(
7373
z_score_y: Optional[str] = "independent",
7474
num_hidden: int = 20,
7575
num_layers: int = 2,
76-
categories: Optional[Tensor] = None,
76+
num_categories: Optional[Tensor] = None,
7777
embedding_net: nn.Module = nn.Identity(),
7878
):
7979
"""Returns a density estimator for a categorical random variable.
@@ -86,13 +86,14 @@ def build_autoregressive_categoricalmassestimator(
8686
num_hidden: Number of hidden units per layer.
8787
num_layers: Number of hidden layers.
8888
embedding_net: Embedding net for y.
89+
num_categories: number of categories for each variable.
8990
"""
9091

9192
if z_score_x != "none":
9293
raise ValueError("Categorical input should not be z-scored.")
93-
if categories is None:
94+
if num_categories is None:
9495
warnings.warn(
95-
"Inferring categories from batch_x. Ensure all categories are present.",
96+
"Inferring num_categories from batch_x. Ensure all categories are present.",
9697
stacklevel=2,
9798
)
9899

@@ -108,10 +109,12 @@ def build_autoregressive_categoricalmassestimator(
108109

109110
batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
110111
inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T])
111-
categories = categories if categories is not None else inferred_categories
112+
num_categories = (
113+
num_categories if num_categories is not None else inferred_categories
114+
)
112115

113116
categorical_net = CategoricalMADE(
114-
categories=categories,
117+
num_categories=num_categories,
115118
hidden_features=num_hidden,
116119
context_features=y_numel,
117120
num_blocks=num_layers,

sbi/neural_nets/net_builders/mnle.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def build_mnle(
6060
z_score_y: Optional[str] = "independent",
6161
flow_model: str = "nsf",
6262
categorical_model: str = "mlp",
63+
num_categorical_columns: Optional[Tensor] = None,
6364
embedding_net: nn.Module = nn.Identity(),
6465
combined_embedding_net: Optional[nn.Module] = None,
6566
num_transforms: int = 2,
@@ -108,6 +109,8 @@ def build_mnle(
108109
data.
109110
categorical_model: type of categorical net to use for the discrete part of
110111
the data. Can be "made" or "mlp".
112+
num_categorical_columns: Number of categories for each categorical column in the
113+
input data. If None, the function will infer this from the data.
111114
embedding_net: Optional embedding network for y, required if y is > 1D.
112115
combined_embedding_net: Optional embedding for combining the discrete
113116
part of the input and the embedded condition into a joined
@@ -137,7 +140,10 @@ def build_mnle(
137140
stacklevel=2,
138141
)
139142
# Separate continuous and discrete data.
140-
num_disc = int(torch.sum(_is_discrete(batch_x)))
143+
if num_categorical_columns is None:
144+
num_disc = int(torch.sum(_is_discrete(batch_x)))
145+
else:
146+
num_disc = len(num_categorical_columns)
141147
cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)
142148

143149
# Set up y-embedding net with z-scoring.
@@ -160,6 +166,7 @@ def build_mnle(
160166
num_hidden=hidden_features,
161167
num_layers=hidden_layers,
162168
embedding_net=embedding_net,
169+
num_categories=num_categorical_columns,
163170
)
164171
elif categorical_model == "mlp":
165172
assert num_disc == 1, "MLP only supports 1D input."

0 commit comments

Comments
 (0)