Skip to content

Commit ce618a4

Browse files
authored
port previous PR (#511)
1 parent 88f0e07 commit ce618a4

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

pymc_extras/prior.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def custom_transform(x):
8484
import copy
8585

8686
from collections.abc import Callable
87+
from functools import partial
8788
from inspect import signature
8889
from typing import Any, Protocol, runtime_checkable
8990

@@ -1354,3 +1355,34 @@ def _is_censored_type(data: dict) -> bool:
13541355

13551356
register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
13561357
register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)
1358+
1359+
1360+
def __getattr__(name: str):
1361+
"""Get Prior class through the module.
1362+
1363+
Examples
1364+
--------
1365+
Create a normal distribution.
1366+
1367+
.. code-block:: python
1368+
1369+
from pymc_extras.prior import Normal
1370+
1371+
dist = Normal(mu=1, sigma=2)
1372+
1373+
Create a hierarchical normal distribution.
1374+
1375+
.. code-block:: python
1376+
1377+
import pymc_extras.prior as pr
1378+
1379+
dist = pr.Normal(mu=pr.Normal(), sigma=pr.HalfNormal(), dims="channel")
1380+
samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})
1381+
1382+
"""
1383+
# Protect against doctest
1384+
if name == "__wrapped__":
1385+
return
1386+
1387+
_get_pymc_distribution(name)
1388+
return partial(Prior, distribution=name)

tests/test_prior.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from pydantic import ValidationError
1313
from pymc.model_graph import fast_eval
1414

15+
import pymc_extras.prior as pr
16+
1517
from pymc_extras.deserialize import (
1618
DESERIALIZERS,
1719
deserialize,
@@ -1147,3 +1149,22 @@ def test_scaled_sample_prior() -> None:
11471149
assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3}
11481150
assert "scaled_var" in prior
11491151
assert "scaled_var_unscaled" in prior
1152+
1153+
1154+
def test_getattr() -> None:
1155+
assert pr.Normal() == Prior("Normal")
1156+
1157+
1158+
def test_import_directly() -> None:
1159+
try:
1160+
from pymc_extras.prior import Normal
1161+
except Exception as e:
1162+
pytest.fail(f"Unexpected exception: {e}")
1163+
1164+
assert Normal() == Prior("Normal")
1165+
1166+
1167+
def test_import_incorrect_directly() -> None:
1168+
match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'"
1169+
with pytest.raises(UnsupportedDistributionError, match=match):
1170+
from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401

0 commit comments

Comments
 (0)