Skip to content

Draft (new feature) : Model to estimate when a intervention had effect #480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
10a017e
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
69d79b3
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
bf4eaaa
Minor fix in docstring
JeanVanDyk May 29, 2025
3420c9a
Minor fix in docstring
JeanVanDyk May 29, 2025
3dc23b3
Minor fix in docstring
JeanVanDyk May 29, 2025
d739b4a
Minor fix in docstring
JeanVanDyk May 29, 2025
d48f0c3
Minor fix in docstring
JeanVanDyk May 29, 2025
14afe09
Minor fix in docstring
JeanVanDyk May 29, 2025
60357a5
Minor fix in docstring
JeanVanDyk May 29, 2025
7f57b13
Minor fix in docstring
JeanVanDyk May 29, 2025
2cb92fc
Minor fix in docstring
JeanVanDyk May 29, 2025
d9c06ac
Minor fix in docstring
JeanVanDyk May 29, 2025
52cc0fa
Minor fix in docstring
JeanVanDyk May 29, 2025
faf085b
Minor fix in docstring
JeanVanDyk May 29, 2025
cc9a1f4
Minor fix in docstring
JeanVanDyk May 29, 2025
dea9d6e
Minor fix in docstring
JeanVanDyk May 29, 2025
5e9cde6
fix : hiding progressbar
JeanVanDyk May 30, 2025
ee701f2
Enhancement : Adding the possibility for the user to indicate priors …
JeanVanDyk May 30, 2025
5ee3cb4
Minor fix in docstring
JeanVanDyk Jun 4, 2025
08c520c
updating example notebook
JeanVanDyk Jun 4, 2025
b1681da
updating example notebook
JeanVanDyk Jun 4, 2025
fcfd059
Supporting Date format and adding exceptions for model related issues
JeanVanDyk Jun 4, 2025
64c97b7
changing column index restriction to label restriction
JeanVanDyk Jun 5, 2025
2996331
codespell
JeanVanDyk Jun 17, 2025
1da80fd
resolved merge
JeanVanDyk Jun 17, 2025
020f679
fixing merging issues
JeanVanDyk Jun 18, 2025
5039fda
fixing merging issues
JeanVanDyk Jun 18, 2025
4761b7e
codespell
JeanVanDyk Jun 18, 2025
bec5cd8
codespell
JeanVanDyk Jun 18, 2025
2d4d158
updating notebook
JeanVanDyk Jun 19, 2025
8d607b8
updating notebook with examples and adding time_variable_name parameter
JeanVanDyk Jun 20, 2025
d00f828
Merge branch 'main' into pr/480
drbenvincent Jun 20, 2025
942a1d5
fixing example
JeanVanDyk Jun 20, 2025
4aef14b
revert changes in docs and fixing issues
JeanVanDyk Jun 20, 2025
2b2cbdf
Removing the overriding of fit and calculate_impact, adding a test an…
JeanVanDyk Jun 20, 2025
6769aa7
Using all samples for uncertainty
JeanVanDyk Jun 23, 2025
692d85c
uml and docs
JeanVanDyk Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions causalpy/custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ class DataException(Exception):

def __init__(self, message: str):
self.message = message


class ModelException(Exception):
"""Exception raised given when there is some error in user-provided model"""

def __init__(self, message: str):
self.message = message
265 changes: 223 additions & 42 deletions causalpy/experiments/interrupted_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,157 @@
from patsy import build_design_matrices, dmatrices
from sklearn.base import RegressorMixin

from causalpy.custom_exceptions import BadIndexException
from causalpy.custom_exceptions import BadIndexException, ModelException
from causalpy.experiments.base import BaseExperiment
from causalpy.plot_utils import get_hdi_to_df, plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.utils import round_num

from .base import BaseExperiment

LEGEND_FONT_SIZE = 12


class HandlerUTT:
"""
Handle data preprocessing, postprocessing, and plotting steps for models
with unknown treatment intervention times.
"""

def data_preprocessing(self, data, treatment_time, model):
"""
Preprocess the input data and update the model's treatment time constraints.
"""
# Restrict model's treatment time inference to given range
model.set_time_range(treatment_time, data)
return data

def data_postprocessing(self, model, data, idata, treatment_time, pre_y, pre_X):
"""
Postprocess data based on the inferred treatment time for further analysis and plotting.
"""
# --- Getting the time_variable_name ---
time_variable_name = model.get_time_variable_name()

# --- Inferred treatment time ---
treatment_time_mean = idata.posterior["treatment_time"].mean().item()
inferred_treatment_time = int(treatment_time_mean)
idx_treatment_time = data.index[
data[time_variable_name] == inferred_treatment_time
][0]

# --- HDI bounds (credible interval) ---
hdi_bounds = az.hdi(idata, var_names=["treatment_time"])[
"treatment_time"
].values
hdi_start_time = int(hdi_bounds[0])
indice = data.index.get_loc(
data.index[data[time_variable_name] == hdi_start_time][0]
)

# --- Slicing ---
datapre = data[data[time_variable_name] < hdi_start_time]
datapost = data[data[time_variable_name] >= hdi_start_time]

truncated_y = pre_y.isel(obs_ind=slice(0, indice))
truncated_X = pre_X.isel(obs_ind=slice(0, indice))

return datapre, datapost, truncated_y, truncated_X, idx_treatment_time

def plot_intervention_line(self, ax, model, idata, datapost, treatment_time):
"""
Plot a vertical line at the inferred treatment time, along with a shaded area
representing the Highest Density Interval (HDI) of the inferred time.
"""
# --- Getting the time_variable_name ---
time_variable_name = model.get_time_variable_name()

# Extract the HDI (uncertainty interval) of the treatment time
hdi = az.hdi(idata, var_names=["treatment_time"])["treatment_time"].values
x1 = datapost.index[datapost[time_variable_name] == int(hdi[0])][0]
x2 = datapost.index[datapost[time_variable_name] == int(hdi[1])][0]

for i in [0, 1, 2]:
ymin, ymax = ax[i].get_ylim()

# Vertical line for inferred treatment time
ax[i].plot(
[treatment_time, treatment_time],
[ymin, ymax],
ls="-",
lw=3,
color="r",
solid_capstyle="butt",
)

# Shaded region for HDI of treatment time
ax[i].fill_betweenx(
y=[ymin, ymax],
x1=x1,
x2=x2,
alpha=0.1,
color="r",
)

def plot_treated_counterfactual(
self, ax, handles, labels, datapost, post_pred, post_y
):
"""
Plot the inferred post-intervention trajectory (with treatment effect).
"""
# --- Plot predicted trajectory under treatment (with HDI)
h_line, h_patch = plot_xY(
datapost.index,
post_pred["posterior_predictive"].mu_ts,
ax=ax[0],
plot_hdi_kwargs={"color": "yellowgreen"},
)
handles.append((h_line, h_patch))
labels.append("Treated counterfactual")


class HandlerKTT:
"""
Handles data preprocessing, postprocessing, and plotting logic for models
where the treatment time is known in advance.
"""

def data_preprocessing(self, data, treatment_time, model):
"""
Preprocess the data by selecting only the pre-treatment period for model fitting.
"""
# Use only data before treatment for training the model
return data[data.index < treatment_time]

def data_postprocessing(self, model, data, idata, treatment_time, pre_y, pre_X):
"""
Split data into pre- and post-treatment periods using the known treatment time.
"""
return (
data[data.index < treatment_time],
data[data.index >= treatment_time],
pre_y,
pre_X,
treatment_time,
)

def plot_intervention_line(self, model, ax, idata, datapost, treatment_time):
"""
Plot a vertical line at the known treatment time on provided axes.
"""
# --- Plot a vertical line at the known treatment time
for i in [0, 1, 2]:
ax[i].axvline(
x=treatment_time, ls="-", lw=3, color="r", solid_capstyle="butt"
)

def plot_treated_counterfactual(
self, sax, handles, labels, datapost, post_pred, post_y
):
"""
Placeholder method to maintain interface compatibility with HandlerUTT.
"""
pass


class InterruptedTimeSeries(BaseExperiment):
"""
The class for interrupted time series analysis.
Expand Down Expand Up @@ -79,37 +220,41 @@ class InterruptedTimeSeries(BaseExperiment):
def __init__(
self,
data: pd.DataFrame,
treatment_time: Union[int, float, pd.Timestamp],
treatment_time: Union[int, float, pd.Timestamp, tuple, None],
formula: str,
model=None,
**kwargs,
) -> None:
super().__init__(model=model)

# rename the index to "obs_ind"
data.index.name = "obs_ind"
self.input_validation(data, treatment_time)
self.treatment_time = treatment_time
self.input_validation(data, treatment_time, model)
# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"
# split data in to pre and post intervention
self.datapre = data[data.index < self.treatment_time]
self.datapost = data[data.index >= self.treatment_time]

self.treatment_time = treatment_time
self.formula = formula

# set things up with pre-intervention data
# Getting the right handler
if treatment_time is None or isinstance(treatment_time, tuple):
self.handler = HandlerUTT()
else:
self.handler = HandlerKTT()

# Preprocessing based on handler type
self.datapre = self.handler.data_preprocessing(
data, self.treatment_time, self.model
)

y, X = dmatrices(formula, self.datapre)
# set things up with pre-intervention data
self.outcome_variable_name = y.design_info.column_names[0]
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
# process post-intervention data
(new_y, new_x) = build_design_matrices(
[self._y_design_info, self._x_design_info], self.datapost
)
self.post_X = np.asarray(new_x)
self.post_y = np.asarray(new_y)

# turn into xarray.DataArray's
self.pre_X = xr.DataArray(
self.pre_X,
Expand All @@ -124,35 +269,50 @@ def __init__(
dims=["obs_ind"],
coords={"obs_ind": self.datapre.index},
)
self.post_X = xr.DataArray(
self.post_X,
dims=["obs_ind", "coeffs"],
coords={
"obs_ind": self.datapost.index,
"coeffs": self.labels,
},
)
self.post_y = xr.DataArray(
self.post_y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": self.datapost.index},
)

# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(X.shape[0])}
idata = self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.pre_X, y=self.pre_y)
idata = None
else:
raise ValueError("Model type not recognized")

# score the goodness of fit to the pre-intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)

# Postprocessing with handler
self.datapre, self.datapost, self.pre_y, self.pre_X, self.treatment_time = (
self.handler.data_postprocessing(
self.model, data, idata, treatment_time, self.pre_y, self.pre_X
)
)

# get the model predictions of the observed (pre-intervention) data
self.pre_pred = self.model.predict(X=self.pre_X)

# process post-intervention data
(new_y, new_x) = build_design_matrices(
[self._y_design_info, self._x_design_info], self.datapost
)
self.post_X = np.asarray(new_x)
self.post_y = np.asarray(new_y)
self.post_X = xr.DataArray(
self.post_X,
dims=["obs_ind", "coeffs"],
coords={
"obs_ind": self.datapost.index,
"coeffs": self.labels,
},
)
self.post_y = xr.DataArray(
self.post_y[:, 0],
dims=["obs_ind"],
coords={"obs_ind": self.datapost.index},
)

# calculate the counterfactual
self.post_pred = self.model.predict(X=self.post_X)
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
Expand All @@ -161,16 +321,24 @@ def __init__(
self.post_impact
)

def input_validation(self, data, treatment_time):
def input_validation(self, data, treatment_time, model):
"""Validate the input data and model formula for correctness"""
if treatment_time is None and not hasattr(model, "set_time_range"):
raise ModelException(
"If treatment_time is None, provided model must have a 'set_time_range' method"
)
if isinstance(treatment_time, tuple) and not hasattr(model, "set_time_range"):
raise ModelException(
"If treatment_time is a tuple, provided model must have a 'set_time_range' method"
)
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
treatment_time, pd.Timestamp
treatment_time, (pd.Timestamp, tuple, type(None))
):
raise BadIndexException(
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
)
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
treatment_time, pd.Timestamp
treatment_time, (pd.Timestamp)
):
raise BadIndexException(
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
Expand Down Expand Up @@ -199,6 +367,7 @@ def _bayesian_plot(

fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
# TOP PLOT --------------------------------------------------

# pre-intervention period
h_line, h_patch = plot_xY(
self.datapre.index,
Expand All @@ -213,6 +382,11 @@ def _bayesian_plot(
handles.append(h)
labels.append("Observations")

# Green line for treated counterfactual (if unknown treatment time)
self.handler.plot_treated_counterfactual(
ax, handles, labels, self.datapost, self.post_pred, self.post_y
)

# post intervention period
h_line, h_patch = plot_xY(
self.datapost.index,
Expand Down Expand Up @@ -277,14 +451,10 @@ def _bayesian_plot(
)
ax[2].axhline(y=0, c="k")

# Intervention line
for i in [0, 1, 2]:
ax[i].axvline(
x=self.treatment_time,
ls="-",
lw=3,
color="r",
)
# Plot vertical line marking treatment time (with HDI if it's inferred)
self.handler.plot_intervention_line(
ax, self.model, self.idata, self.datapost, self.treatment_time
)

ax[0].legend(
handles=(h_tuple for h_tuple in handles),
Expand Down Expand Up @@ -429,3 +599,14 @@ def get_plot_data_ols(self) -> pd.DataFrame:
self.plot_data = pd.concat([pre_data, post_data])

return self.plot_data

def plot_treatment_time(self):
"""
display the posterior estimates of the treatment time
"""
if "treatment_time" not in self.idata.posterior.data_vars:
raise ValueError(
"Variable 'treatment_time' not found in inference data (idata)."
)

az.plot_trace(self.idata, var_names="treatment_time")
Loading
Loading