Skip to content

Commit 7eaaccd

Browse files
authored
use deserialize in Censored.from_dict (#525)
1 parent 4cb1851 commit 7eaaccd

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

pymc_extras/prior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ def from_dict(cls, data: dict[str, Any]) -> Censored:
11761176
"""Create a censored distribution from a dictionary."""
11771177
data = data["data"]
11781178
return cls( # type: ignore
1179-
distribution=Prior.from_dict(data["dist"]),
1179+
distribution=deserialize(data["dist"]),
11801180
lower=data["lower"],
11811181
upper=data["upper"],
11821182
)

tests/test_prior.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,3 +1168,38 @@ def test_import_incorrect_directly() -> None:
11681168
match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'"
11691169
with pytest.raises(UnsupportedDistributionError, match=match):
11701170
from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401
1171+
1172+
1173+
@pytest.fixture
1174+
def alternative_prior_deserialize():
1175+
def is_type(data):
1176+
return isinstance(data, dict) and "distribution" in data
1177+
1178+
def deserialize(data):
1179+
return Prior(**data)
1180+
1181+
register_deserialization(is_type=is_type, deserialize=deserialize)
1182+
1183+
yield
1184+
1185+
DESERIALIZERS.pop()
1186+
1187+
1188+
def test_censored_with_alternative(alternative_prior_deserialize) -> None:
1189+
data = {
1190+
"class": "Censored",
1191+
"data": {
1192+
"dist": {
1193+
"distribution": "Normal",
1194+
},
1195+
"lower": 0,
1196+
"upper": 10,
1197+
},
1198+
}
1199+
1200+
instance = deserialize(data)
1201+
1202+
assert isinstance(instance, Censored)
1203+
assert instance.lower == 0
1204+
assert instance.upper == 10
1205+
assert instance.distribution == Prior("Normal")

0 commit comments

Comments
 (0)