Skip to content

Commit 24930b5

Browse files
authored
added sample_filter_outputs utility and accompanying simple tests (#526)
* Added sample_filter_outputs utility and accompanying simple tests. Rebased from upstream.
* 1. Removed modelcontext call that is not needed.
  2. Added handling for when the filter_output param is passed in as a str.
  3. Removed case statement in favor of the dictionary mapping that already exists in conf.py.
* Updated plurality for some of the constants in constants.py.
* Cleaned up commented code, moved internal checks to the top, reduced intermediate variables.
* Updated Kalman filter outputs to use the names defined in constants.py; updated sample_filter_outputs to allow sampling any filter outputs defined in constants.py.
1 parent 5c9eee5 commit 24930b5

File tree

7 files changed

+152
-44
lines changed

7 files changed

+152
-44
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -805,16 +805,16 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari
805805
states, covs = outputs[:4], outputs[4:]
806806

807807
state_names = [
808-
"filtered_state",
809-
"predicted_state",
810-
"predicted_observed_state",
811-
"smoothed_state",
808+
"filtered_states",
809+
"predicted_states",
810+
"predicted_observed_states",
811+
"smoothed_states",
812812
]
813813
cov_names = [
814-
"filtered_covariance",
815-
"predicted_covariance",
816-
"predicted_observed_covariance",
817-
"smoothed_covariance",
814+
"filtered_covariances",
815+
"predicted_covariances",
816+
"predicted_observed_covariances",
817+
"smoothed_covariances",
818818
]
819819

820820
with mod:
@@ -939,7 +939,7 @@ def build_statespace_graph(
939939
all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances]
940940
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
941941

942-
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"]
942+
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"]
943943
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None
944944

945945
SequenceMvNormal(
@@ -1678,6 +1678,78 @@ def sample_statespace_matrices(
16781678

16791679
return matrix_idata
16801680

1681+
def sample_filter_outputs(
1682+
self, idata, filter_output_names: str | list[str] | None, group: str = "posterior", **kwargs
1683+
):
1684+
if isinstance(filter_output_names, str):
1685+
filter_output_names = [filter_output_names]
1686+
1687+
if filter_output_names is None:
1688+
filter_output_names = list(FILTER_OUTPUT_DIMS.keys())
1689+
else:
1690+
unknown_filter_output_names = np.setdiff1d(
1691+
filter_output_names, list(FILTER_OUTPUT_DIMS.keys())
1692+
)
1693+
if unknown_filter_output_names.size > 0:
1694+
raise ValueError(f"{unknown_filter_output_names} not a valid filter output name!")
1695+
filter_output_names = [x for x in FILTER_OUTPUT_DIMS.keys() if x in filter_output_names]
1696+
1697+
compile_kwargs = kwargs.pop("compile_kwargs", {})
1698+
compile_kwargs.setdefault("mode", self.mode)
1699+
1700+
with pm.Model(coords=self.coords) as m:
1701+
self._build_dummy_graph()
1702+
self._insert_random_variables()
1703+
1704+
if self.data_names:
1705+
for name in self.data_names:
1706+
pm.Data(**self._exog_data_info[name])
1707+
1708+
self._insert_data_variables()
1709+
1710+
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
1711+
data = self._fit_data
1712+
1713+
obs_coords = m.coords.get(OBS_STATE_DIM, None)
1714+
1715+
data, nan_mask = register_data_with_pymc(
1716+
data,
1717+
n_obs=self.ssm.k_endog,
1718+
obs_coords=obs_coords,
1719+
register_data=True,
1720+
)
1721+
1722+
filter_outputs = self.kalman_filter.build_graph(
1723+
data,
1724+
x0,
1725+
P0,
1726+
c,
1727+
d,
1728+
T,
1729+
Z,
1730+
R,
1731+
H,
1732+
Q,
1733+
)
1734+
1735+
smoother_outputs = self.kalman_smoother.build_graph(
1736+
T, R, Q, filter_outputs[0], filter_outputs[3]
1737+
)
1738+
1739+
filter_outputs = filter_outputs[:-1] + list(smoother_outputs)
1740+
for output in filter_outputs:
1741+
if output.name in filter_output_names:
1742+
dims = FILTER_OUTPUT_DIMS[output.name]
1743+
pm.Deterministic(output.name, output, dims=dims)
1744+
1745+
with freeze_dims_and_data(m):
1746+
return pm.sample_posterior_predictive(
1747+
idata if group == "posterior" else idata.prior,
1748+
var_names=filter_output_names,
1749+
compile_kwargs=compile_kwargs,
1750+
**kwargs,
1751+
)
1752+
16811753
@staticmethod
16821754
def _validate_forecast_args(
16831755
time_index: pd.RangeIndex | pd.DatetimeIndex,

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@
1515
split_vars_into_seq_and_nonseq,
1616
stabilize,
1717
)
18-
from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, MISSING_FILL
18+
from pymc_extras.statespace.utils.constants import (
19+
FILTER_OUTPUT_NAMES,
20+
JITTER_DEFAULT,
21+
MATRIX_NAMES,
22+
MISSING_FILL,
23+
)
1924

2025
MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64"))
21-
PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"]
26+
PARAM_NAMES = MATRIX_NAMES[2:]
2227

2328
assert_time_varying_dim_correct = Assert(
2429
"The first dimension of a time varying matrix (the time dimension) must be "
@@ -119,7 +124,7 @@ def unpack_args(self, args) -> tuple:
119124
# There are always two outputs_info wedged between the seqs and non_seqs
120125
seqs, (a0, P0), non_seqs = args[:n_seq], args[n_seq : n_seq + 2], args[n_seq + 2 :]
121126
return_ordered = []
122-
for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
127+
for name in PARAM_NAMES:
123128
if name in self.seq_names:
124129
idx = self.seq_names.index(name)
125130
return_ordered.append(seqs[idx])
@@ -253,28 +258,28 @@ def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]:
253258
)
254259

255260
filtered_states = pt.specify_shape(filtered_states, (n, self.n_states))
256-
filtered_states.name = "filtered_states"
261+
filtered_states.name = FILTER_OUTPUT_NAMES[0]
257262

258263
predicted_states = pt.specify_shape(predicted_states, (n, self.n_states))
259-
predicted_states.name = "predicted_states"
260-
261-
observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
262-
observed_states.name = "observed_states"
264+
predicted_states.name = FILTER_OUTPUT_NAMES[1]
263265

264266
filtered_covariances = pt.specify_shape(
265267
filtered_covariances, (n, self.n_states, self.n_states)
266268
)
267-
filtered_covariances.name = "filtered_covariances"
269+
filtered_covariances.name = FILTER_OUTPUT_NAMES[2]
268270

269271
predicted_covariances = pt.specify_shape(
270272
predicted_covariances, (n, self.n_states, self.n_states)
271273
)
272-
predicted_covariances.name = "predicted_covariances"
274+
predicted_covariances.name = FILTER_OUTPUT_NAMES[3]
275+
276+
observed_states = pt.specify_shape(observed_states, (n, self.n_endog))
277+
observed_states.name = FILTER_OUTPUT_NAMES[4]
273278

274279
observed_covariances = pt.specify_shape(
275280
observed_covariances, (n, self.n_endog, self.n_endog)
276281
)
277-
observed_covariances.name = "observed_covariances"
282+
observed_covariances.name = FILTER_OUTPUT_NAMES[5]
278283

279284
loglike_obs = pt.specify_shape(loglike_obs.squeeze(), (n,))
280285
loglike_obs.name = "loglike_obs"

pymc_extras/statespace/utils/constants.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@
3838
LONG_NAME_TO_SHORT = dict(zip(LONG_MATRIX_NAMES, MATRIX_NAMES))
3939

4040
FILTER_OUTPUT_NAMES = [
41-
"filtered_state",
42-
"predicted_state",
43-
"filtered_covariance",
44-
"predicted_covariance",
41+
"filtered_states",
42+
"predicted_states",
43+
"filtered_covariances",
44+
"predicted_covariances",
45+
"predicted_observed_states",
46+
"predicted_observed_covariances",
4547
]
4648

47-
SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
48-
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
49+
SMOOTHER_OUTPUT_NAMES = ["smoothed_states", "smoothed_covariances"]
50+
OBSERVED_OUTPUT_NAMES = ["predicted_observed_states", "predicted_observed_covariances"]
4951

5052
MATRIX_DIMS = {
5153
"x0": (ALL_STATE_DIM,),
@@ -60,14 +62,14 @@
6062
}
6163

6264
FILTER_OUTPUT_DIMS = {
63-
"filtered_state": (TIME_DIM, ALL_STATE_DIM),
64-
"smoothed_state": (TIME_DIM, ALL_STATE_DIM),
65-
"predicted_state": (TIME_DIM, ALL_STATE_DIM),
66-
"filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
67-
"smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
68-
"predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
69-
"predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
70-
"predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
65+
"filtered_states": (TIME_DIM, ALL_STATE_DIM),
66+
"smoothed_states": (TIME_DIM, ALL_STATE_DIM),
67+
"predicted_states": (TIME_DIM, ALL_STATE_DIM),
68+
"filtered_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
69+
"smoothed_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
70+
"predicted_covariances": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
71+
"predicted_observed_states": (TIME_DIM, OBS_STATE_DIM),
72+
"predicted_observed_covariances": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
7173
}
7274

7375
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]

tests/statespace/core/test_statespace.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
from collections.abc import Sequence
24
from functools import partial
35

@@ -485,16 +487,16 @@ def test_build_statespace_graph_raises_if_data_has_missing_fill():
485487

486488
def test_build_statespace_graph(pymc_mod):
487489
for name in [
488-
"filtered_state",
489-
"predicted_state",
490-
"predicted_covariance",
491-
"filtered_covariance",
490+
"filtered_states",
491+
"predicted_states",
492+
"predicted_covariances",
493+
"filtered_covariances",
492494
]:
493495
assert name in [x.name for x in pymc_mod.deterministics]
494496

495497

496498
def test_build_smoother_graph(ss_mod, pymc_mod):
497-
names = ["smoothed_state", "smoothed_covariance"]
499+
names = ["smoothed_states", "smoothed_covariances"]
498500
for name in names:
499501
assert name in [x.name for x in pymc_mod.deterministics]
500502

@@ -1191,11 +1193,11 @@ def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_
11911193

11921194
# Check that the frozen states and covariances correctly match the sliced index
11931195
np.testing.assert_allclose(
1194-
idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values,
1196+
idata_exog.posterior["predicted_covariances"].sel(time=t0).mean(("chain", "draw")).values,
11951197
idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values,
11961198
)
11971199
np.testing.assert_allclose(
1198-
idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values,
1200+
idata_exog.posterior["predicted_states"].sel(time=t0).mean(("chain", "draw")).values,
11991201
idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values,
12001202
)
12011203

@@ -1244,3 +1246,30 @@ def test_param_dims_coords(ss_mod_multi_component):
12441246
assert i == len(
12451247
ss_mod_multi_component.coords[s]
12461248
), f"Mismatch between shape {i} and dimension {s}"
1249+
1250+
1251+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
def test_sample_filter_outputs(rng, exog_ss_mod, idata_exog):
    """Smoke-test sample_filter_outputs: prior sampling, a named subset, and bad names."""
    # Sampling every output from the prior group only has to run without raising.
    idata_filter_prior = exog_ss_mod.sample_filter_outputs(
        idata_exog, group="prior", filter_output_names=None
    )

    # A named subset must show up in the posterior_predictive group of the result.
    wanted = ["filtered_states", "filtered_covariances"]
    sampled = exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=wanted)
    available = list(sampled.posterior_predictive.data_vars)
    not_found = np.setdiff1d(wanted, available)
    assert len(not_found) == 0

    # Unknown names are rejected, and the message lists them as np.setdiff1d returns them.
    bad_names = ["filter_states", "filter_covariances"]
    expected_message = "['filter_covariances' 'filter_states'] not a valid filter output name!"
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        exog_ss_mod.sample_filter_outputs(idata_exog, filter_output_names=bad_names)

tests/statespace/models/test_SARIMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
321321

322322
@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
323323
def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
324-
rv = pymc_mod[f"{filter_output}_covariance"]
324+
rv = pymc_mod[f"{filter_output}_covariances"]
325325
cov_mats = pm.draw(rv, 100, random_seed=rng)
326326
w, v = np.linalg.eig(cov_mats)
327327
assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")

tests/statespace/models/test_VARMAX.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng):
156156

157157
@pytest.mark.parametrize("filter_output", ["filtered", "predicted", "smoothed"])
158158
def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
159-
rv = pymc_mod[f"{filter_output}_covariance"]
159+
rv = pymc_mod[f"{filter_output}_covariances"]
160160
cov_mats = pm.draw(rv, 100, random_seed=rng)
161161
w, v = np.linalg.eig(cov_mats)
162162
assert_array_less(0, w, err_msg=f"Smallest eigenvalue: {min(w.ravel())}")

tests/statespace/utils/test_coord_assignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_filter_output_coord_assignment(f, warning, create_model):
9393
with warning:
9494
pymc_model = create_model(f)
9595

96-
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_state"]:
96+
for output in FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES + ["predicted_observed_states"]:
9797
assert pymc_model.named_vars_to_dims[output] == FILTER_OUTPUT_DIMS[output]
9898

9999

0 commit comments

Comments
 (0)