Skip to content

Draft (new feature) : Model to estimate when a intervention had effect #480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
10a017e
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
69d79b3
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
bf4eaaa
Minor fix in docstring
JeanVanDyk May 29, 2025
3420c9a
Minor fix in docstring
JeanVanDyk May 29, 2025
3dc23b3
Minor fix in docstring
JeanVanDyk May 29, 2025
d739b4a
Minor fix in docstring
JeanVanDyk May 29, 2025
d48f0c3
Minor fix in docstring
JeanVanDyk May 29, 2025
14afe09
Minor fix in docstring
JeanVanDyk May 29, 2025
60357a5
Minor fix in docstring
JeanVanDyk May 29, 2025
7f57b13
Minor fix in docstring
JeanVanDyk May 29, 2025
2cb92fc
Minor fix in docstring
JeanVanDyk May 29, 2025
d9c06ac
Minor fix in docstring
JeanVanDyk May 29, 2025
52cc0fa
Minor fix in docstring
JeanVanDyk May 29, 2025
faf085b
Minor fix in docstring
JeanVanDyk May 29, 2025
cc9a1f4
Minor fix in docstring
JeanVanDyk May 29, 2025
dea9d6e
Minor fix in docstring
JeanVanDyk May 29, 2025
5e9cde6
fix : hiding progressbar
JeanVanDyk May 30, 2025
ee701f2
Enhancement : Adding the possibility for the user to indicate priors …
JeanVanDyk May 30, 2025
5ee3cb4
Minor fix in docstring
JeanVanDyk Jun 4, 2025
08c520c
updating example notebook
JeanVanDyk Jun 4, 2025
b1681da
updating example notebook
JeanVanDyk Jun 4, 2025
fcfd059
Supporting Date format and adding exceptions for model related issues
JeanVanDyk Jun 4, 2025
64c97b7
changing column index restriction to label restriction
JeanVanDyk Jun 5, 2025
2996331
codespell
JeanVanDyk Jun 17, 2025
1da80fd
resolved merge
JeanVanDyk Jun 17, 2025
020f679
fixing merging issues
JeanVanDyk Jun 18, 2025
5039fda
fixing merging issues
JeanVanDyk Jun 18, 2025
4761b7e
codespell
JeanVanDyk Jun 18, 2025
bec5cd8
codespell
JeanVanDyk Jun 18, 2025
2d4d158
updating notebook
JeanVanDyk Jun 19, 2025
8d607b8
updating notebook with examples and adding time_variable_name parameter
JeanVanDyk Jun 20, 2025
d00f828
Merge branch 'main' into pr/480
drbenvincent Jun 20, 2025
942a1d5
fixing example
JeanVanDyk Jun 20, 2025
4aef14b
revert changes in docs and fixing issues
JeanVanDyk Jun 20, 2025
2b2cbdf
Removing the overriding of fit and calculate_impact, adding a test an…
JeanVanDyk Jun 20, 2025
6769aa7
Using all samples for uncertainty
JeanVanDyk Jun 23, 2025
692d85c
uml and docs
JeanVanDyk Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions causalpy/pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,147 @@
)
)
return self.idata


class InterventionTimeEstimator(PyMCModel):
r"""
Custom PyMC model to estimate the time an intervetnion took place.

defines the PyMC model :

.. math::
\alpha &\sim \mathrm{Normal}(0, 1) \\
\beta &\sim \mathrm{Normal}(0, 1) \\
s(t) &= \gamma_{i(t)} \quad \textrm{with} \quad \gamma_{k \in [0, ..., n_{seasons}-1]} \sim \mathrm{Normal}(0, 1)\\
base_{\mu}(t) &= \alpha + \beta \cdot t + s_t\\
\\
\tau &\sim \mathrm{Uniform}(0, 1) \\
w(t) &= sigmoid(t-\tau) \\
\\
level &\sim \mathrm{Normal}(0, 1) \\
trend &\sim \mathrm{Normal}(0, 1) \\
A &\sim \mathrm{Normal}(0, 1) \\
\lambda &\sim \mathrm{HalfNormal}(0, 1) \\
impulse(t) &= A \cdot exp(-\lambda \cdot |t-\tau|) \\
intervention(t) &= level + trend \cdot (t-\tau) + impulse_t\\
\\
\sigma &\sim \mathrm{Normal}(0, 1) \\
\mu(t) &= base_{\mu}(t) + w(t) \cdot intervention(t) \\
\\
y(t) &\sim \mathrm{Normal}(\mu (t), \sigma)

Example
--------
>>> import causalpy as cp
>>> import numpy as np
>>> from causalpy.pymc_models import InterventionTimeEstimator
>>> df = cp.load_data("its")
>>> y = df["y"].values
>>> t = df["t"].values
>>> coords = {"seasons": range(12)} # The data is monthly
>>> estimator = InterventionTimeEstimator()
>>> # We are trying to capture an impulse in the number of death per month due to Covid.
>>> estimator.fit(
... t,
... y,
... coords,
... priors={"impulse":[]}
... )
Inference data...
"""

def build_model(self, t, y, coords, time_range, grain_season, priors):
"""
Defines the PyMC model

:param t: An array of values representing the time over which y is spread
:param y: An array of values representing our outcome y
:param coords: An optional dictionary with the coordinate names for our instruments.
In particular, used to determine the number of seasons.
:param time_range: An optional tuple providing a specific time_range where the
intervention effect should have taken place.
:param priors: An optional dictionary of priors for the parameters of the
different distributions.
:code:`priors = {"alpha":[0, 5], "beta":[0,2], "level":[5, 5], "impulse":[1, 2 ,3]}`
"""

with self:
self.add_coords(coords)

Check warning on line 565 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L564-L565

Added lines #L564 - L565 were not covered by tests

if time_range is None:
time_range = (t.min(), t.max())

Check warning on line 568 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L567-L568

Added lines #L567 - L568 were not covered by tests

# --- Priors ---
switchpoint = pm.Uniform(

Check warning on line 571 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L571

Added line #L571 was not covered by tests
"switchpoint", lower=time_range[0], upper=time_range[1]
)
alpha = pm.Normal(name="alpha", mu=0, sigma=50)
beta = pm.Normal(name="beta", mu=0, sigma=50)
seasons = 0
if "seasons" in coords and len(coords["seasons"]) > 0:
season_idx = np.arange(len(y)) // grain_season % len(coords["seasons"])
seasons_effect = pm.Normal(

Check warning on line 579 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L574-L579

Added lines #L574 - L579 were not covered by tests
"seasons_effect", mu=0, sigma=50, dims="seasons"
)
seasons = seasons_effect[season_idx]

Check warning on line 582 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L582

Added line #L582 was not covered by tests

# --- Intervention effect ---
level = trend = impulse = 0

Check warning on line 585 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L585

Added line #L585 was not covered by tests

if "level" in priors:
mu, sigma = (

Check warning on line 588 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L587-L588

Added lines #L587 - L588 were not covered by tests
(0, 50)
if len(priors["level"]) != 2
else (priors["level"][0], priors["level"][1])
)
level = pm.Normal(

Check warning on line 593 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L593

Added line #L593 was not covered by tests
"level",
mu=mu,
sigma=sigma,
)
if "trend" in priors:
mu, sigma = (

Check warning on line 599 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L598-L599

Added lines #L598 - L599 were not covered by tests
(0, 50)
if len(priors["trend"]) != 2
else (priors["trend"][0], priors["trend"][1])
)
trend = pm.Normal("trend", mu=mu, sigma=sigma)
if "impulse" in priors:
mu, sigma1, sigma2 = (

Check warning on line 606 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L604-L606

Added lines #L604 - L606 were not covered by tests
(0, 50, 50)
if len(priors["impulse"]) != 3
else (
priors["impulse"][0],
priors["impulse"][1],
priors["impulse"][2],
)
)
impulse_amplitude = pm.Normal("impulse_amplitude", mu=mu, sigma=sigma1)
decay_rate = pm.HalfNormal("decay_rate", sigma=sigma2)
impulse = impulse_amplitude * pm.math.exp(

Check warning on line 617 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L615-L617

Added lines #L615 - L617 were not covered by tests
-decay_rate * abs(t - switchpoint)
)

# --- Parameterization ---
weight = pm.math.sigmoid(t - switchpoint)

Check warning on line 622 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L622

Added line #L622 was not covered by tests
# Compute and store the modelled time series
mu_ts = pm.Deterministic(name="mu_ts", var=alpha + beta * t + seasons)

Check warning on line 624 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L624

Added line #L624 was not covered by tests
# Compute and store the modelled intervention effect
mu_in = pm.Deterministic(

Check warning on line 626 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L626

Added line #L626 was not covered by tests
name="mu_in", var=level + trend * (t - switchpoint) + impulse
)
# Compute and store the the sum of the intervention and the time series
mu = pm.Deterministic("mu", mu_ts + weight * mu_in)
sigma = pm.HalfNormal("sigma", 1)

Check warning on line 631 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L630-L631

Added lines #L630 - L631 were not covered by tests

# --- Likelihood ---
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y)

Check warning on line 634 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L634

Added line #L634 was not covered by tests

def fit(self, t, y, coords, time_range=None, grain_season=1, priors={}, n=1000):
"""
Draw samples from posterior distribution
"""
self.build_model(t, y, coords, time_range, grain_season, priors)
with self:
self.idata = pm.sample(n, progressbar=False, **self.sample_kwargs)
return self.idata

Check warning on line 643 in causalpy/pymc_models.py

View check run for this annotation

Codecov / codecov/patch

causalpy/pymc_models.py#L640-L643

Added lines #L640 - L643 were not covered by tests
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.