Skip to content

Commit ce363b7

Browse files
authored
refactor: unify build_posterior in base class for all trainers (#1610)
* fix: spelling update * refactor(NeuralInference): define methods for building the posterior * refactor(nle): update build_posterior method to use parent class implementation * refactor(npe): update build_posterior method to use parent class implementation * refactor(nre): update build_posterior method to use parent class implementation * refactor(npse): implement required abstract method from parent class * refactor(npse, fmpe): remove and replace build_posterior method with superclass method * refactor: update build_posterior estimator annotation * refactor: move _get_potential_function method from build_posterior to _create_posterior * refactor: update _get_potential_function return type * refactor: Update conditional arrangement for _create_posterior * docs: update vector_field_potential docstring * test(build_posterior): raise error for invalid density_estimator * refactor(build_posterior): move estimator and prior checking into helper functions * refactor: rearrange method order * refactor: remove comment and update return type for abstract method train * refactor: update method ordering for trainer classes to follow convention * refactor(FMPE, NPSE): add vectorfield_sampling_parameters parameter * test: update NRE_A to NRE base class * test(FMPE, NPSE): add vectorfield_sampling_parameters to build_posterior argument
1 parent 1b78d06 commit ce363b7

17 files changed

+564
-506
lines changed

docs/tutorials/00_getting_started.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
"cell_type": "markdown",
118118
"metadata": {},
119119
"source": [
120-
"The `sbi` toolbox uses neural networks to learn the relationship between parameters and data. In this exampmle, we will use neural perform posterior estimation (NPE). To run NPE, we first instatiate a trainer, which we call `inference`:"
120+
"The `sbi` toolbox uses neural networks to learn the relationship between parameters and data. In this example, we will use neural perform posterior estimation (NPE). To run NPE, we first instatiate a trainer, which we call `inference`:"
121121
]
122122
},
123123
{

sbi/inference/potentials/likelihood_based_potential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def likelihood_estimator_based_potential(
2626
prior: Distribution, # type: ignore
2727
x_o: Optional[Tensor],
2828
enable_transform: bool = True,
29-
) -> Tuple[Callable, TorchTransform]:
29+
) -> Tuple["LikelihoodBasedPotential", TorchTransform]:
3030
r"""Returns potential :math:`\log(p(x_o|\theta)p(\theta))` for likelihood estimator.
3131
3232
It also returns a transformation that can be used to transform the potential into

sbi/inference/potentials/ratio_based_potential.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4-
from typing import Callable, Optional, Tuple, Union
4+
from typing import Optional, Tuple, Union
55

66
import torch
77
from torch import Tensor, nn
@@ -18,7 +18,7 @@ def ratio_estimator_based_potential(
1818
prior: Distribution,
1919
x_o: Optional[Tensor],
2020
enable_transform: bool = True,
21-
) -> Tuple[Callable, TorchTransform]:
21+
) -> Tuple["RatioBasedPotential", TorchTransform]:
2222
r"""Returns the potential for ratio-based methods.
2323
2424
It also returns a transformation that can be used to transform the potential into

sbi/inference/potentials/vector_field_potential.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def vector_field_estimator_based_potential(
3737
enable_transform: Whether to enable transforms. Not supported yet.
3838
**kwargs: Additional keyword arguments passed to
3939
`VectorFieldBasedPotential`.
40+
Returns:
41+
The potential function and a transformation that maps
42+
to unconstrained space.
4043
"""
4144
device = str(next(vector_field_estimator.parameters()).device)
4245

@@ -60,9 +63,9 @@ def __init__(
6063
vector_field_estimator: ConditionalVectorFieldEstimator,
6164
prior: Optional[Distribution], # type: ignore
6265
x_o: Optional[Tensor] = None,
66+
device: Union[str, torch.device] = "cpu",
6367
iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
6468
iid_params: Optional[Dict[str, Any]] = None,
65-
device: Union[str, torch.device] = "cpu",
6669
neural_ode_backend: str = "zuko",
6770
neural_ode_kwargs: Optional[Dict[str, Any]] = None,
6871
):
@@ -78,11 +81,11 @@ def __init__(
7881
vector_field_estimator: The neural network modelling the vector field.
7982
prior: The prior distribution.
8083
x_o: The observed data at which to evaluate the posterior.
84+
device: The device on which to evaluate the potential.
8185
iid_method: Which method to use for computing the score in the iid setting.
8286
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
8387
iid_params: Parameters for the iid method, for arguments see
8488
`IIDScoreFunction`.
85-
device: The device on which to evaluate the potential.
8689
neural_ode_backend: The backend to use for the neural ODE. Currently,
8790
only "zuko" is supported.
8891
neural_ode_kwargs: Additional keyword arguments for the neural ODE.

0 commit comments

Comments
 (0)