@@ -117,23 +117,25 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
117
117
return log_likelihood_batches + self .prior .log_prob (theta ) # type: ignore
118
118
119
119
def condition_on_theta (
120
- self , theta_condition : Tensor , dims_to_sample : List [int ]
120
+ self , local_theta : Tensor , dims_global_theta : List [int ]
121
121
) -> Callable :
122
- """Returns a potential function conditioned on a subset of theta dimensions.
122
+ r """Returns a potential function conditioned on a subset of theta dimensions.
123
123
124
- The condition is a part of theta, but is assumed to correspond to a batch of iid
125
- x_o. For example, it can be a batch of experimental conditions that corresponds
126
- to a batch of i.i.d. trials in x_o.
124
+ The goal of this function is to divide the original `theta` into a
125
+ `global_theta` we do inference over, and a `local_theta` we condition on (in
126
+ addition to conditioning on `x_o`). Thus, the returned potential function will
127
+ calculate $\prod_{i=1}^{N}p(x_i | local_theta_i, \global_theta)$, where `x_i`
128
+ and `local_theta_i` are fixed and `global_theta` varies at inference time.
127
129
128
130
Args:
129
- theta_condition : The condition values to be conditioned.
130
- dims_to_sample : The indices of the columns in theta that will be sampled,
131
- i.e., that *not* conditioned. For example, if original theta has shape
132
- `(batch_dim, 3)`, and `dims_to_sample =[0, 1]`, then the potential will
133
- set `theta[:, 3] = theta_condition ` at inference time.
131
+ local_theta : The condition values to be conditioned.
132
+ dims_global_theta : The indices of the columns in ` theta` that will be
133
+ sampled, i.e., that *not* conditioned. For example, if original theta
134
+ has shape `(batch_dim, 3)`, and `dims_global_theta =[0, 1]`, then the
135
+ potential will set `theta[:, 3] = local_theta ` at inference time.
134
136
135
137
Returns:
136
- A potential function conditioned on the theta_condition .
138
+ A potential function conditioned on the `local_theta` .
137
139
"""
138
140
139
141
assert self .x_is_iid , "Conditioning is only supported for iid data."
@@ -142,20 +144,20 @@ def conditioned_potential(
142
144
theta : Tensor , x_o : Optional [Tensor ] = None , track_gradients : bool = True
143
145
) -> Tensor :
144
146
assert (
145
- len (dims_to_sample ) == theta .shape [1 ]
146
- ), "dims_to_sample must match the number of parameters to sample."
147
- theta_without_condition = theta [:, dims_to_sample ]
147
+ len (dims_global_theta ) == theta .shape [1 ]
148
+ ), "dims_global_theta must match the number of parameters to sample."
149
+ global_theta = theta [:, dims_global_theta ]
148
150
x_o = x_o if x_o is not None else self .x_o
149
151
# x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
150
152
if x_o .dim () < 3 :
151
153
x_o = reshape_to_sample_batch_event (
152
154
x_o , event_shape = x_o .shape [1 :], leading_is_sample = self .x_is_iid
153
155
)
154
156
155
- return _log_likelihood_over_iid_conditions (
157
+ return _log_likelihood_over_iid_trials_and_local_theta (
156
158
x = x_o ,
157
- theta_without_condition = theta_without_condition ,
158
- condition = theta_condition ,
159
+ global_theta = global_theta ,
160
+ local_theta = local_theta ,
159
161
estimator = self .likelihood_estimator ,
160
162
track_gradients = track_gradients ,
161
163
)
@@ -219,51 +221,50 @@ def _log_likelihoods_over_trials(
219
221
return log_likelihood_trial_sum
220
222
221
223
222
- def _log_likelihood_over_iid_conditions (
224
+ def _log_likelihood_over_iid_trials_and_local_theta (
223
225
x : Tensor ,
224
- theta_without_condition : Tensor ,
225
- condition : Tensor ,
226
+ global_theta : Tensor ,
227
+ local_theta : Tensor ,
226
228
estimator : ConditionalDensityEstimator ,
227
229
track_gradients : bool = False ,
228
230
) -> Tensor :
229
- """Returns $\\ log(p(x_o|\t heta, condition)$, where x_o is a batch of iid data, and
230
- condition is a matching batch of conditions.
231
+ """Returns $\\ prod_{i=1}^N \\ log(p(x_i|\t heta, local_theta_i)$.
232
+
233
+ `x` is a batch of iid data, and `local_theta` is a matching batch of condition
234
+ values that were part of `theta` but are treated as local iid variables at inference
235
+ time.
231
236
232
237
This function is different from `_log_likelihoods_over_trials` in that it moves the
233
238
iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
234
239
the likelihood estimator is conditioned on a batch of conditions that are iid with
235
240
the batch of `x`. It avoids the evaluation of the likelihood for every combination
236
- of `x` and `condition`. Instead, it manually constructs a batch covering all
237
- combination of iid trials and theta batch and reshapes to sum over the iid
238
- likelihoods.
241
+ of `x` and `local_theta`.
239
242
240
243
Args:
241
244
x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim
242
245
holds the i.i.d. trials and batch_dim holds a batch of xs, e.g., non-iid
243
246
observations.
244
- theta_without_condition : Batch of parameters `(theta_batch_dim,
247
+ global_theta : Batch of parameters `(theta_batch_dim,
245
248
num_parameters)`.
246
- condition : Batch of conditions of shape `(sample_dim, num_conditions )`, must
249
+ local_theta : Batch of conditions of shape `(sample_dim, num_local_thetas )`, must
247
250
match x's `sample_dim`.
248
251
estimator: DensityEstimator.
249
252
track_gradients: Whether to track gradients.
250
253
251
254
Returns:
252
255
log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
253
- theta_batch_dim, summed over all i.i.d. trials. Shape
254
- `(x_batch_dim, theta_batch_dim)`.
256
+ theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
257
+ theta_batch_dim)`.
255
258
"""
256
259
assert x .dim () > 2 , "x must have shape (sample_dim, batch_dim, *event_shape)."
257
260
assert (
258
- condition .dim () == 2
261
+ local_theta .dim () == 2
259
262
), "condition must have shape (sample_dim, num_conditions)."
260
- assert (
261
- theta_without_condition .dim () == 2
262
- ), "theta must have shape (batch_dim, num_parameters)."
263
+ assert global_theta .dim () == 2 , "theta must have shape (batch_dim, num_parameters)."
263
264
num_trials , num_xs = x .shape [:2 ]
264
- num_thetas = theta_without_condition .shape [0 ]
265
+ num_thetas = global_theta .shape [0 ]
265
266
assert (
266
- condition .shape [0 ] == num_trials
267
+ local_theta .shape [0 ] == num_trials
267
268
), "Condition batch size must match the number of iid trials in x."
268
269
269
270
# move the iid batch dimension onto the batch dimension of theta and repeat it there
@@ -272,8 +273,8 @@ def _log_likelihood_over_iid_conditions(
272
273
# construct theta and condition to cover all trial-theta combinations
273
274
theta_with_condition = torch .cat (
274
275
[
275
- theta_without_condition .repeat (num_trials , 1 ), # repeat ABAB
276
- condition .repeat_interleave (num_thetas , dim = 0 ), # repeat AABB
276
+ global_theta .repeat (num_trials , 1 ), # repeat ABAB
277
+ local_theta .repeat_interleave (num_thetas , dim = 0 ), # repeat AABB
277
278
],
278
279
dim = - 1 ,
279
280
)
0 commit comments