
Commit 4c173ab

docs: fix docstrings in LikelihoodEstimator and RatioEstimator (#1571)
Parent: 8ab6841

2 files changed: +26 additions, −2 deletions

sbi/inference/trainers/nle/nle_base.py

Lines changed: 12 additions & 0 deletions
@@ -102,6 +102,8 @@ def append_simulations(
                 With default settings, this is not used at all for `NLE`. Only when
                 the user later on requests `.train(discard_prior_samples=True)`, we
                 use these indices to find which training data stemmed from the prior.
+            algorithm: Which algorithm is used. This is used to give a more informative
+                warning or error message when invalid simulations are found.
             data_device: Where to store the data, default is on the same device where
                 the training is happening. If training a large dataset on a GPU with not
                 much VRAM can set to 'cpu' to store data on system memory instead.
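As a minimal usage sketch (not part of the commit) of the `append_simulations` flow that this docstring describes: the example below assumes the high-level `NLE` trainer and `BoxUniform` prior exported by sbi, with a toy simulator invented for illustration. The same pattern applies to the `NRE` counterpart further below.

import torch
from sbi.inference import NLE
from sbi.utils import BoxUniform

# Toy setup, invented for illustration.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

def simulator(theta: torch.Tensor) -> torch.Tensor:
    # Gaussian observation noise around theta.
    return theta + 0.1 * torch.randn_like(theta)

theta = prior.sample((1000,))
x = simulator(theta)

trainer = NLE(prior=prior)
# The stored round indices are only consulted later, if the user
# requests `.train(discard_prior_samples=True)` as described above.
trainer.append_simulations(theta, x)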
@@ -141,6 +143,16 @@ def train(
         r"""Train the density estimator to learn the distribution $p(x|\theta)$.
 
         Args:
+            training_batch_size: Training batch size.
+            learning_rate: Learning rate for Adam optimizer.
+            validation_fraction: The fraction of data to use for validation.
+            stop_after_epochs: The number of epochs to wait for improvement on the
+                validation set before terminating training.
+            max_num_epochs: Maximum number of epochs to run. If reached, we stop
+                training even when the validation loss is still decreasing. Otherwise,
+                we train until validation loss increases (see also `stop_after_epochs`).
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
             resume_training: Can be used in case training time is limited, e.g. on a
                 cluster. If `True`, the split between train and validation set, the
                 optimizer, the number of epochs, and the best validation log-prob will
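Continuing the sketch above, here are the newly documented `train` arguments in use. The keyword names come straight from the docstring additions in this hunk; the values are arbitrary illustrations, not recommended settings.

likelihood_estimator = trainer.train(
    training_batch_size=200,   # batch size per training step
    learning_rate=5e-4,        # Adam learning rate
    validation_fraction=0.1,   # hold out 10% of the data for validation
    stop_after_epochs=20,      # early-stopping patience on validation loss
    max_num_epochs=500,        # hard cap even if validation loss still improves
    clip_max_norm=5.0,         # clip total gradient norm; None disables clipping
)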

sbi/inference/trainers/nre/nre_base.py

Lines changed: 14 additions & 2 deletions
@@ -112,6 +112,8 @@ def append_simulations(
                 With default settings, this is not used at all for `NRE`. Only when
                 the user later on requests `.train(discard_prior_samples=True)`, we
                 use these indices to find which training data stemmed from the prior.
+            algorithm: Which algorithm is used. This is used to give a more informative
+                warning or error message when invalid simulations are found.
             data_device: Where to store the data, default is on the same device where
                 the training is happening. If training a large dataset on a GPU with not
                 much VRAM can set to 'cpu' to store data on system memory instead.
@@ -153,8 +155,16 @@ def train(
 
         Args:
             num_atoms: Number of atoms to use for classification.
-            exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
-                during training. Expect errors, silent or explicit, when `False`.
+            training_batch_size: Training batch size.
+            learning_rate: Learning rate for Adam optimizer.
+            validation_fraction: The fraction of data to use for validation.
+            stop_after_epochs: The number of epochs to wait for improvement on the
+                validation set before terminating training.
+            max_num_epochs: Maximum number of epochs to run. If reached, we stop
+                training even when the validation loss is still decreasing. Otherwise,
+                we train until validation loss increases (see also `stop_after_epochs`).
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
             resume_training: Can be used in case training time is limited, e.g. on a
                 cluster. If `True`, the split between train and validation set, the
                 optimizer, the number of epochs, and the best validation log-prob will
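For the NRE side, a corresponding sketch assuming the high-level `NRE` trainer: the same training arguments as above plus the NRE-specific `num_atoms` documented in this hunk (reusing `prior`, `theta`, and `x` from the earlier sketch).

from sbi.inference import NRE

nre_trainer = NRE(prior=prior)
nre_trainer.append_simulations(theta, x)
ratio_estimator = nre_trainer.train(
    num_atoms=10,              # atoms per contrastive classification problem
    training_batch_size=100,
    stop_after_epochs=20,
    clip_max_norm=5.0,
)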
@@ -164,6 +174,8 @@ def train(
                 samples.
             retrain_from_scratch: Whether to retrain the conditional density
                 estimator for the posterior from scratch each round.
+            show_train_summary: Whether to print the number of epochs and validation
+                loss after the training.
             dataloader_kwargs: Additional or updated kwargs to be passed to the training
                 and validation dataloaders (like, e.g., a collate_fn).
             loss_kwargs: Additional or updated kwargs to be passed to the self._loss fn.
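Finally, a hedged sketch combining the newly documented `show_train_summary` with `resume_training`: a run capped by a cluster time limit is resumed later on the same trainer object, with the optimizer state and train/validation split restored as the docstring describes. The epoch counts are arbitrary illustrations.

# First job: stop after a budgeted number of epochs.
ratio_estimator = nre_trainer.train(
    max_num_epochs=50,
    show_train_summary=True,   # print epochs and validation loss at the end
)

# Follow-up job on the same trainer object: pick up where training left off.
ratio_estimator = nre_trainer.train(
    resume_training=True,
    max_num_epochs=500,
)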
