@@ -112,6 +112,8 @@ def append_simulations(
112
112
With default settings, this is not used at all for `NRE`. Only when
113
113
the user later on requests `.train(discard_prior_samples=True)`, we
114
114
use these indices to find which training data stemmed from the prior.
115
+ algorithm: Which algorithm is used. This is used to give a more informative
116
+ warning or error message when invalid simulations are found.
115
117
data_device: Where to store the data, default is on the same device where
116
118
the training is happening. If training a large dataset on a GPU with not
117
119
much VRAM can set to 'cpu' to store data on system memory instead.
@@ -153,8 +155,16 @@ def train(
153
155
154
156
Args:
155
157
num_atoms: Number of atoms to use for classification.
156
- exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
157
- during training. Expect errors, silent or explicit, when `False`.
158
+ training_batch_size: Training batch size.
159
+ learning_rate: Learning rate for Adam optimizer.
160
+ validation_fraction: The fraction of data to use for validation.
161
+ stop_after_epochs: The number of epochs to wait for improvement on the
162
+ validation set before terminating training.
163
+ max_num_epochs: Maximum number of epochs to run. If reached, we stop
164
+ training even when the validation loss is still decreasing. Otherwise,
165
+ we train until validation loss increases (see also `stop_after_epochs`).
166
+ clip_max_norm: Value at which to clip the total gradient norm in order to
167
+ prevent exploding gradients. Use None for no clipping.
158
168
resume_training: Can be used in case training time is limited, e.g. on a
159
169
cluster. If `True`, the split between train and validation set, the
160
170
optimizer, the number of epochs, and the best validation log-prob will
@@ -164,6 +174,8 @@ def train(
164
174
samples.
165
175
retrain_from_scratch: Whether to retrain the conditional density
166
176
estimator for the posterior from scratch each round.
177
+ show_train_summary: Whether to print the number of epochs and validation
178
+ loss after the training.
167
179
dataloader_kwargs: Additional or updated kwargs to be passed to the training
168
180
and validation dataloaders (like, e.g., a collate_fn).
169
181
loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.
0 commit comments