Skip to content

Commit 94d5ec4

Browse files
committed
Draft external sampler API
1 parent 011fb35 commit 94d5ec4

File tree

8 files changed

+373
-272
lines changed

8 files changed

+373
-272
lines changed

pymc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __set_compiler_flags():
7171
from pymc.printing import *
7272
from pymc.pytensorf import *
7373
from pymc.sampling import *
74+
from pymc.sampling import external
7475
from pymc.smc import *
7576
from pymc.stats import *
7677
from pymc.step_methods import *

pymc/sampling/external/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pymc.sampling.external.base import ExternalSampler
15+
from pymc.sampling.external.jax import Blackjax, Numpyro
16+
from pymc.sampling.external.nutpie import Nutpie

pymc/sampling/external/base.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from abc import ABC, abstractmethod
15+
16+
from pymc.model.core import modelcontext
17+
from pymc.util import get_value_vars_from_user_vars
18+
19+
20+
class ExternalSampler(ABC):
    """Abstract base class for samplers that sample a whole model at once.

    The sampler is bound to a model when it is constructed and must cover
    every free random variable of that model; sampling only a subset is
    rejected.

    Parameters
    ----------
    vars : optional
        Variables the sampler should handle. They are normalised with
        ``get_value_vars_from_user_vars`` and must correspond exactly to the
        model's ``free_RVs``. When ``None``, all ``free_RVs`` are used.
    model : optional
        Model to sample; passed through ``modelcontext``.

    Attributes
    ----------
    vars
        The model's ``free_RVs``.
    model
        The model returned by ``modelcontext``.

    Raises
    ------
    ValueError
        If ``vars`` is given and does not cover exactly the model's free
        variables.
    """

    def __init__(self, vars=None, model=None):
        model = modelcontext(model)
        if vars is not None:
            # get_value_vars_from_user_vars yields *value* variables (see
            # pymc.util), so the check must be made against the value
            # variables of the free RVs. Comparing with model.free_RVs
            # directly would reject even a complete selection.
            requested = get_value_vars_from_user_vars(vars, model=model)
            expected = [model.rvs_to_values[rv] for rv in model.free_RVs]
            if set(requested) != set(expected):
                raise ValueError(
                    "External samplers must sample all the model free_RVs, not just a subset"
                )
        # Both paths cover the full model, so always expose the free RVs.
        self.vars = model.free_RVs
        self.model = model

    @abstractmethod
    def sample(
        self,
        tune,
        draws,
        chains,
        initvals,
        random_seed,
        progressbar,
        var_names,
        idata_kwargs,
        compute_convergence_checks,
        **kwargs,
    ):
        """Run the sampler on ``self.model``.

        Subclasses must implement this method. The parameters are the
        sampling arguments every external sampler is expected to accept;
        ``**kwargs`` carries any sampler-specific extras.
        """

pymc/sampling/external/jax.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections.abc import Sequence
15+
from typing import Literal
16+
17+
from arviz import InferenceData
18+
19+
from pymc.sampling.external.base import ExternalSampler
20+
from pymc.util import RandomState
21+
22+
23+
class JAXSampler(ExternalSampler):
    """External sampler that delegates to ``pymc.sampling.jax.sample_jax_nuts``.

    Subclasses select the concrete NUTS implementation by overriding the
    ``nuts_sampler`` class attribute.

    Parameters
    ----------
    vars, model
        Forwarded to ``ExternalSampler.__init__``.
    postprocessing_backend, chain_method, jitter, keep_untransformed
        Stored on the instance and forwarded unchanged to
        ``sample_jax_nuts`` on every ``sample`` call.
    nuts_kwargs : dict, optional
        Stored (``None`` becomes an empty dict) and forwarded to
        ``sample_jax_nuts`` as ``nuts_kwargs``.
    """

    nuts_sampler = None  # Should be defined by subclass

    def __init__(
        self,
        vars=None,
        model=None,
        postprocessing_backend: Literal["cpu", "gpu"] | None = None,
        chain_method: Literal["parallel", "vectorized"] = "parallel",
        jitter: bool = True,
        keep_untransformed: bool = False,
        nuts_kwargs: dict | None = None,
    ):
        super().__init__(vars, model)
        self.postprocessing_backend = postprocessing_backend
        self.chain_method = chain_method
        self.jitter = jitter
        self.keep_untransformed = keep_untransformed
        self.nuts_kwargs = nuts_kwargs or {}

    def sample(
        self,
        *,
        tune: int = 1000,
        draws: int = 1000,
        chains: int = 4,
        initvals=None,
        random_seed: RandomState | None = None,
        progressbar: bool = True,
        var_names: Sequence[str] | None = None,
        idata_kwargs: dict | None = None,
        compute_convergence_checks: bool = True,
        target_accept: float = 0.8,
        nuts_sampler=None,
        **kwargs,
    ) -> InferenceData:
        """Sample ``self.model`` through ``sample_jax_nuts``.

        All arguments are keyword-only. The sampling arguments are forwarded
        as given, together with the options stored at construction time and
        the class-level ``nuts_sampler``.

        Parameters
        ----------
        nuts_sampler : optional
            Accepted for call compatibility only and ignored: the value
            passed on is always ``self.nuts_sampler``. It used to be
            required even though it was never read; it now defaults to
            ``None`` so callers may omit it.
        **kwargs
            Forwarded unchanged to ``sample_jax_nuts``.

        Returns
        -------
        InferenceData
            Whatever ``sample_jax_nuts`` returns.
        """
        # Imported lazily, as in the original code, so the JAX sampling
        # module is only loaded when sampling is actually requested.
        from pymc.sampling.jax import sample_jax_nuts

        return sample_jax_nuts(
            tune=tune,
            draws=draws,
            chains=chains,
            target_accept=target_accept,
            random_seed=random_seed,
            var_names=var_names,
            progressbar=progressbar,
            idata_kwargs=idata_kwargs,
            compute_convergence_checks=compute_convergence_checks,
            initvals=initvals,
            jitter=self.jitter,
            model=self.model,
            chain_method=self.chain_method,
            postprocessing_backend=self.postprocessing_backend,
            keep_untransformed=self.keep_untransformed,
            nuts_kwargs=self.nuts_kwargs,
            nuts_sampler=self.nuts_sampler,
            **kwargs,
        )
81+
82+
83+
class Numpyro(JAXSampler):
    """JAX-based external sampler that selects the ``"numpyro"`` NUTS sampler."""

    nuts_sampler = "numpyro"
85+
86+
87+
class Blackjax(JAXSampler):
    """JAX-based external sampler that selects the ``"blackjax"`` NUTS sampler."""

    nuts_sampler = "blackjax"

pymc/sampling/external/nutpie.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2025 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import warnings
15+
16+
from arviz import InferenceData, dict_to_dataset
17+
from pytensor.scalar import discrete_dtypes
18+
19+
from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
20+
from pymc.sampling.external.base import ExternalSampler
21+
from pymc.stats.convergence import log_warnings, run_convergence_checks
22+
from pymc.util import _get_seeds_per_chain
23+
24+
25+
class Nutpie(ExternalSampler):
    """External sampler that compiles and samples the model with nutpie.

    Parameters
    ----------
    vars, model
        Forwarded to ``ExternalSampler.__init__``.
    backend, gradient_backend
        Forwarded to ``nutpie.compile_pymc_model``.
    compile_kwargs : dict, optional
        Extra keyword arguments for ``nutpie.compile_pymc_model``.
    sample_kwargs : dict, optional
        Default extra keyword arguments for ``nutpie.sample``. Keyword
        arguments given to ``sample`` at call time override entries with the
        same name.

    Raises
    ------
    ValueError
        If any sampled variable has a discrete dtype.
    """

    def __init__(
        self,
        vars=None,
        model=None,
        backend="numba",
        gradient_backend="pytensor",
        compile_kwargs=None,
        sample_kwargs=None,
    ):
        super().__init__(vars, model)
        if any(var.dtype in discrete_dtypes for var in self.vars):
            raise ValueError("Nutpie can only sample continuous variables")
        self.backend = backend
        self.gradient_backend = gradient_backend
        self.compile_kwargs = compile_kwargs or {}
        self.sample_kwargs = sample_kwargs or {}

    def sample(
        self,
        *,
        tune,
        draws,
        chains,
        initvals,
        random_seed,
        progressbar,
        var_names,
        idata_kwargs,
        compute_convergence_checks,
        **kwargs,
    ):
        """Compile ``self.model`` with nutpie and sample it.

        ``initvals`` and ``idata_kwargs`` are not supported: a truthy value
        only triggers a ``UserWarning`` and is otherwise ignored.

        Returns
        -------
        The post-processed result of ``nutpie.sample``. When nutpie returns
        a ``_BackgroundSampler``, a wrapper around it is returned instead,
        which applies the same post-processing when the trace is extracted.

        Raises
        ------
        ImportError
            If nutpie is not installed.
        """
        try:
            import nutpie
        except ImportError as err:
            raise ImportError(
                "nutpie not found. Install it with conda install -c conda-forge nutpie"
            ) from err

        from nutpie.sample import _BackgroundSampler

        if initvals:
            warnings.warn(
                "initvals are currently ignored by the nutpie sampler.",
                UserWarning,
            )
        if idata_kwargs:
            warnings.warn(
                "idata_kwargs are currently ignored by the nutpie sampler.",
                UserWarning,
            )

        compiled_model = nutpie.compile_pymc_model(
            self.model,
            var_names=var_names,
            backend=self.backend,
            gradient_backend=self.gradient_backend,
            **self.compile_kwargs,
        )

        # Merge instead of unpacking both mappings into the call: a key
        # present in both would otherwise raise "got multiple values for
        # keyword argument". Call-time kwargs take precedence.
        extra_sample_kwargs = {**self.sample_kwargs, **kwargs}

        result = nutpie.sample(
            compiled_model,
            tune=tune,
            draws=draws,
            chains=chains,
            # A single seed is derived from random_seed for the whole run.
            seed=_get_seeds_per_chain(random_seed, 1)[0],
            progress_bar=progressbar,
            **extra_sample_kwargs,
        )
        if isinstance(result, _BackgroundSampler):
            # Wrap _BackgroundSampler so that when sampling is finished we run post_process_sampler
            class NutpieBackgroundSamplerWrapper(_BackgroundSampler):
                def __init__(self, *args, pymc_model, compute_convergence_checks, **kwargs):
                    self.pymc_model = pymc_model
                    self.compute_convergence_checks = compute_convergence_checks
                    super().__init__(*args, **kwargs, return_raw_trace=False)

                def _extract(self, *args, **kwargs):
                    idata = super()._extract(*args, **kwargs)
                    return Nutpie._post_process_sample(
                        model=self.pymc_model,
                        idata=idata,
                        compute_convergence_checks=self.compute_convergence_checks,
                    )

            # non-blocked sampling
            return NutpieBackgroundSamplerWrapper(
                result,
                pymc_model=self.model,
                compute_convergence_checks=compute_convergence_checks,
            )
        else:
            return self._post_process_sample(self.model, result, compute_convergence_checks)

    @staticmethod
    def _post_process_sample(
        model, idata: InferenceData, compute_convergence_checks
    ) -> InferenceData:
        """Run convergence checks and attach the model's data groups to ``idata``.

        Parameters
        ----------
        model
            The PyMC model that was sampled.
        idata : InferenceData
            Sampling result; modified in place via ``add_groups``.
        compute_convergence_checks
            When truthy, convergence checks are run and their warnings logged.

        Returns
        -------
        InferenceData
            The same ``idata`` object, with ``constant_data`` and
            ``observed_data`` groups added.
        """
        if compute_convergence_checks:
            log_warnings(run_convergence_checks(idata, model))

        # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
        # gather observed and constant data as nutpie.sample() has no access to the PyMC model
        coords, dims = coords_and_dims_for_inferencedata(model)
        library = idata.attrs.get("library", None)
        constant_data = dict_to_dataset(
            find_constants(model),
            library=library,
            coords=coords,
            dims=dims,
            default_dims=[],
        )
        observed_data = dict_to_dataset(
            find_observations(model),
            library=library,
            coords=coords,
            dims=dims,
            default_dims=[],
        )
        idata.add_groups(
            {"constant_data": constant_data, "observed_data": observed_data},
            coords=coords,
            dims=dims,
        )
        return idata

0 commit comments

Comments
 (0)