diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 3d810cc5f..55d8f0d88 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -1,7 +1,6 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see -import inspect from abc import abstractmethod from typing import Any, Callable, Dict, Optional, Union from warnings import warn @@ -54,14 +53,6 @@ def __init__( # Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable. if not isinstance(potential_fn, BasePotential): - kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys()) - for key in ["theta", "x_o"]: - assert key in kwargs_of_callable, ( - "If you pass a `Callable` as `potential_fn` then it must have " - "`theta` and `x_o` as inputs, even if some of these keyword " - "arguments are unused." - ) - # If the `potential_fn` is a Callable then we wrap it as a # `CallablePotentialWrapper` which inherits from `BasePotential`. potential_device = "cpu" if device is None else device diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index f7f9dfe41..fa65a7d88 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -1,8 +1,9 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +import inspect from abc import ABCMeta, abstractmethod -from typing import Optional +from typing import Callable, Optional import torch from torch import Tensor @@ -85,18 +86,39 @@ def return_x_o(self) -> Optional[Tensor]: class CallablePotentialWrapper(BasePotential): """If `potential_fn` is a callable it gets wrapped as this.""" - allow_iid_x = True # type: ignore - def __init__( self, - callable_potential, + potential_fn: Callable, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu", ): + """Wraps a callable potential function. + + Args: + potential_fn: Callable potential function, must have `theta` and `x_o` as + arguments. + prior: Prior distribution. + x_o: Observed data. + device: Device on which to evaluate the potential function. + + """ super().__init__(prior, x_o, device) - self.callable_potential = callable_potential + + kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys()) + required_keys = ["theta", "x_o"] + for key in required_keys: + assert key in kwargs_of_callable, ( + "If you pass a `Callable` as `potential_fn` then it must have " + "`theta` and `x_o` as inputs, even if some of these keyword " + "arguments are unused." + ) + self.potential_fn = potential_fn def __call__(self, theta, track_gradients: bool = True): + """Call the callable potential function on given theta. + + Note, x_o is re-used from the initialization of the potential function. + """ with torch.set_grad_enabled(track_gradients): - return self.callable_potential(theta=theta, x_o=self.x_o) + return self.potential_fn(theta=theta, x_o=self.x_o) diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py index ee8606389..d6c73b7c9 100644 --- a/sbi/utils/conditional_density_utils.py +++ b/sbi/utils/conditional_density_utils.py @@ -271,7 +271,7 @@ def condition_mog( return logits, means, precfs_xx, sumlogdiag -class ConditionedPotential: +class ConditionedPotential(BasePotential): def __init__( self, potential_fn: BasePotential, @@ -282,20 +282,26 @@ def __init__( Return conditional posterior log-probability or $-\infty$ if outside prior. Args: - theta: Free parameters $\theta_i$, batch dimension 1. + potential_fn: Potential function to condition on. + condition: Fixed parameters $\theta_j$, batch size 1. + dims_to_sample: Which dimensions to sample from. The dimensions not + specified in `dims_to_sample` will be fixed to values given in + `condition`. Returns: Conditional posterior log-probability $\log(p(\theta_i|\theta_j, x))$, masked outside of prior. """ + condition = torch.atleast_2d(condition) + if condition.shape[0] != 1: + raise ValueError("Condition with batch size > 1 not supported.") + self.potential_fn = potential_fn self.condition = condition self.dims_to_sample = dims_to_sample self.device = self.potential_fn.device - def __call__( - self, theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True - ) -> Tensor: + def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: r""" Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$. diff --git a/tests/mnle_test.py b/tests/mnle_test.py index 80f34bbd8..a95a2a6ac 100644 --- a/tests/mnle_test.py +++ b/tests/mnle_test.py @@ -165,7 +165,7 @@ def test_mnle_accuracy_with_different_samplers_and_trials( # True posterior samples transform = mcmc_transform(prior) true_posterior_samples = MCMCPosterior( - PotentialFunctionProvider(prior, atleast_2d(x_o)), + BinomialGammaPotential(prior, atleast_2d(x_o)), theta_transform=transform, proposal=prior, **mcmc_kwargs, @@ -189,14 +189,9 @@ def test_mnle_accuracy_with_different_samplers_and_trials( ) -class PotentialFunctionProvider(BasePotential): - """Returns potential function for reference posterior of a mixed likelihood.""" - - allow_iid_x = True # type: ignore - +class BinomialGammaPotential(BasePotential): def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"): super().__init__(prior, x_o, device) - self.concentration_scaling = concentration_scaling def __call__(self, theta, track_gradients: bool = True): @@ -207,33 +202,25 @@ def __call__(self, theta, track_gradients: bool = True): return iid_ll + self.prior.log_prob(theta) - def iid_likelihood(self, theta: torch.Tensor) -> torch.Tensor: - """Returns the likelihood summed over a batch of i.i.d. data.""" - - lp_choices = torch.stack( - [ - Binomial(probs=th.reshape(1, -1)).log_prob(self.x_o[:, 1:]) - for th in theta[:, 1:] - ], - dim=1, + def iid_likelihood(self, theta): + batch_size = theta.shape[0] + num_trials = self.x_o.shape[0] + theta = theta.reshape(batch_size, 1, -1) + beta, rho = theta[:, :, :1], theta[:, :, 1:] + # vectorized + logprob_choices = Binomial(probs=rho).log_prob( + self.x_o[:, 1:].reshape(1, num_trials, -1) ) - lp_rts = torch.stack( - [ - InverseGamma( - concentration=self.concentration_scaling * torch.ones_like(beta_i), - rate=beta_i, - ).log_prob(self.x_o[:, :1]) - for beta_i in theta[:, :1] - ], - dim=1, - ) + logprob_rts = InverseGamma( + concentration=self.concentration_scaling * torch.ones_like(beta), + rate=beta, + ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1)) - joint_likelihood = (lp_choices + lp_rts).reshape( - self.x_o.shape[0], theta.shape[0] - ) + joint_likelihood = (logprob_choices + logprob_rts).squeeze() - return joint_likelihood.sum(0) + assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]]) + return joint_likelihood.sum(1) @pytest.mark.slow @@ -295,7 +282,7 @@ def sim_wrapper(theta): ) prior_transform = mcmc_transform(prior) true_posterior_samples = MCMCPosterior( - PotentialFunctionProvider( + BinomialGammaPotential( prior, atleast_2d(x_o), concentration_scaling=float(theta_o[0, 2]) diff --git a/tests/potential_test.py b/tests/potential_test.py index 77fb83cd1..9584e6916 100644 --- a/tests/potential_test.py +++ b/tests/potential_test.py @@ -5,7 +5,7 @@ import pytest import torch -from torch import eye, ones, zeros +from torch import Tensor, eye, ones, zeros from torch.distributions import MultivariateNormal from sbi.inference import ( @@ -14,6 +14,9 @@ RejectionPosterior, VIPosterior, ) +from sbi.inference.potentials.base_potential import CallablePotentialWrapper +from sbi.utils import BoxUniform +from sbi.utils.conditional_density_utils import ConditionedPotential @pytest.mark.parametrize( @@ -64,3 +67,25 @@ def potential(theta, x_o): sample_std = torch.std(approx_samples, dim=0) assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2) assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1) + + +@pytest.mark.parametrize( + "condition", + [ + torch.rand(1, 2), + pytest.param( + torch.rand(2, 2), + marks=pytest.mark.xfail( + raises=ValueError, + match="Condition with batch size > 1 not supported", + ), + ), + ], +) +def test_conditioned_potential(condition: Tensor): + potential_fn = CallablePotentialWrapper( + potential_fn=lambda theta, x_o: theta, + prior=BoxUniform(low=zeros(2), high=ones(2)), + ) + + ConditionedPotential(potential_fn, condition=condition, dims_to_sample=[0]) diff --git a/tutorials/Example_01_DecisionMakingModel.ipynb b/tutorials/Example_01_DecisionMakingModel.ipynb index 49bf6e728..fcfa10ced 100644 --- a/tutorials/Example_01_DecisionMakingModel.ipynb +++ b/tutorials/Example_01_DecisionMakingModel.ipynb @@ -73,7 +73,15 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + } + ], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", @@ -109,9 +117,9 @@ " distribution, mimics an experimental condition.\n", "\n", " \"\"\"\n", - " beta, ps = theta[:, :1], theta[:, 1:]\n", + " beta, rho = theta[:, :1], theta[:, 1:]\n", "\n", - " choices = Binomial(probs=ps).sample()\n", + " choices = Binomial(probs=rho).sample()\n", " rts = InverseGamma(\n", " concentration=concentration_scaling * torch.ones_like(beta), rate=beta\n", " ).sample()\n", @@ -121,8 +129,7 @@ "\n", "# The potential function defines the ground truth likelihood and allows us to\n", "# obtain reference posterior samples via MCMC.\n", - "class PotentialFunctionProvider(BasePotential):\n", - " allow_iid_x = True # type: ignore\n", + "class BinomialGammaPotential(BasePotential):\n", "\n", " def __init__(self, prior, x_o, concentration_scaling=1.0, device=\"cpu\"):\n", " super().__init__(prior, x_o, device)\n", @@ -137,29 +144,24 @@ " return iid_ll + self.prior.log_prob(theta)\n", "\n", " def iid_likelihood(self, theta):\n", - " lp_choices = torch.stack(\n", - " [\n", - " Binomial(probs=th.reshape(1, -1)).log_prob(self.x_o[:, 1:])\n", - " for th in theta[:, 1:]\n", - " ],\n", - " dim=1,\n", + " batch_size = theta.shape[0]\n", + " num_trials = self.x_o.shape[0]\n", + " theta = theta.reshape(batch_size, 1, -1)\n", + " beta, rho = theta[:, :, :1], theta[:, :, 1:]\n", + " # vectorized\n", + " logprob_choices = Binomial(probs=rho).log_prob(\n", + " self.x_o[:, 1:].reshape(1, num_trials, -1)\n", " )\n", "\n", - " lp_rts = torch.stack(\n", - " [\n", - " InverseGamma(\n", - " concentration=self.concentration_scaling * torch.ones_like(beta_i),\n", - " rate=beta_i,\n", - " ).log_prob(self.x_o[:, :1])\n", - " for beta_i in theta[:, :1]\n", - " ],\n", - " dim=1,\n", - " )\n", + " logprob_rts = InverseGamma(\n", + " concentration=self.concentration_scaling * torch.ones_like(beta),\n", + " rate=beta,\n", + " ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))\n", "\n", - " joint_likelihood = (lp_choices + lp_rts).squeeze()\n", + " joint_likelihood = (logprob_choices + logprob_rts).squeeze()\n", "\n", - " assert joint_likelihood.shape == torch.Size([self.x_o.shape[0], theta.shape[0]])\n", - " return joint_likelihood.sum(0)" + " assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])\n", + " return joint_likelihood.sum(1)" ] }, { @@ -202,7 +204,30 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n", + " thin = _process_thin_default(thin)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8070275b9eac45d1991d5be41935c145", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00" + "
" ] }, "metadata": {}, @@ -307,6 +355,7 @@ " points_colors=[\"k\"],\n", " ),\n", " labels=[r\"$\\beta$\", r\"$\\rho$\"],\n", + " figsize=(6, 6),\n", ")\n", "\n", "plt.sca(ax[1, 1])\n", @@ -335,9 +384,38 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fb02120c58a54d029953b4c589f24eca", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00" + "
" ] }, "metadata": {}, @@ -379,6 +457,7 @@ " points_colors=[\"k\"],\n", " ),\n", " labels=[r\"$\\beta$\", r\"$\\rho$\"],\n", + " figsize=(6, 6),\n", ")\n", "\n", "plt.sca(ax[1, 1])\n", @@ -391,14 +470,14 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "c2st between true and MNLE posterior: 0.567\n" + "c2st between true and MNLE posterior: 0.593\n" ] } ], @@ -423,15 +502,16 @@ "source": [ "## MNLE with experimental conditions\n", "\n", - "In the perceptual decision-making research it is common to design experiments with varying experimental decisions, e.g., to vary the difficulty of the task.\n", + "In the perceptual decision-making research, it is common to design experiments with varying experimental decisions, e.g., to vary the difficulty of the task.\n", "During parameter inference, it can be beneficial to incorporate the experimental conditions.\n", + "\n", "In MNLE, we are learning an emulator that should be able to generate synthetic experimental data including reaction times and choices given different experimental conditions.\n", "Thus, to make MNLE work with experimental conditions, we need to include them in the training process, i.e., treat them like auxiliary parameters of the simulator:\n" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -440,15 +520,17 @@ "def sim_wrapper(theta):\n", " # simulate with experiment conditions\n", " return mixed_simulator(\n", + " # we assume the first two parameters are beta and rho\n", " theta=theta[:, :2],\n", - " concentration_scaling=theta[:, 2:]\n", - " + 1, # add 1 to deal with 0 values from Categorical distribution\n", + " # we treat the third concentration parameter as an experimental condition\n", + " # add 1 to deal with 0 values from Categorical distribution\n", + " concentration_scaling=theta[:, 2:] + 1,\n", " )" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -458,7 +540,7 @@ " [\n", " Gamma(torch.tensor([1.0]), torch.tensor([0.5])),\n", " Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n", - " Categorical(probs=torch.ones(1, 3)),\n", + " Categorical(probs=torch.ones(1, 3)), # 3 discrete conditions\n", " ],\n", " validate_args=False,\n", ")\n", @@ -474,6 +556,7 @@ "num_trials = 10\n", "theta_o = proposal.sample((1,))\n", "theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.\n", + "# NOTE: we use the same experimental condition for all trials.\n", "x_o = sim_wrapper(theta_o.repeat(num_trials, 1))" ] }, @@ -492,9 +575,32 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n", + " thin = _process_thin_default(thin)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ad169fdca3da40649e6e1c329460e355", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00" ]