15
15
from torch .distributions import MultivariateNormal
16
16
17
17
from sbi import utils as utils
18
- from sbi .inference import (
19
- ABC ,
20
- FMPE ,
21
- NLE ,
22
- NPE ,
23
- NPE_A ,
24
- NPE_C ,
25
- NPSE ,
26
- NRE_A ,
27
- NRE_B ,
28
- NRE_C ,
29
- VIPosterior ,
30
- likelihood_estimator_based_potential ,
31
- ratio_estimator_based_potential ,
32
- )
18
+ from sbi .inference .abc import MCABC as ABC
33
19
from sbi .inference .posteriors .ensemble_posterior import (
34
20
EnsemblePotential ,
35
21
)
38
24
from sbi .inference .posteriors .posterior_parameters import (
39
25
MCMCPosteriorParameters ,
40
26
)
27
+ from sbi .inference .posteriors .vi_posterior import VIPosterior
41
28
from sbi .inference .potentials .base_potential import BasePotential
42
- from sbi .inference .potentials .likelihood_based_potential import LikelihoodBasedPotential
29
+ from sbi .inference .potentials .likelihood_based_potential import (
30
+ LikelihoodBasedPotential ,
31
+ likelihood_estimator_based_potential ,
32
+ )
43
33
from sbi .inference .potentials .posterior_based_potential import PosteriorBasedPotential
44
- from sbi .inference .potentials .ratio_based_potential import RatioBasedPotential
45
- from sbi .inference .trainers .vfpe .base_vf_inference import VectorFieldTrainer
34
+ from sbi .inference .potentials .ratio_based_potential import (
35
+ RatioBasedPotential ,
36
+ ratio_estimator_based_potential ,
37
+ )
38
+ from sbi .inference .trainers .nle import NLE
39
+ from sbi .inference .trainers .npe import NPE , NPE_A , NPE_C
40
+ from sbi .inference .trainers .nre import NRE_A , NRE_B , NRE_C
41
+ from sbi .inference .trainers .vfpe import FMPE , NPSE
46
42
from sbi .neural_nets .embedding_nets import FCEmbedding
47
43
from sbi .neural_nets .factory import (
48
44
classifier_nn ,
49
45
embedding_net_warn_msg ,
50
46
likelihood_nn ,
51
47
posterior_nn ,
52
48
)
53
- from sbi .simulators import diagonal_linear_gaussian , linear_gaussian
54
- from sbi .utils .torchutils import (
55
- BoxUniform ,
56
- gpu_available ,
57
- process_device ,
58
- )
59
- from sbi .utils .user_input_checks import (
60
- validate_theta_and_x ,
61
- )
49
+ from sbi .simulators .linear_gaussian import diagonal_linear_gaussian , linear_gaussian
50
+ from sbi .utils import BoxUniform
51
+ from sbi .utils .torchutils import gpu_available , process_device
52
+ from sbi .utils .user_input_checks import validate_theta_and_x
62
53
63
- # tests in this file are skipped if there is GPU device available
64
54
pytestmark = pytest .mark .skipif (
65
55
not gpu_available (), reason = "No CUDA or MPS device available."
66
56
)
@@ -720,29 +710,43 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):
720
710
@pytest .mark .gpu
721
711
@pytest .mark .parametrize ("device" , ["cpu" , "gpu" ])
722
712
@pytest .mark .parametrize ("device_inference" , ["cpu" , "gpu" ])
723
- @pytest .mark .parametrize (
724
- "iid_method" , ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" , None ]
725
- )
726
- @pytest .mark .parametrize ("method" , (FMPE , NPSE ))
727
- def test_VectorFieldPosterior_device_handling (
728
- method : VectorFieldTrainer , device : str , device_inference : str , iid_method : str
713
+ @pytest .mark .parametrize ("num_trials" , [1 , 2 ])
714
+ @pytest .mark .parametrize ("vf_trainer" , [FMPE , NPSE ])
715
+ def test_vector_field_methods_device_handling (
716
+ vf_trainer , device : str , device_inference : str , num_trials : int
729
717
):
730
718
"""Test VectorFieldPosterior on different devices training and inference devices.
731
719
720
+ Tests both ode and sde sampling for both FMPE and NPSE.
721
+
722
+ Tests iid methods for num_trials = 2.
723
+
732
724
Args:
725
+ vf_trainer: vector field trainer class to use.
733
726
device: device to train the model on.
734
727
device_inference: device to run the inference on.
735
728
iid_method: method to sample from the posterior.
736
729
"""
737
- num_trials = 2
730
+
731
+ num_dims = 2
732
+ num_simulations = 1000
733
+ if vf_trainer == NPSE :
734
+ iid_methods = ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ]
735
+ else :
736
+ iid_methods = ["fnpe" ]
737
+
738
738
device = process_device (device )
739
739
device_inference = process_device (device_inference )
740
- prior = BoxUniform (torch .zeros (3 ), torch .ones (3 ), device = device )
741
- inference = method (prior = prior , vf_estimator = "mlp" , device = device )
742
- density_estimator = inference .append_simulations (
743
- torch .randn ((100 , 3 )), torch .randn ((100 , 2 ))
744
- ).train (max_num_epochs = 1 )
745
- posterior = inference .build_posterior (density_estimator , prior )
740
+
741
+ prior = BoxUniform (torch .zeros (num_dims ), torch .ones (num_dims ), device = device )
742
+ theta = prior .sample ((num_simulations ,))
743
+ x = theta + 0.1 * torch .randn_like (theta )
744
+
745
+ inference = vf_trainer (prior = prior , device = device )
746
+ _ = inference .append_simulations (theta , x ).train (max_num_epochs = 10 )
747
+ posterior = inference .build_posterior (
748
+ sample_with = "sde" if num_trials > 1 else "ode"
749
+ )
746
750
747
751
# faster but inaccurate log_prob computation
748
752
posterior .potential_fn .neural_ode .update_params (exact = False , atol = 1e-4 , rtol = 1e-4 )
@@ -752,13 +756,23 @@ def test_VectorFieldPosterior_device_handling(
752
756
f"VectorFieldPosterior is not in device { device_inference } ."
753
757
)
754
758
755
- x_o = torch .ones (num_trials ).to (device_inference )
756
- samples = posterior .sample ((2 ,), x = x_o , iid_method = iid_method )
757
- assert samples .device .type == device_inference .split (":" )[0 ], (
758
- f"Samples are not on device { device_inference } ."
759
- )
759
+ x_o = torch .ones (num_trials , num_dims ).to (device_inference )
760
+ if num_trials > 1 :
761
+ for iid_method in iid_methods :
762
+ samples = posterior .sample ((2 ,), x = x_o , iid_method = iid_method )
763
+ assert samples .device .type == device_inference .split (":" )[0 ], (
764
+ f"Samples are not on device { device_inference } . "
765
+ f"{ vf_trainer .__name__ } with { iid_method } "
766
+ )
767
+ else :
768
+ samples = posterior .sample ((2 ,), x = x_o )
769
+ assert samples .device .type == device_inference .split (":" )[0 ], (
770
+ f"Samples are not on device { device_inference } . "
771
+ f"{ vf_trainer .__name__ } with { iid_method } "
772
+ )
760
773
761
- log_probs = posterior .log_prob (samples , x = x_o )
762
- assert log_probs .device .type == device_inference .split (":" )[0 ], (
763
- f"log_prob was not correctly moved to { device_inference } ."
764
- )
774
+ log_probs = posterior .log_prob (samples , x = x_o )
775
+ assert log_probs .device .type == device_inference .split (":" )[0 ], (
776
+ f"log_prob was not correctly moved to { device_inference } . "
777
+ f"{ vf_trainer .__name__ } with { iid_method } "
778
+ )
0 commit comments