From d7cedd7fe9c0dc2b89e991cf20d04346a55d31c5 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 2 Aug 2025 11:07:24 -0400 Subject: [PATCH 1/4] add kwarg to allow for custom sample_stats --- pymc/testing.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/pymc/testing.py b/pymc/testing.py index 886177ef0..ec1bc4678 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -20,10 +20,12 @@ import numpy as np import pytensor import pytensor.tensor as pt +import xarray as xr from arviz import InferenceData from numpy import random as nr from numpy import testing as npt +from numpy.typing import NDArray from pytensor.compile import SharedVariable from pytensor.compile.mode import Mode from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs @@ -976,7 +978,14 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: raise AssertionError(f"RV found in graph: {rvs}") -def mock_sample(draws: int = 10, **kwargs): +SampleStatsCreator = Callable[[tuple[str, ...]], NDArray] + + +def mock_sample( + draws: int = 10, + sample_stats: dict[str, SampleStatsCreator] | None = None, + **kwargs, +) -> InferenceData: """Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`. Useful for testing models that use pm.sample without running MCMC sampling. @@ -1006,6 +1015,36 @@ def mock_pymc_sample(): pm.sample = original_sample + By default, the sample_stats group is not created. Pass a dictionary of functions + that create sample statistics, where the keys are the names of the statistics + and the values are functions that take a size tuple and return an array of that size. + + .. code-block:: python + + from functools import partial + + import numpy as np + import numpy.typing as npt + + from pymc.testing import mock_sample + + + def mock_diverging(size: tuple[str, ...]) -> npt.NDArray: + return np.zeros(size) + + + def mock_tree_depth(size: tuple[str, ...]) -> npt.NDArray: + return np.random.choice(range(2, 10), size=size) + + + mock_sample_with_stats = partial( + mock_sample, + sample_stats={ + "diverging": mock_diverging, + "tree_depth": mock_tree_depth, + }, + ) + """ random_seed = kwargs.get("random_seed", None) model = kwargs.get("model", None) @@ -1028,6 +1067,16 @@ def mock_pymc_sample(): del idata["prior"] if "prior_predictive" in idata: del idata["prior_predictive"] + + if sample_stats is not None: + sizes = idata["posterior"].sizes + size = (sizes["chain"], sizes["draw"]) + sample_stats_ds = xr.Dataset( + {name: (("chain", "draw"), creator(size)) for name, creator in sample_stats.items()}, + coords=idata["posterior"].coords, + ) + idata.add_groups(sample_stats=sample_stats_ds) + return idata From a7a1b3f3120765650d8d42d8486f4de90d20a2b4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 2 Aug 2025 11:07:59 -0400 Subject: [PATCH 2/4] test the testing --- tests/test_testing.py | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/tests/test_testing.py b/tests/test_testing.py index 105e2f620..b49ec768b 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -13,6 +13,7 @@ # limitations under the License. from contextlib import ExitStack as does_not_raise +import numpy as np import pytest import pymc as pm @@ -38,28 +39,47 @@ def test_domain(values, edges, expectation): @pytest.mark.parametrize( - "args, kwargs, expected_size", + "args, kwargs, expected_size, sample_stats", [ - pytest.param((), {}, (1, 10), id="default"), - pytest.param((100,), {}, (1, 100), id="positional-draws"), - pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"), - pytest.param((100,), {"chains": 6}, (6, 100), id="chains"), + pytest.param((), {}, (1, 10), None, id="default"), + pytest.param((100,), {}, (1, 100), None, id="positional-draws"), + pytest.param((), {"draws": 100}, (1, 100), None, id="keyword-draws"), + pytest.param((100,), {"chains": 6}, (6, 100), None, id="chains"), + pytest.param( + (100,), + {"chains": 6}, + (6, 100), + { + "diverging": np.zeros, + "tree_depth": lambda size: np.random.choice(range(2, 10), size=size), + }, + id="with_sample_stats", + ), ], ) -def test_mock_sample(args, kwargs, expected_size) -> None: +def test_mock_sample(args, kwargs, expected_size, sample_stats) -> None: expected_chains, expected_draws = expected_size _, model, _ = simple_normal(bounded_prior=True) with model: - idata = mock_sample(*args, **kwargs) + idata = mock_sample(*args, **kwargs, sample_stats=sample_stats) assert "posterior" in idata assert "observed_data" in idata assert "prior" not in idata assert "posterior_predictive" not in idata - assert "sample_stats" not in idata - assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws} + expected_sizes = {"chain": expected_chains, "draw": expected_draws} + + if sample_stats: + sample_stats_ds = idata["sample_stats"] + for name in sample_stats.keys(): + assert sample_stats_ds[name].sizes == expected_sizes + + else: + assert "sample_stats" not in idata + + assert idata.posterior.sizes == expected_sizes mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) From 04337581e6188e87ad9b9b39761dfa0e6f52cb58 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 3 Aug 2025 01:09:25 -0400 Subject: [PATCH 3/4] fix the typing for the size --- pymc/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/testing.py b/pymc/testing.py index ec1bc4678..6f6c742dc 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -978,7 +978,7 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None: raise AssertionError(f"RV found in graph: {rvs}") -SampleStatsCreator = Callable[[tuple[str, ...]], NDArray] +SampleStatsCreator = Callable[[tuple[int, int]], NDArray] def mock_sample( From 1c193849d803f88bd44c9eae232a464ac18d7925 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 3 Aug 2025 01:11:40 -0400 Subject: [PATCH 4/4] change the typehints in the docstrings --- pymc/testing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/testing.py b/pymc/testing.py index 6f6c742dc..aae6be414 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -1029,11 +1029,11 @@ def mock_pymc_sample(): from pymc.testing import mock_sample - def mock_diverging(size: tuple[str, ...]) -> npt.NDArray: + def mock_diverging(size: tuple[int, int]) -> npt.NDArray: return np.zeros(size) - def mock_tree_depth(size: tuple[str, ...]) -> npt.NDArray: + def mock_tree_depth(size: tuple[int, int]) -> npt.NDArray: return np.random.choice(range(2, 10), size=size)