@@ -116,20 +116,28 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        )
        return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore

-    def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
-        """Returns a potential conditioned on a subset of theta dimensions.
+    def condition_on_theta(
+        self, theta_condition: Tensor, dims_to_sample: List[int]
+    ) -> Callable:
+        """Returns a potential function conditioned on a subset of theta dimensions.

        The condition is a part of theta, but is assumed to correspond to a batch of iid
-        x_o.
+        x_o. For example, it can be a batch of experimental conditions that corresponds
+        to a batch of i.i.d. trials in x_o.

        Args:
-            condition: The condition to fix.
-            dims_to_sample: The indices of the parameters to sample.
+            theta_condition: The values of theta to condition on.
+            dims_to_sample: The indices of the columns in theta that will be sampled,
+                i.e., that are *not* conditioned. For example, if the original theta
+                has shape `(batch_dim, 3)` and `dims_to_sample=[0, 1]`, then the
+                potential will set `theta[:, 2] = theta_condition` at inference time.

        Returns:
-            A potential function conditioned on the condition.
+            A potential function conditioned on the theta_condition.
        """

+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
        def conditioned_potential(
            theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
        ) -> Tensor:
@@ -138,10 +146,10 @@ def conditioned_potential(
            ), "dims_to_sample must match the number of parameters to sample."
            theta_without_condition = theta[:, dims_to_sample]

-            return _log_likelihood_with_iid_condition(
+            return _log_likelihood_over_iid_conditions(
                x=x_o if x_o is not None else self.x_o,
                theta_without_condition=theta_without_condition,
-                condition=condition,
+                condition=theta_condition,
                estimator=self.likelihood_estimator,
                track_gradients=track_gradients,
            )
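For orientation, a minimal usage sketch of the renamed method, assuming a potential over theta = (mu, sigma, experimental_condition) whose third column is tied to the iid trials in x_o. The `potential_fn` object and all shapes below are hypothetical, not taken from this commit:

    # experimental_conditions: one condition per iid trial in x_o, shape (num_trials, 1).
    conditioned_potential_fn = potential_fn.condition_on_theta(
        theta_condition=experimental_conditions,
        dims_to_sample=[0, 1],  # sample mu and sigma; column 2 is fixed to the condition
    )
    # The returned callable evaluates theta over the two sampled columns only.
    log_potential = conditioned_potential_fn(theta_2d)  # theta_2d: shape (batch_dim, 2)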
@@ -205,63 +213,75 @@ def _log_likelihoods_over_trials(
    return log_likelihood_trial_sum


-def _log_likelihood_with_iid_condition(
+def _log_likelihood_over_iid_conditions(
    x: Tensor,
    theta_without_condition: Tensor,
    condition: Tensor,
    estimator: ConditionalDensityEstimator,
    track_gradients: bool = False,
) -> Tensor:
-    """Return log likelihoods summed over iid trials of `x` with a matching batch of
-    conditions.
+    """Returns $\\log p(x_o | \\theta, condition)$, where x_o is a batch of iid data, and
+    condition is a matching batch of conditions.

    This function is different from `_log_likelihoods_over_trials` in that it moves the
-    iid batch dimension of `x` onto the batch dimension of `theta`. This is useful when
+    iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
    the likelihood estimator is conditioned on a batch of conditions that are iid with
-    the batch of `x`. It avoid the evaluation of the likelihood for every combination of
-    `x` and `condition`. Instead, it manually constructs a batch covering all
-    combination of iid trial and theta batch and reshapes to sum over the iid
+    the batch of `x`. It avoids evaluating the likelihood for every combination
+    of `x` and `condition`. Instead, it manually constructs a batch covering all
+    combinations of iid trials and the theta batch and reshapes to sum over the iid
    likelihoods.

    Args:
-        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
-        theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`
-        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
+        x: Data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
+            sample_dim holds the i.i.d. trials and x_batch_dim holds a batch of xs,
+            e.g., non-iid observations.
+        theta_without_condition: Batch of parameters `(theta_batch_dim,
+            num_parameters)`.
+        condition: Batch of conditions of shape `(sample_dim, num_conditions)`; must
+            match x's `sample_dim`.
        estimator: DensityEstimator.
        track_gradients: Whether to track gradients.

    Returns:
-        log_likelihood_trial_sum: log likelihood for each parameter, summed over all
-            batch entries (iid trials) in `x`.
+        log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
+            theta_batch_dim, summed over all i.i.d. trials. Shape
+            `(x_batch_dim, theta_batch_dim)`.
    """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
    assert (
-        condition.shape[0] == x.shape[0]
-    ), "Condition and iid x must have the same batch size."
-    num_trials = x.shape[0]
-    num_theta = theta_without_condition.shape[0]
-    x = reshape_to_sample_batch_event(
-        x, event_shape=x.shape[1:], leading_is_sample=True
-    )
+        condition.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert (
+        theta_without_condition.dim() == 2
+    ), "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = theta_without_condition.shape[0]
+    assert (
+        condition.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."

    # move the iid batch dimension onto the batch dimension of theta and repeat it there
-    x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
-    # for this to work we construct theta and condition to cover all combinations in the
-    # trial batch and the theta batch.
-    theta = torch.cat(
+    x.transpose_(0, 1)
+    x_repeated = x.repeat_interleave(num_thetas, dim=1)
+
+    # construct theta and condition to cover all trial-theta combinations
+    theta_with_condition = torch.cat(
        [
            theta_without_condition.repeat(num_trials, 1),  # repeat ABAB
-            condition.repeat_interleave(num_theta, dim=0),  # repeat AABB
+            condition.repeat_interleave(num_thetas, dim=0),  # repeat AABB
        ],
        dim=-1,
    )

    with torch.set_grad_enabled(track_gradients):
-        # Calculate likelihood in one batch. Returns (1, num_trials * theta_batch_size)
-        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
+        # Calculate likelihood in one batch. Returns (num_xs, num_trials * num_thetas)
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
        # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
-            num_trials, num_theta
-        ).sum(0)
+            num_xs, num_trials, num_thetas
+        ).sum(1)

    return log_likelihood_trial_sum
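As a sanity check on the repeat/repeat_interleave construction above, the following self-contained sketch (toy shapes only, no density estimator) verifies that row k of `theta_with_condition` pairs trial k // num_thetas with theta k % num_thetas; that pairing is exactly what makes the later reshape to `(num_xs, num_trials, num_thetas)` and the sum over dim 1 correct:

    import torch

    num_trials, num_thetas, num_params, num_conditions = 3, 4, 2, 1
    theta = torch.randn(num_thetas, num_params)
    condition = torch.randn(num_trials, num_conditions)

    theta_with_condition = torch.cat(
        [
            theta.repeat(num_trials, 1),  # ABAB: thetas cycle fastest
            condition.repeat_interleave(num_thetas, dim=0),  # AABB: one block per trial
        ],
        dim=-1,
    )

    for k in range(num_trials * num_thetas):
        i, j = divmod(k, num_thetas)  # trial i, theta j
        assert torch.equal(theta_with_condition[k, :num_params], theta[j])
        assert torch.equal(theta_with_condition[k, num_params:], condition[i])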