Skip to content

Commit 60c763f

Browse files
committed
Allow censoring Categorical distributions
1 parent fa43eba commit 60c763f

File tree

3 files changed

+88
-28
lines changed

3 files changed

+88
-28
lines changed

pymc/distributions/discrete.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,29 +1155,58 @@ def support_point(rv, size, p):
11551155
mode = pt.full(size, mode)
11561156
return mode
11571157

1158-
def logp(value, p):
1159-
k = pt.shape(p)[-1]
1160-
value_clip = pt.clip(value, 0, k - 1)
1158+
@staticmethod
1159+
def _safe_index_value_p(value, p):
1160+
# Find the probability of the given value by indexing in p,
1161+
# after handling broadcasting and invalid values.
11611162

11621163
# In the standard case p has one more dimension than value
11631164
dim_diff = p.type.ndim - value.type.ndim
11641165
if dim_diff > 1:
11651166
# p broadcasts implicitly beyond value
1166-
value_clip = pt.shape_padleft(value_clip, dim_diff - 1)
1167+
value = pt.shape_padleft(value, dim_diff - 1)
11671168
elif dim_diff < 1:
11681169
# value broadcasts implicitly beyond p
11691170
p = pt.shape_padleft(p, 1 - dim_diff)
11701171

1171-
a = pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
1172+
k = pt.shape(p)[-1]
1173+
value_clip = pt.clip(value, 0, k - 1).astype(int)
1174+
return value, pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
11721175

1173-
res = pt.switch(
1176+
def logp(value, p):
1177+
k = pt.shape(p)[-1]
1178+
value, safe_value_p = Categorical._safe_index_value_p(value, p)
1179+
1180+
value_p = pt.switch(
11741181
pt.or_(pt.lt(value, 0), pt.gt(value, k - 1)),
11751182
-np.inf,
1176-
a,
1183+
safe_value_p,
11771184
)
11781185

11791186
return check_parameters(
1180-
res,
1187+
value_p,
1188+
0 <= p,
1189+
p <= 1,
1190+
pt.isclose(pt.sum(p, axis=-1), 1),
1191+
msg="0 <= p <=1, sum(p) = 1",
1192+
)
1193+
1194+
def logcdf(value, p):
1195+
k = pt.shape(p)[-1]
1196+
value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1))
1197+
1198+
value_p = pt.switch(
1199+
pt.lt(value, 0),
1200+
-np.inf,
1201+
pt.switch(
1202+
pt.gt(value, k - 1),
1203+
0,
1204+
safe_value_p,
1205+
),
1206+
)
1207+
1208+
return check_parameters(
1209+
value_p,
11811210
0 <= p,
11821211
p <= 1,
11831212
pt.isclose(pt.sum(p, axis=-1), 1),

tests/distributions/test_censored.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pymc as pm
1919

20+
from pymc import logp
2021
from pymc.distributions.shape_utils import change_dist_size
2122

2223

@@ -110,3 +111,18 @@ def test_dist_broadcasted_by_lower_upper(self):
110111
pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
111112
)
112113
assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)
114+
115+
def test_censored_categorical(self):
116+
cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,))
117+
118+
np.testing.assert_allclose(
119+
logp(cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
120+
[0, 0.1, 0.2, 0.2, 0.3, 0.2, 0],
121+
)
122+
123+
censored_cat = pm.Censored.dist(cat, lower=1, upper=3, shape=(5,))
124+
125+
np.testing.assert_allclose(
126+
logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
127+
[0, 0, 0.3, 0.2, 0.5, 0, 0],
128+
)

tests/distributions/test_discrete.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -367,43 +367,58 @@ def test_poisson(self):
367367

368368
@pytest.mark.parametrize("n", [2, 3, 4])
369369
def test_categorical(self, n):
370+
domain = Domain(range(n), dtype="int64", edges=(0, n))
371+
paramdomains = {"p": Simplex(n)}
372+
370373
check_logp(
371374
pm.Categorical,
372-
Domain(range(n), dtype="int64", edges=(0, n)),
373-
{"p": Simplex(n)},
375+
domain,
376+
paramdomains,
374377
lambda value, p: categorical_logpdf(value, p),
375378
)
376379

377-
def test_categorical_logp_batch_dims(self):
380+
check_selfconsistency_discrete_logcdf(
381+
pm.Categorical,
382+
domain,
383+
paramdomains,
384+
)
385+
386+
@pytest.mark.parametrize("method", (logp, logcdf), ids=lambda x: x.__name__)
387+
def test_categorical_logp_batch_dims(self, method):
378388
# Core case
379389
p = np.array([0.2, 0.3, 0.5])
380390
value = np.array(2.0)
381-
logp_expr = logp(pm.Categorical.dist(p=p, shape=value.shape), value)
382-
assert logp_expr.type.ndim == 0
383-
np.testing.assert_allclose(logp_expr.eval(), np.log(0.5))
391+
expr = method(pm.Categorical.dist(p=p, shape=value.shape), value)
392+
assert expr.type.ndim == 0
393+
expected_p = 0.5 if method is logp else 1.0
394+
np.testing.assert_allclose(expr.exp().eval(), expected_p)
384395

385396
# Explicit batched value broadcasts p
386397
bcast_p = p[None] # shape (1, 3)
387398
batch_value = np.array([0, 1]) # shape (2,)
388-
logp_expr = logp(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value)
389-
assert logp_expr.type.ndim == 1
390-
np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3]))
399+
expr = method(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value)
400+
assert expr.type.ndim == 1
401+
expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5]
402+
np.testing.assert_allclose(expr.exp().eval(), expected_p)
403+
404+
# Implicit batch value broadcasts p
405+
expr = method(pm.Categorical.dist(p=p, shape=()), batch_value)
406+
assert expr.type.ndim == 1
407+
expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5]
408+
np.testing.assert_allclose(expr.exp().eval(), expected_p)
391409

392410
# Explicit batched value and batched p
393411
batch_p = np.array([p[::-1], p])
394-
logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value)
395-
assert logp_expr.type.ndim == 1
396-
np.testing.assert_allclose(logp_expr.eval(), np.log([0.5, 0.3]))
397-
398-
# Implicit batch value broadcasts p
399-
logp_expr = logp(pm.Categorical.dist(p=p, shape=()), batch_value)
400-
assert logp_expr.type.ndim == 1
401-
np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3]))
412+
expr = method(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value)
413+
assert expr.type.ndim == 1
414+
expected_p = [0.5, 0.3] if method is logp else [0.5, 0.5]
415+
np.testing.assert_allclose(expr.exp().eval(), expected_p)
402416

403417
# Implicit batch p broadcasts value
404-
logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=None), value)
405-
assert logp_expr.type.ndim == 1
406-
np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.5]))
418+
expr = method(pm.Categorical.dist(p=batch_p, shape=None), value)
419+
assert expr.type.ndim == 1
420+
expected_p = [0.2, 0.5] if method is logp else [1.0, 1.0]
421+
np.testing.assert_allclose(expr.exp().eval(), expected_p)
407422

408423
@pytensor.config.change_flags(compute_test_value="raise")
409424
def test_categorical_bounds(self):

0 commit comments

Comments
 (0)