1
+ import pytest
2
+
3
+ import pytensor .tensor .random as ptr
4
+ from pytensor .graph .basic import equal_computations
5
+ from pytensor .tensor .random .type import random_generator_type
6
+ from pytensor .xtensor import xtensor
7
+ from pytensor .xtensor .random import multinomial , multivariate_normal , normal , categorical
8
+
9
+ lower_rewrite = lambda x : x
10
+
11
+ def test_normal ():
12
+ pass
13
+
14
+ def test_categorical ():
15
+ pass
16
+
17
+ def test_multinomial ():
18
+ rng = random_generator_type ("rng" )
19
+ n = xtensor (shape = (2 ,), dims = ("a" ,))
20
+ p = xtensor (shape = (3 , None ), dims = ("p" , "a" ))
21
+ c_size = xtensor (shape = (), dims = (), dtype = int )
22
+ a_size = n .sizes ["a" ]
23
+
24
+ out = multinomial (n , p , core_dims = ("p" ,), rng = rng )
25
+ assert out .type .dims == ("a" , "p" )
26
+ assert out .type .shape == (2 , 3 )
27
+ assert equal_computations (
28
+ [lower_rewrite (out )],
29
+ [ptr .multinomial (n .values , p .values .T , size = None , rng = rng )],
30
+ )
31
+ # TODO: Make sure we can actually evaluate it
32
+ ...
33
+
34
+ out = multinomial (n , p , core_dims = ("p" ,), size = dict (a = a_size ), rng = rng )
35
+ assert out .type .dims == ("a" , "p" )
36
+ assert equal_computations (
37
+ [lower_rewrite (out )],
38
+ [ptr .multinomial (n .values , p .values .T , size = (a_size .values ,), rng = rng )],
39
+ )
40
+
41
+ out = multinomial (n , p , core_dims = ("p" ,), size = dict (a = a_size , c = c_size ), rng = rng )
42
+ assert out .type .dims == ("a" , "c" , "p" )
43
+ assert equal_computations (
44
+ [lower_rewrite (out )],
45
+ [ptr .multinomial (n .values [:, None ], p .values .T [:, None , :], size = (a_size .values , c_size .values ), rng = rng )],
46
+ )
47
+
48
+ out = multinomial (n , p , core_dims = ("p" ,), size = dict (c = c_size , a = a_size ,), rng = rng )
49
+ assert out .type .dims == ("c" , "a" , "p" )
50
+ assert equal_computations (
51
+ [lower_rewrite (out )],
52
+ [ptr .multinomial (n .values , p .values .T , size = (c_size .values , a_size .values ), rng = rng )],
53
+ )
54
+
55
+ # Test missing core_dims
56
+ with pytest .raises (ValueError ):
57
+ multinomial (n , p , rng = rng )
58
+
59
+ # Test invalid core_dims
60
+ with pytest .raises (ValueError ):
61
+ # n cannot have a core dimension
62
+ multinomial (n , p , core_dims = ("a" ,), rng = rng )
63
+
64
+ # Test incomplete size
65
+ with pytest .raises (ValueError ):
66
+ multinomial (n , p , core_dims = ("p" ,), size = dict (c = c_size ), rng = rng )
67
+
68
+
69
+ def test_multivariate_normal ():
70
+ pass
71
+
72
+ def test_new_out_dim ()
73
+ pass
0 commit comments