Skip to content

Commit d85c014

Browse files
committed
fix: add multi-round handling for FMPE.
1 parent 77989f1 commit d85c014

File tree

2 files changed

+109
-58
lines changed

2 files changed

+109
-58
lines changed

sbi/inference/fmpe/fmpe_base.py

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,63 @@ def __init__(
6767
show_progress_bars=show_progress_bars,
6868
)
6969

70+
def append_simulations(
71+
self,
72+
theta: torch.Tensor,
73+
x: torch.Tensor,
74+
proposal: Optional[DirectPosterior] = None,
75+
exclude_invalid_x: Optional[bool] = None,
76+
data_device: Optional[str] = None,
77+
) -> NeuralInference:
78+
if (
79+
proposal is None
80+
or proposal is self._prior
81+
or (
82+
isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior
83+
)
84+
):
85+
current_round = 0
86+
else:
87+
raise NotImplementedError(
88+
"FMPE with proposal different from prior is not implemented."
89+
)
90+
91+
if exclude_invalid_x is None:
92+
exclude_invalid_x = current_round == 0
93+
94+
if data_device is None:
95+
data_device = self._device
96+
97+
theta, x = validate_theta_and_x(
98+
theta, x, data_device=data_device, training_device=self._device
99+
)
100+
101+
is_valid_x, num_nans, num_infs = handle_invalid_x(
102+
x, exclude_invalid_x=exclude_invalid_x
103+
)
104+
105+
x = x[is_valid_x]
106+
theta = theta[is_valid_x]
107+
108+
# Check for problematic z-scoring
109+
warn_if_zscoring_changes_data(x)
110+
# Check whether there are NaNs or Infs in the data and remove accordingly.
111+
npe_msg_on_invalid_x(
112+
num_nans=num_nans,
113+
num_infs=num_infs,
114+
exclude_invalid_x=exclude_invalid_x,
115+
algorithm="Single-round FMPE",
116+
)
117+
118+
self._data_round_index.append(current_round)
119+
prior_masks = mask_sims_from_prior(int(current_round > 0), theta.size(0))
120+
121+
self._theta_roundwise.append(theta)
122+
self._x_roundwise.append(x)
123+
self._prior_masks.append(prior_masks)
124+
125+
return self
126+
70127
def train(
71128
self,
72129
training_batch_size: int = 50,
@@ -76,6 +133,7 @@ def train(
76133
max_num_epochs: int = 2**31 - 1,
77134
clip_max_norm: Optional[float] = 5.0,
78135
resume_training: bool = False,
136+
allow_multi_round_usage: bool = False,
79137
show_train_summary: bool = False,
80138
dataloader_kwargs: Optional[dict] = None,
81139
) -> ConditionalDensityEstimator:
@@ -89,16 +147,32 @@ def train(
89147
max_num_epochs: Maximum number of epochs to train for.
90148
clip_max_norm: Maximum norm for gradient clipping. Defaults to 5.0.
91149
resume_training: Whether to resume training. Defaults to False.
150+
allow_multi_round_usage: Whether to allow training with simulations that
151+
have not been sampled from the prior, e.g., in a sequential inference
152+
setting. Note that can lead to biased inference results.
92153
show_train_summary: Whether to show the training summary. Defaults to False.
93154
dataloader_kwargs: Additional keyword arguments for the dataloader.
94155
95156
Returns:
96157
DensityEstimator: Trained flow matching estimator.
97158
"""
98159

160+
# Load data from most recent round.
161+
self._round = max(self._data_round_index)
162+
163+
if self._round == 0 and self._neural_net is not None:
164+
assert allow_multi_round_usage, (
165+
"You have already trained this neural network and now appended new "
166+
"simulations with `append_simulations(theta, x)` without providing a "
167+
"proposal. If the new simulations are sampled from the prior, you "
168+
"can avoid this error by passing `allow_multi_round_usage=True` to "
169+
"the `train(...)` method. However, if the new simulations were not "
170+
"sampled from the prior, the result of FMPE will not be the true "
171+
"posterior. Instead, it will be the proposal posterior, which "
172+
"(usually) is more narrow than the true posterior. ",
173+
)
174+
99175
start_idx = 0 # as there is no multi-round FMPE yet
100-
current_round = 1 # as there is no multi-round FMPE yet
101-
self._data_round_index.append(current_round)
102176

103177
train_loader, val_loader = self.get_dataloaders(
104178
start_idx,
@@ -130,7 +204,7 @@ def train(
130204
list(self._neural_net.net.parameters()), lr=learning_rate
131205
)
132206
self.epoch = 0
133-
# NOTE: we deal with losses, not log probs here.
207+
# NOTE: in the FMPE context we use MSE loss, not log probs.
134208
self._val_loss = float("Inf")
135209

136210
while self.epoch <= max_num_epochs and not self._converged(
@@ -223,7 +297,7 @@ def build_posterior(
223297
Args:
224298
density_estimator: Density estimator for the posterior.
225299
prior: Prior distribution.
226-
sample_with: Sampling method.
300+
sample_with: Sampling method, currently only "direct" is supported.
227301
direct_sampling_parameters: kwargs for DirectPosterior.
228302
229303
Returns:
@@ -261,57 +335,3 @@ def build_posterior(
261335
)
262336

263337
return deepcopy(self._posterior)
264-
265-
def append_simulations(
266-
self,
267-
theta: torch.Tensor,
268-
x: torch.Tensor,
269-
proposal: Optional[DirectPosterior] = None,
270-
exclude_invalid_x: Optional[bool] = None,
271-
data_device: Optional[str] = None,
272-
) -> NeuralInference:
273-
if (
274-
proposal is None
275-
or proposal is self._prior
276-
or (
277-
isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior
278-
)
279-
):
280-
current_round = 0
281-
else:
282-
raise NotImplementedError("Mutli-round FMPE is currently not supported.")
283-
284-
if exclude_invalid_x is None:
285-
exclude_invalid_x = current_round == 0
286-
287-
if data_device is None:
288-
data_device = self._device
289-
290-
theta, x = validate_theta_and_x(
291-
theta, x, data_device=data_device, training_device=self._device
292-
)
293-
294-
is_valid_x, num_nans, num_infs = handle_invalid_x(
295-
x, exclude_invalid_x=exclude_invalid_x
296-
)
297-
298-
x = x[is_valid_x]
299-
theta = theta[is_valid_x]
300-
301-
# Check for problematic z-scoring
302-
warn_if_zscoring_changes_data(x)
303-
# Check whether there are NaNs or Infs in the data and remove accordingly.
304-
npe_msg_on_invalid_x(
305-
num_nans=num_nans,
306-
num_infs=num_infs,
307-
exclude_invalid_x=exclude_invalid_x,
308-
algorithm="Single-round FMPE",
309-
)
310-
311-
prior_masks = mask_sims_from_prior(int(current_round > 0), theta.size(0))
312-
313-
self._theta_roundwise.append(theta)
314-
self._x_roundwise.append(x)
315-
self._prior_masks.append(prior_masks)
316-
317-
return self

tests/linearGaussian_fmpe_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,34 @@ def test_fmpe_map():
387387

388388
# Check whether the MAP is close to the ground truth.
389389
assert torch.allclose(map_, gt_posterior.mean, atol=0.1)
390+
391+
392+
def test_multi_round_handling_fmpe():
393+
"""Test whether we can append data and train multiple times with FMPE."""
394+
395+
num_dim = 3
396+
num_simulations = 100
397+
398+
likelihood_shift = -1.0 * ones(num_dim)
399+
likelihood_cov = 0.3 * eye(num_dim)
400+
401+
prior_mean = zeros(num_dim)
402+
prior_cov = eye(num_dim)
403+
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)
404+
405+
theta = prior.sample((num_simulations,))
406+
x = linear_gaussian(theta, likelihood_shift, likelihood_cov)
407+
408+
inference = FMPE(prior, show_progress_bars=False)
409+
inference.append_simulations(theta, x).train(max_num_epochs=2)
410+
411+
# Append new data without passing a proposal.
412+
theta_new = prior.sample((num_simulations,))
413+
x_new = linear_gaussian(theta_new, likelihood_shift, likelihood_cov)
414+
with pytest.raises(AssertionError, match="You have already trained*"):
415+
inference.append_simulations(theta_new, x_new).train()
416+
417+
# Append new data with a proposal. This should work without any issues.
418+
inference.append_simulations(theta_new, x_new).train(
419+
max_num_epochs=2, allow_multi_round_usage=True
420+
)

0 commit comments

Comments
 (0)