@@ -36,7 +36,156 @@ def test_all_basic_rvs_are_wrapped():
36
36
37
37
38
38
def test_normal ():
39
- pass
39
+ c_size = tensor ("c_size" , shape = (), dtype = int )
40
+ c_size_xr = as_xtensor (c_size , name = "c_size_xr" )
41
+
42
+ # Vector inputs
43
+ mu_vec_in = tensor ("mu" , shape = (2 ,))
44
+ sigma_vec_in = tensor ("sigma" , shape = (2 ,))
45
+
46
+ mu_vec_xr = as_xtensor (mu_vec_in , dims = ("a" ,), name = "mu_xr" )
47
+ sigma_vec_xr = as_xtensor (sigma_vec_in , dims = ("a" ,), name = "sigma_xr" )
48
+
49
+ # Vector inputs: Basic case (no extra_dims)
50
+ out_vec = pxr .normal (mu_vec_xr , sigma_vec_xr , rng = rng )
51
+ assert out_vec .type .dims == ("a" ,)
52
+ assert out_vec .type .shape == (2 ,)
53
+ assert equal_computations (
54
+ [lower_rewrite (out_vec .values )],
55
+ [rewrite_graph (ptr .normal (mu_vec_in , sigma_vec_in , rng = rng ))],
56
+ )
57
+
58
+ mu_val = np .array ([0.0 , 10.0 ])
59
+ sigma_val = np .array ([1.0 , 2.0 ])
60
+ eval_rng_seed_vec_basic = 12345
61
+
62
+ actual_val_vec_basic = out_vec .eval (
63
+ {
64
+ mu_vec_in : mu_val ,
65
+ sigma_vec_in : sigma_val ,
66
+ rng : np .random .default_rng (eval_rng_seed_vec_basic ),
67
+ }
68
+ )
69
+ expected_val_vec_basic = np .random .default_rng (eval_rng_seed_vec_basic ).normal (
70
+ mu_val , sigma_val
71
+ )
72
+ np .testing .assert_allclose (actual_val_vec_basic , expected_val_vec_basic )
73
+
74
+ # Vector inputs: With extra_dims
75
+ out_vec_extra = pxr .normal (
76
+ mu_vec_xr , sigma_vec_xr , extra_dims = dict (c = c_size_xr ), rng = rng
77
+ )
78
+ assert out_vec_extra .type .dims == ("c" , "a" )
79
+ assert equal_computations (
80
+ [lower_rewrite (out_vec_extra .values )],
81
+ [
82
+ rewrite_graph (
83
+ ptr .normal (
84
+ mu_vec_in , sigma_vec_in , size = (c_size , mu_vec_in .shape [0 ]), rng = rng
85
+ )
86
+ )
87
+ ],
88
+ )
89
+
90
+ c_size_val = 5
91
+ eval_rng_seed_vec_extra = 67890
92
+ actual_val_vec_extra = out_vec_extra .eval (
93
+ {
94
+ mu_vec_in : mu_val ,
95
+ sigma_vec_in : sigma_val ,
96
+ c_size : c_size_val ,
97
+ rng : np .random .default_rng (eval_rng_seed_vec_extra ),
98
+ }
99
+ )
100
+ expected_val_vec_extra = np .random .default_rng (eval_rng_seed_vec_extra ).normal (
101
+ loc = mu_val , scale = sigma_val , size = (c_size_val , mu_val .shape [0 ])
102
+ )
103
+ np .testing .assert_allclose (actual_val_vec_extra , expected_val_vec_extra )
104
+
105
+ # Scalar inputs
106
+ mu_scalar_in = tensor ("mu_s" , shape = ())
107
+ sigma_scalar_in = tensor ("sigma_s" , shape = ())
108
+
109
+ mu_scalar_xr = as_xtensor (mu_scalar_in , name = "mu_s_xr" )
110
+ sigma_scalar_xr = as_xtensor (sigma_scalar_in , name = "sigma_s_xr" )
111
+
112
+ # Scalar inputs: Basic case
113
+ out_scalar = pxr .normal (mu_scalar_xr , sigma_scalar_xr , rng = rng )
114
+ assert out_scalar .type .dims == ()
115
+ assert out_scalar .type .shape == ()
116
+ assert equal_computations (
117
+ [lower_rewrite (out_scalar .values )],
118
+ [rewrite_graph (ptr .normal (mu_scalar_in , sigma_scalar_in , rng = rng ))],
119
+ )
120
+
121
+ mu_s_val = 0.0
122
+ sigma_s_val = 1.0
123
+ eval_rng_seed_scalar_basic = 23456
124
+ actual_val_scalar_basic = out_scalar .eval (
125
+ {
126
+ mu_scalar_in : mu_s_val ,
127
+ sigma_scalar_in : sigma_s_val ,
128
+ rng : np .random .default_rng (eval_rng_seed_scalar_basic ),
129
+ }
130
+ )
131
+ expected_val_scalar_basic = np .random .default_rng (
132
+ eval_rng_seed_scalar_basic
133
+ ).normal (mu_s_val , sigma_s_val )
134
+ np .testing .assert_allclose (actual_val_scalar_basic , expected_val_scalar_basic )
135
+
136
+ # Scalar inputs: With extra_dims
137
+ out_scalar_extra = pxr .normal (
138
+ mu_scalar_xr , sigma_scalar_xr , extra_dims = dict (c = c_size_xr ), rng = rng
139
+ )
140
+ assert out_scalar_extra .type .dims == ("c" ,)
141
+ assert equal_computations (
142
+ [lower_rewrite (out_scalar_extra .values )],
143
+ [
144
+ rewrite_graph (
145
+ ptr .normal (mu_scalar_in , sigma_scalar_in , size = (c_size ,), rng = rng )
146
+ )
147
+ ],
148
+ )
149
+
150
+ eval_rng_seed_scalar_extra = 78901
151
+ actual_val_scalar_extra = out_scalar_extra .eval (
152
+ {
153
+ mu_scalar_in : mu_s_val ,
154
+ sigma_scalar_in : sigma_s_val ,
155
+ c_size : c_size_val ,
156
+ rng : np .random .default_rng (eval_rng_seed_scalar_extra ),
157
+ }
158
+ )
159
+ expected_val_scalar_extra = np .random .default_rng (
160
+ eval_rng_seed_scalar_extra
161
+ ).normal (loc = mu_s_val , scale = sigma_s_val , size = (c_size_val ,))
162
+ np .testing .assert_allclose (actual_val_scalar_extra , expected_val_scalar_extra )
163
+
164
+ # Error conditions
165
+ # Invalid core_dims: normal is element-wise, expects core_dims=() for params.
166
+ with pytest .raises (
167
+ ValueError ,
168
+ match = re .escape (
169
+ "Parameter mu_xr has invalid core dimensions ['a']. "
170
+ "Expected [] based on RV definition and core_dims argument."
171
+ ),
172
+ ):
173
+ pxr .normal (mu_vec_xr , sigma_vec_xr , core_dims = ("a" ,), rng = rng )
174
+
175
+ # Invalid extra_dims (conflicting with existing batch dims)
176
+ a_size_xr = mu_vec_xr .sizes ["a" ]
177
+ with pytest .raises (
178
+ ValueError ,
179
+ match = re .escape (
180
+ "Size dimensions ['a'] conflict with parameter dimensions. They should be unique."
181
+ ),
182
+ ):
183
+ pxr .normal (
184
+ mu_vec_xr ,
185
+ sigma_vec_xr ,
186
+ extra_dims = dict (c = c_size_xr , a = a_size_xr ), # 'a' conflicts
187
+ rng = rng ,
188
+ )
40
189
41
190
42
191
def test_categorical ():
0 commit comments