-
Notifications
You must be signed in to change notification settings - Fork 196
MNPE class similar to MNLE #1362
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
a567fc8
wip: build_mnpe integrated, build_mnle refactored, mnle_test fails; T…
dgedon b9a5b92
wip: added MNPE class and test case for it
dgedon fe274c9
wip: added MNPE class and test case for it, not working yet
dgedon 6016f07
fix: tests, embned+maf not working
dgedon 736e3ac
wip: fixed discrete data issue in mcmc_transform with MultipleIndepen…
dgedon 0934e86
wip: remove unnecessary helper function (introduced while working on …
dgedon 6412964
bug fix with normalization when using embedding nets
dgedon 48f0e63
revert unnecessary gpu handling things. Now MultipleIndependent does …
dgedon 4f9524f
review changes: comments, missing import, static type check
dgedon 7b140b4
simplify mixed nets (default logtransform set for mnle/mnpe), merge c…
dgedon b475446
remove legacy mnle.py that was not interacted with by users
dgedon 3d7f870
refactor prior transform function (code duplication)
dgedon 07783d2
cleanup cosmetics
dgedon 8b94843
cleanup cosmetics (again)
dgedon 2c35ea3
add accuracy test with MoG and analytic reference posterior
dgedon 03cbde8
revert tutorial file
dgedon 88827a3
old tutorial to remove conflict
dgedon 43d7e52
jupyter merge conflict
dgedon b9356ad
incorporate review comments
dgedon f40a10b
mark gpu test as xfail
dgedon bfaabe8
commenting
dgedon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed | ||
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/> | ||
|
||
from typing import Any, Callable, Dict, Optional, Union | ||
|
||
from torch.distributions import Distribution | ||
|
||
from sbi.inference.posteriors import ( | ||
DirectPosterior, | ||
ImportanceSamplingPosterior, | ||
MCMCPosterior, | ||
RejectionPosterior, | ||
VIPosterior, | ||
) | ||
from sbi.inference.trainers.npe.npe_c import NPE_C | ||
from sbi.neural_nets.estimators import MixedDensityEstimator | ||
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule | ||
from sbi.utils.sbiutils import del_entries | ||
|
||
|
||
class MNPE(NPE_C): | ||
def __init__( | ||
self, | ||
prior: Optional[Distribution] = None, | ||
density_estimator: Union[str, Callable] = "mnpe", | ||
device: str = "cpu", | ||
logging_level: Union[int, str] = "WARNING", | ||
summary_writer: Optional[TensorboardSummaryWriter] = None, | ||
show_progress_bars: bool = True, | ||
): | ||
r"""Mixed Neural Posterior Estimation (MNPE). | ||
Like NPE-C, but designed to be applied to data with mixed types, e.g., | ||
continuous parameters and discrete parameters like they occur in models with | ||
switching components. The emebedding net will only operate on the continuous | ||
parameters, note this to design the dimension of the embedding net. | ||
Args: | ||
prior: A probability distribution that expresses prior knowledge about the | ||
parameters, e.g. which ranges are meaningful for them. If `None`, the | ||
prior must be passed to `.build_posterior()`. | ||
density_estimator: If it is a string, it must be "mnpe" to use the | ||
preconfigured neural nets for MNPE. Alternatively, a function | ||
that builds a custom neural network can be provided. The function will | ||
be called with the first batch of simulations (theta, x), which can | ||
thus be used for shape inference and potentially for z-scoring. It | ||
needs to return a PyTorch `nn.Module` implementing the density | ||
estimator. The density estimator needs to provide the methods | ||
`.log_prob`, `.log_prob_iid()` and `.sample()`. | ||
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}". | ||
logging_level: Minimum severity of messages to log. One of the strings | ||
INFO, WARNING, DEBUG, ERROR and CRITICAL. | ||
summary_writer: A tensorboard `SummaryWriter` to control, among others, log | ||
file location (default is `<current working directory>/logs`.) | ||
show_progress_bars: Whether to show a progressbar during simulation and | ||
sampling. | ||
""" | ||
|
||
if isinstance(density_estimator, str): | ||
assert ( | ||
density_estimator == "mnpe" | ||
), f"""MNPE can be used with preconfigured 'mnpe' density estimator only, | ||
not with {density_estimator}.""" | ||
kwargs = del_entries(locals(), entries=("self", "__class__")) | ||
super().__init__(**kwargs) | ||
|
||
def train( | ||
self, | ||
training_batch_size: int = 200, | ||
learning_rate: float = 5e-4, | ||
validation_fraction: float = 0.1, | ||
stop_after_epochs: int = 20, | ||
max_num_epochs: int = 2**31 - 1, | ||
clip_max_norm: Optional[float] = 5.0, | ||
resume_training: bool = False, | ||
discard_prior_samples: bool = False, | ||
retrain_from_scratch: bool = False, | ||
show_train_summary: bool = False, | ||
dataloader_kwargs: Optional[Dict] = None, | ||
) -> MixedDensityEstimator: | ||
density_estimator = super().train( | ||
**del_entries(locals(), entries=("self", "__class__")) | ||
) | ||
assert isinstance( | ||
density_estimator, MixedDensityEstimator | ||
), f"""Internal net must be of type | ||
MixedDensityEstimator but is {type(density_estimator)}.""" | ||
return density_estimator | ||
|
||
def build_posterior( | ||
self, | ||
density_estimator: Optional[TorchModule] = None, | ||
prior: Optional[Distribution] = None, | ||
sample_with: str = "direct", | ||
mcmc_method: str = "slice_np_vectorized", | ||
vi_method: str = "rKL", | ||
direct_sampling_parameters: Optional[Dict[str, Any]] = None, | ||
mcmc_parameters: Optional[Dict[str, Any]] = None, | ||
vi_parameters: Optional[Dict[str, Any]] = None, | ||
rejection_sampling_parameters: Optional[Dict[str, Any]] = None, | ||
importance_sampling_parameters: Optional[Dict[str, Any]] = None, | ||
) -> Union[ | ||
MCMCPosterior, | ||
RejectionPosterior, | ||
VIPosterior, | ||
DirectPosterior, | ||
ImportanceSamplingPosterior, | ||
]: | ||
"""Build posterior from the neural density estimator. | ||
Args: | ||
density_estimator: The density estimator that the posterior is based on. | ||
If `None`, use the latest neural density estimator that was trained. | ||
prior: Prior distribution. | ||
sample_with: Method to use for sampling from the posterior. Must be one of | ||
[`direct` | `mcmc` | `rejection` | `vi` | `importance`]. | ||
mcmc_method: Method used for MCMC sampling, one of `slice_np`, `slice`, | ||
`hmc`, `nuts`. Currently defaults to `slice_np` for a custom numpy | ||
implementation of slice sampling; select `hmc`, `nuts` or `slice` for | ||
Pyro-based sampling. | ||
vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. | ||
direct_sampling_parameters: Additional kwargs passed to `DirectPosterior`. | ||
mcmc_parameters: Additional kwargs passed to `MCMCPosterior`. | ||
vi_parameters: Additional kwargs passed to `VIPosterior`. | ||
rejection_sampling_parameters: Additional kwargs passed to ` | ||
RejectionPosterior`. | ||
importance_sampling_parameters: Additional kwargs passed to | ||
`ImportanceSamplingPosterior`. | ||
Returns: | ||
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods. | ||
""" | ||
if density_estimator is not None: | ||
assert isinstance( | ||
density_estimator, MixedDensityEstimator | ||
), f"""net must be of type MixedDensityEstimator but is { | ||
type(density_estimator) | ||
}.""" | ||
|
||
return super().build_posterior( | ||
density_estimator=density_estimator, | ||
prior=prior, | ||
sample_with=sample_with, | ||
mcmc_method=mcmc_method, | ||
vi_method=vi_method, | ||
direct_sampling_parameters=direct_sampling_parameters, | ||
mcmc_parameters=mcmc_parameters, | ||
vi_parameters=vi_parameters, | ||
rejection_sampling_parameters=rejection_sampling_parameters, | ||
importance_sampling_parameters=importance_sampling_parameters, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -11,10 +11,12 @@ | |||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
class MixedDensityEstimator(ConditionalDensityEstimator): | ||||||||||||||||||||
"""Class performing Mixed Neural Likelihood Estimation. | ||||||||||||||||||||
"""Class performing Mixed Neural Density Estimation. | ||||||||||||||||||||
janfb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||
|
||||||||||||||||||||
MNLE combines a Categorical net and a neural density estimator to model data | ||||||||||||||||||||
with mixed types, e.g., as they occur in decision-making models. | ||||||||||||||||||||
This estimator combines a Categorical net and a neural density estimator to model | ||||||||||||||||||||
data with mixed types (discrete and continuous), e.g., as they occur in | ||||||||||||||||||||
decision-making models. It can be used for both likelihood and posterior estimation | ||||||||||||||||||||
of mixed data. | ||||||||||||||||||||
|
This estimator combines a Categorical net and a neural density estimator to model | |
data with mixed types (discrete and continuous), e.g., as they occur in | |
decision-making models. It can be used for both likelihood and posterior estimation | |
of mixed data. | |
This estimator combines a categorical mass estimator and a density estimator to model | |
variables with mixed types (discrete and continuous). It can be used for both likelihood | |
estimation (e.g., for discrete decisions and continuous reaction times in decision-making | |
models) or posterior estimation (e.g., for models that have both discrete and continuous | |
parameters). |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggested change
estimation,use '.sample(...)' to generate samples though a forward | |
estimation, use '.sample(...)' to generate samples though a forward |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.