1-
21import sympy .parsing .sympy_parser as parser
32import sympy
43from pyhf .parameters import ParamViewer
54import jax .numpy as jnp
65import jax
76
8- def create_modifiers (additional_parameters = None ):
7+
8+ def create_modifiers (additional_parameters = None ):
99
1010 class PureFunctionModifierBuilder :
1111 is_shared = True
12+
1213 def __init__ (self , pdfconfig ):
1314 self .config = pdfconfig
1415 self .required_parsets = additional_parameters or {}
15- self .builder_data = {'local' : {},'global' : {'symbols' : set ()}}
16+ self .builder_data = {'local' : {}, 'global' : {'symbols' : set ()}}
1617
1718 def collect (self , thismod , nom ):
1819 maskval = True if thismod else False
1920 mask = [maskval ] * len (nom )
2021 return {'mask' : mask }
2122
2223 def append (self , key , channel , sample , thismod , defined_samp ):
23- self .builder_data ['local' ].setdefault (key , {}).setdefault (sample , {}).setdefault ('data' , {'mask' : []})
24+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
25+ sample , {}
26+ ).setdefault ('data' , {'mask' : []})
2427
2528 nom = (
2629 defined_samp ['data' ]
@@ -35,10 +38,12 @@ def append(self, key, channel, sample, thismod, defined_samp):
3538 parsed = parser .parse_expr (formula )
3639 free_symbols = parsed .free_symbols
3740 for x in free_symbols :
38- self .builder_data ['global' ].setdefault ('symbols' ,set ()).add (x )
41+ self .builder_data ['global' ].setdefault ('symbols' , set ()).add (x )
3942 else :
4043 parsed = None
41- self .builder_data ['local' ].setdefault (key ,{}).setdefault (sample ,{}).setdefault ('channels' ,{}).setdefault (channel ,{})['parsed' ] = parsed
44+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
45+ sample , {}
46+ ).setdefault ('channels' , {}).setdefault (channel , {})['parsed' ] = parsed
4247
4348 def finalize (self ):
4449 list_of_symbols = [str (x ) for x in self .builder_data ['global' ]['symbols' ]]
@@ -47,7 +52,9 @@ def finalize(self):
4752 for sample , samplespec in modspec .items ():
4853 for channel , channelspec in samplespec ['channels' ].items ():
4954 if channelspec ['parsed' ] is not None :
50- channelspec ['jaxfunc' ] = sympy .lambdify (list_of_symbols , channelspec ['parsed' ], 'jax' )
55+ channelspec ['jaxfunc' ] = sympy .lambdify (
56+ list_of_symbols , channelspec ['parsed' ], 'jax'
57+ )
5158 else :
5259 channelspec ['jaxfunc' ] = lambda * args : 1.0
5360 return self .builder_data
@@ -73,28 +80,37 @@ def __init__(
7380 else (pdfconfig .npars ,)
7481 )
7582
76- self .param_viewer = ParamViewer (parfield_shape , pdfconfig .par_map , self .inputs )
83+ self .param_viewer = ParamViewer (
84+ parfield_shape , pdfconfig .par_map , self .inputs
85+ )
7786 self .create_jax_eval ()
7887
7988 def create_jax_eval (self ):
8089 def eval_func (pars ):
81- return jnp .array ([
90+ return jnp .array (
8291 [
83- jnp .concatenate ([
84- self .builder_data ['local' ][m ][s ]['channels' ][c ]['jaxfunc' ](* pars )* jnp .ones (self .pdfconfig .channel_nbins [c ])
85- for c in self .pdfconfig .channels
86- ])
87- for s in self .pdfconfig .samples
92+ [
93+ jnp .concatenate (
94+ [
95+ self .builder_data ['local' ][m ][s ]['channels' ][c ][
96+ 'jaxfunc'
97+ ](* pars )
98+ * jnp .ones (self .pdfconfig .channel_nbins [c ])
99+ for c in self .pdfconfig .channels
100+ ]
101+ )
102+ for s in self .pdfconfig .samples
103+ ]
104+ for m in self .keys
88105 ]
89- for m in self . keys
106+ )
90107
91- ])
92108 self .jaxeval = eval_func
93-
94- def apply_nonbatched (self ,pars ):
95- return jnp .expand_dims (self .jaxeval (pars ),2 )
96109
97- def apply_batched (self ,pars ):
110+ def apply_nonbatched (self , pars ):
111+ return jnp .expand_dims (self .jaxeval (pars ), 2 )
112+
113+ def apply_batched (self , pars ):
98114 return jax .vmap (self .jaxeval , in_axes = (1 ,), out_axes = 2 )(pars )
99115
100116 def apply (self , pars ):
@@ -107,35 +123,36 @@ def apply(self, pars):
107123 par_selection = self .param_viewer .get (pars )
108124 results_purefunc = self .apply_batched (par_selection )
109125 return results_purefunc
110-
126+
111127 return PureFunctionModifierBuilder , PureFunctionModifierApplicator
112128
113129
114130from pyhf .modifiers import histfactory_set
115131
116- def enable (new_params = None ):
132+
133+ def enable (new_params = None ):
117134 modifier_set = {}
118135 modifier_set .update (** histfactory_set )
119136
120137 builder , applicator = create_modifiers (new_params )
121138
122- modifier_set .update (** {
123- applicator .name : (builder , applicator )}
124- )
139+ modifier_set .update (** {applicator .name : (builder , applicator )})
125140 return modifier_set
126141
142+
127143def new_unconstrained_scalars (new_params ):
128144 param_spec = {
129- p ['name' ]:
130- [{
131- 'paramset_type' : 'unconstrained' ,
132- 'n_parameters' : 1 ,
133- 'is_shared' : True ,
134- 'inits' : (p ['init' ],),
135- 'bounds' : ((p ['min' ], p ['max' ]),),
136- 'is_scalar' : True ,
137- 'fixed' : False ,
138- }]
145+ p ['name' ]: [
146+ {
147+ 'paramset_type' : 'unconstrained' ,
148+ 'n_parameters' : 1 ,
149+ 'is_shared' : True ,
150+ 'inits' : (p ['init' ],),
151+ 'bounds' : ((p ['min' ], p ['max' ]),),
152+ 'is_scalar' : True ,
153+ 'fixed' : False ,
154+ }
155+ ]
139156 for p in new_params
140157 }
141- return param_spec
158+ return param_spec
0 commit comments