diff --git a/pymc_extras/statespace/models/structural/core.py b/pymc_extras/statespace/models/structural/core.py index 35ef8939..e6f981ce 100644 --- a/pymc_extras/statespace/models/structural/core.py +++ b/pymc_extras/statespace/models/structural/core.py @@ -310,16 +310,20 @@ def _hidden_states_from_data(self, data): for i, (name, s) in enumerate(zip(names, state_slices)): obs_idx = info[name]["obs_state_idx"] + if obs_idx is None: continue X = data[..., s] + if info[name]["combine_hidden_states"]: - sum_idx = np.flatnonzero(obs_idx) - result.append(X[..., sum_idx].sum(axis=-1)[..., None]) + sum_idx_joined = np.flatnonzero(obs_idx) + sum_idx_split = np.split(sum_idx_joined, info[name]["k_endog"]) + for sum_idx in sum_idx_split: + result.append(X[..., sum_idx].sum(axis=-1)[..., None]) else: - comp_names = self.state_names[s] - for j, state_name in enumerate(comp_names): + n_components = len(self.state_names[s]) + for j in range(n_components): result.append(X[..., j, None]) return np.concatenate(result, axis=-1) @@ -332,7 +336,15 @@ def _get_subcomponent_names(self): for i, (name, s) in enumerate(zip(names, state_slices)): if info[name]["combine_hidden_states"]: - result.append(name) + if self.k_endog == 1: + result.append(name) + else: + # If there are multiple observed states, we will combine per hidden state, preserving the + # observed state names. Note this happens even if this *component* has only 1 state for consistency, + # as long as the statespace model has multiple observed states. + result.extend( + [f"{name}[{obs_name}]" for obs_name in info[name]["observed_state_names"]] + ) else: comp_names = self.state_names[s] result.extend([f"{name}[{comp_name}]" for comp_name in comp_names]) @@ -540,7 +552,7 @@ def __init__( self._component_info = { self.name: { "k_states": self.k_states, - "k_enodg": self.k_endog, + "k_endog": self.k_endog, "k_posdef": self.k_posdef, "observed_state_names": self.observed_state_names, "combine_hidden_states": combine_hidden_states, diff --git a/tests/statespace/models/structural/test_core.py b/tests/statespace/models/structural/test_core.py index 43ee7af3..bd9dcb03 100644 --- a/tests/statespace/models/structural/test_core.py +++ b/tests/statespace/models/structural/test_core.py @@ -128,8 +128,9 @@ def test_extract_multiple_observed(rng): reg = st.RegressionComponent( state_names=["a", "b"], name="exog", observed_state_names=["data_2", "data_3"] ) + ar = st.AutoregressiveComponent(observed_state_names=["data_1", "data_2"], order=3) me = st.MeasurementError("obs", observed_state_names=["data_1", "data_3"]) - mod = (ll + season + reg + me).build(verbose=True) + mod = (ll + season + reg + ar + me).build(verbose=True) with pm.Model(coords=mod.coords) as m: data_exog = pm.Data("data_exog", data.values) @@ -137,6 +138,10 @@ def test_extract_multiple_observed(rng): x0 = pm.Normal("x0", dims=["state"]) P0 = pm.Deterministic("P0", pt.eye(mod.k_states), dims=["state", "state_aux"]) beta_exog = pm.Normal("beta_exog", dims=["endog_exog", "state_exog"]) + params_auto_regressive = pm.Normal( + "params_auto_regressive", dims=["endog_auto_regressive", "lag_auto_regressive"] + ) + sigma_auto_regressive = pm.Normal("sigma_auto_regressive", dims=["endog_auto_regressive"]) initial_trend = pm.Normal("initial_trend", dims=["endog_trend", "state_trend"]) sigma_trend = pm.Exponential("sigma_trend", 1, dims=["endog_trend", "shock_trend"]) seasonal_coefs = pm.Normal("seasonal", dims=["state_seasonal"]) @@ -155,11 +160,13 @@ def test_extract_multiple_observed(rng): "trend[trend[data_1]]", "trend[level[data_2]]", "trend[trend[data_2]]", - "seasonal", + "seasonal[data_1]", "exog[a[data_2]]", "exog[b[data_2]]", "exog[a[data_3]]", "exog[b[data_3]]", + "auto_regressive[data_1]", + "auto_regressive[data_2]", ] missing = set(comp_states) - set(expected_states)