@@ -367,43 +367,58 @@ def test_poisson(self):
367
367
368
368
@pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
369
369
def test_categorical (self , n ):
370
+ domain = Domain (range (n ), dtype = "int64" , edges = (0 , n ))
371
+ paramdomains = {"p" : Simplex (n )}
372
+
370
373
check_logp (
371
374
pm .Categorical ,
372
- Domain ( range ( n ), dtype = "int64" , edges = ( 0 , n )) ,
373
- { "p" : Simplex ( n )} ,
375
+ domain ,
376
+ paramdomains ,
374
377
lambda value , p : categorical_logpdf (value , p ),
375
378
)
376
379
377
- def test_categorical_logp_batch_dims (self ):
380
+ check_selfconsistency_discrete_logcdf (
381
+ pm .Categorical ,
382
+ domain ,
383
+ paramdomains ,
384
+ )
385
+
386
+ @pytest .mark .parametrize ("method" , (logp , logcdf ), ids = lambda x : x .__name__ )
387
+ def test_categorical_logp_batch_dims (self , method ):
378
388
# Core case
379
389
p = np .array ([0.2 , 0.3 , 0.5 ])
380
390
value = np .array (2.0 )
381
- logp_expr = logp (pm .Categorical .dist (p = p , shape = value .shape ), value )
382
- assert logp_expr .type .ndim == 0
383
- np .testing .assert_allclose (logp_expr .eval (), np .log (0.5 ))
391
+ expr = method (pm .Categorical .dist (p = p , shape = value .shape ), value )
392
+ assert expr .type .ndim == 0
393
+ expected_p = 0.5 if method is logp else 1.0
394
+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
384
395
385
396
# Explicit batched value broadcasts p
386
397
bcast_p = p [None ] # shape (1, 3)
387
398
batch_value = np .array ([0 , 1 ]) # shape(3,)
388
- logp_expr = logp (pm .Categorical .dist (p = bcast_p , shape = batch_value .shape ), batch_value )
389
- assert logp_expr .type .ndim == 1
390
- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.2 , 0.3 ]))
399
+ expr = method (pm .Categorical .dist (p = bcast_p , shape = batch_value .shape ), batch_value )
400
+ assert expr .type .ndim == 1
401
+ expected_p = [0.2 , 0.3 ] if method is logp else [0.2 , 0.5 ]
402
+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
403
+
404
+ # Implicit batch value broadcasts p
405
+ expr = method (pm .Categorical .dist (p = p , shape = ()), batch_value )
406
+ assert expr .type .ndim == 1
407
+ expected_p = [0.2 , 0.3 ] if method is logp else [0.2 , 0.5 ]
408
+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
391
409
392
410
# Explicit batched value and batched p
393
411
batch_p = np .array ([p [::- 1 ], p ])
394
- logp_expr = logp (pm .Categorical .dist (p = batch_p , shape = batch_value .shape ), batch_value )
395
- assert logp_expr .type .ndim == 1
396
- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.5 , 0.3 ]))
397
-
398
- # Implicit batch value broadcasts p
399
- logp_expr = logp (pm .Categorical .dist (p = p , shape = ()), batch_value )
400
- assert logp_expr .type .ndim == 1
401
- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.2 , 0.3 ]))
412
+ expr = method (pm .Categorical .dist (p = batch_p , shape = batch_value .shape ), batch_value )
413
+ assert expr .type .ndim == 1
414
+ expected_p = [0.5 , 0.3 ] if method is logp else [0.5 , 0.5 ]
415
+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
402
416
403
417
# Implicit batch p broadcasts value
404
- logp_expr = logp (pm .Categorical .dist (p = batch_p , shape = None ), value )
405
- assert logp_expr .type .ndim == 1
406
- np .testing .assert_allclose (logp_expr .eval (), np .log ([0.2 , 0.5 ]))
418
+ expr = method (pm .Categorical .dist (p = batch_p , shape = None ), value )
419
+ assert expr .type .ndim == 1
420
+ expected_p = [0.2 , 0.5 ] if method is logp else [1.0 , 1.0 ]
421
+ np .testing .assert_allclose (expr .exp ().eval (), expected_p )
407
422
408
423
@pytensor .config .change_flags (compute_test_value = "raise" )
409
424
def test_categorical_bounds (self ):
0 commit comments