Skip to content

Commit c5c511c

Browse files
authored
Change xfail to skipif as outcome is not consistent (#1487)
* Change xfail to skipif as outcome is not consistent * skipif for all tests using nuts_pymc sampler
1 parent 70560a3 commit c5c511c

File tree

5 files changed

+43
-15
lines changed

5 files changed

+43
-15
lines changed

tests/inference_on_device_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
from typing import Tuple, Union
88

9+
import pymc
910
import pytest
1011
import torch
1112
import torch.distributions.transforms as torch_tf
@@ -76,10 +77,10 @@
7677
"nuts_pymc",
7778
marks=(
7879
pytest.mark.mcmc,
79-
pytest.mark.xfail(
80-
condition=sys.version_info >= (3, 10),
81-
reason="Fails with pymc>=5.20.1 and python>=3.10",
82-
raises=TypeError,
80+
pytest.mark.skipif(
81+
condition=sys.version_info >= (3, 10)
82+
and pymc.__version__ >= "5.20.1",
83+
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
8384
),
8485
),
8586
),

tests/linearGaussian_snle_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from __future__ import annotations
55

6+
import sys
7+
8+
import pymc
69
import pytest
710
import torch
811
from torch import eye, ones, zeros
@@ -393,7 +396,18 @@ def simulator(theta):
393396
pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc),
394397
pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc),
395398
pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
396-
pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc),
399+
pytest.param(
400+
"nuts_pymc",
401+
"gaussian",
402+
marks=(
403+
pytest.mark.mcmc,
404+
pytest.mark.skipif(
405+
condition=sys.version_info >= (3, 10)
406+
and pymc.__version__ >= "5.20.1",
407+
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
408+
),
409+
),
410+
),
397411
pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
398412
pytest.param("hmc_pymc", "gaussian", marks=pytest.mark.mcmc),
399413
("rejection", "uniform"),

tests/linearGaussian_snre_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from __future__ import annotations
55

6+
import sys
7+
8+
import pymc
69
import pytest
710
from torch import eye, ones, zeros
811
from torch.distributions import MultivariateNormal
@@ -317,7 +320,18 @@ def simulator(theta):
317320
pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc),
318321
pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc),
319322
pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
320-
pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc),
323+
pytest.param(
324+
"nuts_pymc",
325+
"gaussian",
326+
marks=(
327+
pytest.mark.mcmc,
328+
pytest.mark.skipif(
329+
condition=sys.version_info >= (3, 10)
330+
and pymc.__version__ >= "5.20.1",
331+
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
332+
),
333+
),
334+
),
321335
pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
322336
pytest.param("hmc_pyro", "gaussian", marks=pytest.mark.mcmc),
323337
("rejection", "uniform"),

tests/mcmc_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77

88
import numpy as np
9+
import pymc
910
import pytest
1011
import torch
1112
from torch import eye, ones, zeros
@@ -187,11 +188,9 @@ def lp_f(x, track_gradients=True):
187188
"hmc_pyro",
188189
pytest.param(
189190
"nuts_pymc",
190-
marks=pytest.mark.xfail(
191-
condition=sys.version_info >= (3, 10),
192-
reason="Fails with pymc>=5.20.1 and python>=3.10",
193-
strict=True,
194-
raises=TypeError,
191+
marks=pytest.mark.skipif(
192+
condition=sys.version_info >= (3, 10) and pymc.__version__ >= "5.20.1",
193+
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
195194
),
196195
),
197196
"hmc_pymc",

tests/posterior_sampler_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import sys
77

8+
import pymc
89
import pytest
910
from pyro.infer.mcmc import MCMC
1011
from torch import Tensor, eye, zeros
@@ -29,10 +30,9 @@
2930
"hmc_pyro",
3031
pytest.param(
3132
"nuts_pymc",
32-
marks=pytest.mark.xfail(
33-
condition=sys.version_info >= (3, 10),
34-
reason="Fails with pymc>=5.20.1 and python>=3.10",
35-
raises=TypeError,
33+
marks=pytest.mark.skipif(
34+
condition=sys.version_info >= (3, 10) and pymc.__version__ >= "5.20.1",
35+
reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10",
3636
),
3737
),
3838
"hmc_pymc",

0 commit comments

Comments
 (0)