1
+
2
+ import sympy .parsing .sympy_parser as parser
3
+ import sympy
4
+ from pyhf .parameters import ParamViewer
5
+ import jax .numpy as jnp
6
+ import jax
7
+
8
+ def create_modifiers (additional_parameters = None ):
9
+
10
+ class PureFunctionModifierBuilder :
11
+ is_shared = True
12
+ def __init__ (self , pdfconfig ):
13
+ self .config = pdfconfig
14
+ self .required_parsets = additional_parameters or {}
15
+ self .builder_data = {'local' : {},'global' : {'symbols' : set ()}}
16
+
17
+ def collect (self , thismod , nom ):
18
+ maskval = True if thismod else False
19
+ mask = [maskval ] * len (nom )
20
+ return {'mask' : mask }
21
+
22
+ def append (self , key , channel , sample , thismod , defined_samp ):
23
+ self .builder_data ['local' ].setdefault (key , {}).setdefault (sample , {}).setdefault ('data' , {'mask' : []})
24
+
25
+ nom = (
26
+ defined_samp ['data' ]
27
+ if defined_samp
28
+ else [0.0 ] * self .config .channel_nbins [channel ]
29
+ )
30
+ moddata = self .collect (thismod , nom )
31
+ self .builder_data ['local' ][key ][sample ]['data' ]['mask' ] += moddata ['mask' ]
32
+
33
+ if thismod is not None :
34
+ formula = thismod ['data' ]['formula' ]
35
+ parsed = parser .parse_expr (formula )
36
+ free_symbols = parsed .free_symbols
37
+ for x in free_symbols :
38
+ self .builder_data ['global' ].setdefault ('symbols' ,set ()).add (x )
39
+ else :
40
+ parsed = None
41
+ self .builder_data ['local' ].setdefault (key ,{}).setdefault (sample ,{}).setdefault ('channels' ,{}).setdefault (channel ,{})['parsed' ] = parsed
42
+
43
+ def finalize (self ):
44
+ list_of_symbols = [str (x ) for x in self .builder_data ['global' ]['symbols' ]]
45
+ self .builder_data ['global' ]['symbol_names' ] = list_of_symbols
46
+ for modname , modspec in self .builder_data ['local' ].items ():
47
+ for sample , samplespec in modspec .items ():
48
+ for channel , channelspec in samplespec ['channels' ].items ():
49
+ if channelspec ['parsed' ] is not None :
50
+ channelspec ['jaxfunc' ] = sympy .lambdify (list_of_symbols , channelspec ['parsed' ], 'jax' )
51
+ else :
52
+ channelspec ['jaxfunc' ] = lambda * args : 1.0
53
+ return self .builder_data
54
+
55
+ class PureFunctionModifierApplicator :
56
+ op_code = 'multiplication'
57
+ name = 'purefunc'
58
+
59
+ def __init__ (
60
+ self , modifiers = None , pdfconfig = None , builder_data = None , batch_size = None
61
+ ):
62
+ self .builder_data = builder_data
63
+ self .batch_size = batch_size
64
+ self .pdfconfig = pdfconfig
65
+ self .inputs = [str (x ) for x in builder_data ['global' ]['symbols' ]]
66
+
67
+ self .keys = [f'{ mtype } /{ m } ' for m , mtype in modifiers ]
68
+ self .modifiers = [m for m , _ in modifiers ]
69
+
70
+ parfield_shape = (
71
+ (self .batch_size , pdfconfig .npars )
72
+ if self .batch_size
73
+ else (pdfconfig .npars ,)
74
+ )
75
+
76
+ self .param_viewer = ParamViewer (parfield_shape , pdfconfig .par_map , self .inputs )
77
+ self .create_jax_eval ()
78
+
79
+ def create_jax_eval (self ):
80
+ def eval_func (pars ):
81
+ return jnp .array ([
82
+ [
83
+ jnp .concatenate ([
84
+ self .builder_data ['local' ][m ][s ]['channels' ][c ]['jaxfunc' ](* pars )* jnp .ones (self .pdfconfig .channel_nbins [c ])
85
+ for c in self .pdfconfig .channels
86
+ ])
87
+ for s in self .pdfconfig .samples
88
+ ]
89
+ for m in self .keys
90
+
91
+ ])
92
+ self .jaxeval = eval_func
93
+
94
+ def apply_nonbatched (self ,pars ):
95
+ return jnp .expand_dims (self .jaxeval (pars ),2 )
96
+
97
+ def apply_batched (self ,pars ):
98
+ return jax .vmap (self .jaxeval , in_axes = (1 ,), out_axes = 2 )(pars )
99
+
100
+ def apply (self , pars ):
101
+ if not self .param_viewer .index_selection :
102
+ return
103
+ if self .batch_size is None :
104
+ par_selection = self .param_viewer .get (pars )
105
+ results_purefunc = self .apply_nonbatched (par_selection )
106
+ else :
107
+ par_selection = self .param_viewer .get (pars )
108
+ results_purefunc = self .apply_batched (par_selection )
109
+ return results_purefunc
110
+
111
+ return PureFunctionModifierBuilder , PureFunctionModifierApplicator
112
+
113
+
114
+ from pyhf .modifiers import histfactory_set
115
+
116
+ def enable (new_params = None ):
117
+ modifier_set = {}
118
+ modifier_set .update (** histfactory_set )
119
+
120
+ builder , applicator = create_modifiers (new_params )
121
+
122
+ modifier_set .update (** {
123
+ applicator .name : (builder , applicator )}
124
+ )
125
+ return modifier_set
126
+
127
+ def new_unconstrained_scalars (new_params ):
128
+ param_spec = {
129
+ p ['name' ]:
130
+ [{
131
+ 'paramset_type' : 'unconstrained' ,
132
+ 'n_parameters' : 1 ,
133
+ 'is_shared' : True ,
134
+ 'inits' : (p ['init' ],),
135
+ 'bounds' : ((p ['min' ], p ['max' ]),),
136
+ 'is_scalar' : True ,
137
+ 'fixed' : False ,
138
+ }]
139
+ for p in new_params
140
+ }
141
+ return param_spec
0 commit comments