-
Notifications
You must be signed in to change notification settings - Fork 197
Score-based iid sampling #1381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Score-based iid sampling #1381
Changes from 26 commits
5350e58
58fd7b8
74cfad4
857399d
bcea468
843ce7d
384f36f
016c5a7
9152d28
df1f30d
85bf355
5195417
5bcf427
09f0113
ad240b7
44e08f2
d29ebf8
b0b8b41
0d0991b
f22467d
6cbe5ae
97a8dda
5349747
0539b62
bfe6df2
fd0f964
904c8dc
9cbfcce
b5bae5f
d680a08
6a343e7
2032a53
1de5fcf
2f02c2e
f3e3d6d
8b8179f
9fa0c41
f4c578d
8fc7ec6
9cf38bb
8ce68cb
7aa36c3
05c8e8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
from zuko.transforms import FreeFormJacobianTransform | ||
|
||
from sbi.inference.potentials.base_potential import BasePotential | ||
from sbi.inference.potentials.score_fn_iid import get_iid_method | ||
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator | ||
from sbi.neural_nets.estimators.shape_handling import ( | ||
reshape_to_batch_event, | ||
|
@@ -57,7 +58,8 @@ def __init__( | |
score_estimator: ConditionalScoreEstimator, | ||
prior: Optional[Distribution], | ||
x_o: Optional[Tensor] = None, | ||
iid_method: str = "iid_bridge", | ||
iid_method: str = "auto_gauss", | ||
iid_params: Optional[dict] = None, | ||
manuelgloeckler marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
device: str = "cpu", | ||
): | ||
r"""Returns the score function for score-based methods. | ||
|
@@ -66,30 +68,37 @@ def __init__( | |
score_estimator: The neural network modelling the score. | ||
prior: The prior distribution. | ||
x_o: The observed data at which to evaluate the posterior. | ||
iid_method: Which method to use for computing the score. Currently, only | ||
`iid_bridge` as proposed in Geffner et al. is implemented. | ||
iid_method: Which method to use for computing the score in the iid setting. | ||
We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss". | ||
iid_params: Parameters for the iid method, for arguments see ScoreFnIID. | ||
device: The device on which to evaluate the potential. | ||
""" | ||
self.score_estimator = score_estimator | ||
self.score_estimator.eval() | ||
self.iid_method = iid_method | ||
self.iid_params = iid_params | ||
super().__init__(prior, x_o, device=device) | ||
|
||
def set_x( | ||
self, | ||
x_o: Optional[Tensor], | ||
x_is_iid: Optional[bool] = False, | ||
iid_method: str = "auto_gauss", | ||
iid_params: Optional[dict] = None, | ||
manuelgloeckler marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
rebuild_flow: Optional[bool] = True, | ||
): | ||
""" | ||
Set the observed data and whether it is IID. | ||
Args: | ||
x_o: The observed data. | ||
x_is_iid: Whether the observed data is IID (if batch_dim>1). | ||
rebuild_flow: Whether to save (overwrrite) a low-tolerance flow model, useful if | ||
the flow needs to be evaluated many times (e.g. for MAP calculation). | ||
x_o: The observed data. | ||
x_is_iid: Whether the observed data is IID (if batch_dim>1). | ||
rebuild_flow: Whether to save (overwrrite) a low-tolerance flow model, | ||
useful if the flow needs to be evaluated many times | ||
(e.g. for MAP calculation). | ||
manuelgloeckler marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
""" | ||
super().set_x(x_o, x_is_iid) | ||
self.iid_method = iid_method | ||
self.iid_params = iid_params | ||
if rebuild_flow and self._x_o is not None: | ||
# By default, we want a high-tolerance flow. | ||
# This flow will be used mainly for MAP calculations, hence we want to save | ||
|
@@ -172,10 +181,16 @@ def gradient( | |
input=theta, condition=self.x_o, time=time | ||
) | ||
else: | ||
raise NotImplementedError( | ||
"Score accumulation for IID data is not yet implemented." | ||
assert self.prior is not None, "Prior is required for iid methods." | ||
|
||
method_iid = get_iid_method(self.iid_method) | ||
manuelgloeckler marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# Always creating a new object every call is not efficient... | ||
manuelgloeckler marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
score_fn_iid = method_iid( | ||
self.score_estimator, self.prior, **(self.iid_params or {}) | ||
) | ||
|
||
score = score_fn_iid(theta, self.x_o, time) # type: ignore | ||
manuelgloeckler marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
return score | ||
|
||
def get_continuous_normalizing_flow( | ||
|
@@ -217,9 +232,6 @@ def rebuild_flow( | |
x_density_estimator = reshape_to_batch_event( | ||
self.x_o, event_shape=self.score_estimator.condition_shape | ||
) | ||
assert x_density_estimator.shape[0] == 1, ( | ||
|
||
"PosteriorScoreBasedPotential supports only x batchsize of 1`." | ||
) | ||
|
||
flow = self.get_continuous_normalizing_flow( | ||
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact | ||
|
Uh oh!
There was an error while loading. Please reload this page.