Skip to content

Commit 33aa5ad

Browse files
manuelgloecklermichaeldeistler
authored andcommitted
Remove iid_bridge (other PR)
1 parent 63305d9 commit 33aa5ad

File tree

1 file changed

+0
-89
lines changed

1 file changed

+0
-89
lines changed

sbi/inference/potentials/score_based_potential.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -230,92 +230,3 @@ def f(t, x):
230230
)
231231
return transform
232232

233-
234-
def _iid_bridge(
235-
theta: Tensor,
236-
xos: Tensor,
237-
time: Tensor,
238-
score_estimator: ConditionalScoreEstimator,
239-
prior: Distribution,
240-
t_max: float = 1.0,
241-
):
242-
r"""
243-
Returns the score-based potential for multiple IID observations.
244-
245-
This can require a special solver to obtain the correct tall posterior.
246-
247-
Args:
248-
input: The parameter values at which to evaluate the potential.
249-
condition: The observed data at which to evaluate the potential.
250-
time: The diffusion time.
251-
score_estimator: The neural network modelling the score.
252-
prior: The prior distribution.
253-
"""
254-
255-
assert (
256-
next(score_estimator.parameters()).device == xos.device
257-
and xos.device == theta.device
258-
), f"""device mismatch: estimator, x, theta: \
259-
{next(score_estimator.parameters()).device}, {xos.device},
260-
{theta.device}."""
261-
262-
# Get number of observations which are left from event_shape if they exist.
263-
condition_shape = score_estimator.condition_shape
264-
num_obs = xos.shape[-len(condition_shape) - 1]
265-
266-
# Calculate likelihood in one batch.
267-
# xos is of shape (num_obs, *condition_shape).
268-
# theta is of shape (num_samples, *parameter_shape).
269-
270-
# TODO: we need to combine the batch shapes of num_obs and num_samples for both
271-
# theta and xos.
272-
theta_per_xo = theta.repeat(num_obs, 1)
273-
xos_per_theta = xos.repeat_interleave(theta.shape[0], dim=0)
274-
275-
score_trial_batch = score_estimator.forward(
276-
input=theta_per_xo,
277-
condition=xos_per_theta,
278-
time=time,
279-
).reshape(num_obs, theta.shape[0], -1)
280-
281-
# Sum over m observations, as in Geffner et al., equation (7).
282-
score_trial_sum = score_trial_batch.sum(0)
283-
prior_contribution = _get_prior_contribution(time, prior, theta, num_obs, t_max)
284-
285-
return score_trial_sum + prior_contribution
286-
287-
288-
def _get_prior_contribution(
289-
diffusion_time: Tensor,
290-
prior: Distribution,
291-
theta: Tensor,
292-
num_obs: int,
293-
t_max: float = 1.0,
294-
):
295-
r"""Returns the prior contribution for multiple IID observations.
296-
297-
Args:
298-
diffusion_time: The diffusion time.
299-
prior: The prior distribution.
300-
theta: The parameter values at which to evaluate the prior contribution.
301-
num_obs: The number of independent observations.
302-
"""
303-
# This method can be used to add several different bridges
304-
# to obtain the posterior for multiple IID observations.
305-
# For now, it only implements the approach by Geffner et al.
306-
307-
# TODO Check if prior has the grad property else use torch autograd.
308-
# For now just use autograd.
309-
# Ensure theta requires gradients
310-
theta.requires_grad_(True)
311-
312-
log_prob_theta = prior.log_prob(theta)
313-
314-
grad_log_prob_theta = torch.autograd.grad(
315-
log_prob_theta,
316-
theta,
317-
grad_outputs=torch.ones_like(log_prob_theta),
318-
create_graph=True,
319-
)[0]
320-
321-
return ((1 - num_obs) * (t_max - diffusion_time)) / t_max * grad_log_prob_theta

0 commit comments

Comments
 (0)