Merged
Changes from 43 commits
Commits
61 commits
e503c88
logit transform changes
Mar 17, 2025
f488f86
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Mar 18, 2025
3737ee6
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Mar 18, 2025
8be9079
add new ZScoreTypes
Mar 18, 2025
23bd180
add new ZScoreTypes
Mar 18, 2025
162b976
resolving bug z_scoring last year
Mar 19, 2025
c5772d2
resolving bug z_scoring last year
Mar 19, 2025
b279eb0
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Mar 19, 2025
993efb7
revert z_score parser
Mar 19, 2025
443b8ef
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Mar 19, 2025
862f3d2
adjusted logit structure in build_zuko_flow
Mar 19, 2025
bccd82b
resolve pyright error
Mar 19, 2025
3e3a8d5
Merge branch 'main' into logit_transform
anastasiakrouglova Mar 19, 2025
145ef4e
revert flow as a test
Mar 20, 2025
5a26157
add x_dist variable
Mar 20, 2025
0dd3baa
add logit to sbiutils_test.py
Mar 20, 2025
000c123
add logit if statement
Mar 20, 2025
b3bc54b
add logit if statement
Mar 20, 2025
fa80559
add logit if statement
Mar 20, 2025
1de1a98
add logit if statement
Mar 20, 2025
e2007af
remove logit if statement
Mar 20, 2025
1603756
resolve pyright issues
Mar 20, 2025
1ffd0e9
cover logit in tests
Mar 20, 2025
73af5ac
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Mar 20, 2025
9cc887b
cover tests for logit in flow.py
Mar 20, 2025
9573905
cover tests for CNF
Mar 20, 2025
adff499
adding faq for logit transformation
Mar 20, 2025
8176707
adding faq for logit transformation
Mar 20, 2025
12c4c85
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Mar 21, 2025
69502d7
stash changes
Mar 21, 2025
d617030
feedback guy adjustments
Mar 21, 2025
9037534
add documentation if statements
Mar 21, 2025
db361a0
update sbiutils
Jun 10, 2025
d660631
resolve comment 1 and 2 of Jan
Jun 10, 2025
9332583
ruff linted push
Jun 10, 2025
7dcf919
cleanup density_estimator_test.py
Jun 10, 2025
bbca1ed
cleanup density_estimator_test.py and ruff check
Jun 10, 2025
7142ccb
Merge branch 'sbi-dev:main' into logit_transform
anastasiakrouglova Jul 1, 2025
b09ebb7
adjusted docstrings
Jul 1, 2025
8c029f4
add tests convergence unconstrained space
Jul 2, 2025
d63af48
adjust faq
Jul 2, 2025
9fb371a
add test snle
Jul 2, 2025
be06d58
adjust linear gaussian and estimate c2st
Jul 2, 2025
c732ded
adjust documentation
Jul 28, 2025
021d67f
Update sbi/neural_nets/net_builders/flow.py
anastasiakrouglova Jul 28, 2025
d375df9
Update sbi/utils/sbiutils.py
anastasiakrouglova Jul 28, 2025
bd0c055
Update sbi/utils/sbiutils.py
anastasiakrouglova Jul 28, 2025
3908896
add literal import to sbi utils
Jul 29, 2025
33129d8
adjust literals and add get_transform_to_unconstrained
Jul 29, 2025
580fd50
adjust literals and format
Jul 29, 2025
dfd9948
add new line for ruff
Jul 29, 2025
d6dca31
stying ruff
Jul 29, 2025
b8cd4da
stying ruff
Jul 29, 2025
e528737
fix flow builder z-score defaults.
janfb Jul 31, 2025
0665400
refactor zuko flow build functions
janfb Jul 31, 2025
4d4bfb8
re-use y-embedding helper function.
janfb Jul 31, 2025
8946a09
fix typing
janfb Jul 31, 2025
89f990c
Merge branch 'main' into logit_transform
janfb Jul 31, 2025
340ba70
small fixes.
janfb Jul 31, 2025
c4aa2d1
fix unconstrained nle test
janfb Jul 31, 2025
1ee797d
refactor z-score-parser test
janfb Jul 31, 2025
44 changes: 44 additions & 0 deletions docs/faq/question_08_unconstrained.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# What if almost all posterior samples lie outside the prior bounds for some conditionals?

If you've encountered the following warning:

```
WARNING:root:Only 0.002% proposal samples were accepted. It
may take a long time to collect the remaining 99980
samples. Consider interrupting (Ctrl-C) and switching to a
different sampling method with
`build_posterior(..., sample_with='mcmc')`. or
`build_posterior(..., sample_with='vi')`.
```

this indicates that almost all samples proposed by the density estimator fall outside the prior bounds. Several factors might be causing this issue:

1) Simulator Issues: Ensure that your simulator is functioning as expected and producing realistic outputs.
2) Insufficient Training Data: If the density estimator has been trained on too few simulations, it may lead to invalid estimations.
3) Problematic True Data: Check if there are inconsistencies or unexpected values in the observed data.
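
To get a feel for why the warning suggests interrupting, the arithmetic can be sketched as follows. This is only an illustration: `expected_proposals` is a hypothetical helper (not part of `sbi`), and it assumes each proposal is accepted independently with the reported rate, so the expected number of draws is the number of remaining samples divided by the acceptance rate.

```python
def expected_proposals(num_remaining: int, acceptance_rate: float) -> float:
    """Expected number of proposal draws needed to accept `num_remaining` samples.

    Assumes independent draws, each accepted with probability `acceptance_rate`.
    """
    if not 0.0 < acceptance_rate <= 1.0:
        raise ValueError("acceptance_rate must be in (0, 1].")
    return num_remaining / acceptance_rate


# Numbers taken from the warning: 0.002% accepted, 99980 samples remaining.
draws = expected_proposals(99980, 0.002 / 100)
print(f"about {draws:.2e} proposal draws needed")
```

With the numbers from the warning this comes out at roughly five billion draws, which is why switching the sampling method or fixing the underlying cause is usually the better option.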


### Possible solutions

If you've ruled out these issues, you can try training your density estimator in an unbounded space using a logit transformation. This transformation maps the bounded parameters to logit space before training. When sampling, the inverse logit (sigmoid function) maps the samples back, so that they always lie within the prior bounds.

Instead of standardizing parameters using z-scoring, you can use the logit transformation. However, this requires providing a distribution whose support defines the bounds (the `x_dist` argument). The specific approach depends on the method you're using:

- For NPE (Neural Posterior Estimation): You can simply pass the prior as `x_dist`.
- For NLE/NRE (Neural Likelihood Estimation / Neural Ratio Estimation): You need to provide a rough distribution over the data whose support contains all of the data, making the process more complex.
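
The logit idea can be sketched in a few lines of plain Python. This is a minimal illustration for a single parameter with box bounds, not the code path used by `sbi`; the function names and the bounds `low`, `high` are assumptions made for the example.

```python
import math


def to_unconstrained(theta: float, low: float, high: float) -> float:
    """Map theta in (low, high) to the real line with a logit."""
    u = (theta - low) / (high - low)  # rescale to (0, 1)
    return math.log(u) - math.log1p(-u)  # logit(u) = log(u / (1 - u))


def to_constrained(z: float, low: float, high: float) -> float:
    """Inverse map: sigmoid, then rescale back to the interval [low, high]."""
    u = 1.0 / (1.0 + math.exp(-z))
    return low + (high - low) * u


low, high = -2.0, 2.0  # hypothetical prior bounds
theta = 1.5
z = to_unconstrained(theta, low, high)  # unbounded value seen during training
theta_back = to_constrained(z, low, high)  # recovers theta up to rounding
print(z, theta_back)
```

Whatever real value the estimator produces in the unconstrained space, `to_constrained` returns a value inside the bounds (up to floating-point rounding at the edges).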


### What do I do if my data is highly nonlinear?

To enable the `transform_to_unconstrained` transformation when defining your density estimator, use:

```
from sbi.inference import NPE_C
from sbi.neural_nets import posterior_nn

density_estimator_build_fun = posterior_nn(
    model="zuko_nsf", hidden_features=60, num_transforms=3,
    z_score_theta="transform_to_unconstrained", x_dist=prior,  # `prior`: your prior
)
inference = NPE_C(prior, density_estimator=density_estimator_build_fun)
```
This ensures that your density estimator operates in a transformed space where it respects prior bounds, improving the efficiency of rejection sampling.
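
The effect on rejection sampling can be illustrated with a toy comparison. The two Gaussians below are stand-ins for a density estimator (an assumption for illustration only, not `sbi` code): one places mass directly in the bounded space near the upper bound, the other lives in unconstrained space and is mapped back through a sigmoid.

```python
import math
import random

random.seed(0)
low, high = 0.0, 1.0  # hypothetical prior bounds
n = 10_000


def inside(x: float) -> bool:
    """Check whether a sample lies within the prior bounds."""
    return low <= x <= high


# Toy estimator defined directly in the bounded space: mass leaks past `high`.
direct = [random.gauss(0.95, 0.2) for _ in range(n)]

# Toy estimator in unconstrained space, mapped back with a sigmoid.
mapped = [
    low + (high - low) / (1.0 + math.exp(-random.gauss(3.0, 1.0)))
    for _ in range(n)
]

frac_direct = sum(inside(x) for x in direct) / n
frac_mapped = sum(inside(x) for x in mapped) / n
print(f"accepted directly: {frac_direct:.2f}, accepted after sigmoid: {frac_mapped:.2f}")
```

In this sketch every mapped sample is accepted, whereas a sizeable share of the direct samples is rejected because it falls above the upper bound.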

Note: The logit transformation is currently only supported for `zuko` density estimators (e.g., `zuko_nsf` and `zuko_maf`).
66 changes: 64 additions & 2 deletions sbi/neural_nets/net_builders/flow.py
@@ -13,10 +13,13 @@
rational_quadratic, # pyright: ignore[reportAttributeAccessIssue]
)
from torch import Tensor, nn, relu, tanh, tensor, uint8
from torch.distributions import Distribution

from sbi.neural_nets.estimators import NFlowsFlow, ZukoFlow, ZukoUnconditionalFlow
from sbi.utils.nn_utils import MADEMoGWrapper, get_numel
from sbi.utils.sbiutils import (
biject_transform_zuko,
mcmc_transform,
standardizing_net,
standardizing_transform,
standardizing_transform_zuko,
@@ -1009,11 +1012,20 @@
hidden_features: Union[Sequence[int], int] = 50,
num_transforms: int = 5,
embedding_net: nn.Module = nn.Identity(),
x_dist: Optional[Distribution] = None,
**kwargs,
) -> ZukoFlow:
"""
Fundamental building blocks to build a Zuko normalizing flow model.
The if statements below handle the following cases:
- z_score_x is `independent` or `structured`: the usual standardizing transform
is applied. If it is None, no standardization is applied.
- z_score_x is `transform_to_unconstrained`: we check that `x_dist` is provided
and has a `support` property. If `x_dist` is not valid (i.e. None
or has no `support` property), we raise an error.
Args:
which_nf (str): The type of normalizing flow to build.
batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
@@ -1024,11 +1036,21 @@
- `structured`: treat dimensions as related, therefore compute mean and std
over the entire batch, instead of per-dimension. Should be used when each
sample is, for example, a time series or an image.
- `transform_to_unconstrained`: Transforms to
an unbounded space, if bounds from `x_dist` are given.
z_score_y: Whether to z-score ys passing into the network, same options as
z_score_x.
hidden_features: The number of hidden features in the flow. Defaults to 50.
num_transforms: The number of transformations in the flow. Defaults to 5.
embedding_net: The embedding network to use. Defaults to nn.Identity().
x_dist: The distribution over x, used to determine the bounds for the
unconstrained transformation.
- In Neural Posterior Estimation (NPE), `x_dist` typically corresponds
to the prior over x (e.g., a `BoxUniform`).
- For Neural Likelihood Estimation (NLE) or Neural Ratio Estimation (NRE),
`x_dist` may instead be a user-specified distribution. However, make sure
all the data lies within the support of the distribution if you want to
use the `transform_to_unconstrained` option for NLE and NRE.
**kwargs: Additional keyword arguments to pass to the flow constructor.
Returns:
@@ -1066,7 +1088,28 @@
transform = flow_built.transform

z_score_x_bool, structured_x = z_score_parser(z_score_x)
if z_score_x_bool:

# Only x (i.e., the parameters for NPE) can be transformed to unbounded space
# (not y), and only when x_dist is provided.
if z_score_x == "transform_to_unconstrained":
if x_dist is None:
raise ValueError(
"Transformation to unconstrained space requires a distribution "
"provided through `x_dist`."
)
elif not hasattr(x_dist, "support"):
raise ValueError(

"`x_dist` requires a `.support` attribute for "
"an unconstrained transformation."
)
else:
transform_to_unconstrained = mcmc_transform(x_dist)
transform = (
biject_transform_zuko(transform_to_unconstrained),
transform,
)

elif z_score_x_bool:

transform = (
standardizing_transform_zuko(batch_x, structured_x),
transform,
@@ -1085,7 +1128,26 @@
transforms = flow_built.transform.transforms

z_score_x_bool, structured_x = z_score_parser(z_score_x)
if z_score_x_bool:

if z_score_x == "transform_to_unconstrained":
if x_dist is None:
raise ValueError(
"Transformation to unconstrained space requires a distribution "
"provided through `x_dist`."
)
elif not hasattr(x_dist, "support"):
raise ValueError(

"`x_dist` requires a `.support` attribute for "
"an unconstrained transformation."
)
else:
transform_to_unconstrained = mcmc_transform(x_dist)
transforms = (
biject_transform_zuko(transform_to_unconstrained),
*transforms,
)

elif z_score_x_bool:
transforms = (
standardizing_transform_zuko(batch_x, structured_x),
*transforms,
51 changes: 47 additions & 4 deletions sbi/utils/sbiutils.py
@@ -5,7 +5,17 @@
import random
import warnings
from math import pi
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)

import numpy as np
import pyknos.nflows.transforms as nflows_tf
@@ -101,7 +111,7 @@
return clamped_val


def z_score_parser(z_score_flag: Optional["str"]) -> Tuple[bool, bool]:
def z_score_parser(z_score_flag: Optional[str]) -> Tuple[bool, bool]:
"""Parses string z-score flag into booleans.
Converts string flag into booleans denoting whether to z-score or not, and whether
@@ -133,11 +143,15 @@
# Got one of two valid z-scoring methods.
z_score_bool = True
structured_data = z_score_flag == "structured"

elif z_score_flag == "transform_to_unconstrained":
# Depending on the distribution, the biject_to function
# will provide, e.g., a logit, exponential or z-scoring transform.
z_score_bool, structured_data = False, False
else:
# Invalid option: raise an error.
raise ValueError(
"Invalid z-scoring option. Use 'none', 'independent', or 'structured'."
"Invalid z-scoring option. Use 'none', 'independent', "
"'structured' or 'transform_to_unconstrained'."
)

return z_score_bool, structured_data
@@ -197,6 +211,35 @@
)


class CallableTransform:
    """Wraps a PyTorch Transform to be used in Zuko UnconditionalTransform."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self):
        return self.transform


def biject_transform_zuko(
    transform,
) -> zuko.flows.UnconditionalTransform:
    """
    Wraps a PyTorch transform in a Zuko unconditional transform on a bounded interval.

    Args:
        transform: a bijective transformation for Zuko, depending on the input
            (e.g., logit, exponential or z-scored)

    Returns:
        Zuko bijective transformation
    """
    return zuko.flows.UnconditionalTransform(
        CallableTransform(transform),
        buffer=True,
    )


def z_standardization(
batch_t: Tensor,
structured_dims: bool = False,
52 changes: 51 additions & 1 deletion tests/density_estimator_test.py
@@ -8,10 +8,11 @@
import pytest
import torch
from torch import eye, zeros
from torch.distributions import MultivariateNormal
from torch.distributions import HalfNormal, MultivariateNormal

from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
from sbi.neural_nets.net_builders import (
build_categoricalmassestimator,
build_made,
@@ -34,6 +35,8 @@
build_zuko_sospf,
build_zuko_unaf,
)
from sbi.neural_nets.net_builders.flow import build_zuko_flow
from sbi.utils.torchutils import BoxUniform

# List of all density estimator builders for testing.
model_builders = [
@@ -462,3 +465,50 @@ def test_mixed_density_estimator(
# Test samples
samples = density_estimator.sample(sample_shape, condition=conditions)
assert samples.shape == (*sample_shape, batch_dim, *input_event_shape)


@pytest.mark.parametrize("which_nf", ["MAF", "CNF"])
@pytest.mark.parametrize(
    "x_dist",
    [
        BoxUniform(low=-2 * torch.ones(5), high=2 * torch.ones(5)),
        HalfNormal(scale=torch.ones(1) * 2),
        MultivariateNormal(loc=zeros(5), covariance_matrix=eye(5)),
    ],
)
def test_build_zuko_flow_with_valid_unconstrained_transform(which_nf, x_dist):
    """Test that ZukoFlow builds successfully with valid `x_dist`."""
    # input dimension is 5
    batch_x = torch.randn(10, 5)
    batch_y = torch.randn(10, 3)

    # Test case where x_dist is provided (should not raise an error)
    flow = build_zuko_flow(
        which_nf=which_nf,
        batch_x=batch_x,
        batch_y=batch_y,
        z_score_x="transform_to_unconstrained",
        z_score_y="transform_to_unconstrained",
        x_dist=x_dist,
    )
    assert isinstance(flow, ZukoFlow)


@pytest.mark.parametrize("which_nf", ["MAF", "CNF"])
def test_build_zuko_flow_missing_x_dist_raises_error(which_nf):
    """Test that ValueError is raised if `x_dist` is None when required."""
    batch_x = torch.randn(10, 5)
    batch_y = torch.randn(10, 3)

    with pytest.raises(
        ValueError,
        match=r".*distribution.*x_dist.*",
    ):
        build_zuko_flow(
            which_nf=which_nf,
            batch_x=batch_x,
            batch_y=batch_y,
            z_score_x="transform_to_unconstrained",
            z_score_y="transform_to_unconstrained",
            x_dist=None,  # No distribution provided
        )