Skip to content

Commit 6012fa1

Browse files
committed
fix: resume training vs force first round loss handling, also for SNPE'
1 parent 9e880d6 commit 6012fa1

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

sbi/inference/fmpe/fmpe_base.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def train(
133133
max_num_epochs: int = 2**31 - 1,
134134
clip_max_norm: Optional[float] = 5.0,
135135
resume_training: bool = False,
136-
train_with_proposal_without_correction: bool = False,
136+
force_first_round_loss: bool = False,
137137
show_train_summary: bool = False,
138138
dataloader_kwargs: Optional[dict] = None,
139139
) -> ConditionalDensityEstimator:
@@ -146,8 +146,11 @@ def train(
146146
stop_after_epochs: Number of epochs to train for. Defaults to 20.
147147
max_num_epochs: Maximum number of epochs to train for.
148148
clip_max_norm: Maximum norm for gradient clipping. Defaults to 5.0.
149-
resume_training: Whether to resume training. Defaults to False.
150-
train_with_proposal_without_correction: Whether to allow training with
149+
resume_training: Can be used in case training time is limited, e.g. on a
150+
cluster. If `True`, the split between train and validation set, the
151+
optimizer, the number of epochs, and the best validation log-prob will
152+
be restored from the last time `.train()` was called.
153+
force_first_round_loss: Whether to allow training with
151154
simulations that have not been sampled from the prior, e.g., in a
152155
sequential inference setting. Note that can lead to biased inference
153156
results.
@@ -162,16 +165,18 @@ def train(
162165
self._round = max(self._data_round_index)
163166

164167
if self._round == 0 and self._neural_net is not None:
165-
assert train_with_proposal_without_correction or resume_training, (
166-
"You have already trained this neural network and now appended new "
167-
"simulations with `append_simulations(theta, x)` without providing a "
168-
"proposal. If the new simulations are sampled from the prior, you "
169-
"can avoid this error by passing "
170-
"`train_with_proposal_without_correction=True` to `train(...)` "
171-
"However, if the new simulations were not "
172-
"sampled from the prior, the result of FMPE will not be the true "
173-
"posterior. Instead, it will be the proposal posterior, which "
174-
"(usually) is more narrow than the true posterior. ",
168+
assert force_first_round_loss or resume_training, (
169+
"You have already trained this neural network. After you had trained "
170+
"the network, you again appended simulations with `append_simulations"
171+
"(theta, x)`, but you did not provide a proposal. If the new "
172+
"simulations are sampled from the prior, you can set "
173+
"`.train(..., force_first_round_loss=True`). However, if the new "
174+
"simulations were not sampled from the prior, you should pass the "
175+
"proposal, i.e. `append_simulations(theta, x, proposal)`. If "
176+
"your samples are not sampled from the prior and you do not pass a "
177+
"proposal and you set `force_first_round_loss=True`, the result of "
178+
"FMPE will not be the true posterior. Instead, it will be the proposal "
179+
"posterior, which (usually) is more narrow than the true posterior."
175180
)
176181

177182
start_idx = 0 # as there is no multi-round FMPE yet

sbi/inference/snpe/snpe_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def train(
263263
self._round = max(self._data_round_index)
264264

265265
if self._round == 0 and self._neural_net is not None:
266-
assert force_first_round_loss, (
266+
assert force_first_round_loss or resume_training, (
267267
"You have already trained this neural network. After you had trained "
268268
"the network, you again appended simulations with `append_simulations"
269269
"(theta, x)`, but you did not provide a proposal. If the new "

tests/linearGaussian_fmpe_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_c2st_fmpe_for_different_dims_and_resume_training(density_estimator="mlp
214214
)
215215

216216
inference = inference.append_simulations(theta, x)
217-
posterior_estimator = inference.train(max_num_epochs=10)
217+
posterior_estimator = inference.train(max_num_epochs=2)
218218
# Test whether we can stop and resume.
219219
posterior_estimator = inference.train(resume_training=True)
220220

@@ -416,5 +416,5 @@ def test_multi_round_handling_fmpe():
416416

417417
# Append new data with a proposal. This should work without any issues.
418418
inference.append_simulations(theta_new, x_new).train(
419-
max_num_epochs=2, train_with_proposal_without_correction=True
419+
max_num_epochs=2, force_first_round_loss=True
420420
)

0 commit comments

Comments
 (0)