diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py
index 96e1e9b5..76ec71e4 100644
--- a/pymc_extras/statespace/core/statespace.py
+++ b/pymc_extras/statespace/core/statespace.py
@@ -1678,6 +1678,124 @@ def sample_statespace_matrices(
 
         return matrix_idata
 
+    def sample_filter_outputs(
+        self,
+        idata,
+        filter_output_names: str | list[str] | None = None,
+        group: str = "posterior",
+        **kwargs,
+    ):
+        """
+        Sample Kalman filter and smoother outputs for the draws stored in ``idata``.
+
+        The statespace graph is rebuilt inside a fresh PyMC model, the requested outputs are
+        registered as deterministics, and ``pm.sample_posterior_predictive`` evaluates them.
+
+        Parameters
+        ----------
+        idata
+            Draws to condition on. Used as-is when ``group == "posterior"``; for any other
+            value of ``group``, ``idata.prior`` is used instead.
+        filter_output_names : str or list of str, optional
+            Name(s) of the outputs to sample. If None, all filter outputs except the last
+            element returned by ``kalman_filter.build_graph``, plus all smoother outputs,
+            are sampled.
+        group : str, default "posterior"
+            Group of ``idata`` that supplies the draws.
+        **kwargs
+            Forwarded to ``pm.sample_posterior_predictive``. ``compile_kwargs`` is extracted
+            first, and its ``mode`` entry defaults to ``self.mode``.
+
+        Returns
+        -------
+        idata_filter
+            The result of ``pm.sample_posterior_predictive`` holding the requested outputs.
+
+        Raises
+        ------
+        ValueError
+            If ``filter_output_names`` contains a name that is not an available output.
+        """
+        if isinstance(filter_output_names, str):
+            # A bare string would otherwise be matched by substring in the selection below.
+            filter_output_names = [filter_output_names]
+
+        # Copy so that setting the default mode never mutates a dict owned by the caller.
+        compile_kwargs = dict(kwargs.pop("compile_kwargs", None) or {})
+        compile_kwargs.setdefault("mode", self.mode)
+
+        with pm.Model(coords=self.coords) as m:
+            self._build_dummy_graph()
+            self._insert_random_variables()
+
+            if self.data_names:
+                for name in self.data_names:
+                    pm.Data(**self._exog_data_info[name])
+
+            self._insert_data_variables()
+
+            x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
+            obs_coords = m.coords.get(OBS_STATE_DIM, None)
+
+            data, _ = register_data_with_pymc(
+                self._fit_data,
+                n_obs=self.ssm.k_endog,
+                obs_coords=obs_coords,
+                register_data=True,
+            )
+
+            filter_outputs = self.kalman_filter.build_graph(data, x0, P0, c, d, T, Z, R, H, Q)
+
+            # NOTE(review): indices 0 and 3 are presumably the filtered states and filtered
+            # covariances -- confirm against ``kalman_filter.build_graph``.
+            smoother_outputs = self.kalman_smoother.build_graph(
+                T, R, Q, filter_outputs[0], filter_outputs[3]
+            )
+
+            # Drop the last filter output; unpacking works for list and tuple returns alike.
+            all_filter_outputs = [*filter_outputs[:-1], *smoother_outputs]
+
+            if filter_output_names is None:
+                selected_outputs = all_filter_outputs
+            else:
+                unknown_filter_output_names = np.setdiff1d(
+                    filter_output_names, [x.name for x in all_filter_outputs]
+                )
+                if unknown_filter_output_names.size > 0:
+                    raise ValueError(
+                        f"{unknown_filter_output_names} not a valid filter output name!"
+                    )
+                selected_outputs = [
+                    x for x in all_filter_outputs if x.name in filter_output_names
+                ]
+
+            for output in selected_outputs:
+                match output.name:
+                    case "filtered_states" | "predicted_states" | "smoothed_states":
+                        dims = [TIME_DIM, "state"]
+                    case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
+                        dims = [TIME_DIM, "state", "state_aux"]
+                    case "observed_states":
+                        dims = [TIME_DIM, "observed_state"]
+                    case "observed_covariances":
+                        dims = [TIME_DIM, "observed_state", "observed_state_aux"]
+                    case _:
+                        # Unknown name: register without dims instead of hitting an unbound
+                        # ``dims`` or reusing the value left over from the previous iteration.
+                        dims = None
+
+                pm.Deterministic(output.name, output, dims=dims)
+
+        frozen_model = freeze_dims_and_data(m)
+        with frozen_model:
+            idata_filter = pm.sample_posterior_predictive(
+                idata if group == "posterior" else idata.prior,
+                var_names=[x.name for x in frozen_model.deterministics],
+                compile_kwargs=compile_kwargs,
+                **kwargs,
+            )
+        return idata_filter
+
     @staticmethod
     def _validate_forecast_args(
         time_index: pd.RangeIndex | pd.DatetimeIndex,
diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py
index bfcd114a..1371a351 100644
--- a/tests/statespace/core/test_statespace.py
+++ b/tests/statespace/core/test_statespace.py
@@ -1,3 +1,5 @@
+import re
+
 from collections.abc import Sequence
 from functools import partial
 
@@ -1017,3 +1019,37 @@ def test_foreacast_valid_index(exog_pymc_mod, exog_ss_mod, exog_data):
 
     assert forecasts.forecast_latent.shape[2] == n_periods
     assert forecasts.forecast_observed.shape[2] == n_periods
+
+
+@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
+@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
+@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
+@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
+@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
+def test_sample_filter_outputs(rng, exog_ss_mod, idata_exog):
+    # Sampling every output from the prior group must yield a predictive group
+    idata_filter_prior = exog_ss_mod.sample_filter_outputs(
+        idata_exog, filter_output_names=None, group="prior"
+    )
+    assert hasattr(idata_filter_prior, "posterior_predictive")
+
+    # Requesting specific outputs must return each of them
+    specific_outputs = ["filtered_states", "filtered_covariances"]
+    idata_filter_specific = exog_ss_mod.sample_filter_outputs(
+        idata_exog, filter_output_names=specific_outputs
+    )
+    missing_outputs = np.setdiff1d(
+        specific_outputs, list(idata_filter_specific.posterior_predictive.data_vars)
+    )
+
+    assert missing_outputs.size == 0
+
+    # Unknown names are rejected, whether passed as a list or as a single string
+    msg = "['filter_covariances' 'filter_states'] not a valid filter output name!"
+    incorrect_outputs = ["filter_states", "filter_covariances"]
+    with pytest.raises(ValueError, match=re.escape(msg)):
+        exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=incorrect_outputs)
+
+    msg = "['filter_states'] not a valid filter output name!"
+    with pytest.raises(ValueError, match=re.escape(msg)):
+        exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names="filter_states")