1
-
2
1
import sympy .parsing .sympy_parser as parser
3
2
import sympy
4
3
from pyhf .parameters import ParamViewer
5
4
import jax .numpy as jnp
6
5
import jax
7
6
8
- def create_modifiers (additional_parameters = None ):
7
+
8
+ def create_modifiers (additional_parameters = None ):
9
9
10
10
class PureFunctionModifierBuilder :
11
11
is_shared = True
12
+
12
13
def __init__ (self , pdfconfig ):
13
14
self .config = pdfconfig
14
15
self .required_parsets = additional_parameters or {}
15
- self .builder_data = {'local' : {},'global' : {'symbols' : set ()}}
16
+ self .builder_data = {'local' : {}, 'global' : {'symbols' : set ()}}
16
17
17
18
def collect (self , thismod , nom ):
18
19
maskval = True if thismod else False
19
20
mask = [maskval ] * len (nom )
20
21
return {'mask' : mask }
21
22
22
23
def append (self , key , channel , sample , thismod , defined_samp ):
23
- self .builder_data ['local' ].setdefault (key , {}).setdefault (sample , {}).setdefault ('data' , {'mask' : []})
24
+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
25
+ sample , {}
26
+ ).setdefault ('data' , {'mask' : []})
24
27
25
28
nom = (
26
29
defined_samp ['data' ]
@@ -35,10 +38,12 @@ def append(self, key, channel, sample, thismod, defined_samp):
35
38
parsed = parser .parse_expr (formula )
36
39
free_symbols = parsed .free_symbols
37
40
for x in free_symbols :
38
- self .builder_data ['global' ].setdefault ('symbols' ,set ()).add (x )
41
+ self .builder_data ['global' ].setdefault ('symbols' , set ()).add (x )
39
42
else :
40
43
parsed = None
41
- self .builder_data ['local' ].setdefault (key ,{}).setdefault (sample ,{}).setdefault ('channels' ,{}).setdefault (channel ,{})['parsed' ] = parsed
44
+ self .builder_data ['local' ].setdefault (key , {}).setdefault (
45
+ sample , {}
46
+ ).setdefault ('channels' , {}).setdefault (channel , {})['parsed' ] = parsed
42
47
43
48
def finalize (self ):
44
49
list_of_symbols = [str (x ) for x in self .builder_data ['global' ]['symbols' ]]
@@ -47,7 +52,9 @@ def finalize(self):
47
52
for sample , samplespec in modspec .items ():
48
53
for channel , channelspec in samplespec ['channels' ].items ():
49
54
if channelspec ['parsed' ] is not None :
50
- channelspec ['jaxfunc' ] = sympy .lambdify (list_of_symbols , channelspec ['parsed' ], 'jax' )
55
+ channelspec ['jaxfunc' ] = sympy .lambdify (
56
+ list_of_symbols , channelspec ['parsed' ], 'jax'
57
+ )
51
58
else :
52
59
channelspec ['jaxfunc' ] = lambda * args : 1.0
53
60
return self .builder_data
@@ -73,28 +80,37 @@ def __init__(
73
80
else (pdfconfig .npars ,)
74
81
)
75
82
76
- self .param_viewer = ParamViewer (parfield_shape , pdfconfig .par_map , self .inputs )
83
+ self .param_viewer = ParamViewer (
84
+ parfield_shape , pdfconfig .par_map , self .inputs
85
+ )
77
86
self .create_jax_eval ()
78
87
79
88
def create_jax_eval (self ):
80
89
def eval_func (pars ):
81
- return jnp .array ([
90
+ return jnp .array (
82
91
[
83
- jnp .concatenate ([
84
- self .builder_data ['local' ][m ][s ]['channels' ][c ]['jaxfunc' ](* pars )* jnp .ones (self .pdfconfig .channel_nbins [c ])
85
- for c in self .pdfconfig .channels
86
- ])
87
- for s in self .pdfconfig .samples
92
+ [
93
+ jnp .concatenate (
94
+ [
95
+ self .builder_data ['local' ][m ][s ]['channels' ][c ][
96
+ 'jaxfunc'
97
+ ](* pars )
98
+ * jnp .ones (self .pdfconfig .channel_nbins [c ])
99
+ for c in self .pdfconfig .channels
100
+ ]
101
+ )
102
+ for s in self .pdfconfig .samples
103
+ ]
104
+ for m in self .keys
88
105
]
89
- for m in self . keys
106
+ )
90
107
91
- ])
92
108
self .jaxeval = eval_func
93
-
94
- def apply_nonbatched (self ,pars ):
95
- return jnp .expand_dims (self .jaxeval (pars ),2 )
96
109
97
- def apply_batched (self ,pars ):
110
+ def apply_nonbatched (self , pars ):
111
+ return jnp .expand_dims (self .jaxeval (pars ), 2 )
112
+
113
+ def apply_batched (self , pars ):
98
114
return jax .vmap (self .jaxeval , in_axes = (1 ,), out_axes = 2 )(pars )
99
115
100
116
def apply (self , pars ):
@@ -107,35 +123,36 @@ def apply(self, pars):
107
123
par_selection = self .param_viewer .get (pars )
108
124
results_purefunc = self .apply_batched (par_selection )
109
125
return results_purefunc
110
-
126
+
111
127
return PureFunctionModifierBuilder , PureFunctionModifierApplicator
112
128
113
129
114
130
from pyhf .modifiers import histfactory_set
115
131
116
- def enable (new_params = None ):
132
+
133
+ def enable (new_params = None ):
117
134
modifier_set = {}
118
135
modifier_set .update (** histfactory_set )
119
136
120
137
builder , applicator = create_modifiers (new_params )
121
138
122
- modifier_set .update (** {
123
- applicator .name : (builder , applicator )}
124
- )
139
+ modifier_set .update (** {applicator .name : (builder , applicator )})
125
140
return modifier_set
126
141
142
+
127
143
def new_unconstrained_scalars (new_params ):
128
144
param_spec = {
129
- p ['name' ]:
130
- [{
131
- 'paramset_type' : 'unconstrained' ,
132
- 'n_parameters' : 1 ,
133
- 'is_shared' : True ,
134
- 'inits' : (p ['init' ],),
135
- 'bounds' : ((p ['min' ], p ['max' ]),),
136
- 'is_scalar' : True ,
137
- 'fixed' : False ,
138
- }]
145
+ p ['name' ]: [
146
+ {
147
+ 'paramset_type' : 'unconstrained' ,
148
+ 'n_parameters' : 1 ,
149
+ 'is_shared' : True ,
150
+ 'inits' : (p ['init' ],),
151
+ 'bounds' : ((p ['min' ], p ['max' ]),),
152
+ 'is_scalar' : True ,
153
+ 'fixed' : False ,
154
+ }
155
+ ]
139
156
for p in new_params
140
157
}
141
- return param_spec
158
+ return param_spec
0 commit comments