Skip to content

Commit 6daf716

Browse files
manuelgloecklermichaeldeistler
authored andcommitted
Add options to docstring
1 parent 33aa5ad commit 6daf716

File tree

4 files changed

+3
-5
lines changed

4 files changed

+3
-5
lines changed

sbi/inference/npse/npse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def __init__(
5656
Args:
5757
prior: Prior distribution.
5858
score_estimator: Neural network architecture for the score estimator. Can be
59-
a string (e.g. 'mlp') or a callable that returns a neural network.
59+
a string (e.g. 'mlp' or 'ada_mlp') or a callable that returns a neural
60+
network.
6061
sde_type: Type of SDE to use. Must be one of ['vp', 've', 'subvp'].
6162
device: Device to run the training on.
6263
logging_level: Logging level for the training. Can be an integer or a

sbi/inference/potentials/score_based_potential.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def gradient(
162162
raise NotImplementedError(
163163
"Score accumulation for IID data is not yet implemented."
164164
)
165-
165+
166166
return score
167167

168168
def get_continuous_normalizing_flow(
@@ -229,4 +229,3 @@ def f(t, x):
229229
exact=exact,
230230
)
231231
return transform
232-

sbi/samplers/score/predictors.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,3 @@ def predict(self, theta: Tensor, t1: Tensor, t0: Tensor):
120120
f_backward = f - (1 + self.eta**2) / 2 * g**2 * score
121121
g_backward = self.eta * g
122122
return theta - f_backward * dt + g_backward * torch.randn_like(theta) * dt_sqrt
123-

tests/score_samplers_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def _build_gaussian_score_estimator(
6161
# Note the precondition predicts a correct Gaussian score by default if the neural
6262
# net predicts 0!
6363
class DummyNet(torch.nn.Module):
64-
6564
def __init__(self):
6665
super().__init__()
6766
self.dummy_param_for_device_detection = torch.nn.Linear(1, 1)

0 commit comments

Comments
 (0)