Skip to content

Commit e86992b

Browse files
committed
refactor: make predictions argument explicit
1 parent d7a82b1 commit e86992b

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

pymc_extras/model_builder.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ def predict(
530530
self,
531531
X_pred: np.ndarray | pd.DataFrame | pd.Series,
532532
extend_idata: bool = True,
533+
predictions: bool = False,
533534
**kwargs,
534535
) -> np.ndarray:
535536
"""
@@ -559,7 +560,7 @@ def predict(
559560
"""
560561

561562
posterior_predictive_samples = self.sample_posterior_predictive(
562-
X_pred, extend_idata, combined=False, **kwargs
563+
X_pred, extend_idata, predictions, combined=False, **kwargs
563564
)
564565

565566
if self.output_var not in posterior_predictive_samples:
@@ -624,7 +625,7 @@ def sample_prior_predictive(
624625

625626
return prior_predictive_samples
626627

627-
def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
628+
def sample_posterior_predictive(self, X_pred, extend_idata, predictions, combined, **kwargs):
628629
"""
629630
Sample from the model's posterior predictive distribution.
630631
@@ -646,12 +647,12 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
646647
self._data_setter(X_pred)
647648

648649
with self.model: # sample with new input data
649-
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
650+
post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs)
650651
if extend_idata:
651652
self.idata.extend(post_pred, join="right")
652653

653-
# Determine the correct group dynamically
654-
group_name = "predictions" if kwargs.get("predictions", False) else "posterior_predictive"
654+
# Determine the correct group
655+
group_name = "predictions" if predictions else "posterior_predictive"
655656

656657
posterior_predictive_samples = az.extract(
657658
post_pred, group_name, combined=combined
@@ -703,6 +704,7 @@ def predict_posterior(
703704
X_pred: np.ndarray | pd.DataFrame | pd.Series,
704705
extend_idata: bool = True,
705706
combined: bool = True,
707+
predictions: bool = False,
706708
**kwargs,
707709
) -> xr.DataArray:
708710
"""
@@ -726,7 +728,7 @@ def predict_posterior(
726728

727729
X_pred = self._validate_data(X_pred)
728730
posterior_predictive_samples = self.sample_posterior_predictive(
729-
X_pred, extend_idata, combined, **kwargs
731+
X_pred, extend_idata, predictions, combined, **kwargs
730732
)
731733

732734
if self.output_var not in posterior_predictive_samples:

0 commit comments

Comments
 (0)