15
15
from torch .distributions import MultivariateNormal
16
16
17
17
from sbi import utils as utils
18
- from sbi .inference import (
19
- ABC ,
20
- FMPE ,
21
- NLE ,
22
- NPE ,
23
- NPE_A ,
24
- NPE_C ,
25
- NRE_A ,
26
- NRE_B ,
27
- NRE_C ,
28
- VIPosterior ,
29
- likelihood_estimator_based_potential ,
30
- ratio_estimator_based_potential ,
31
- )
18
+ from sbi .inference .abc import MCABC as ABC
32
19
from sbi .inference .posteriors .ensemble_posterior import (
33
20
EnsemblePotential ,
34
21
)
37
24
from sbi .inference .posteriors .posterior_parameters import (
38
25
MCMCPosteriorParameters ,
39
26
)
27
+ from sbi .inference .posteriors .vi_posterior import VIPosterior
40
28
from sbi .inference .potentials .base_potential import BasePotential
41
- from sbi .inference .potentials .likelihood_based_potential import LikelihoodBasedPotential
29
+ from sbi .inference .potentials .likelihood_based_potential import (
30
+ LikelihoodBasedPotential ,
31
+ likelihood_estimator_based_potential ,
32
+ )
42
33
from sbi .inference .potentials .posterior_based_potential import PosteriorBasedPotential
43
- from sbi .inference .potentials .ratio_based_potential import RatioBasedPotential
34
+ from sbi .inference .potentials .ratio_based_potential import (
35
+ RatioBasedPotential ,
36
+ ratio_estimator_based_potential ,
37
+ )
38
+ from sbi .inference .trainers .nle import NLE
39
+ from sbi .inference .trainers .npe import NPE , NPE_A , NPE_C
40
+ from sbi .inference .trainers .nre import NRE_A , NRE_B , NRE_C
41
+ from sbi .inference .trainers .vfpe import FMPE , NPSE
44
42
from sbi .neural_nets .embedding_nets import FCEmbedding
45
43
from sbi .neural_nets .factory import (
46
44
classifier_nn ,
47
45
embedding_net_warn_msg ,
48
46
likelihood_nn ,
49
47
posterior_nn ,
50
48
)
51
- from sbi .simulators import diagonal_linear_gaussian , linear_gaussian
52
- from sbi .utils .torchutils import (
53
- BoxUniform ,
54
- gpu_available ,
55
- process_device ,
56
- )
57
- from sbi .utils .user_input_checks import (
58
- validate_theta_and_x ,
59
- )
49
+ from sbi .simulators .linear_gaussian import diagonal_linear_gaussian , linear_gaussian
50
+ from sbi .utils import BoxUniform
51
+ from sbi .utils .torchutils import gpu_available , process_device
52
+ from sbi .utils .user_input_checks import validate_theta_and_x
60
53
61
- # tests in this file are skipped if there is GPU device available
62
54
pytestmark = pytest .mark .skipif (
63
55
not gpu_available (), reason = "No CUDA or MPS device available."
64
56
)
@@ -718,27 +710,43 @@ def test_to_method_on_posteriors(device: str, sampling_method: str):
718
710
@pytest .mark .gpu
719
711
@pytest .mark .parametrize ("device" , ["cpu" , "gpu" ])
720
712
@pytest .mark .parametrize ("device_inference" , ["cpu" , "gpu" ])
721
- @pytest .mark .parametrize (
722
- "iid_method" , ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" , None ]
723
- )
724
- def test_VectorFieldPosterior_device_handling (
725
- device : str , device_inference : str , iid_method : str
713
+ @pytest .mark .parametrize ("num_trials" , [1 , 2 ])
714
+ @pytest .mark .parametrize ("vf_trainer" , [FMPE , NPSE ])
715
+ def test_vector_field_methods_device_handling (
716
+ vf_trainer , device : str , device_inference : str , num_trials : int
726
717
):
727
718
"""Test VectorFieldPosterior on different devices training and inference devices.
728
719
720
+ Tests both ode and sde sampling for both FMPE and NPSE.
721
+
722
+ Tests iid methods for num_trials = 2.
723
+
729
724
Args:
725
+ vf_trainer: vector field trainer class to use.
730
726
device: device to train the model on.
731
727
device_inference: device to run the inference on.
732
728
iid_method: method to sample from the posterior.
733
729
"""
730
+
731
+ num_dims = 2
732
+ num_simulations = 1000
733
+ if vf_trainer == NPSE :
734
+ iid_methods = ["fnpe" , "gauss" , "auto_gauss" , "jac_gauss" ]
735
+ else :
736
+ iid_methods = ["fnpe" ]
737
+
734
738
device = process_device (device )
735
739
device_inference = process_device (device_inference )
736
- prior = BoxUniform (torch .zeros (3 ), torch .ones (3 ), device = device )
737
- inference = FMPE (score_estimator = "mlp" , prior = prior , device = device )
738
- density_estimator = inference .append_simulations (
739
- torch .randn ((100 , 3 )), torch .randn ((100 , 2 ))
740
- ).train (max_num_epochs = 1 )
741
- posterior = inference .build_posterior (density_estimator , prior )
740
+
741
+ prior = BoxUniform (torch .zeros (num_dims ), torch .ones (num_dims ), device = device )
742
+ theta = prior .sample ((num_simulations ,))
743
+ x = theta + 0.1 * torch .randn_like (theta )
744
+
745
+ inference = vf_trainer (prior = prior , device = device )
746
+ _ = inference .append_simulations (theta , x ).train (max_num_epochs = 10 )
747
+ posterior = inference .build_posterior (
748
+ sample_with = "sde" if num_trials > 1 else "ode"
749
+ )
742
750
743
751
# faster but inaccurate log_prob computation
744
752
posterior .potential_fn .neural_ode .update_params (exact = False , atol = 1e-4 , rtol = 1e-4 )
@@ -748,13 +756,23 @@ def test_VectorFieldPosterior_device_handling(
748
756
f"VectorFieldPosterior is not in device { device_inference } ."
749
757
)
750
758
751
- x_o = torch .ones (2 ).to (device_inference )
752
- samples = posterior .sample ((2 ,), x = x_o , iid_method = iid_method )
753
- assert samples .device .type == device_inference .split (":" )[0 ], (
754
- f"Samples are not on device { device_inference } ."
755
- )
759
+ x_o = torch .ones (num_trials , num_dims ).to (device_inference )
760
+ if num_trials > 1 :
761
+ for iid_method in iid_methods :
762
+ samples = posterior .sample ((2 ,), x = x_o , iid_method = iid_method )
763
+ assert samples .device .type == device_inference .split (":" )[0 ], (
764
+ f"Samples are not on device { device_inference } . "
765
+ f"{ vf_trainer .__name__ } with { iid_method } "
766
+ )
767
+ else :
768
+ samples = posterior .sample ((2 ,), x = x_o )
769
+ assert samples .device .type == device_inference .split (":" )[0 ], (
770
+ f"Samples are not on device { device_inference } . "
771
+ f"{ vf_trainer .__name__ } with { iid_method } "
772
+ )
756
773
757
- log_probs = posterior .log_prob (samples , x = x_o )
758
- assert log_probs .device .type == device_inference .split (":" )[0 ], (
759
- f"log_prob was not correctly moved to { device_inference } ."
760
- )
774
+ log_probs = posterior .log_prob (samples , x = x_o )
775
+ assert log_probs .device .type == device_inference .split (":" )[0 ], (
776
+ f"log_prob was not correctly moved to { device_inference } . "
777
+ f"{ vf_trainer .__name__ } with { iid_method } "
778
+ )
0 commit comments