Skip to content

Commit 53cb22d

Browse files
committed
fix: iid score device handling, ref tests
1 parent 85bb404 commit 53cb22d

File tree

3 files changed

+78
-56
lines changed

sbi/inference/potentials/score_fn_iid.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108

109109
self.vector_field_estimator = vector_field_estimator.to(device).eval()
110110
self.prior = prior
111+
self.device = device
111112

112113
def to(self, device: Union[str, torch.device]) -> None:
113114
"""
@@ -332,9 +333,9 @@ def marginal_denoising_posterior_precision_est_fn(
332333
std = self.vector_field_estimator.std_fn(time)
333334

334335
if precisions_posteriors.ndim == 4:
335-
Ident = torch.eye(precisions_posteriors.shape[-1])
336+
Ident = torch.eye(precisions_posteriors.shape[-1], device=self.device)
336337
else:
337-
Ident = torch.ones_like(precisions_posteriors)
338+
Ident = torch.ones_like(precisions_posteriors, device=self.device)
338339

339340
marginal_precisions = m**2 / std**2 * Ident + precisions_posteriors
340341
return marginal_precisions
@@ -649,7 +650,9 @@ def estimate_posterior_precision(
649650
# NOTE: To avoid circular imports :(
650651
from sbi.inference.posteriors.vector_field_posterior import VectorFieldPosterior
651652

652-
posterior = VectorFieldPosterior(vector_field_estimator, prior)
653+
posterior = VectorFieldPosterior(
654+
vector_field_estimator, prior, device=conditions.device
655+
)
653656

654657
if precision_est_budget is None:
655658
if precision_est_only_diag:
@@ -725,7 +728,7 @@ def marginal_denoising_posterior_precision_est_fn(
725728

726729
m = self.vector_field_estimator.mean_t_fn(time)
727730
std = self.vector_field_estimator.std_fn(time)
728-
cov0 = std**2 * jac + torch.eye(d)[None, None, :, :]
731+
cov0 = std**2 * jac + torch.eye(d, device=self.device)[None, None, :, :]
729732

730733
denoising_posterior_precision = m**2 / std**2 * torch.inverse(cov0)
731734

sbi/inference/potentials/vector_field_potential.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ def gradient(
237237
"is not defined for this vector field estimator."
238238
)
239239

240+
device = theta.device
241+
240242
if time is None:
241243
time = torch.tensor([self.vector_field_estimator.t_min])
242244

@@ -249,14 +251,17 @@ def gradient(
249251
with torch.set_grad_enabled(track_gradients):
250252
if not self.x_is_iid or self._x_o.shape[0] == 1:
251253
score = self.vector_field_estimator.score(
252-
input=theta, condition=self.x_o, t=time
254+
input=theta, condition=self.x_o, t=time.to(device)
253255
)
254256
else:
255257
assert self.prior is not None, "Prior is required for iid methods."
256258

257259
iid_method = get_iid_method(self.iid_method)
258260
score_fn_iid = iid_method(
259-
self.vector_field_estimator, self.prior, **(self.iid_params or {})
261+
self.vector_field_estimator,
262+
self.prior,
263+
device=device,
264+
**(self.iid_params or {}),
260265
)
261266

262267
score = score_fn_iid(theta, self.x_o, time)

tests/inference_on_device_test.py

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,7 @@
1515
from torch.distributions import MultivariateNormal
1616

1717
from sbi import utils as utils
18-
from sbi.inference import (
19-
ABC,
20-
FMPE,
21-
NLE,
22-
NPE,
23-
NPE_A,
24-
NPE_C,
25-
NPSE,
26-
NRE_A,
27-
NRE_B,
28-
NRE_C,
29-
VIPosterior,
30-
likelihood_estimator_based_potential,
31-
ratio_estimator_based_potential,
32-
)
18+
from sbi.inference.abc import MCABC as ABC
3319
from sbi.inference.posteriors.ensemble_posterior import (
3420
EnsemblePotential,
3521
)
@@ -38,29 +24,33 @@
3824
from sbi.inference.posteriors.posterior_parameters import (
3925
MCMCPosteriorParameters,
4026
)
27+
from sbi.inference.posteriors.vi_posterior import VIPosterior
4128
from sbi.inference.potentials.base_potential import BasePotential
42-
from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
29+
from sbi.inference.potentials.likelihood_based_potential import (
30+
LikelihoodBasedPotential,
31+
likelihood_estimator_based_potential,
32+
)
4333
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
44-
from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential
45-
from sbi.inference.trainers.vfpe.base_vf_inference import VectorFieldTrainer
34+
from sbi.inference.potentials.ratio_based_potential import (
35+
RatioBasedPotential,
36+
ratio_estimator_based_potential,
37+
)
38+
from sbi.inference.trainers.nle import NLE
39+
from sbi.inference.trainers.npe import NPE, NPE_A, NPE_C
40+
from sbi.inference.trainers.nre import NRE_A, NRE_B, NRE_C
41+
from sbi.inference.trainers.vfpe import FMPE, NPSE
4642
from sbi.neural_nets.embedding_nets import FCEmbedding
4743
from sbi.neural_nets.factory import (
4844
classifier_nn,
4945
embedding_net_warn_msg,
5046
likelihood_nn,
5147
posterior_nn,
5248
)
53-
from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
54-
from sbi.utils.torchutils import (
55-
BoxUniform,
56-
gpu_available,
57-
process_device,
58-
)
59-
from sbi.utils.user_input_checks import (
60-
validate_theta_and_x,
61-
)
49+
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian, linear_gaussian
50+
from sbi.utils import BoxUniform
51+
from sbi.utils.torchutils import gpu_available, process_device
52+
from sbi.utils.user_input_checks import validate_theta_and_x
6253

63-
# tests in this file are skipped if there is GPU device available
6454
pytestmark = pytest.mark.skipif(
6555
not gpu_available(), reason="No CUDA or MPS device available."
6656
)
@@ -720,29 +710,43 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):
720710
@pytest.mark.gpu
721711
@pytest.mark.parametrize("device", ["cpu", "gpu"])
722712
@pytest.mark.parametrize("device_inference", ["cpu", "gpu"])
723-
@pytest.mark.parametrize(
724-
"iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None]
725-
)
726-
@pytest.mark.parametrize("method", (FMPE, NPSE))
727-
def test_VectorFieldPosterior_device_handling(
728-
method: VectorFieldTrainer, device: str, device_inference: str, iid_method: str
713+
@pytest.mark.parametrize("num_trials", [1, 2])
714+
@pytest.mark.parametrize("vf_trainer", [FMPE, NPSE])
715+
def test_vector_field_methods_device_handling(
716+
vf_trainer, device: str, device_inference: str, num_trials: int
729717
):
730718
"""Test VectorFieldPosterior on different devices training and inference devices.
731719
720+
Tests both ode and sde sampling for both FMPE and NPSE.
721+
722+
Tests iid methods for num_trials = 2.
723+
732724
Args:
725+
vf_trainer: vector field trainer class to use.
733726
device: device to train the model on.
734727
device_inference: device to run the inference on.
735728
iid_method: method to sample from the posterior.
736729
"""
737-
num_trials = 2
730+
731+
num_dims = 2
732+
num_simulations = 1000
733+
if vf_trainer == NPSE:
734+
iid_methods = ["fnpe", "gauss", "auto_gauss", "jac_gauss"]
735+
else:
736+
iid_methods = ["fnpe"]
737+
738738
device = process_device(device)
739739
device_inference = process_device(device_inference)
740-
prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device)
741-
inference = method(prior=prior, vf_estimator="mlp", device=device)
742-
density_estimator = inference.append_simulations(
743-
torch.randn((100, 3)), torch.randn((100, 2))
744-
).train(max_num_epochs=1)
745-
posterior = inference.build_posterior(density_estimator, prior)
740+
741+
prior = BoxUniform(torch.zeros(num_dims), torch.ones(num_dims), device=device)
742+
theta = prior.sample((num_simulations,))
743+
x = theta + 0.1 * torch.randn_like(theta)
744+
745+
inference = vf_trainer(prior=prior, device=device)
746+
_ = inference.append_simulations(theta, x).train(max_num_epochs=10)
747+
posterior = inference.build_posterior(
748+
sample_with="sde" if num_trials > 1 else "ode"
749+
)
746750

747751
# faster but inaccurate log_prob computation
748752
posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4)
@@ -752,13 +756,23 @@ def test_VectorFieldPosterior_device_handling(
752756
f"VectorFieldPosterior is not in device {device_inference}."
753757
)
754758

755-
x_o = torch.ones(num_trials).to(device_inference)
756-
samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
757-
assert samples.device.type == device_inference.split(":")[0], (
758-
f"Samples are not on device {device_inference}."
759-
)
759+
x_o = torch.ones(num_trials, num_dims).to(device_inference)
760+
if num_trials > 1:
761+
for iid_method in iid_methods:
762+
samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
763+
assert samples.device.type == device_inference.split(":")[0], (
764+
f"Samples are not on device {device_inference}. "
765+
f"{vf_trainer.__name__} with {iid_method}"
766+
)
767+
else:
768+
samples = posterior.sample((2,), x=x_o)
769+
assert samples.device.type == device_inference.split(":")[0], (
770+
f"Samples are not on device {device_inference}. "
771+
f"{vf_trainer.__name__} with {iid_method}"
772+
)
760773

761-
log_probs = posterior.log_prob(samples, x=x_o)
762-
assert log_probs.device.type == device_inference.split(":")[0], (
763-
f"log_prob was not correctly moved to {device_inference}."
764-
)
774+
log_probs = posterior.log_prob(samples, x=x_o)
775+
assert log_probs.device.type == device_inference.split(":")[0], (
776+
f"log_prob was not correctly moved to {device_inference}. "
777+
f"{vf_trainer.__name__} with {iid_method}"
778+
)

0 commit comments

Comments (0)