-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Draft external sampler API #7880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
321e391
to
6a86352
Compare
Not against this proposal, but I haven't carefully looked at it yet. Wanted to chime in with an alternative right away. We could offer comprehensive developer docs for how to take a PyMC model and get the logp/dlogp out of it. This would encourage an ecosystem of packages like nutpie, that can work directly with model object. If we wanted to then bridge to their thing directly via pm.whatever, it would be a thin wrapper just calling their API and maybe aligning defaults/argnames (again, like nutpie). |
Looking over the actual PR, I like it much better than what we have now from a maintenance perspective. It also makes it easier to have specific parameters for each sampler, removing the horrible I guess my proposal would be an "in addition to" this |
We have a deprecation from now, everything should work except you get warnings. One thing I don't love is the |
6a86352
to
94d5ec4
Compare
Codecov Report❌ Patch coverage is
❌ Your patch check has failed because the patch coverage (44.06%) is below the target coverage (50.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #7880 +/- ##
==========================================
- Coverage 92.94% 91.64% -1.31%
==========================================
Files 116 120 +4
Lines 18845 18878 +33
==========================================
- Hits 17516 17300 -216
- Misses 1329 1578 +249
🚀 New features to boost your workflow:
|
I personally prefer this because I think you end up with less dependencies that you have to worry about integrating nicely with each other. For each new sampling algorithm you need to first implement the |
# limitations under the License. | ||
from pymc.sampling.external.base import ExternalSampler | ||
from pymc.sampling.external.jax import Blackjax, Numpyro | ||
from pymc.sampling.external.nutpie import Nutpie |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from pymc.sampling.external.nutpie import Nutpie | |
from pymc.sampling.external.nutpie import Nutpie | |
__all__ = ["Blackjax", "Numpyro", "Nutpie"] |
You probably don't need to export the ABC right?
@@ -704,8 +666,8 @@ def sample_jax_nuts( | |||
dims.update(idata_kwargs.pop("dims")) | |||
|
|||
# Use 'partial' to set default arguments before passing 'idata_kwargs' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unnecessary comment now
compute_convergence_checks, | ||
**kwargs, | ||
): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pass | |
raise NotImplementedError |
make it fail fast
def sample( | ||
self, | ||
tune, | ||
draws, | ||
chains, | ||
initvals, | ||
random_seed, | ||
progressbar, | ||
var_names, | ||
idata_kwargs, | ||
compute_convergence_checks, | ||
**kwargs, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Golden opportunity to add type hints and return types early!
|
||
|
||
class JAXSampler(ExternalSampler): | ||
nuts_sampler = None # Should be defined by subclass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be enforced by making an abstract method property that inheritors must implement, I've found that pattern quite useful
from pymc.util import RandomState | ||
|
||
|
||
class JAXSampler(ExternalSampler): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably best not to assume that any JAXSampler is a NUTS sampler
I know it's just a draft @ricardoV94, but popped some comments in where it made sense to me |
What do you think about just calling e.g. |
As I wrote on the original issue I'm more inclined to that but I know some people love the |
I don't like using the Having gone through the step methods themselves not so long ago, I feel that they could be thoroughly refactored, but that would require a whole different discussion and a lot of work. So I think that the simplest option is to use The uniform API is very nice, and I would recommend that we rely on |
Discussion needed !!
Draft PR motivated by #7699
The need being
The first idea discussed (and implemented here) was to allow
pm.sample(step=ExternalSampler())
that defers sampling to an external library (including the ones we already supported before like nutpie, numpyro, ...)Pros: We are not assuming everything is a
nuts_sampler
, and there are objects / functions users can read to find the arguments that parametrize the samplers.Cons: We're trying to put everything through
pm.sample
so it can reuse 6-8 arguments (tune, draws, chains, idata_kwargs ...) that already existed inpm.sample
? Also tune/draws/chains may not make sense for some external samplers.What I think is useful is to provide a standard API point to find external samplers (with our logic to connect pymc-library), which this PR kind of offers in
pm.external
.The new
ExternalSampler
object is a bit awkward. It doesn't do much other other than allowingpm.sample
to recognize it, so you can passpm.sample(step=pm.external.Nutpie())
, but then it has to arbitrarily split arguments between instantiation and sample, so as to make the sampler specific arguments discoverable, while reusing the few sampler-agnostic arguments.Wouldn't it make more sense to just offer
pm.external.sample_nutpie()
?IIRC this is why we went with
pm.sample_smc
instead ofpm.sample(step=pm.SMC())
which used to exist before. It became awkward to have a function for both approaches.📚 Documentation preview 📚: https://pymc--7880.org.readthedocs.build/en/7880/