Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a567fc8
wip: build_mnpe integrated, build_mnle refactored, mnle_test fails; T…
dgedon Jan 9, 2025
b9a5b92
wip: added MNPE class and test case for it
dgedon Jan 9, 2025
fe274c9
wip: added MNPE class and test case for it, not working yet
dgedon Jan 10, 2025
6016f07
fix: tests, embned+maf not working
dgedon Jan 10, 2025
736e3ac
wip: fixed discrete data issue in mcmc_transform with MultipleIndepen…
dgedon Jan 17, 2025
0934e86
wip: remove unnecessary helper function (introduced while working on …
dgedon Jan 17, 2025
6412964
bug fix with normalization when using embedding nets
dgedon Mar 18, 2025
48f0e63
revert unnecessary gpu handling things. Now MultipleIndependent does …
dgedon Mar 18, 2025
4f9524f
review changes: comments, missing import, static type check
dgedon Mar 19, 2025
7b140b4
simplify mixed nets (default logtransform set for mnle/mnpe), merge c…
dgedon Mar 19, 2025
b475446
remove legacy mnle.py that was not interacted with by users
dgedon Mar 19, 2025
3d7f870
refactor prior transform function (code duplication)
dgedon Mar 19, 2025
07783d2
cleanup cosmetics
dgedon Mar 19, 2025
8b94843
cleanup cosmetics (again)
dgedon Mar 19, 2025
2c35ea3
add accuracy test with MoG and analytic reference posterior
dgedon Mar 20, 2025
03cbde8
revert tutorial file
dgedon Mar 20, 2025
88827a3
old tutorial to remove conflict
dgedon Mar 20, 2025
43d7e52
jupyter merge conflict
dgedon Mar 20, 2025
b9356ad
incorporate review comments
dgedon Mar 20, 2025
f40a10b
mark gpu test as xfail
dgedon Mar 20, 2025
bfaabe8
commenting
dgedon Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sbi/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from sbi.inference.trainers.fmpe import FMPE
from sbi.inference.trainers.nle import MNLE, NLE_A
from sbi.inference.trainers.npe import NPE_A, NPE_B, NPE_C # noqa: F401
from sbi.inference.trainers.npe import MNPE, NPE_A, NPE_B, NPE_C # noqa: F401
from sbi.inference.trainers.npse import NPSE
from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C # noqa: F401

Expand Down
1 change: 1 addition & 0 deletions sbi/inference/trainers/npe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sbi.inference.trainers.npe.mnpe import MNPE # noqa: F401
from sbi.inference.trainers.npe.npe_a import NPE_A # noqa: F401
from sbi.inference.trainers.npe.npe_b import NPE_B # noqa: F401
from sbi.inference.trainers.npe.npe_base import PosteriorEstimator # noqa: F401
Expand Down
151 changes: 151 additions & 0 deletions sbi/inference/trainers/npe/mnpe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Dict, Optional, Union

from torch.distributions import Distribution

from sbi.inference.posteriors import (
DirectPosterior,
ImportanceSamplingPosterior,
MCMCPosterior,
RejectionPosterior,
VIPosterior,
)
from sbi.inference.trainers.npe.npe_c import NPE_C
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
from sbi.utils.sbiutils import del_entries


class MNPE(NPE_C):
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[str, Callable] = "mnpe",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[TensorboardSummaryWriter] = None,
show_progress_bars: bool = True,
):
r"""Mixed Neural Posterior Estimation (MNPE).
Like NPE-C, but designed to be applied to data with mixed types, e.g.,
continuous parameters and discrete parameters like they occur in models with
switching components. The emebedding net will only operate on the continuous
parameters, note this to design the dimension of the embedding net.
Args:
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
density_estimator: If it is a string, it must be "mnpe" to use the
preconfigured neural nets for MNPE. Alternatively, a function
that builds a custom neural network can be provided. The function will
be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. It
needs to return a PyTorch `nn.Module` implementing the density
estimator. The density estimator needs to provide the methods
`.log_prob`, `.log_prob_iid()` and `.sample()`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
summary_writer: A tensorboard `SummaryWriter` to control, among others, log
file location (default is `<current working directory>/logs`.)
show_progress_bars: Whether to show a progressbar during simulation and
sampling.
"""

if isinstance(density_estimator, str):
assert (
density_estimator == "mnpe"
), f"""MNPE can be used with preconfigured 'mnpe' density estimator only,
not with {density_estimator}."""
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
self,
training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> MixedDensityEstimator:
density_estimator = super().train(
**del_entries(locals(), entries=("self", "__class__"))
)
assert isinstance(
density_estimator, MixedDensityEstimator
), f"""Internal net must be of type
MixedDensityEstimator but is {type(density_estimator)}."""
return density_estimator

def build_posterior(
self,
density_estimator: Optional[TorchModule] = None,
prior: Optional[Distribution] = None,
sample_with: str = "direct",
mcmc_method: str = "slice_np_vectorized",
vi_method: str = "rKL",
direct_sampling_parameters: Optional[Dict[str, Any]] = None,
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
importance_sampling_parameters: Optional[Dict[str, Any]] = None,
) -> Union[
MCMCPosterior,
RejectionPosterior,
VIPosterior,
DirectPosterior,
ImportanceSamplingPosterior,
]:
"""Build posterior from the neural density estimator.
Args:
density_estimator: The density estimator that the posterior is based on.
If `None`, use the latest neural density estimator that was trained.
prior: Prior distribution.
sample_with: Method to use for sampling from the posterior. Must be one of
[`direct` | `mcmc` | `rejection` | `vi` | `importance`].
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`,
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy
implementation of slice sampling; select `hmc`, `nuts` or `slice` for
Pyro-based sampling.
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`].
direct_sampling_parameters: Additional kwargs passed to `DirectPosterior`.
mcmc_parameters: Additional kwargs passed to `MCMCPosterior`.
vi_parameters: Additional kwargs passed to `VIPosterior`.
rejection_sampling_parameters: Additional kwargs passed to `
RejectionPosterior`.
importance_sampling_parameters: Additional kwargs passed to
`ImportanceSamplingPosterior`.
Returns:
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods.
"""
if density_estimator is not None:
assert isinstance(
density_estimator, MixedDensityEstimator
), f"""net must be of type MixedDensityEstimator but is {
type(density_estimator)
}."""

return super().build_posterior(
density_estimator=density_estimator,
prior=prior,
sample_with=sample_with,
mcmc_method=mcmc_method,
vi_method=vi_method,
direct_sampling_parameters=direct_sampling_parameters,
mcmc_parameters=mcmc_parameters,
vi_parameters=vi_parameters,
rejection_sampling_parameters=rejection_sampling_parameters,
importance_sampling_parameters=importance_sampling_parameters,
)
16 changes: 10 additions & 6 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@


class MixedDensityEstimator(ConditionalDensityEstimator):
"""Class performing Mixed Neural Likelihood Estimation.
"""Class performing Mixed Neural Density Estimation.

MNLE combines a Categorical net and a neural density estimator to model data
with mixed types, e.g., as they occur in decision-making models.
This estimator combines a Categorical net and a neural density estimator to model
data with mixed types (discrete and continuous), e.g., as they occur in
decision-making models. It can be used for both likelihood and posterior estimation
of mixed data.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This estimator combines a Categorical net and a neural density estimator to model
data with mixed types (discrete and continuous), e.g., as they occur in
decision-making models. It can be used for both likelihood and posterior estimation
of mixed data.
This estimator combines a categorical mass estimator and a density estimator to model
variables with mixed types (discrete and continuous). It can be used for both likelihood
estimation (e.g., for discrete decisions and continuous reaction times in decision-making
models) or posterior estimation (e.g., for models that have both discrete and continuous
parameters).

"""

def __init__(
Expand All @@ -26,7 +28,8 @@ def __init__(
embedding_net: nn.Module = nn.Identity(),
log_transform_input: bool = False,
):
"""Initialize class for combining density estimators for MNLE.
"""Initialize class for combining density estimators for mixed neural
density estimation.

Args:
discrete_net: neural net to model discrete part of the data.
Expand All @@ -51,8 +54,9 @@ def __init__(

def forward(self, input: Tensor):
raise NotImplementedError(
"""The forward method is not implemented for MNLE, use '.sample(...)' to
generate samples though a forward pass."""
"""The forward method is not implemented for mixed neural density
estimation,use '.sample(...)' to generate samples though a forward
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
estimation,use '.sample(...)' to generate samples though a forward
estimation, use '.sample(...)' to generate samples though a forward

pass."""
)

def sample(
Expand Down
3 changes: 2 additions & 1 deletion sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
build_resnet_flowmatcher,
)
from sbi.neural_nets.net_builders.mdn import build_mdn
from sbi.neural_nets.net_builders.mnle import build_mnle
from sbi.neural_nets.net_builders.mixed_nets import build_mnle, build_mnpe
from sbi.neural_nets.net_builders.score_nets import build_score_estimator
from sbi.utils.nn_utils import check_net_device

Expand All @@ -42,6 +42,7 @@
"maf_rqs": build_maf_rqs,
"nsf": build_nsf,
"mnle": build_mnle,
"mnpe": build_mnpe,
"zuko_nice": build_zuko_nice,
"zuko_maf": build_zuko_maf,
"zuko_nsf": build_zuko_nsf,
Expand Down
2 changes: 1 addition & 1 deletion sbi/neural_nets/net_builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@
build_resnet_flowmatcher,
)
from sbi.neural_nets.net_builders.mdn import build_mdn
from sbi.neural_nets.net_builders.mnle import build_mnle
from sbi.neural_nets.net_builders.mixed_nets import build_mnle, build_mnpe
from sbi.neural_nets.net_builders.score_nets import build_score_estimator
Loading