
Commit 68ee1a7

fix: temporary wrappers to fix MADE (#1398)
* add temporary MADE wrappers
* test MADEMoG
1 parent 16436e6 commit 68ee1a7

File tree

4 files changed: +137 -3 lines changed

sbi/neural_nets/net_builders/flow.py
sbi/utils/nn_utils.py
tests/density_estimator_test.py
tests/linearGaussian_snpe_test.py


sbi/neural_nets/net_builders/flow.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from torch import Tensor, nn, relu, tanh, tensor, uint8

 from sbi.neural_nets.estimators import NFlowsFlow, ZukoFlow
-from sbi.utils.nn_utils import get_numel
+from sbi.utils.nn_utils import MADEMoGWrapper, get_numel
 from sbi.utils.sbiutils import (
     standardizing_net,
     standardizing_transform,
@@ -77,7 +77,7 @@ def build_made(
             standardizing_net(batch_y, structured_y), embedding_net
         )

-    distribution = distributions_.MADEMoG(
+    distribution = MADEMoGWrapper(
         features=x_numel,
         hidden_features=hidden_features,
         context_features=y_numel,
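
For orientation, a minimal sketch of what calling the patched builder could look like; the batch shapes and the positional order of the data batches are illustrative assumptions, not taken from this diff:

# Hedged sketch: build_made now returns an estimator backed by MADEMoGWrapper
# instead of the buggy distributions_.MADEMoG. Shapes are arbitrary.
import torch

from sbi.neural_nets.net_builders import build_made

batch_x = torch.randn(100, 3)  # samples of the variable being modeled
batch_y = torch.randn(100, 2)  # samples of the conditioning variable

density_estimator = build_made(batch_x, batch_y)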

sbi/utils/nn_utils.py

Lines changed: 132 additions & 0 deletions
@@ -2,6 +2,11 @@
 from typing import Optional
 from warnings import warn

+import nflows.nn.nde.made as made
+import numpy as np
+import torch
+import torch.nn.functional as F
+from pyknos.nflows import distributions as distributions_
 from torch import Tensor, nn


@@ -62,3 +67,130 @@ def check_net_device(
         return net.to(device)
     else:
         return net
+
+
+"""
+Temporary Patches to fix nflows MADE bug. Remove once upstream bug is fixed.
+"""
+
+
+class MADEWrapper(made.MADE):
+    """Implementation of MADE.
+
+    It can use either feedforward blocks or residual blocks (default is residual).
+    Optionally, it can use batch norm or dropout within blocks (default is no).
+    """
+
+    def __init__(
+        self,
+        features,
+        hidden_features,
+        context_features=None,
+        num_blocks=2,
+        output_multiplier=1,
+        use_residual_blocks=True,
+        random_mask=False,
+        activation=F.relu,
+        dropout_probability=0.0,
+        use_batch_norm=False,
+    ):
+        if use_residual_blocks and random_mask:
+            raise ValueError("Residual blocks can't be used with random masks.")
+        super().__init__(
+            features + 1,
+            hidden_features,
+            context_features,
+            num_blocks,
+            output_multiplier,
+            use_residual_blocks,
+            random_mask,
+            activation,
+            dropout_probability,
+            use_batch_norm,
+        )
+
+    def forward(self, inputs, context=None):
+        # add dummy input to ensure all dims conditioned on context.
+        dummy_input = torch.zeros((inputs.shape[:-1] + (1,)))
+        concat_input = torch.cat((dummy_input, inputs), dim=-1)
+        outputs = super().forward(concat_input, context)
+        # the final layer of MADE produces self.output_multiplier outputs for each
+        # input dimension, in order. We only want the outputs corresponding to the
+        # real inputs, so we discard the first self.output_multiplier outputs.
+        return outputs[..., self.output_multiplier :]
+
+
+"""
+Temporary Patches to fix nflows MADE bug. Remove once upstream bug is fixed.
+"""
+
+
+class MADEMoGWrapper(distributions_.MADEMoG):
+    def __init__(
+        self,
+        features,
+        hidden_features,
+        context_features,
+        num_blocks=2,
+        num_mixture_components=1,
+        use_residual_blocks=True,
+        random_mask=False,
+        activation=F.relu,
+        dropout_probability=0.0,
+        use_batch_norm=False,
+        custom_initialization=False,
+    ):
+        super().__init__(
+            features + 1,
+            hidden_features,
+            context_features,
+            num_blocks,
+            num_mixture_components,
+            use_residual_blocks,
+            random_mask,
+            activation,
+            dropout_probability,
+            use_batch_norm,
+            custom_initialization,
+        )
+
+    def _log_prob(self, inputs, context=None):
+        dummy_input = torch.zeros((inputs.shape[:-1] + (1,)))
+        concat_inputs = torch.cat((dummy_input, inputs), dim=-1)
+
+        outputs = self._made.forward(concat_inputs, context=context)
+        outputs = outputs.reshape(
+            *concat_inputs.shape, self._made.num_mixture_components, 3
+        )
+
+        logits, means, unconstrained_stds = (
+            outputs[..., 0],
+            outputs[..., 1],
+            outputs[..., 2],
+        )
+        # remove first dimension of means, unconstrained_stds
+        logits = logits[..., 1:, :]
+        means = means[..., 1:, :]
+        unconstrained_stds = unconstrained_stds[..., 1:, :]
+
+        log_mixture_coefficients = torch.log_softmax(logits, dim=-1)
+        stds = F.softplus(unconstrained_stds) + self._made.epsilon
+
+        log_prob = torch.sum(
+            torch.logsumexp(
+                log_mixture_coefficients
+                - 0.5
+                * (
+                    np.log(2 * np.pi)
+                    + 2 * torch.log(stds)
+                    + ((inputs[..., None] - means) / stds) ** 2
+                ),
+                dim=-1,
+            ),
+            dim=-1,
+        )
+        return log_prob
+
+    def _sample(self, num_samples, context=None):
+        samples = self._made.sample(num_samples, context=context)
+        return samples[..., 1:]
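
Both wrappers share one trick: per the in-code comment, nflows' MADE leaves its first input dimension unconditioned on the context, so each wrapper builds the underlying MADE over `features + 1` dimensions, prepends a zero-valued dummy feature to every input, and strips everything belonging to that dummy from the outputs (in `forward`), the mixture parameters (in `_log_prob`), and the samples (in `_sample`). A minimal usage sketch, assuming `MADEMoGWrapper` inherits the standard nflows `log_prob`/`sample` interface:

# Hedged sketch, not part of the diff: exercising the wrapper directly.
import torch

from sbi.utils.nn_utils import MADEMoGWrapper

mog = MADEMoGWrapper(
    features=3,            # dimensionality of the modeled variable
    hidden_features=32,
    context_features=2,    # dimensionality of the conditioning variable
    num_mixture_components=5,
)

inputs = torch.randn(10, 3)
context = torch.randn(10, 2)

# log_prob/sample are assumed to come from the nflows Distribution base class;
# the wrapper's _log_prob/_sample handle the dummy feature internally.
log_probs = mog.log_prob(inputs, context=context)  # one log-density per row
samples = mog.sample(4, context=context)           # dummy column already dropped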

tests/density_estimator_test.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 )
 from sbi.neural_nets.net_builders import (
     build_categoricalmassestimator,
+    build_made,
     build_maf,
     build_maf_rqs,
     build_mdn,
@@ -36,6 +37,7 @@

 # List of all density estimator builders for testing.
 model_builders = [
+    build_made,
     build_mdn,
     build_maf,
     build_maf_rqs,
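
The diff describes `model_builders` as the list of builders the suite iterates over; a hedged sketch of how such a parametrized test might consume it (the test body and shapes are assumptions, not the suite's actual code):

# Hedged sketch of a parametrized smoke test over model_builders.
import pytest
import torch


@pytest.mark.parametrize("builder", model_builders)
def test_builder_smoke(builder):
    batch_x = torch.randn(50, 3)
    batch_y = torch.randn(50, 2)
    estimator = builder(batch_x, batch_y)
    # loss(input, condition) is assumed from sbi's density estimator interface.
    loss = estimator.loss(batch_x, condition=batch_y)
    assert torch.isfinite(loss).all()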

tests/linearGaussian_snpe_test.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def simulator(theta):
 @pytest.mark.slow
 @pytest.mark.parametrize(
     "density_estimator",
-    ["mdn", "maf", "maf_rqs", "nsf", "zuko_maf", "zuko_nsf"],
+    ["made", "mdn", "maf", "maf_rqs", "nsf", "zuko_maf", "zuko_nsf"],
 )
 def test_density_estimators_on_linearGaussian(density_estimator):
     """Test NPE with different density estimators on linear Gaussian example."""

0 commit comments
