Skip to content

Commit ee701f2

Browse files
committed
Enhancement : Adding the possibility for the user to indicate priors over intervention's distributions
1 parent 5e9cde6 commit ee701f2

File tree

1 file changed

+55
-24
lines changed

1 file changed

+55
-24
lines changed

causalpy/pymc_models.py

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -541,48 +541,79 @@ class InterventionTimeEstimator(PyMCModel):
541541
... t,
542542
... y,
543543
... coords,
544-
... effect=["impulse"]
544+
... priors={"impulse":[]}
545545
... )
546546
Inference data...
547547
"""
548548

549-
def build_model(self, t, y, coords, effect, span, grain_season):
549+
def build_model(self, t, y, coords, time_range, grain_season, priors):
550550
"""
551551
Defines the PyMC model
552552
553553
:param t: An array of values representing the time over which y is spread
554554
:param y: An array of values representing our outcome y
555-
:param coords: A dictionary with the coordinate names for our instruments
555+
:param coords: An optional dictionary with the coordinate names for our instruments.
556+
In particular, used to determine the number of seasons.
557+
:param time_range: An optional tuple providing a specific time_range where the
558+
intervention effect should have taken place.
559+
:param priors: An optional dictionary of priors for the parameters of the
560+
different distributions.
561+
:code:`priors = {"alpha":[0, 5], "beta":[0,2], "level":[5, 5], "impulse":[1, 2 ,3]}`
556562
"""
557563

558564
with self:
559565
self.add_coords(coords)
560566

561-
if span is None:
562-
span = (t.min(), t.max())
567+
if time_range is None:
568+
time_range = (t.min(), t.max())
563569

564570
# --- Priors ---
565-
switchpoint = pm.Uniform("switchpoint", lower=span[0], upper=span[1])
566-
alpha = pm.Normal(name="alpha", mu=0, sigma=10)
567-
beta = pm.Normal(name="beta", mu=0, sigma=10)
571+
switchpoint = pm.Uniform(
572+
"switchpoint", lower=time_range[0], upper=time_range[1]
573+
)
574+
alpha = pm.Normal(name="alpha", mu=0, sigma=50)
575+
beta = pm.Normal(name="beta", mu=0, sigma=50)
568576
seasons = 0
569577
if "seasons" in coords and len(coords["seasons"]) > 0:
570578
season_idx = np.arange(len(y)) // grain_season % len(coords["seasons"])
571-
season_effect = pm.Normal("season", mu=0, sigma=1, dims="seasons")
572-
seasons = season_effect[season_idx]
579+
seasons_effect = pm.Normal(
580+
"seasons_effect", mu=0, sigma=50, dims="seasons"
581+
)
582+
seasons = seasons_effect[season_idx]
573583

574584
# --- Intervention effect ---
575585
level = trend = impulse = 0
576586

577-
if "level" in effect:
578-
level = pm.Normal("level", mu=0, sigma=10)
579-
580-
if "trend" in effect:
581-
trend = pm.Normal("trend", mu=0, sigma=10)
582-
583-
if "impulse" in effect:
584-
impulse_amplitude = pm.Normal("impulse_amplitude", mu=0, sigma=1)
585-
decay_rate = pm.HalfNormal("decay_rate", sigma=1)
587+
if "level" in priors:
588+
mu, sigma = (
589+
(0, 50)
590+
if len(priors["level"]) != 2
591+
else (priors["level"][0], priors["level"][1])
592+
)
593+
level = pm.Normal(
594+
"level",
595+
mu=mu,
596+
sigma=sigma,
597+
)
598+
if "trend" in priors:
599+
mu, sigma = (
600+
(0, 50)
601+
if len(priors["trend"]) != 2
602+
else (priors["trend"][0], priors["trend"][1])
603+
)
604+
trend = pm.Normal("trend", mu=mu, sigma=sigma)
605+
if "impulse" in priors:
606+
mu, sigma1, sigma2 = (
607+
(0, 50, 50)
608+
if len(priors["impulse"]) != 3
609+
else (
610+
priors["impulse"][0],
611+
priors["impulse"][1],
612+
priors["impulse"][2],
613+
)
614+
)
615+
impulse_amplitude = pm.Normal("impulse_amplitude", mu=mu, sigma=sigma1)
616+
decay_rate = pm.HalfNormal("decay_rate", sigma=sigma2)
586617
impulse = impulse_amplitude * pm.math.exp(
587618
-decay_rate * abs(t - switchpoint)
588619
)
@@ -597,16 +628,16 @@ def build_model(self, t, y, coords, effect, span, grain_season):
597628
)
598629
# Compute and store the the sum of the intervention and the time series
599630
mu = pm.Deterministic("mu", mu_ts + weight * mu_in)
631+
sigma = pm.HalfNormal("sigma", 1)
600632

601633
# --- Likelihood ---
602-
pm.Normal("y_hat", mu=mu, sigma=2, observed=y)
634+
pm.Normal("y_hat", mu=mu, sigma=sigma, observed=y)
603635

604-
def fit(self, t, y, coords, effect=[], span=None, grain_season=1, n=1000):
636+
def fit(self, t, y, coords, time_range=None, grain_season=1, priors={}, n=1000):
605637
"""
606638
Draw samples from posterior distribution
607639
"""
608-
self.sample_kwargs["progressbar"] = False
609-
self.build_model(t, y, coords, effect, span, grain_season)
640+
self.build_model(t, y, coords, time_range, grain_season, priors)
610641
with self:
611-
self.idata = pm.sample(n, **self.sample_kwargs)
642+
self.idata = pm.sample(n, progressbar=False, **self.sample_kwargs)
612643
return self.idata

0 commit comments

Comments
 (0)