Skip to content

Commit e3fdb10

Browse files
authored
fix iid-score device handling (#1650)
* fix device handling
* fix: iid score device handling, ref tests
* remove unnecessary to(device)
1 parent 333427f commit e3fdb10

File tree

4 files changed

+81
-55
lines changed

4 files changed

+81
-55
lines changed

sbi/inference/potentials/score_fn_iid.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108

109109
self.vector_field_estimator = vector_field_estimator.to(device).eval()
110110
self.prior = prior
111+
self.device = device
111112

112113
def to(self, device: Union[str, torch.device]) -> None:
113114
"""
@@ -332,9 +333,9 @@ def marginal_denoising_posterior_precision_est_fn(
332333
std = self.vector_field_estimator.std_fn(time)
333334

334335
if precisions_posteriors.ndim == 4:
335-
Ident = torch.eye(precisions_posteriors.shape[-1])
336+
Ident = torch.eye(precisions_posteriors.shape[-1], device=self.device)
336337
else:
337-
Ident = torch.ones_like(precisions_posteriors)
338+
Ident = torch.ones_like(precisions_posteriors, device=self.device)
338339

339340
marginal_precisions = m**2 / std**2 * Ident + precisions_posteriors
340341
return marginal_precisions
@@ -649,7 +650,9 @@ def estimate_posterior_precision(
649650
# NOTE: To avoid circular imports :(
650651
from sbi.inference.posteriors.vector_field_posterior import VectorFieldPosterior
651652

652-
posterior = VectorFieldPosterior(vector_field_estimator, prior)
653+
posterior = VectorFieldPosterior(
654+
vector_field_estimator, prior, device=conditions.device
655+
)
653656

654657
if precision_est_budget is None:
655658
if precision_est_only_diag:
@@ -725,7 +728,7 @@ def marginal_denoising_posterior_precision_est_fn(
725728

726729
m = self.vector_field_estimator.mean_t_fn(time)
727730
std = self.vector_field_estimator.std_fn(time)
728-
cov0 = std**2 * jac + torch.eye(d)[None, None, :, :]
731+
cov0 = std**2 * jac + torch.eye(d, device=self.device)[None, None, :, :]
729732

730733
denoising_posterior_precision = m**2 / std**2 * torch.inverse(cov0)
731734

sbi/inference/potentials/vector_field_potential.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ def gradient(
237237
"is not defined for this vector field estimator."
238238
)
239239

240+
device = theta.device
241+
240242
if time is None:
241243
time = torch.tensor([self.vector_field_estimator.t_min])
242244

@@ -249,14 +251,17 @@ def gradient(
249251
with torch.set_grad_enabled(track_gradients):
250252
if not self.x_is_iid or self._x_o.shape[0] == 1:
251253
score = self.vector_field_estimator.score(
252-
input=theta, condition=self.x_o, t=time
254+
input=theta, condition=self.x_o, t=time.to(device)
253255
)
254256
else:
255257
assert self.prior is not None, "Prior is required for iid methods."
256258

257259
iid_method = get_iid_method(self.iid_method)
258260
score_fn_iid = iid_method(
259-
self.vector_field_estimator, self.prior, **(self.iid_params or {})
261+
self.vector_field_estimator,
262+
self.prior,
263+
device=device,
264+
**(self.iid_params or {}),
260265
)
261266

262267
score = score_fn_iid(theta, self.x_o, time)

sbi/neural_nets/estimators/score_estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
662662
Returns:
663663
Drift function at a given time.
664664
"""
665-
phi = -0.5 * self._beta_schedule(times)
665+
phi = -0.5 * self._beta_schedule(times).to(input.device)
666666

667667
while len(phi.shape) < len(input.shape):
668668
phi = phi.unsqueeze(-1)
@@ -800,7 +800,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
800800
Returns:
801801
Drift function at a given time.
802802
"""
803-
return torch.tensor([0.0])
803+
return torch.tensor([0.0], device=input.device)
804804

805805
def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
806806
"""Diffusion function for variance exploding SDEs.
@@ -819,4 +819,4 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
819819
while len(g.shape) < len(input.shape):
820820
g = g.unsqueeze(-1)
821821

822-
return g
822+
return g.to(input.device)

tests/inference_on_device_test.py

Lines changed: 64 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,7 @@
1515
from torch.distributions import MultivariateNormal
1616

1717
from sbi import utils as utils
18-
from sbi.inference import (
19-
ABC,
20-
FMPE,
21-
NLE,
22-
NPE,
23-
NPE_A,
24-
NPE_C,
25-
NRE_A,
26-
NRE_B,
27-
NRE_C,
28-
VIPosterior,
29-
likelihood_estimator_based_potential,
30-
ratio_estimator_based_potential,
31-
)
18+
from sbi.inference.abc import MCABC as ABC
3219
from sbi.inference.posteriors.ensemble_posterior import (
3320
EnsemblePotential,
3421
)
@@ -37,28 +24,33 @@
3724
from sbi.inference.posteriors.posterior_parameters import (
3825
MCMCPosteriorParameters,
3926
)
27+
from sbi.inference.posteriors.vi_posterior import VIPosterior
4028
from sbi.inference.potentials.base_potential import BasePotential
41-
from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
29+
from sbi.inference.potentials.likelihood_based_potential import (
30+
LikelihoodBasedPotential,
31+
likelihood_estimator_based_potential,
32+
)
4233
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
43-
from sbi.inference.potentials.ratio_based_potential import RatioBasedPotential
34+
from sbi.inference.potentials.ratio_based_potential import (
35+
RatioBasedPotential,
36+
ratio_estimator_based_potential,
37+
)
38+
from sbi.inference.trainers.nle import NLE
39+
from sbi.inference.trainers.npe import NPE, NPE_A, NPE_C
40+
from sbi.inference.trainers.nre import NRE_A, NRE_B, NRE_C
41+
from sbi.inference.trainers.vfpe import FMPE, NPSE
4442
from sbi.neural_nets.embedding_nets import FCEmbedding
4543
from sbi.neural_nets.factory import (
4644
classifier_nn,
4745
embedding_net_warn_msg,
4846
likelihood_nn,
4947
posterior_nn,
5048
)
51-
from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
52-
from sbi.utils.torchutils import (
53-
BoxUniform,
54-
gpu_available,
55-
process_device,
56-
)
57-
from sbi.utils.user_input_checks import (
58-
validate_theta_and_x,
59-
)
49+
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian, linear_gaussian
50+
from sbi.utils import BoxUniform
51+
from sbi.utils.torchutils import gpu_available, process_device
52+
from sbi.utils.user_input_checks import validate_theta_and_x
6053

61-
# tests in this file are skipped if there is GPU device available
6254
pytestmark = pytest.mark.skipif(
6355
not gpu_available(), reason="No CUDA or MPS device available."
6456
)
@@ -718,27 +710,43 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):
718710
@pytest.mark.gpu
719711
@pytest.mark.parametrize("device", ["cpu", "gpu"])
720712
@pytest.mark.parametrize("device_inference", ["cpu", "gpu"])
721-
@pytest.mark.parametrize(
722-
"iid_method", ["fnpe", "gauss", "auto_gauss", "jac_gauss", None]
723-
)
724-
def test_VectorFieldPosterior_device_handling(
725-
device: str, device_inference: str, iid_method: str
713+
@pytest.mark.parametrize("num_trials", [1, 2])
714+
@pytest.mark.parametrize("vf_trainer", [FMPE, NPSE])
715+
def test_vector_field_methods_device_handling(
716+
vf_trainer, device: str, device_inference: str, num_trials: int
726717
):
727718
"""Test VectorFieldPosterior on different devices training and inference devices.
728719
720+
Tests both ode and sde sampling for both FMPE and NPSE.
721+
722+
Tests iid methods for num_trials = 2.
723+
729724
Args:
725+
vf_trainer: vector field trainer class to use.
730726
device: device to train the model on.
731727
device_inference: device to run the inference on.
732728
iid_method: method to sample from the posterior.
733729
"""
730+
731+
num_dims = 2
732+
num_simulations = 1000
733+
if vf_trainer == NPSE:
734+
iid_methods = ["fnpe", "gauss", "auto_gauss", "jac_gauss"]
735+
else:
736+
iid_methods = ["fnpe"]
737+
734738
device = process_device(device)
735739
device_inference = process_device(device_inference)
736-
prior = BoxUniform(torch.zeros(3), torch.ones(3), device=device)
737-
inference = FMPE(score_estimator="mlp", prior=prior, device=device)
738-
density_estimator = inference.append_simulations(
739-
torch.randn((100, 3)), torch.randn((100, 2))
740-
).train(max_num_epochs=1)
741-
posterior = inference.build_posterior(density_estimator, prior)
740+
741+
prior = BoxUniform(torch.zeros(num_dims), torch.ones(num_dims), device=device)
742+
theta = prior.sample((num_simulations,))
743+
x = theta + 0.1 * torch.randn_like(theta)
744+
745+
inference = vf_trainer(prior=prior, device=device)
746+
_ = inference.append_simulations(theta, x).train(max_num_epochs=10)
747+
posterior = inference.build_posterior(
748+
sample_with="sde" if num_trials > 1 else "ode"
749+
)
742750

743751
# faster but inaccurate log_prob computation
744752
posterior.potential_fn.neural_ode.update_params(exact=False, atol=1e-4, rtol=1e-4)
@@ -748,13 +756,23 @@ def test_VectorFieldPosterior_device_handling(
748756
f"VectorFieldPosterior is not in device {device_inference}."
749757
)
750758

751-
x_o = torch.ones(2).to(device_inference)
752-
samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
753-
assert samples.device.type == device_inference.split(":")[0], (
754-
f"Samples are not on device {device_inference}."
755-
)
759+
x_o = torch.ones(num_trials, num_dims).to(device_inference)
760+
if num_trials > 1:
761+
for iid_method in iid_methods:
762+
samples = posterior.sample((2,), x=x_o, iid_method=iid_method)
763+
assert samples.device.type == device_inference.split(":")[0], (
764+
f"Samples are not on device {device_inference}. "
765+
f"{vf_trainer.__name__} with {iid_method}"
766+
)
767+
else:
768+
samples = posterior.sample((2,), x=x_o)
769+
assert samples.device.type == device_inference.split(":")[0], (
770+
f"Samples are not on device {device_inference}. "
771+
f"{vf_trainer.__name__} with {iid_method}"
772+
)
756773

757-
log_probs = posterior.log_prob(samples, x=x_o)
758-
assert log_probs.device.type == device_inference.split(":")[0], (
759-
f"log_prob was not correctly moved to {device_inference}."
760-
)
774+
log_probs = posterior.log_prob(samples, x=x_o)
775+
assert log_probs.device.type == device_inference.split(":")[0], (
776+
f"log_prob was not correctly moved to {device_inference}. "
777+
f"{vf_trainer.__name__} with {iid_method}"
778+
)

0 commit comments

Comments (0)