@@ -530,6 +530,7 @@ def predict(
530
530
self ,
531
531
X_pred : np .ndarray | pd .DataFrame | pd .Series ,
532
532
extend_idata : bool = True ,
533
+ predictions : bool = False ,
533
534
** kwargs ,
534
535
) -> np .ndarray :
535
536
"""
@@ -559,7 +560,7 @@ def predict(
559
560
"""
560
561
561
562
posterior_predictive_samples = self .sample_posterior_predictive (
562
- X_pred , extend_idata , combined = False , ** kwargs
563
+ X_pred , extend_idata , predictions , combined = False , ** kwargs
563
564
)
564
565
565
566
if self .output_var not in posterior_predictive_samples :
@@ -624,7 +625,7 @@ def sample_prior_predictive(
624
625
625
626
return prior_predictive_samples
626
627
627
- def sample_posterior_predictive (self , X_pred , extend_idata , combined , ** kwargs ):
628
+ def sample_posterior_predictive (self , X_pred , extend_idata , predictions , combined , ** kwargs ):
628
629
"""
629
630
Sample from the model's posterior predictive distribution.
630
631
@@ -646,12 +647,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
646
647
self ._data_setter (X_pred )
647
648
648
649
with self .model : # sample with new input data
649
- post_pred = pm .sample_posterior_predictive (self .idata , ** kwargs )
650
+ post_pred = pm .sample_posterior_predictive (self .idata , predictions = predictions , ** kwargs )
650
651
if extend_idata :
651
652
self .idata .extend (post_pred , join = "right" )
652
653
653
- # Determine the correct group dynamically
654
- group_name = "predictions" if kwargs . get ( " predictions" , False ) else "posterior_predictive"
654
+ # Determine the correct group
655
+ group_name = "predictions" if predictions else "posterior_predictive"
655
656
656
657
posterior_predictive_samples = az .extract (
657
658
post_pred , group_name , combined = combined
@@ -703,6 +704,7 @@ def predict_posterior(
703
704
X_pred : np .ndarray | pd .DataFrame | pd .Series ,
704
705
extend_idata : bool = True ,
705
706
combined : bool = True ,
707
+ predictions : bool = False ,
706
708
** kwargs ,
707
709
) -> xr .DataArray :
708
710
"""
@@ -726,7 +728,7 @@ def predict_posterior(
726
728
727
729
X_pred = self ._validate_data (X_pred )
728
730
posterior_predictive_samples = self .sample_posterior_predictive (
729
- X_pred , extend_idata , combined , ** kwargs
731
+ X_pred , extend_idata , predictions , combined , ** kwargs
730
732
)
731
733
732
734
if self .output_var not in posterior_predictive_samples :
0 commit comments