@@ -67,6 +67,63 @@ def __init__(
67
67
show_progress_bars = show_progress_bars ,
68
68
)
69
69
70
+ def append_simulations (
71
+ self ,
72
+ theta : torch .Tensor ,
73
+ x : torch .Tensor ,
74
+ proposal : Optional [DirectPosterior ] = None ,
75
+ exclude_invalid_x : Optional [bool ] = None ,
76
+ data_device : Optional [str ] = None ,
77
+ ) -> NeuralInference :
78
+ if (
79
+ proposal is None
80
+ or proposal is self ._prior
81
+ or (
82
+ isinstance (proposal , RestrictedPrior ) and proposal ._prior is self ._prior
83
+ )
84
+ ):
85
+ current_round = 0
86
+ else :
87
+ raise NotImplementedError (
88
+ "FMPE with proposal different from prior is not implemented."
89
+ )
90
+
91
+ if exclude_invalid_x is None :
92
+ exclude_invalid_x = current_round == 0
93
+
94
+ if data_device is None :
95
+ data_device = self ._device
96
+
97
+ theta , x = validate_theta_and_x (
98
+ theta , x , data_device = data_device , training_device = self ._device
99
+ )
100
+
101
+ is_valid_x , num_nans , num_infs = handle_invalid_x (
102
+ x , exclude_invalid_x = exclude_invalid_x
103
+ )
104
+
105
+ x = x [is_valid_x ]
106
+ theta = theta [is_valid_x ]
107
+
108
+ # Check for problematic z-scoring
109
+ warn_if_zscoring_changes_data (x )
110
+ # Check whether there are NaNs or Infs in the data and remove accordingly.
111
+ npe_msg_on_invalid_x (
112
+ num_nans = num_nans ,
113
+ num_infs = num_infs ,
114
+ exclude_invalid_x = exclude_invalid_x ,
115
+ algorithm = "Single-round FMPE" ,
116
+ )
117
+
118
+ self ._data_round_index .append (current_round )
119
+ prior_masks = mask_sims_from_prior (int (current_round > 0 ), theta .size (0 ))
120
+
121
+ self ._theta_roundwise .append (theta )
122
+ self ._x_roundwise .append (x )
123
+ self ._prior_masks .append (prior_masks )
124
+
125
+ return self
126
+
70
127
def train (
71
128
self ,
72
129
training_batch_size : int = 50 ,
@@ -76,6 +133,7 @@ def train(
76
133
max_num_epochs : int = 2 ** 31 - 1 ,
77
134
clip_max_norm : Optional [float ] = 5.0 ,
78
135
resume_training : bool = False ,
136
+ allow_multi_round_usage : bool = False ,
79
137
show_train_summary : bool = False ,
80
138
dataloader_kwargs : Optional [dict ] = None ,
81
139
) -> ConditionalDensityEstimator :
@@ -89,16 +147,32 @@ def train(
89
147
max_num_epochs: Maximum number of epochs to train for.
90
148
clip_max_norm: Maximum norm for gradient clipping. Defaults to 5.0.
91
149
resume_training: Whether to resume training. Defaults to False.
150
+ allow_multi_round_usage: Whether to allow training with simulations that
151
+ have not been sampled from the prior, e.g., in a sequential inference
152
+ setting. Note that can lead to biased inference results.
92
153
show_train_summary: Whether to show the training summary. Defaults to False.
93
154
dataloader_kwargs: Additional keyword arguments for the dataloader.
94
155
95
156
Returns:
96
157
DensityEstimator: Trained flow matching estimator.
97
158
"""
98
159
160
+ # Load data from most recent round.
161
+ self ._round = max (self ._data_round_index )
162
+
163
+ if self ._round == 0 and self ._neural_net is not None :
164
+ assert allow_multi_round_usage , (
165
+ "You have already trained this neural network and now appended new "
166
+ "simulations with `append_simulations(theta, x)` without providing a "
167
+ "proposal. If the new simulations are sampled from the prior, you "
168
+ "can avoid this error by passing `allow_multi_round_usage=True` to "
169
+ "the `train(...)` method. However, if the new simulations were not "
170
+ "sampled from the prior, the result of FMPE will not be the true "
171
+ "posterior. Instead, it will be the proposal posterior, which "
172
+ "(usually) is more narrow than the true posterior. " ,
173
+ )
174
+
99
175
start_idx = 0 # as there is no multi-round FMPE yet
100
- current_round = 1 # as there is no multi-round FMPE yet
101
- self ._data_round_index .append (current_round )
102
176
103
177
train_loader , val_loader = self .get_dataloaders (
104
178
start_idx ,
@@ -130,7 +204,7 @@ def train(
130
204
list (self ._neural_net .net .parameters ()), lr = learning_rate
131
205
)
132
206
self .epoch = 0
133
- # NOTE: we deal with losses , not log probs here .
207
+ # NOTE: in the FMPE context we use MSE loss , not log probs.
134
208
self ._val_loss = float ("Inf" )
135
209
136
210
while self .epoch <= max_num_epochs and not self ._converged (
@@ -223,7 +297,7 @@ def build_posterior(
223
297
Args:
224
298
density_estimator: Density estimator for the posterior.
225
299
prior: Prior distribution.
226
- sample_with: Sampling method.
300
+ sample_with: Sampling method, currently only "direct" is supported .
227
301
direct_sampling_parameters: kwargs for DirectPosterior.
228
302
229
303
Returns:
@@ -261,57 +335,3 @@ def build_posterior(
261
335
)
262
336
263
337
return deepcopy (self ._posterior )
264
-
265
- def append_simulations (
266
- self ,
267
- theta : torch .Tensor ,
268
- x : torch .Tensor ,
269
- proposal : Optional [DirectPosterior ] = None ,
270
- exclude_invalid_x : Optional [bool ] = None ,
271
- data_device : Optional [str ] = None ,
272
- ) -> NeuralInference :
273
- if (
274
- proposal is None
275
- or proposal is self ._prior
276
- or (
277
- isinstance (proposal , RestrictedPrior ) and proposal ._prior is self ._prior
278
- )
279
- ):
280
- current_round = 0
281
- else :
282
- raise NotImplementedError ("Mutli-round FMPE is currently not supported." )
283
-
284
- if exclude_invalid_x is None :
285
- exclude_invalid_x = current_round == 0
286
-
287
- if data_device is None :
288
- data_device = self ._device
289
-
290
- theta , x = validate_theta_and_x (
291
- theta , x , data_device = data_device , training_device = self ._device
292
- )
293
-
294
- is_valid_x , num_nans , num_infs = handle_invalid_x (
295
- x , exclude_invalid_x = exclude_invalid_x
296
- )
297
-
298
- x = x [is_valid_x ]
299
- theta = theta [is_valid_x ]
300
-
301
- # Check for problematic z-scoring
302
- warn_if_zscoring_changes_data (x )
303
- # Check whether there are NaNs or Infs in the data and remove accordingly.
304
- npe_msg_on_invalid_x (
305
- num_nans = num_nans ,
306
- num_infs = num_infs ,
307
- exclude_invalid_x = exclude_invalid_x ,
308
- algorithm = "Single-round FMPE" ,
309
- )
310
-
311
- prior_masks = mask_sims_from_prior (int (current_round > 0 ), theta .size (0 ))
312
-
313
- self ._theta_roundwise .append (theta )
314
- self ._x_roundwise .append (x )
315
- self ._prior_masks .append (prior_masks )
316
-
317
- return self
0 commit comments