# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-from functools import partial
from typing import Optional, Tuple

import torch
@@ -79,41 +78,36 @@ def set_x(
        self,
        x_o: Optional[Tensor],
        x_is_iid: Optional[bool] = False,
-        rebuild_flow: Optional[bool] = True,
+        atol: float = 1e-5,
+        rtol: float = 1e-6,
+        exact: bool = True,
    ):
        """
        Set the observed data and whether it is IID.
+
+        Rebuilds the continuous normalizing flow if the observed data is set.
+
        Args:
-            x_o: The observed data.
-            x_is_iid: Whether the observed data is IID (if batch_dim>1).
-            rebuild_flow: Whether to save (overwrite) a low-tolerance flow model, useful if
-                the flow needs to be evaluated many times (e.g. for MAP calculation).
+            x_o: The observed data.
+            x_is_iid: Whether the observed data is IID (if batch_dim>1).
+            atol: Absolute tolerance for the ODE solver.
+            rtol: Relative tolerance for the ODE solver.
+            exact: Whether to use the exact ODE solver.
        """
        super().set_x(x_o, x_is_iid)
-        if rebuild_flow and self._x_o is not None:
-            # By default, we want a high-tolerance flow.
-            # This flow will be used mainly for MAP calculations, hence we want to save
-            # it instead of rebuilding it every time.
-            self.flow = self.rebuild_flow(atol=1e-2, rtol=1e-3, exact=True)
+        if self._x_o is not None:
+            self.flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)

    def __call__(
        self,
        theta: Tensor,
        track_gradients: bool = True,
-        rebuild_flow: bool = True,
-        atol: float = 1e-5,
-        rtol: float = 1e-6,
-        exact: bool = True,
    ) -> Tensor:
        """Return the potential (posterior log prob) via probability flow ODE.

        Args:
            theta: The parameters at which to evaluate the potential.
            track_gradients: Whether to track gradients.
-            rebuild_flow: Whether to rebuild the CNF for accurate log_prob evaluation.
-            atol: Absolute tolerance for the ODE solver.
-            rtol: Relative tolerance for the ODE solver.
-            exact: Whether to use the exact ODE solver.

        Returns:
            The potential function, i.e., the log probability of the posterior.
@@ -123,15 +117,9 @@ def __call__(
            theta, theta.shape[1:], leading_is_sample=True
        )
        self.score_estimator.eval()
-        # use rebuild_flow to evaluate log_prob with better precision, without
-        # overwriting self.flow
-        if rebuild_flow or self.flow is None:
-            flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
-        else:
-            flow = self.flow

        with torch.set_grad_enabled(track_gradients):
-            log_probs = flow.log_prob(theta_density_estimator).squeeze(-1)
+            log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
            # Force probability to be zero outside prior support.
            in_prior_support = within_support(self.prior, theta)
@@ -217,7 +205,7 @@ def rebuild_flow(
        x_density_estimator = reshape_to_batch_event(
            self.x_o, event_shape=self.score_estimator.condition_shape
        )
-        assert x_density_estimator.shape[0] == 1, (
+        assert x_density_estimator.shape[0] == 1 or not self.x_is_iid, (
            "PosteriorScoreBasedPotential supports only x batchsize of 1`."
        )
@@ -312,9 +300,8 @@ def __init__(self, posterior_score_based_potential):
        self.posterior_score_based_potential = posterior_score_based_potential

    def __call__(self, input):
-        prepared_potential = partial(
-            self.posterior_score_based_potential.__call__, rebuild_flow=False
-        )
        return DifferentiablePotentialFunction.apply(
-            input, prepared_potential, self.posterior_score_based_potential.gradient
+            input,
+            self.posterior_score_based_potential.__call__,
+            self.posterior_score_based_potential.gradient,
        )
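
Taken together, these hunks move the ODE-solver settings (atol, rtol, exact) from the potential's __call__ into set_x, so the continuous normalizing flow is rebuilt once per observation and then reused for every log-prob evaluation, including the differentiable wrapper. The sketch below illustrates the resulting calling convention; it is only a reading of this diff, and `potential`, `x_o`, and `theta` are hypothetical placeholders rather than objects defined in the commit.

# Hypothetical usage sketch of the API after this commit.
# `potential` stands for an already-constructed PosteriorScoreBasedPotential;
# `x_o` and `theta` are placeholder tensors, not defined in the diff.

# Solver tolerances are now chosen when the observation is set; set_x rebuilds
# and caches the continuous normalizing flow with these settings.
potential.set_x(x_o, x_is_iid=False, atol=1e-5, rtol=1e-6, exact=True)

# __call__ no longer accepts rebuild_flow/atol/rtol/exact; it evaluates the
# cached flow's log_prob and zeroes probability outside the prior support.
log_prob = potential(theta, track_gradients=False)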