|
3 | 3 |
|
4 | 4 | from __future__ import annotations
|
5 | 5 |
|
| 6 | +import sys |
| 7 | + |
| 8 | +import pymc |
6 | 9 | import pytest
|
7 | 10 | from torch import eye, ones, zeros
|
8 | 11 | from torch.distributions import MultivariateNormal
|
@@ -317,7 +320,18 @@ def simulator(theta):
|
317 | 320 | pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc),
|
318 | 321 | pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc),
|
319 | 322 | pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
|
320 |
| - pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc), |
| 323 | + pytest.param( |
| 324 | + "nuts_pymc", |
| 325 | + "gaussian", |
| 326 | + marks=( |
| 327 | + pytest.mark.mcmc, |
| 328 | + pytest.mark.skipif( |
| 329 | + condition=sys.version_info >= (3, 10) |
| 330 | + and pymc.__version__ >= "5.20.1", |
| 331 | + reason="Inconsistent behaviour with pymc>=5.20.1 and python>=3.10", |
| 332 | + ), |
| 333 | + ), |
| 334 | + ), |
321 | 335 | pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
|
322 | 336 | pytest.param("hmc_pyro", "gaussian", marks=pytest.mark.mcmc),
|
323 | 337 | ("rejection", "uniform"),
|
|
0 commit comments