Draft external sampler API #7880

Draft: wants to merge 1 commit into base: main

ricardoV94
Member

@ricardoV94 ricardoV94 commented Jul 30, 2025

Discussion needed !!

Draft PR motivated by #7699

The need being

Basically my question is: would PyMC be open to a PR along these lines? For public facing visibility, we'd like to put the algorithm into PyMC rather than use PyMC indirectly (i.e. extract a density from a probabilistic program in PyMC), and I thought this might be the best way.

The first idea discussed (and implemented here) was to allow pm.sample(step=ExternalSampler()) that defers sampling to an external library (including the ones we already supported before like nutpie, numpyro, ...)

import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x")
    idata = pm.sample(nuts_sampler="nutpie", nuts_sampler_kwargs=kwargs)  # <- Before
    idata = pm.sample(step=pm.external.Nutpie(**kwargs))  # <- Now
    idata = pm.sample(step=pm.external.MCLMC(**kwargs))  # <- Future non NUTS methods can also be used

Pros: We are not assuming everything is a nuts_sampler, and there are objects / functions users can read to find the arguments that parametrize the samplers.

Cons: We're trying to put everything through pm.sample so it can reuse 6-8 arguments (tune, draws, chains, idata_kwargs ...) that already existed in pm.sample? Also tune/draws/chains may not make sense for some external samplers.

What I think is useful is to provide a standard API entry point for finding external samplers (along with our logic for connecting PyMC to each library), which this PR more or less offers in pm.external.

The new ExternalSampler object is a bit awkward. It doesn't do much other than allow pm.sample to recognize it, so that you can pass pm.sample(step=pm.external.Nutpie()). It then has to split arguments somewhat arbitrarily between instantiation and sample, so that the sampler-specific arguments are discoverable while the few sampler-agnostic arguments are reused.
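To make the argument split concrete, here is a minimal sketch in plain Python. It is not the code in this PR: the class bodies, the `ToySampler` name, and the stand-in `sample` function are assumptions for illustration only. Sampler-specific options go to the constructor, while the sampler-agnostic ones are forwarded at call time.

```python
from abc import ABC, abstractmethod


class ExternalSampler(ABC):
    """Hypothetical base class: the instance holds sampler-specific options only."""

    def __init__(self, **sampler_kwargs):
        # Sampler-specific settings live on the instance, so each subclass
        # can document them in its own constructor signature.
        self.sampler_kwargs = sampler_kwargs

    @abstractmethod
    def sample(self, *, tune, draws, chains, random_seed=None):
        """Receive the sampler-agnostic arguments forwarded by the caller."""


class ToySampler(ExternalSampler):
    """Made-up sampler; a real subclass would call an external library."""

    def __init__(self, target_accept=0.8):
        super().__init__(target_accept=target_accept)

    def sample(self, *, tune, draws, chains, random_seed=None):
        return {
            "settings": dict(self.sampler_kwargs),
            "shape": (chains, draws),
            "tune": tune,
        }


def sample(step, tune=1000, draws=1000, chains=4, random_seed=None):
    """Stand-in for pm.sample: defer entirely to an external sampler."""
    if isinstance(step, ExternalSampler):
        return step.sample(
            tune=tune, draws=draws, chains=chains, random_seed=random_seed
        )
    raise TypeError("this sketch only handles ExternalSampler instances")


result = sample(step=ToySampler(target_accept=0.9), draws=500, chains=2)
```

In this sketch `target_accept` is discoverable on `ToySampler`, but `draws` and `chains` are only visible on the outer `sample`, which is the awkwardness described above.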

Wouldn't it make more sense to just offer pm.external.sample_nutpie()?

IIRC this is why we went with pm.sample_smc instead of pm.sample(step=pm.SMC()) which used to exist before. It became awkward to have a function for both approaches.


📚 Documentation preview 📚: https://pymc--7880.org.readthedocs.build/en/7880/

@jessegrabowski
Member

Not against this proposal, but I haven't carefully looked at it yet. Wanted to chime in with an alternative right away. We could offer comprehensive developer docs on how to take a PyMC model and get the logp/dlogp out of it. This would encourage an ecosystem of packages like nutpie that can work directly with the model object.

If we wanted to then bridge to their thing directly via pm.whatever, it would be a thin wrapper just calling their API and maybe aligning defaults/argnames (again, like nutpie).
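A rough sketch of what such a thin wrapper could look like, assuming a made-up third-party entry point. Both `external_sample` and `sample_toy` are hypothetical names, not part of PyMC or any real library; the only job of the wrapper is to translate PyMC-style argument names and pass everything else through.

```python
def external_sample(logp_fn, *, num_warmup, num_samples, num_chains,
                    seed=None, step_size=0.1):
    """Stand-in for a third-party entry point (hypothetical, not a real library)."""
    return {
        "num_warmup": num_warmup,
        "num_samples": num_samples,
        "num_chains": num_chains,
        "seed": seed,
        "step_size": step_size,
    }


def sample_toy(logp_fn, *, tune=1000, draws=1000, chains=4,
               random_seed=None, **sampler_kwargs):
    """Thin bridge: align names and defaults, forward the rest untouched."""
    return external_sample(
        logp_fn,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        seed=random_seed,
        **sampler_kwargs,
    )


def logp(x):
    # Log-density of a standard normal, up to an additive constant.
    return -0.5 * x**2


call = sample_toy(logp, draws=200, chains=2, step_size=0.5)
```

Because the wrapper holds no logic of its own, library-specific options such as `step_size` stay documented in the external package.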

@jessegrabowski
Member

Looking over the actual PR, I like it much better than what we have now from a maintenance perspective. It also makes it easier to have specific parameters for each sampler, removing the horrible nuts_sampler_kwargs. We would need to have a deprecation period for the current API though?

I guess my proposal would be an "in addition to" this

@ricardoV94
Member Author

ricardoV94 commented Jul 30, 2025

We have a deprecation from now, everything should work except you get warnings.
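For readers unfamiliar with this kind of shim, a minimal sketch of how the legacy arguments could be translated while emitting a warning. This is not the code in the PR: the `resolve_step` helper, the lookup table, and the choice of `FutureWarning` are assumptions.

```python
import warnings


class ExternalSampler:
    """Hypothetical stand-in for the classes under pm.external."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class Nutpie(ExternalSampler):
    pass


class Numpyro(ExternalSampler):
    pass


# Assumed mapping from the legacy string names to the new objects.
_LEGACY_NAMES = {"nutpie": Nutpie, "numpyro": Numpyro}


def resolve_step(step=None, nuts_sampler=None, nuts_sampler_kwargs=None):
    """Translate the legacy arguments into a sampler object, with a warning."""
    if nuts_sampler is None:
        return step
    if step is not None:
        raise ValueError("pass either `step` or the legacy `nuts_sampler`, not both")
    if nuts_sampler not in _LEGACY_NAMES:
        raise ValueError(f"unknown nuts_sampler: {nuts_sampler!r}")
    sampler_cls = _LEGACY_NAMES[nuts_sampler]
    warnings.warn(
        f"`nuts_sampler={nuts_sampler!r}` is deprecated; "
        f"pass `step={sampler_cls.__name__}(...)` instead.",
        FutureWarning,
        stacklevel=2,
    )
    return sampler_cls(**(nuts_sampler_kwargs or {}))


# Old-style call: still works, but the warning is recorded.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sampler = resolve_step(
        nuts_sampler="nutpie", nuts_sampler_kwargs={"target_accept": 0.9}
    )
```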

One thing I don't love is the step name, since these are holistic samplers. What about sampler or method?

codecov bot commented Jul 30, 2025

Codecov Report

❌ Patch coverage is 44.06780% with 66 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.64%. Comparing base (58b49f2) to head (94d5ec4).
⚠️ Report is 3 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc/sampling/external/nutpie.py | 25.00% | 36 Missing ⚠️ |
| pymc/sampling/mcmc.py | 61.53% | 10 Missing ⚠️ |
| pymc/sampling/external/base.py | 43.75% | 9 Missing ⚠️ |
| pymc/sampling/external/jax.py | 61.90% | 8 Missing ⚠️ |
| pymc/sampling/jax.py | 0.00% | 3 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (44.06%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7880      +/-   ##
==========================================
- Coverage   92.94%   91.64%   -1.31%     
==========================================
  Files         116      120       +4     
  Lines       18845    18878      +33     
==========================================
- Hits        17516    17300     -216     
- Misses       1329     1578     +249     

| Files with missing lines | Coverage Δ |
|---|---|
| pymc/__init__.py | 100.00% <100.00%> (ø) |
| pymc/sampling/external/__init__.py | 100.00% <100.00%> (ø) |
| pymc/sampling/jax.py | 0.00% <0.00%> (-94.10%) ⬇️ |
| pymc/sampling/external/jax.py | 61.90% <61.90%> (ø) |
| pymc/sampling/external/base.py | 43.75% <43.75%> (ø) |
| pymc/sampling/mcmc.py | 90.26% <61.53%> (-1.12%) ⬇️ |
| pymc/sampling/external/nutpie.py | 25.00% <25.00%> (ø) |

... and 1 file with indirect coverage changes

@Dekermanjian
Contributor

Wouldn't it make more sense to just offer pm.external.sample_nutpie()?

I personally prefer this because I think you end up with fewer dependencies that you have to worry about integrating nicely with each other.

For each new sampling algorithm you need to first implement the pm.external.sample_new_algorithm() and then you potentially also need to make sure that this new algorithm integrates nicely with the current pm.sample() implementation.

# limitations under the License.
from pymc.sampling.external.base import ExternalSampler
from pymc.sampling.external.jax import Blackjax, Numpyro
from pymc.sampling.external.nutpie import Nutpie
Contributor

Suggested change
from pymc.sampling.external.nutpie import Nutpie
from pymc.sampling.external.nutpie import Nutpie
__all__ = ["Blackjax", "Numpyro", "Nutpie"]

You probably don't need to export the ABC, right?

@@ -704,8 +666,8 @@ def sample_jax_nuts(
dims.update(idata_kwargs.pop("dims"))

# Use 'partial' to set default arguments before passing 'idata_kwargs'
Contributor

unnecessary comment now

        compute_convergence_checks,
        **kwargs,
    ):
        pass
Contributor

Suggested change
        pass
        raise NotImplementedError

make it fail fast
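A small illustration of the difference, using made-up class names rather than the PR's actual base class: with `pass`, a subclass that forgets to override `sample` silently returns `None`, whereas raising `NotImplementedError` surfaces the mistake at the call site.

```python
class SilentBase:
    def sample(self, **kwargs):
        pass  # a subclass that forgets to override this returns None


class StrictBase:
    def sample(self, **kwargs):
        raise NotImplementedError(f"{type(self).__name__} must implement sample()")


class ForgetfulSilent(SilentBase):
    pass


class ForgetfulStrict(StrictBase):
    pass


# The silent version "works" and the problem only shows up downstream.
silent_result = ForgetfulSilent().sample(draws=10)

# The strict version fails immediately with a message naming the class.
try:
    ForgetfulStrict().sample(draws=10)
except NotImplementedError as err:
    strict_error = str(err)
```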

Comment on lines +35 to +47
    def sample(
        self,
        tune,
        draws,
        chains,
        initvals,
        random_seed,
        progressbar,
        var_names,
        idata_kwargs,
        compute_convergence_checks,
        **kwargs,
    ):
Contributor

Golden opportunity to add type hints and return types early!
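One possible shape for the annotated signature, as a sketch only. The parameter names come from the diff above, but every type here is a guess (the PR does not state them), and `Any` is used as a placeholder for the real return type.

```python
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional, Union

# Guessed aliases; the real PyMC types may differ.
InitVals = Union[dict[str, Any], Sequence[dict[str, Any]]]
Seed = Union[int, Sequence[int]]


class ExternalSampler(ABC):
    @abstractmethod
    def sample(
        self,
        tune: int,
        draws: int,
        chains: int,
        initvals: Optional[InitVals],
        random_seed: Optional[Seed],
        progressbar: bool,
        var_names: Optional[Sequence[str]],
        idata_kwargs: Optional[dict[str, Any]],
        compute_convergence_checks: bool,
        **kwargs: Any,
    ) -> Any:  # placeholder for the actual return type
        """Run the external sampler and return its results."""


# The annotations are now available to type checkers and to introspection.
hints = ExternalSampler.sample.__annotations__
```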



class JAXSampler(ExternalSampler):
    nuts_sampler = None  # Should be defined by subclass
Contributor

This could be enforced by making it an abstract property that inheritors must implement. I've found that pattern quite useful.
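A minimal sketch of that pattern with the standard `abc` module. For brevity `JAXSampler` here derives from `ABC` directly rather than from the PR's `ExternalSampler`, and `Incomplete` is a made-up class used only to show the failure.

```python
from abc import ABC, abstractmethod


class JAXSampler(ABC):
    @property
    @abstractmethod
    def nuts_sampler(self) -> str:
        """Name of the backend; every concrete subclass must provide it."""


class Numpyro(JAXSampler):
    # A plain class attribute is enough to satisfy the abstract property.
    nuts_sampler = "numpyro"


class Incomplete(JAXSampler):
    pass  # forgot to define nuts_sampler


numpyro = Numpyro()

# Instantiating a subclass that lacks the attribute fails immediately.
try:
    Incomplete()
except TypeError as err:
    error_message = str(err)
```

The check happens at instantiation time, so a missing attribute is caught before any sampling starts rather than showing up later as `None`.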

from pymc.util import RandomState


class JAXSampler(ExternalSampler):
Contributor

Probably best not to assume that any JAXSampler is a NUTS sampler

@nataziel
Contributor

I know it's just a draft @ricardoV94, but I popped some comments in where it made sense to me

@jessegrabowski
Member

We have a deprecation from now, everything should work except you get warnings.

One thing I don't love is the step name, since these are holistic samplers. What about sampler or method?

What do you think about just calling e.g. pm.external.sample_nutpie directly, instead of passing something to pm.sample?

@ricardoV94
Member Author

As I wrote on the original issue, I'm more inclined toward that, but I know some people love the pm.sample funnel.

@lucianopaz
Member

I don't like using the step argument for this. The whole point of step was to be able to combine different step methods to be used on different random variables, so that people could end up writing their own kind of Gibbs sampler that was specialized for their needs. Some step methods are used for more than a single variable, like NUTS, so overloading step to make it be yet another kind of sampler that cannot be combined with other steps for our variables wouldn't be good.

Having gone through the step methods themselves not so long ago, I feel that they could be thoroughly refactored, but that would require a whole different discussion and a lot of work. So I think that the simplest option is to use pymc.external.sample_blabla or add a different argument to pymc.sample that is something like external_sampler=Blabla.
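A sketch of the second option, assuming a separate `external_sampler` argument. The `sample` function below is a hypothetical dispatcher, not the real pm.sample, and the step names are placeholder strings. The point is that `step` keeps its per-variable, combinable meaning, while an external sampler takes over the whole model and cannot be mixed with it.

```python
def sample(step=None, external_sampler=None, draws=1000):
    """Hypothetical dispatcher, not the real pm.sample signature."""
    if external_sampler is not None:
        if step is not None:
            raise ValueError(
                "`external_sampler` takes over the whole model and "
                "cannot be combined with `step`"
            )
        return external_sampler(draws=draws)

    # `step` may be a single step method or several, one per group of variables.
    if step is None:
        steps = []
    elif isinstance(step, (list, tuple)):
        steps = list(step)
    else:
        steps = [step]
    return {"mode": "step methods", "steps": steps, "draws": draws}


def toy_external(draws):
    """Stand-in for an external sampler callable."""
    return {"mode": "external", "draws": draws}


combined = sample(step=["metropolis_for_x", "nuts_for_y"], draws=10)
external = sample(external_sampler=toy_external, draws=10)

try:
    sample(step=["nuts_for_y"], external_sampler=toy_external)
except ValueError as err:
    conflict = str(err)
```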

The uniform API is very nice, and I would recommend that we rely on pydantic.BaseModel to handle the API validation. I'm not sure which methods we would like to expose. From my backend perspective, I would love it if the external samplers exposed some way of collecting samples on the fly, and also something that stores the external sampler's state, so that it could be reset for continued sampling using the ZarrTrace machinery.
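To illustrate the two backend wishes (draws delivered on the fly, and a state that can be stored and restored), here is a toy sketch in plain Python. The method names `iter_draws`, `get_state`, and `set_state` are invented for this example and are not an existing PyMC or ZarrTrace interface; the sampler itself is a random-walk Metropolis on a standard normal.

```python
import math
import random


class ToyExternalSampler:
    """Random-walk Metropolis on a standard normal, standing in for an external sampler."""

    def __init__(self, seed=0, step_size=1.0):
        self.step_size = step_size
        self._rng = random.Random(seed)
        self._x = 0.0

    def iter_draws(self, draws):
        """Yield draws one at a time so a trace backend can store them on the fly."""
        for _ in range(draws):
            proposal = self._x + self._rng.gauss(0.0, self.step_size)
            # log p(proposal) - log p(current) for a standard normal target
            log_ratio = 0.5 * (self._x**2 - proposal**2)
            if math.log(1.0 - self._rng.random()) < log_ratio:
                self._x = proposal
            yield self._x

    def get_state(self):
        """Everything needed to continue the chain later."""
        return {"x": self._x, "rng": self._rng.getstate()}

    def set_state(self, state):
        self._x = state["x"]
        self._rng.setstate(state["rng"])


# Reference run: ten draws without interruption.
uninterrupted = list(ToyExternalSampler(seed=42).iter_draws(10))

# Interrupted run: five draws, checkpoint, then resume in a fresh object.
first = ToyExternalSampler(seed=42)
head = list(first.iter_draws(5))
checkpoint = first.get_state()

resumed = ToyExternalSampler(seed=0)
resumed.set_state(checkpoint)
tail = list(resumed.iter_draws(5))
```

In this sketch the resumed chain reproduces the uninterrupted one exactly, because the checkpoint carries both the current position and the random number generator state.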
