Skip to content

Commit 0bbf0d6

Browse files
Fix hidden state extraction where there are multiple observed states (#548)
1 parent fd2933f commit 0bbf0d6

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

pymc_extras/statespace/models/structural/core.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,16 +310,20 @@ def _hidden_states_from_data(self, data):
310310

311311
for i, (name, s) in enumerate(zip(names, state_slices)):
312312
obs_idx = info[name]["obs_state_idx"]
313+
313314
if obs_idx is None:
314315
continue
315316

316317
X = data[..., s]
318+
317319
if info[name]["combine_hidden_states"]:
318-
sum_idx = np.flatnonzero(obs_idx)
319-
result.append(X[..., sum_idx].sum(axis=-1)[..., None])
320+
sum_idx_joined = np.flatnonzero(obs_idx)
321+
sum_idx_split = np.split(sum_idx_joined, info[name]["k_endog"])
322+
for sum_idx in sum_idx_split:
323+
result.append(X[..., sum_idx].sum(axis=-1)[..., None])
320324
else:
321-
comp_names = self.state_names[s]
322-
for j, state_name in enumerate(comp_names):
325+
n_components = len(self.state_names[s])
326+
for j in range(n_components):
323327
result.append(X[..., j, None])
324328

325329
return np.concatenate(result, axis=-1)
@@ -332,7 +336,15 @@ def _get_subcomponent_names(self):
332336

333337
for i, (name, s) in enumerate(zip(names, state_slices)):
334338
if info[name]["combine_hidden_states"]:
335-
result.append(name)
339+
if self.k_endog == 1:
340+
result.append(name)
341+
else:
342+
# If there are multiple observed states, we will combine the hidden states per observed state, preserving the
343+
# observed state names. Note this happens even if this *component* has only 1 observed state, for consistency,
344+
# as long as the statespace model has multiple observed states.
345+
result.extend(
346+
[f"{name}[{obs_name}]" for obs_name in info[name]["observed_state_names"]]
347+
)
336348
else:
337349
comp_names = self.state_names[s]
338350
result.extend([f"{name}[{comp_name}]" for comp_name in comp_names])
@@ -540,7 +552,7 @@ def __init__(
540552
self._component_info = {
541553
self.name: {
542554
"k_states": self.k_states,
543-
"k_enodg": self.k_endog,
555+
"k_endog": self.k_endog,
544556
"k_posdef": self.k_posdef,
545557
"observed_state_names": self.observed_state_names,
546558
"combine_hidden_states": combine_hidden_states,

tests/statespace/models/structural/test_core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,20 @@ def test_extract_multiple_observed(rng):
128128
reg = st.RegressionComponent(
129129
state_names=["a", "b"], name="exog", observed_state_names=["data_2", "data_3"]
130130
)
131+
ar = st.AutoregressiveComponent(observed_state_names=["data_1", "data_2"], order=3)
131132
me = st.MeasurementError("obs", observed_state_names=["data_1", "data_3"])
132-
mod = (ll + season + reg + me).build(verbose=True)
133+
mod = (ll + season + reg + ar + me).build(verbose=True)
133134

134135
with pm.Model(coords=mod.coords) as m:
135136
data_exog = pm.Data("data_exog", data.values)
136137

137138
x0 = pm.Normal("x0", dims=["state"])
138139
P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"])
139140
beta_exog = pm.Normal("beta_exog", dims=["endog_exog", "state_exog"])
141+
params_auto_regressive = pm.Normal(
142+
"params_auto_regressive", dims=["endog_auto_regressive", "lag_auto_regressive"]
143+
)
144+
sigma_auto_regressive = pm.Normal("sigma_auto_regressive", dims=["endog_auto_regressive"])
140145
initial_trend = pm.Normal("initial_trend", dims=["endog_trend", "state_trend"])
141146
sigma_trend = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "shock_trend"])
142147
seasonal_coefs = pm.Normal("seasonal", dims=["state_seasonal"])
@@ -155,11 +160,13 @@ def test_extract_multiple_observed(rng):
155160
"trend[trend[data_1]]",
156161
"trend[level[data_2]]",
157162
"trend[trend[data_2]]",
158-
"seasonal",
163+
"seasonal[data_1]",
159164
"exog[a[data_2]]",
160165
"exog[b[data_2]]",
161166
"exog[a[data_3]]",
162167
"exog[b[data_3]]",
168+
"auto_regressive[data_1]",
169+
"auto_regressive[data_2]",
163170
]
164171

165172
missing = set(comp_states) - set(expected_states)

0 commit comments

Comments
 (0)