Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion sbi/inference/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
from warnings import warn

import torch
Expand Down
15 changes: 8 additions & 7 deletions sbi/inference/trainers/nle/mnle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Dict, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union

from torch.distributions import Distribution

Expand All @@ -14,6 +14,7 @@
)
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries

Expand All @@ -33,7 +34,7 @@ class MNLE(LikelihoodEstimatorTrainer):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "mnle",
density_estimator: Union[Literal["mnle"], DensityEstimatorBuilder] = "mnle",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorBoardSummaryWriter] = None,
Expand All @@ -47,12 +48,12 @@ def __init__(
prior must be passed to `.build_posterior()`.
density_estimator: If it is a string, it must be "mnle" to use the
preconfiugred neural nets for MNLE. Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob`, `.log_prob_iid()` and `.sample()`.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down
17 changes: 10 additions & 7 deletions sbi/inference/trainers/nle/nle_a.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Optional, Union
from typing import Literal, Optional, Union

from torch.distributions import Distribution

from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries

Expand All @@ -21,7 +22,9 @@ class NLE_A(LikelihoodEstimatorTrainer):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorBoardSummaryWriter] = None,
Expand All @@ -35,12 +38,12 @@ def __init__(
prior must be passed to `.build_posterior()`.
density_estimator: If it is a string, use a pre-configured network of the
provided type (one of nsf, maf, mdn, made). Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down
17 changes: 10 additions & 7 deletions sbi/inference/trainers/nle/nle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from abc import ABC
from copy import deepcopy
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
from typing import Any, Dict, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
Expand All @@ -26,6 +26,7 @@
from sbi.inference.trainers.base import NeuralInference
from sbi.neural_nets import likelihood_nn
from sbi.neural_nets.estimators import ConditionalDensityEstimator
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
Expand All @@ -38,7 +39,9 @@ class LikelihoodEstimatorTrainer(NeuralInference, ABC):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
Expand All @@ -53,12 +56,12 @@ def __init__(
distribution) can be used.
density_estimator: If it is a string, use a pre-configured network of the
provided type (one of nsf, maf, mdn, made). Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`.

See docstring of `NeuralInference` class for all other arguments.
"""
Expand Down
15 changes: 8 additions & 7 deletions sbi/inference/trainers/npe/mnpe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Dict, Literal, Optional, Union
from typing import Any, Dict, Literal, Optional, Union

from torch.distributions import Distribution

Expand All @@ -15,6 +15,7 @@
)
from sbi.inference.trainers.npe.npe_c import NPE_C
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries

Expand All @@ -31,7 +32,7 @@ class MNPE(NPE_C):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "mnpe",
density_estimator: Union[Literal["mnpe"], DensityEstimatorBuilder] = "mnpe",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorBoardSummaryWriter] = None,
Expand All @@ -45,12 +46,12 @@ def __init__(
prior must be passed to `.build_posterior()`.
density_estimator: If it is a string, it must be "mnpe" to use the
preconfigured neural nets for MNPE. Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob`, `.log_prob_iid()` and `.sample()`.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down
31 changes: 18 additions & 13 deletions sbi/inference/trainers/npe/npe_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union

import torch
from pyknos.mdn.mdn import MultivariateGaussianMDN
Expand All @@ -16,7 +16,10 @@
from sbi.inference.trainers.npe.npe_base import (
PosteriorEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.base import (
ConditionalDensityEstimator,
DensityEstimatorBuilder,
)
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils import torchutils
from sbi.utils.sbiutils import (
Expand Down Expand Up @@ -49,7 +52,9 @@ class NPE_A(PosteriorEstimatorTrainer):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "mdn_snpe_a",
density_estimator: Union[
Literal["mdn_snpe_a"], DensityEstimatorBuilder
] = "mdn_snpe_a",
num_components: int = 10,
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
Expand All @@ -65,17 +70,17 @@ def __init__(
distribution) can be used.
density_estimator: If it is a string (only "mdn_snpe_a" is valid), use a
pre-configured mixture of densities network. Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`. Note that until the last round only a
single (multivariate) Gaussian component is used for training (see
Algorithm 1 in [1]). In the last round, this component is replicated
`num_components` times, its parameters are perturbed with a very small
noise, and then the last training round is done with the expanded
Gaussian mixture as estimator for the proposal posterior.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`. Note that until the last round only a single (multivariate)
Gaussian component is used for training (seeAlgorithm 1 in [1]). In the
last round, this component is replicated `num_components` times, its
parameters are perturbed with a very small noise, and then the last
training round is done with the expanded Gaussian mixture as estimator
for the proposal posterior.
num_components: Number of components of the mixture of Gaussians in the
last round. This overrides the `num_components` value passed to
`posterior_nn()`.
Expand Down
21 changes: 13 additions & 8 deletions sbi/inference/trainers/npe/npe_b.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Optional, Union
from typing import Any, Literal, Optional, Union

import torch
from torch import Tensor
from torch.distributions import Distribution

import sbi.utils as utils
from sbi.inference.trainers.npe.npe_base import PosteriorEstimatorTrainer
from sbi.inference.trainers.npe.npe_base import (
PosteriorEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
from sbi.sbi_types import TensorBoardSummaryWriter
from sbi.utils.sbiutils import del_entries
Expand All @@ -33,7 +36,9 @@ class NPE_B(PosteriorEstimatorTrainer):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorBoardSummaryWriter] = None,
Expand All @@ -46,12 +51,12 @@ def __init__(
parameters, e.g. which ranges are meaningful for them.
density_estimator: If it is a string, use a pre-configured network of the
provided type (one of nsf, maf, mdn, made). Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
Expand Down
20 changes: 13 additions & 7 deletions sbi/inference/trainers/npe/npe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@
)
from sbi.inference.potentials import posterior_estimator_based_potential
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
from sbi.inference.trainers.base import NeuralInference, check_if_proposal_has_default_x
from sbi.inference.trainers.base import (
NeuralInference,
check_if_proposal_has_default_x,
)
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.estimators import ConditionalDensityEstimator
from sbi.neural_nets.estimators.base import DensityEstimatorBuilder
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
Expand All @@ -54,7 +58,9 @@ class PosteriorEstimatorTrainer(NeuralInference, ABC):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[
Literal["nsf", "maf", "mdn", "made"], DensityEstimatorBuilder
] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
Expand All @@ -69,12 +75,12 @@ def __init__(
Args:
density_estimator: If it is a string, use a pre-configured network of the
provided type (one of nsf, maf, mdn, made). Alternatively, a function
that builds a custom neural network can be provided. The function will
that builds a custom neural network, which adheres to
`DensityEstimatorBuilder` protocol can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob` and `.sample()`.
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()`.

See docstring of `NeuralInference` class for all other arguments.
"""
Expand Down
Loading