@@ -541,48 +541,79 @@ class InterventionTimeEstimator(PyMCModel):
541
541
... t,
542
542
... y,
543
543
... coords,
544
- ... effect=[ "impulse"]
544
+ ... priors={ "impulse":[]}
545
545
... )
546
546
Inference data...
547
547
"""
548
548
549
- def build_model (self , t , y , coords , effect , span , grain_season ):
549
+ def build_model (self , t , y , coords , time_range , grain_season , priors ):
550
550
"""
551
551
Defines the PyMC model
552
552
553
553
:param t: An array of values representing the time over which y is spread
554
554
:param y: An array of values representing our outcome y
555
- :param coords: A dictionary with the coordinate names for our instruments
555
+ :param coords: An optional dictionary with the coordinate names for our instruments.
556
+ In particular, used to determine the number of seasons.
557
+ :param time_range: An optional tuple providing a specific time_range where the
558
+ intervention effect should have taken place.
559
+ :param priors: An optional dictionary of priors for the parameters of the
560
+ different distributions.
561
+ :code:`priors = {"alpha":[0, 5], "beta":[0,2], "level":[5, 5], "impulse":[1, 2 ,3]}`
556
562
"""
557
563
558
564
with self :
559
565
self .add_coords (coords )
560
566
561
- if span is None :
562
- span = (t .min (), t .max ())
567
+ if time_range is None :
568
+ time_range = (t .min (), t .max ())
563
569
564
570
# --- Priors ---
565
- switchpoint = pm .Uniform ("switchpoint" , lower = span [0 ], upper = span [1 ])
566
- alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 10 )
567
- beta = pm .Normal (name = "beta" , mu = 0 , sigma = 10 )
571
+ switchpoint = pm .Uniform (
572
+ "switchpoint" , lower = time_range [0 ], upper = time_range [1 ]
573
+ )
574
+ alpha = pm .Normal (name = "alpha" , mu = 0 , sigma = 50 )
575
+ beta = pm .Normal (name = "beta" , mu = 0 , sigma = 50 )
568
576
seasons = 0
569
577
if "seasons" in coords and len (coords ["seasons" ]) > 0 :
570
578
season_idx = np .arange (len (y )) // grain_season % len (coords ["seasons" ])
571
- season_effect = pm .Normal ("season" , mu = 0 , sigma = 1 , dims = "seasons" )
572
- seasons = season_effect [season_idx ]
579
+ seasons_effect = pm .Normal (
580
+ "seasons_effect" , mu = 0 , sigma = 50 , dims = "seasons"
581
+ )
582
+ seasons = seasons_effect [season_idx ]
573
583
574
584
# --- Intervention effect ---
575
585
level = trend = impulse = 0
576
586
577
- if "level" in effect :
578
- level = pm .Normal ("level" , mu = 0 , sigma = 10 )
579
-
580
- if "trend" in effect :
581
- trend = pm .Normal ("trend" , mu = 0 , sigma = 10 )
582
-
583
- if "impulse" in effect :
584
- impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = 0 , sigma = 1 )
585
- decay_rate = pm .HalfNormal ("decay_rate" , sigma = 1 )
587
+ if "level" in priors :
588
+ mu , sigma = (
589
+ (0 , 50 )
590
+ if len (priors ["level" ]) != 2
591
+ else (priors ["level" ][0 ], priors ["level" ][1 ])
592
+ )
593
+ level = pm .Normal (
594
+ "level" ,
595
+ mu = mu ,
596
+ sigma = sigma ,
597
+ )
598
+ if "trend" in priors :
599
+ mu , sigma = (
600
+ (0 , 50 )
601
+ if len (priors ["trend" ]) != 2
602
+ else (priors ["trend" ][0 ], priors ["trend" ][1 ])
603
+ )
604
+ trend = pm .Normal ("trend" , mu = mu , sigma = sigma )
605
+ if "impulse" in priors :
606
+ mu , sigma1 , sigma2 = (
607
+ (0 , 50 , 50 )
608
+ if len (priors ["impulse" ]) != 3
609
+ else (
610
+ priors ["impulse" ][0 ],
611
+ priors ["impulse" ][1 ],
612
+ priors ["impulse" ][2 ],
613
+ )
614
+ )
615
+ impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = mu , sigma = sigma1 )
616
+ decay_rate = pm .HalfNormal ("decay_rate" , sigma = sigma2 )
586
617
impulse = impulse_amplitude * pm .math .exp (
587
618
- decay_rate * abs (t - switchpoint )
588
619
)
@@ -597,16 +628,16 @@ def build_model(self, t, y, coords, effect, span, grain_season):
597
628
)
598
629
# Compute and store the the sum of the intervention and the time series
599
630
mu = pm .Deterministic ("mu" , mu_ts + weight * mu_in )
631
+ sigma = pm .HalfNormal ("sigma" , 1 )
600
632
601
633
# --- Likelihood ---
602
- pm .Normal ("y_hat" , mu = mu , sigma = 2 , observed = y )
634
+ pm .Normal ("y_hat" , mu = mu , sigma = sigma , observed = y )
603
635
604
- def fit (self , t , y , coords , effect = [], span = None , grain_season = 1 , n = 1000 ):
636
+ def fit (self , t , y , coords , time_range = None , grain_season = 1 , priors = {} , n = 1000 ):
605
637
"""
606
638
Draw samples from posterior distribution
607
639
"""
608
- self .sample_kwargs ["progressbar" ] = False
609
- self .build_model (t , y , coords , effect , span , grain_season )
640
+ self .build_model (t , y , coords , time_range , grain_season , priors )
610
641
with self :
611
- self .idata = pm .sample (n , ** self .sample_kwargs )
642
+ self .idata = pm .sample (n , progressbar = False , ** self .sample_kwargs )
612
643
return self .idata
0 commit comments