@@ -165,7 +165,7 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
165165 # True posterior samples
166166 transform = mcmc_transform (prior )
167167 true_posterior_samples = MCMCPosterior (
168- PotentialFunctionProvider (prior , atleast_2d (x_o )),
168+ BinomialGammaPotential (prior , atleast_2d (x_o )),
169169 theta_transform = transform ,
170170 proposal = prior ,
171171 ** mcmc_kwargs ,
@@ -189,14 +189,9 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
189189 )
190190
191191
192- class PotentialFunctionProvider (BasePotential ):
193- """Returns potential function for reference posterior of a mixed likelihood."""
194-
195- allow_iid_x = True # type: ignore
196-
192+ class BinomialGammaPotential (BasePotential ):
197193 def __init__ (self , prior , x_o , concentration_scaling = 1.0 , device = "cpu" ):
198194 super ().__init__ (prior , x_o , device )
199-
200195 self .concentration_scaling = concentration_scaling
201196
202197 def __call__ (self , theta , track_gradients : bool = True ):
@@ -207,33 +202,25 @@ def __call__(self, theta, track_gradients: bool = True):
207202
208203 return iid_ll + self .prior .log_prob (theta )
209204
210- def iid_likelihood (self , theta : torch .Tensor ) -> torch .Tensor :
211- """Returns the likelihood summed over a batch of i.i.d. data."""
212-
213- lp_choices = torch .stack (
214- [
215- Binomial (probs = th .reshape (1 , - 1 )).log_prob (self .x_o [:, 1 :])
216- for th in theta [:, 1 :]
217- ],
218- dim = 1 ,
205+ def iid_likelihood (self , theta ):
206+ batch_size = theta .shape [0 ]
207+ num_trials = self .x_o .shape [0 ]
208+ theta = theta .reshape (batch_size , 1 , - 1 )
209+ beta , rho = theta [:, :, :1 ], theta [:, :, 1 :]
210+ # vectorized
211+ logprob_choices = Binomial (probs = rho ).log_prob (
212+ self .x_o [:, 1 :].reshape (1 , num_trials , - 1 )
219213 )
220214
221- lp_rts = torch .stack (
222- [
223- InverseGamma (
224- concentration = self .concentration_scaling * torch .ones_like (beta_i ),
225- rate = beta_i ,
226- ).log_prob (self .x_o [:, :1 ])
227- for beta_i in theta [:, :1 ]
228- ],
229- dim = 1 ,
230- )
215+ logprob_rts = InverseGamma (
216+ concentration = self .concentration_scaling * torch .ones_like (beta ),
217+ rate = beta ,
218+ ).log_prob (self .x_o [:, :1 ].reshape (1 , num_trials , - 1 ))
231219
232- joint_likelihood = (lp_choices + lp_rts ).reshape (
233- self .x_o .shape [0 ], theta .shape [0 ]
234- )
220+ joint_likelihood = (logprob_choices + logprob_rts ).squeeze ()
235221
236- return joint_likelihood .sum (0 )
222+ assert joint_likelihood .shape == torch .Size ([theta .shape [0 ], self .x_o .shape [0 ]])
223+ return joint_likelihood .sum (1 )
237224
238225
239226@pytest .mark .slow
@@ -295,7 +282,7 @@ def sim_wrapper(theta):
295282 )
296283 prior_transform = mcmc_transform (prior )
297284 true_posterior_samples = MCMCPosterior (
298- PotentialFunctionProvider (
285+ BinomialGammaPotential (
299286 prior ,
300287 atleast_2d (x_o ),
301288 concentration_scaling = float (theta_o [0 , 2 ])
0 commit comments