Skip to content

Commit 29a1237

Browse files
committed
.wip
1 parent 81f8bc4 commit 29a1237

File tree

1 file changed

+150
-1
lines changed

1 file changed

+150
-1
lines changed

tests/xtensor/test_random.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,156 @@ def test_all_basic_rvs_are_wrapped():
3636

3737

3838
def test_normal():
39-
pass
39+
c_size = tensor("c_size", shape=(), dtype=int)
40+
c_size_xr = as_xtensor(c_size, name="c_size_xr")
41+
42+
# Vector inputs
43+
mu_vec_in = tensor("mu", shape=(2,))
44+
sigma_vec_in = tensor("sigma", shape=(2,))
45+
46+
mu_vec_xr = as_xtensor(mu_vec_in, dims=("a",), name="mu_xr")
47+
sigma_vec_xr = as_xtensor(sigma_vec_in, dims=("a",), name="sigma_xr")
48+
49+
# Vector inputs: Basic case (no extra_dims)
50+
out_vec = pxr.normal(mu_vec_xr, sigma_vec_xr, rng=rng)
51+
assert out_vec.type.dims == ("a",)
52+
assert out_vec.type.shape == (2,)
53+
assert equal_computations(
54+
[lower_rewrite(out_vec.values)],
55+
[rewrite_graph(ptr.normal(mu_vec_in, sigma_vec_in, rng=rng))],
56+
)
57+
58+
mu_val = np.array([0.0, 10.0])
59+
sigma_val = np.array([1.0, 2.0])
60+
eval_rng_seed_vec_basic = 12345
61+
62+
actual_val_vec_basic = out_vec.eval(
63+
{
64+
mu_vec_in: mu_val,
65+
sigma_vec_in: sigma_val,
66+
rng: np.random.default_rng(eval_rng_seed_vec_basic),
67+
}
68+
)
69+
expected_val_vec_basic = np.random.default_rng(eval_rng_seed_vec_basic).normal(
70+
mu_val, sigma_val
71+
)
72+
np.testing.assert_allclose(actual_val_vec_basic, expected_val_vec_basic)
73+
74+
# Vector inputs: With extra_dims
75+
out_vec_extra = pxr.normal(
76+
mu_vec_xr, sigma_vec_xr, extra_dims=dict(c=c_size_xr), rng=rng
77+
)
78+
assert out_vec_extra.type.dims == ("c", "a")
79+
assert equal_computations(
80+
[lower_rewrite(out_vec_extra.values)],
81+
[
82+
rewrite_graph(
83+
ptr.normal(
84+
mu_vec_in, sigma_vec_in, size=(c_size, mu_vec_in.shape[0]), rng=rng
85+
)
86+
)
87+
],
88+
)
89+
90+
c_size_val = 5
91+
eval_rng_seed_vec_extra = 67890
92+
actual_val_vec_extra = out_vec_extra.eval(
93+
{
94+
mu_vec_in: mu_val,
95+
sigma_vec_in: sigma_val,
96+
c_size: c_size_val,
97+
rng: np.random.default_rng(eval_rng_seed_vec_extra),
98+
}
99+
)
100+
expected_val_vec_extra = np.random.default_rng(eval_rng_seed_vec_extra).normal(
101+
loc=mu_val, scale=sigma_val, size=(c_size_val, mu_val.shape[0])
102+
)
103+
np.testing.assert_allclose(actual_val_vec_extra, expected_val_vec_extra)
104+
105+
# Scalar inputs
106+
mu_scalar_in = tensor("mu_s", shape=())
107+
sigma_scalar_in = tensor("sigma_s", shape=())
108+
109+
mu_scalar_xr = as_xtensor(mu_scalar_in, name="mu_s_xr")
110+
sigma_scalar_xr = as_xtensor(sigma_scalar_in, name="sigma_s_xr")
111+
112+
# Scalar inputs: Basic case
113+
out_scalar = pxr.normal(mu_scalar_xr, sigma_scalar_xr, rng=rng)
114+
assert out_scalar.type.dims == ()
115+
assert out_scalar.type.shape == ()
116+
assert equal_computations(
117+
[lower_rewrite(out_scalar.values)],
118+
[rewrite_graph(ptr.normal(mu_scalar_in, sigma_scalar_in, rng=rng))],
119+
)
120+
121+
mu_s_val = 0.0
122+
sigma_s_val = 1.0
123+
eval_rng_seed_scalar_basic = 23456
124+
actual_val_scalar_basic = out_scalar.eval(
125+
{
126+
mu_scalar_in: mu_s_val,
127+
sigma_scalar_in: sigma_s_val,
128+
rng: np.random.default_rng(eval_rng_seed_scalar_basic),
129+
}
130+
)
131+
expected_val_scalar_basic = np.random.default_rng(
132+
eval_rng_seed_scalar_basic
133+
).normal(mu_s_val, sigma_s_val)
134+
np.testing.assert_allclose(actual_val_scalar_basic, expected_val_scalar_basic)
135+
136+
# Scalar inputs: With extra_dims
137+
out_scalar_extra = pxr.normal(
138+
mu_scalar_xr, sigma_scalar_xr, extra_dims=dict(c=c_size_xr), rng=rng
139+
)
140+
assert out_scalar_extra.type.dims == ("c",)
141+
assert equal_computations(
142+
[lower_rewrite(out_scalar_extra.values)],
143+
[
144+
rewrite_graph(
145+
ptr.normal(mu_scalar_in, sigma_scalar_in, size=(c_size,), rng=rng)
146+
)
147+
],
148+
)
149+
150+
eval_rng_seed_scalar_extra = 78901
151+
actual_val_scalar_extra = out_scalar_extra.eval(
152+
{
153+
mu_scalar_in: mu_s_val,
154+
sigma_scalar_in: sigma_s_val,
155+
c_size: c_size_val,
156+
rng: np.random.default_rng(eval_rng_seed_scalar_extra),
157+
}
158+
)
159+
expected_val_scalar_extra = np.random.default_rng(
160+
eval_rng_seed_scalar_extra
161+
).normal(loc=mu_s_val, scale=sigma_s_val, size=(c_size_val,))
162+
np.testing.assert_allclose(actual_val_scalar_extra, expected_val_scalar_extra)
163+
164+
# Error conditions
165+
# Invalid core_dims: normal is element-wise, expects core_dims=() for params.
166+
with pytest.raises(
167+
ValueError,
168+
match=re.escape(
169+
"Parameter mu_xr has invalid core dimensions ['a']. "
170+
"Expected [] based on RV definition and core_dims argument."
171+
),
172+
):
173+
pxr.normal(mu_vec_xr, sigma_vec_xr, core_dims=("a",), rng=rng)
174+
175+
# Invalid extra_dims (conflicting with existing batch dims)
176+
a_size_xr = mu_vec_xr.sizes["a"]
177+
with pytest.raises(
178+
ValueError,
179+
match=re.escape(
180+
"Size dimensions ['a'] conflict with parameter dimensions. They should be unique."
181+
),
182+
):
183+
pxr.normal(
184+
mu_vec_xr,
185+
sigma_vec_xr,
186+
extra_dims=dict(c=c_size_xr, a=a_size_xr), # 'a' conflicts
187+
rng=rng,
188+
)
40189

41190

42191
def test_categorical():

0 commit comments

Comments
 (0)