Skip to content

Commit 244c37d

Browse files
committed
Pass size to specialized truncated dispatch
1 parent e419d53 commit 244c37d

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

pymc/distributions/truncated.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def update(self, node: Node):
5353

5454

5555
@singledispatch
56-
def _truncated(op: Op, lower, upper, *params):
56+
def _truncated(op: Op, lower, upper, size, *params):
5757
"""Return the truncated equivalent of another `RandomVariable`."""
5858
raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")
5959

@@ -150,7 +150,7 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
150150

151151
# Try to use specialized Op
152152
try:
153-
return _truncated(dist.owner.op, lower, upper, *dist.owner.inputs)
153+
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
154154
except NotImplementedError:
155155
pass
156156

@@ -339,7 +339,7 @@ def truncated_logprob(op, values, *inputs, **kwargs):
339339

340340

341341
@_truncated.register(NormalRV)
342-
def _truncated_normal(op, lower, upper, rng, size, dtype, mu, sigma):
342+
def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma):
343343
return TruncatedNormal.dist(
344344
mu=mu,
345345
sigma=sigma,

pymc/tests/distributions/test_truncated.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,31 @@ def _icdf_not_implemented(*args, **kwargs):
5353
raise NotImplementedError()
5454

5555

56-
def test_truncation_specialized_op():
56+
@pytest.mark.parametrize("shape_info", ("shape", "dims", "observed"))
57+
def test_truncation_specialized_op(shape_info):
5758
rng = aesara.shared(np.random.default_rng())
5859
x = at.random.normal(0, 10, rng=rng, name="x")
5960

60-
xt = Truncated.dist(x, lower=5, upper=15, shape=(100,))
61+
with Model(coords={"dim": range(100)}) as m:
62+
if shape_info == "shape":
63+
xt = Truncated("xt", dist=x, lower=5, upper=15, shape=(100,))
64+
elif shape_info == "dims":
65+
xt = Truncated("xt", dist=x, lower=5, upper=15, dims=("dim",))
66+
elif shape_info == "observed":
67+
xt = Truncated(
68+
"xt",
69+
dist=x,
70+
lower=5,
71+
upper=15,
72+
observed=np.empty(
73+
100,
74+
),
75+
)
76+
else:
77+
raise ValueError(f"Not a valid shape_info parametrization: {shape_info}")
78+
6179
assert isinstance(xt.owner.op, TruncatedNormalRV)
80+
assert xt.shape.eval() == (100,)
6281

6382
# Test RNG is not reused
6483
assert xt.owner.inputs[0] is not rng

0 commit comments

Comments
 (0)