Skip to content

Commit 737764d

Browse files
committed
feedback; fix texts
1 parent b058b1c commit 737764d

File tree

3 files changed

+47
-44
lines changed

3 files changed

+47
-44
lines changed

sbi/inference/potentials/likelihood_based_potential.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -117,23 +117,25 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
117117
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore
118118

119119
def condition_on_theta(
120-
self, theta_condition: Tensor, dims_to_sample: List[int]
120+
self, local_theta: Tensor, dims_global_theta: List[int]
121121
) -> Callable:
122-
"""Returns a potential function conditioned on a subset of theta dimensions.
122+
r"""Returns a potential function conditioned on a subset of theta dimensions.
123123
124-
The condition is a part of theta, but is assumed to correspond to a batch of iid
125-
x_o. For example, it can be a batch of experimental conditions that corresponds
126-
to a batch of i.i.d. trials in x_o.
124+
The goal of this function is to divide the original `theta` into a
125+
`global_theta` we do inference over, and a `local_theta` we condition on (in
126+
addition to conditioning on `x_o`). Thus, the returned potential function will
127+
calculate $\prod_{i=1}^{N}p(x_i | local_theta_i, global_theta)$, where `x_i`
128+
and `local_theta_i` are fixed and `global_theta` varies at inference time.
127129
128130
Args:
129-
theta_condition: The condition values to be conditioned.
130-
dims_to_sample: The indices of the columns in theta that will be sampled,
131-
i.e., that *not* conditioned. For example, if original theta has shape
132-
`(batch_dim, 3)`, and `dims_to_sample=[0, 1]`, then the potential will
133-
set `theta[:, 3] = theta_condition` at inference time.
131+
local_theta: The values of the local parameters to condition on.
132+
dims_global_theta: The indices of the columns in `theta` that will be
133+
sampled, i.e., that are *not* conditioned. For example, if original theta
134+
has shape `(batch_dim, 3)`, and `dims_global_theta=[0, 1]`, then the
135+
potential will set `theta[:, 2] = local_theta` at inference time.
134136
135137
Returns:
136-
A potential function conditioned on the theta_condition.
138+
A potential function conditioned on the `local_theta`.
137139
"""
138140

139141
assert self.x_is_iid, "Conditioning is only supported for iid data."
@@ -142,20 +144,20 @@ def conditioned_potential(
142144
theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
143145
) -> Tensor:
144146
assert (
145-
len(dims_to_sample) == theta.shape[1]
146-
), "dims_to_sample must match the number of parameters to sample."
147-
theta_without_condition = theta[:, dims_to_sample]
147+
len(dims_global_theta) == theta.shape[1]
148+
), "dims_global_theta must match the number of parameters to sample."
149+
global_theta = theta[:, dims_global_theta]
148150
x_o = x_o if x_o is not None else self.x_o
149151
# x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
150152
if x_o.dim() < 3:
151153
x_o = reshape_to_sample_batch_event(
152154
x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
153155
)
154156

155-
return _log_likelihood_over_iid_conditions(
157+
return _log_likelihood_over_iid_trials_and_local_theta(
156158
x=x_o,
157-
theta_without_condition=theta_without_condition,
158-
condition=theta_condition,
159+
global_theta=global_theta,
160+
local_theta=local_theta,
159161
estimator=self.likelihood_estimator,
160162
track_gradients=track_gradients,
161163
)
@@ -219,51 +221,50 @@ def _log_likelihoods_over_trials(
219221
return log_likelihood_trial_sum
220222

221223

222-
def _log_likelihood_over_iid_conditions(
224+
def _log_likelihood_over_iid_trials_and_local_theta(
223225
x: Tensor,
224-
theta_without_condition: Tensor,
225-
condition: Tensor,
226+
global_theta: Tensor,
227+
local_theta: Tensor,
226228
estimator: ConditionalDensityEstimator,
227229
track_gradients: bool = False,
228230
) -> Tensor:
229-
"""Returns $\\log(p(x_o|\theta, condition)$, where x_o is a batch of iid data, and
230-
condition is a matching batch of conditions.
231+
"""Returns $\\prod_{i=1}^N \\log(p(x_i|\theta, local_theta_i)$.
232+
233+
`x` is a batch of iid data, and `local_theta` is a matching batch of condition
234+
values that were part of `theta` but are treated as local iid variables at inference
235+
time.
231236
232237
This function is different from `_log_likelihoods_over_trials` in that it moves the
233238
iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
234239
the likelihood estimator is conditioned on a batch of conditions that are iid with
235240
the batch of `x`. It avoids the evaluation of the likelihood for every combination
236-
of `x` and `condition`. Instead, it manually constructs a batch covering all
237-
combination of iid trials and theta batch and reshapes to sum over the iid
238-
likelihoods.
241+
of `x` and `local_theta`.
239242
240243
Args:
241244
x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim
242245
holds the i.i.d. trials and batch_dim holds a batch of xs, e.g., non-iid
243246
observations.
244-
theta_without_condition: Batch of parameters `(theta_batch_dim,
247+
global_theta: Batch of parameters `(theta_batch_dim,
245248
num_parameters)`.
246-
condition: Batch of conditions of shape `(sample_dim, num_conditions)`, must
249+
local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`, must
247250
match x's `sample_dim`.
248251
estimator: DensityEstimator.
249252
track_gradients: Whether to track gradients.
250253
251254
Returns:
252255
log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
253-
theta_batch_dim, summed over all i.i.d. trials. Shape
254-
`(x_batch_dim, theta_batch_dim)`.
256+
theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
257+
theta_batch_dim)`.
255258
"""
256259
assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
257260
assert (
258-
condition.dim() == 2
261+
local_theta.dim() == 2
259262
), "condition must have shape (sample_dim, num_conditions)."
260-
assert (
261-
theta_without_condition.dim() == 2
262-
), "theta must have shape (batch_dim, num_parameters)."
263+
assert global_theta.dim() == 2, "global_theta must have shape (batch_dim, num_parameters)."
263264
num_trials, num_xs = x.shape[:2]
264-
num_thetas = theta_without_condition.shape[0]
265+
num_thetas = global_theta.shape[0]
265266
assert (
266-
condition.shape[0] == num_trials
267+
local_theta.shape[0] == num_trials
267268
), "Condition batch size must match the number of iid trials in x."
268269

269270
# move the iid batch dimension onto the batch dimension of theta and repeat it there
@@ -272,8 +273,8 @@ def _log_likelihood_over_iid_conditions(
272273
# construct theta and condition to cover all trial-theta combinations
273274
theta_with_condition = torch.cat(
274275
[
275-
theta_without_condition.repeat(num_trials, 1), # repeat ABAB
276-
condition.repeat_interleave(num_thetas, dim=0), # repeat AABB
276+
global_theta.repeat(num_trials, 1), # repeat ABAB
277+
local_theta.repeat_interleave(num_thetas, dim=0), # repeat AABB
277278
],
278279
dim=-1,
279280
)

tests/mnle_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sbi.inference.posteriors.vi_posterior import VIPosterior
1515
from sbi.inference.potentials.base_potential import BasePotential
1616
from sbi.inference.potentials.likelihood_based_potential import (
17-
_log_likelihood_over_iid_conditions,
17+
_log_likelihood_over_iid_trials_and_local_theta,
1818
likelihood_estimator_based_potential,
1919
)
2020
from sbi.neural_nets import likelihood_nn
@@ -299,8 +299,8 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
299299
estimator = trainer.append_simulations(theta, x).train()
300300

301301
potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
302-
conditioned_potential_fn = potential_fn.condition_on(
303-
condition_o, dims_to_sample=[0, 1]
302+
conditioned_potential_fn = potential_fn.condition_on_theta(
303+
condition_o, dims_global_theta=[0, 1]
304304
)
305305

306306
# True posterior samples
@@ -350,7 +350,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
350350
),
351351
],
352352
)
353-
def test_log_likelihood_over_iid_conditions(
353+
def test_log_likelihood_over_local_iid_theta(
354354
num_thetas, num_trials, num_xs, num_conditions
355355
):
356356
"""Test log likelihood over iid conditions using MNLE.
@@ -397,7 +397,9 @@ def test_log_likelihood_over_iid_conditions(
397397
theta = proposal.sample((num_thetas,))[:, :2]
398398
# x_o has shape (iid, batch, *event)
399399
# condition_o has shape (iid, num_conditions)
400-
ll_batched = _log_likelihood_over_iid_conditions(x_o, theta, condition_o, estimator)
400+
ll_batched = _log_likelihood_over_iid_trials_and_local_theta(
401+
x_o, theta, condition_o, estimator
402+
)
401403

402404
# looped conditioning
403405
ll_single = []

tutorials/Example_01_DecisionMakingModel.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@
658658
"# Then, we condition on the experimental conditions.\n",
659659
"conditioned_potential_fn = potential_fn.condition_on_theta(\n",
660660
" conditions, # pass only the conditions, must match the batch of iid data in x_o\n",
661-
" dims_to_sample=[0, 1] # pass the dimensions in the original theta that correspond to beta and rho\n",
661+
" dims_global_theta=[0, 1] # pass the dimensions in the original theta that correspond to beta and rho\n",
662662
")\n",
663663
"\n",
664664
"# Using this potential function, we can now obtain conditional samples.\n",
@@ -815,7 +815,7 @@
815815
" # condition the potential\n",
816816
" conditioned_potential_fn = potential_fn.condition_on_theta(\n",
817817
" conditions[idx],\n",
818-
" dims_to_sample=[0, 1]\n",
818+
" dims_global_theta=[0, 1]\n",
819819
" )\n",
820820
"\n",
821821
" # pass potential to sampler\n",

0 commit comments

Comments
 (0)