Skip to content

Commit 4ea5fbc

Browse files
committed
test: added test for predictions grouping
1 parent ce1b2d5 commit 4ea5fbc

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/test_model_builder.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,35 @@ def test_id():
304304
).hexdigest()[:16]
305305

306306
assert model_builder.id == expected_id
307+
308+
@pytest.mark.parametrize("predictions", [True, False])
309+
def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
310+
x_pred = np.random.uniform(0, 1, 100)
311+
prediction_data = pd.DataFrame({"input": x_pred})
312+
output_var = fitted_model_instance.output_var
313+
314+
# Snapshot the original posterior_predictive values
315+
pp_before = fitted_model_instance.idata.posterior_predictive[output_var].values.copy()
316+
317+
# Ensure 'predictions' group is not present initially
318+
assert "predictions" not in fitted_model_instance.idata.groups()
319+
320+
# Run prediction with predictions=True or False
321+
fitted_model_instance.predict(
322+
prediction_data["input"],
323+
extend_idata=True,
324+
combined=False,
325+
predictions=predictions,
326+
)
327+
328+
pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values
329+
330+
# Check predictions group presence
331+
if predictions:
332+
assert "predictions" in fitted_model_instance.idata.groups()
333+
# Posterior predictive should remain unchanged
334+
np.testing.assert_array_equal(pp_before, pp_after)
335+
else:
336+
assert "predictions" not in fitted_model_instance.idata.groups()
337+
# Posterior predictive should be updated
338+
np.testing.assert_array_not_equal(pp_before, pp_after)

0 commit comments

Comments
 (0)