Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sbi/inference/trainers/npe/mnpe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Dict, Optional, Union

from torch.distributions import Distribution

Expand All @@ -12,6 +12,7 @@
RejectionPosterior,
VIPosterior,
)
from sbi.inference.trainers.npe.npe_base import DensityEstimatorBuilder
from sbi.inference.trainers.npe.npe_c import NPE_C
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter
Expand All @@ -22,7 +23,7 @@ class MNPE(NPE_C):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "mnpe",
density_estimator: Union[str, DensityEstimatorBuilder] = "mnpe",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand Down
9 changes: 6 additions & 3 deletions sbi/inference/trainers/npe/npe_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from torch.distributions import Distribution, MultivariateNormal

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.trainers.npe.npe_base import PosteriorEstimator
from sbi.inference.trainers.npe.npe_base import (
DensityEstimatorBuilder,
PosteriorEstimator,
)
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.utils import torchutils
Expand All @@ -32,7 +35,7 @@ class NPE_A(PosteriorEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "mdn_snpe_a",
density_estimator: Union[str, DensityEstimatorBuilder] = "mdn_snpe_a",
num_components: int = 10,
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
Expand Down Expand Up @@ -231,7 +234,7 @@ def train(

def correct_for_proposal(
self,
density_estimator: Optional[TorchModule] = None,
density_estimator: Optional[torch.nn.Module] = None,
) -> "NPE_A_MDN":
r"""Build mixture of Gaussians that approximates the posterior.

Expand Down
9 changes: 6 additions & 3 deletions sbi/inference/trainers/npe/npe_b.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union

import torch
from torch import Tensor
from torch.distributions import Distribution

import sbi.utils as utils
from sbi.inference.trainers.npe.npe_base import PosteriorEstimator
from sbi.inference.trainers.npe.npe_base import (
DensityEstimatorBuilder,
PosteriorEstimator,
)
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
from sbi.sbi_types import TensorboardSummaryWriter
from sbi.utils.sbiutils import del_entries
Expand All @@ -20,7 +23,7 @@ class NPE_B(PosteriorEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[str, DensityEstimatorBuilder] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand Down
25 changes: 22 additions & 3 deletions sbi/inference/trainers/npe/npe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Protocol, Union
from warnings import warn

import torch
Expand All @@ -25,7 +25,7 @@
from sbi.inference.potentials import posterior_estimator_based_potential
from sbi.inference.trainers.base import NeuralInference, check_if_proposal_has_default_x
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.estimators import ConditionalDensityEstimator
from sbi.neural_nets.estimators import ConditionalDensityEstimator, NFlowsFlow
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
Expand All @@ -45,11 +45,30 @@
from sbi.utils.torchutils import assert_all_finite


class DensityEstimatorBuilder(Protocol):
"""Protocol for building a neural network from the data for the density
estimator."""

def __call__(self, theta: Tensor, x: Tensor, **kwargs) -> NFlowsFlow:
"""Build a density estimator from theta and x, which is mainly used for infering
shape and z-scoring. The density estimator should have the methods `.sample()`
and `.log_prob()`. The function should return an inheritance of `nn.Module`.

Args:
theta: Parameter sets.
x: Simulation outputs.

Returns:
Density Estimator.
"""
...


class PosteriorEstimator(NeuralInference, ABC):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[str, DensityEstimatorBuilder] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
Expand Down
7 changes: 5 additions & 2 deletions sbi/inference/trainers/npe/npe_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from torch.distributions import Distribution, MultivariateNormal, Uniform

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.trainers.npe.npe_base import PosteriorEstimator
from sbi.inference.trainers.npe.npe_base import (
DensityEstimatorBuilder,
PosteriorEstimator,
)
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
Expand All @@ -34,7 +37,7 @@ class NPE_C(PosteriorEstimator):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "maf",
density_estimator: Union[str, DensityEstimatorBuilder] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
Expand Down
4 changes: 2 additions & 2 deletions sbi/sbi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@

# Define alias types because otherwise, the documentation by mkdocs became very long and
# made the website look ugly.
TensorboardSummaryWriter = NewType("Writer", SummaryWriter)
TensorboardSummaryWriter = NewType("TensorboardSummaryWriter", SummaryWriter)
# TorchTransform = NewType("torch Transform", Transform)
TorchModule = NewType("Module", Module)
TorchModule = NewType("TorchModule", Module)
TorchDistribution = NewType("torch Distribution", Distribution)
# See PEP 613 for the reason why we need to use TypeAlias here.
TorchTransform: TypeAlias = Transform
Expand Down