@@ -304,3 +304,35 @@ def test_id():
304
304
).hexdigest ()[:16 ]
305
305
306
306
assert model_builder .id == expected_id
307
+
308
+ @pytest .mark .parametrize ("predictions" , [True , False ])
309
+ def test_predict_respects_predictions_flag (fitted_model_instance , predictions ):
310
+ x_pred = np .random .uniform (0 , 1 , 100 )
311
+ prediction_data = pd .DataFrame ({"input" : x_pred })
312
+ output_var = fitted_model_instance .output_var
313
+
314
+ # Snapshot the original posterior_predictive values
315
+ pp_before = fitted_model_instance .idata .posterior_predictive [output_var ].values .copy ()
316
+
317
+ # Ensure 'predictions' group is not present initially
318
+ assert "predictions" not in fitted_model_instance .idata .groups ()
319
+
320
+ # Run prediction with predictions=True or False
321
+ fitted_model_instance .predict (
322
+ prediction_data ["input" ],
323
+ extend_idata = True ,
324
+ combined = False ,
325
+ predictions = predictions ,
326
+ )
327
+
328
+ pp_after = fitted_model_instance .idata .posterior_predictive [output_var ].values
329
+
330
+ # Check predictions group presence
331
+ if predictions :
332
+ assert "predictions" in fitted_model_instance .idata .groups ()
333
+ # Posterior predictive should remain unchanged
334
+ np .testing .assert_array_equal (pp_before , pp_after )
335
+ else :
336
+ assert "predictions" not in fitted_model_instance .idata .groups ()
337
+ # Posterior predictive should be updated
338
+ np .testing .assert_array_not_equal (pp_before , pp_after )
0 commit comments