Skip to content

Commit 9b21a38

Browse files
committed
refactor: vectorize gt potential, refactor nb
1 parent 510cd15 commit 9b21a38

File tree

2 files changed

+53
-53
lines changed

2 files changed

+53
-53
lines changed

tests/mnle_test.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
165165
# True posterior samples
166166
transform = mcmc_transform(prior)
167167
true_posterior_samples = MCMCPosterior(
168-
PotentialFunctionProvider(prior, atleast_2d(x_o)),
168+
BinomialGammaPotential(prior, atleast_2d(x_o)),
169169
theta_transform=transform,
170170
proposal=prior,
171171
**mcmc_kwargs,
@@ -189,14 +189,9 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
189189
)
190190

191191

192-
class PotentialFunctionProvider(BasePotential):
193-
"""Returns potential function for reference posterior of a mixed likelihood."""
194-
195-
allow_iid_x = True # type: ignore
196-
192+
class BinomialGammaPotential(BasePotential):
197193
def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
198194
super().__init__(prior, x_o, device)
199-
200195
self.concentration_scaling = concentration_scaling
201196

202197
def __call__(self, theta, track_gradients: bool = True):
@@ -207,33 +202,25 @@ def __call__(self, theta, track_gradients: bool = True):
207202

208203
return iid_ll + self.prior.log_prob(theta)
209204

210-
def iid_likelihood(self, theta: torch.Tensor) -> torch.Tensor:
211-
"""Returns the likelihood summed over a batch of i.i.d. data."""
212-
213-
lp_choices = torch.stack(
214-
[
215-
Binomial(probs=th.reshape(1, -1)).log_prob(self.x_o[:, 1:])
216-
for th in theta[:, 1:]
217-
],
218-
dim=1,
205+
def iid_likelihood(self, theta):
206+
batch_size = theta.shape[0]
207+
num_trials = self.x_o.shape[0]
208+
theta = theta.reshape(batch_size, 1, -1)
209+
beta, rho = theta[:, :, :1], theta[:, :, 1:]
210+
# vectorized
211+
logprob_choices = Binomial(probs=rho).log_prob(
212+
self.x_o[:, 1:].reshape(1, num_trials, -1)
219213
)
220214

221-
lp_rts = torch.stack(
222-
[
223-
InverseGamma(
224-
concentration=self.concentration_scaling * torch.ones_like(beta_i),
225-
rate=beta_i,
226-
).log_prob(self.x_o[:, :1])
227-
for beta_i in theta[:, :1]
228-
],
229-
dim=1,
230-
)
215+
logprob_rts = InverseGamma(
216+
concentration=self.concentration_scaling * torch.ones_like(beta),
217+
rate=beta,
218+
).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))
231219

232-
joint_likelihood = (lp_choices + lp_rts).reshape(
233-
self.x_o.shape[0], theta.shape[0]
234-
)
220+
joint_likelihood = (logprob_choices + logprob_rts).squeeze()
235221

236-
return joint_likelihood.sum(0)
222+
assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])
223+
return joint_likelihood.sum(1)
237224

238225

239226
@pytest.mark.slow
@@ -295,7 +282,7 @@ def sim_wrapper(theta):
295282
)
296283
prior_transform = mcmc_transform(prior)
297284
true_posterior_samples = MCMCPosterior(
298-
PotentialFunctionProvider(
285+
BinomialGammaPotential(
299286
prior,
300287
atleast_2d(x_o),
301288
concentration_scaling=float(theta_o[0, 2])

0 commit comments

Comments (0)