# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-from typing import Callable, Optional, Tuple
+import warnings
+from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
@@ -115,6 +116,54 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        )
        return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore

+    def condition_on_theta(
+        self, local_theta: Tensor, dims_global_theta: List[int]
+    ) -> Callable:
+        r"""Returns a potential function conditioned on a subset of theta dimensions.
+
+        The goal of this function is to divide the original `theta` into a
+        `global_theta` we do inference over, and a `local_theta` we condition on (in
+        addition to conditioning on `x_o`). Thus, the returned potential function will
+        calculate $\prod_{i=1}^{N}p(x_i | local_theta_i, global_theta)$, where `x_i`
+        and `local_theta_i` are fixed and `global_theta` varies at inference time.
+
+        Args:
+            local_theta: The fixed parameter values to condition on.
+            dims_global_theta: The indices of the columns in `theta` that will be
+                sampled, i.e., that are *not* conditioned. For example, if the original
+                theta has shape `(batch_dim, 3)` and `dims_global_theta=[0, 1]`, then
+                the potential will set `theta[:, 2] = local_theta` at inference time.
+
+        Returns:
+            A potential function conditioned on the `local_theta`.
+        """
+
+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
+        def conditioned_potential(
+            theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+        ) -> Tensor:
+            assert (
+                len(dims_global_theta) == theta.shape[1]
+            ), "dims_global_theta must match the number of parameters to sample."
+            global_theta = theta[:, dims_global_theta]
+            x_o = x_o if x_o is not None else self.x_o
+            # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
+            if x_o.dim() < 3:
+                x_o = reshape_to_sample_batch_event(
+                    x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
+                )
+
+            return _log_likelihood_over_iid_trials_and_local_theta(
+                x=x_o,
+                global_theta=global_theta,
+                local_theta=local_theta,
+                estimator=self.likelihood_estimator,
+                track_gradients=track_gradients,
+            )
+
+        return conditioned_potential
+

def _log_likelihoods_over_trials(
    x: Tensor,
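A rough usage sketch of the new `condition_on_theta` method, not part of the diff itself: `likelihood_estimator` and `prior` are assumed to exist (a trained estimator over a 3-dimensional theta and its prior), and the shapes are illustrative only.

```python
# Sketch only: `likelihood_estimator` and `prior` are assumed to be defined elsewhere.
import torch

num_trials = 10
x_o = torch.randn(num_trials, 1)           # iid observations, one row per trial
local_values = torch.randn(num_trials, 1)  # one local_theta value per trial

# Assumed construction of the potential over the full 3-dim theta.
potential_fn = LikelihoodBasedPotential(likelihood_estimator, prior, x_o=x_o)

# Infer columns 0 and 1 of theta; condition column 2 on `local_values`.
conditioned_potential = potential_fn.condition_on_theta(
    local_values, dims_global_theta=[0, 1]
)

# Evaluate on a batch of 100 global parameters. The result sums the per-trial
# log likelihoods, one entry per (x batch, theta batch) pair.
global_theta = torch.randn(100, 2)
log_likelihoods = conditioned_potential(global_theta)
```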
@@ -172,6 +221,77 @@ def _log_likelihoods_over_trials(
    return log_likelihood_trial_sum


+def _log_likelihood_over_iid_trials_and_local_theta(
+    x: Tensor,
+    global_theta: Tensor,
+    local_theta: Tensor,
+    estimator: ConditionalDensityEstimator,
+    track_gradients: bool = False,
+) -> Tensor:
+    r"""Returns $\log \prod_{i=1}^N p(x_i | \theta, local_theta_i)$.
+
+    `x` is a batch of iid data, and `local_theta` is a matching batch of condition
+    values that were part of `theta` but are treated as local iid variables at
+    inference time.
+
+    This function differs from `_log_likelihoods_over_trials` in that it moves the
+    iid batch dimension of `x` onto the batch dimension of `theta`. This is needed
+    when the likelihood estimator is conditioned on a batch of conditions that are
+    iid with the batch of `x`. It avoids evaluating the likelihood for every
+    combination of `x` and `local_theta`.
+
+    Args:
+        x: Data of shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
+            `sample_dim` holds the iid trials and `x_batch_dim` holds a batch of xs,
+            e.g., non-iid observations.
+        global_theta: Batch of parameters of shape `(theta_batch_dim,
+            num_parameters)`.
+        local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`;
+            must match x's `sample_dim`.
+        estimator: DensityEstimator.
+        track_gradients: Whether to track gradients.
+
+    Returns:
+        log_likelihood: Log likelihood for each x in `x_batch_dim`, for each theta in
+            `theta_batch_dim`, summed over all iid trials. Shape `(x_batch_dim,
+            theta_batch_dim)`.
+    """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
+    assert (
+        local_theta.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = global_theta.shape[0]
+    assert (
+        local_theta.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."
+
+    # Move the iid batch dim of x onto the batch dim of theta and repeat it there.
+    x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
+
+    # Construct theta and condition to cover all trial-theta combinations.
+    theta_with_condition = torch.cat(
+        [
+            global_theta.repeat(num_trials, 1),  # repeat ABAB
+            local_theta.repeat_interleave(num_thetas, dim=0),  # repeat AABB
+        ],
+        dim=-1,
+    )
+
+    with torch.set_grad_enabled(track_gradients):
+        # Calculate the likelihood in one batch. Returns (1, num_trials * num_thetas).
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
+        # Reshape to (num_xs, num_trials, num_thetas), sum over trial log likelihoods.
+        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
+            num_xs, num_trials, num_thetas
+        ).sum(1)
+
+    return log_likelihood_trial_sum
+
+
def mixed_likelihood_estimator_based_potential(
    likelihood_estimator: MixedDensityEstimator,
    prior: Distribution,
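The `repeat` (ABAB) / `repeat_interleave` (AABB) pairing above is what guarantees that every trial-theta combination appears exactly once, in the order expected by the later `reshape(num_xs, num_trials, num_thetas)`. A minimal standalone check, with made-up tensors:

```python
import torch

num_trials, num_thetas = 3, 2
global_theta = torch.tensor([[10.0], [20.0]])       # (num_thetas, 1)
local_theta = torch.tensor([[1.0], [2.0], [3.0]])   # (num_trials, 1)

combos = torch.cat(
    [
        global_theta.repeat(num_trials, 1),                # ABAB: 10, 20, 10, 20, 10, 20
        local_theta.repeat_interleave(num_thetas, dim=0),  # AABB: 1, 1, 2, 2, 3, 3
    ],
    dim=-1,
)
# Rows enumerate trials (outer loop) and thetas (inner loop):
# [[10, 1], [20, 1], [10, 2], [20, 2], [10, 3], [20, 3]]
print(combos)
```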
@@ -192,6 +312,13 @@ def mixed_likelihood_estimator_based_potential(
        to unconstrained space.
    """

+    warnings.warn(
+        "This function is deprecated and will be removed in a future release. Use "
+        "`likelihood_estimator_based_potential` instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
    device = str(next(likelihood_estimator.discrete_net.parameters()).device)

    potential_fn = MixedLikelihoodBasedPotential(
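For callers migrating off the deprecated function, a sketch of the replacement suggested by the warning text; `mnle_estimator`, `prior`, and `x_o` are placeholder names, and the exact keyword arguments are assumed from the existing API rather than confirmed by this diff.

```python
# Deprecated:
# potential_fn, theta_transform = mixed_likelihood_estimator_based_potential(
#     mnle_estimator, prior, x_o
# )
# Suggested replacement:
potential_fn, theta_transform = likelihood_estimator_based_potential(
    likelihood_estimator=mnle_estimator, prior=prior, x_o=x_o
)
```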
@@ -212,6 +339,13 @@ def __init__(
    ):
        super().__init__(likelihood_estimator, prior, x_o, device)

+        warnings.warn(
+            "This class is deprecated and will be removed in a future release. Use "
+            "`LikelihoodBasedPotential` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        prior_log_prob = self.prior.log_prob(theta)  # type: ignore
@@ -231,7 +365,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        with torch.set_grad_enabled(track_gradients):
            # Call the specific log prob method of the mixed likelihood estimator as
            # this optimizes the evaluation of the discrete data part.
-            # TODO log_prob_iid
            log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
                input=x,
                condition=theta.to(self.device),