Skip to content

Commit 3900dbc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d45fc58 commit 3900dbc

File tree

1 file changed

+53
-36
lines changed

1 file changed

+53
-36
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
1-
21
import sympy.parsing.sympy_parser as parser
32
import sympy
43
from pyhf.parameters import ParamViewer
54
import jax.numpy as jnp
65
import jax
76

8-
def create_modifiers(additional_parameters = None):
7+
8+
def create_modifiers(additional_parameters=None):
99

1010
class PureFunctionModifierBuilder:
1111
is_shared = True
12+
1213
def __init__(self, pdfconfig):
1314
self.config = pdfconfig
1415
self.required_parsets = additional_parameters or {}
15-
self.builder_data = {'local': {},'global': {'symbols': set()}}
16+
self.builder_data = {'local': {}, 'global': {'symbols': set()}}
1617

1718
def collect(self, thismod, nom):
1819
maskval = True if thismod else False
1920
mask = [maskval] * len(nom)
2021
return {'mask': mask}
2122

2223
def append(self, key, channel, sample, thismod, defined_samp):
23-
self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []})
24+
self.builder_data['local'].setdefault(key, {}).setdefault(
25+
sample, {}
26+
).setdefault('data', {'mask': []})
2427

2528
nom = (
2629
defined_samp['data']
@@ -35,10 +38,12 @@ def append(self, key, channel, sample, thismod, defined_samp):
3538
parsed = parser.parse_expr(formula)
3639
free_symbols = parsed.free_symbols
3740
for x in free_symbols:
38-
self.builder_data['global'].setdefault('symbols',set()).add(x)
41+
self.builder_data['global'].setdefault('symbols', set()).add(x)
3942
else:
4043
parsed = None
41-
self.builder_data['local'].setdefault(key,{}).setdefault(sample,{}).setdefault('channels',{}).setdefault(channel,{})['parsed'] = parsed
44+
self.builder_data['local'].setdefault(key, {}).setdefault(
45+
sample, {}
46+
).setdefault('channels', {}).setdefault(channel, {})['parsed'] = parsed
4247

4348
def finalize(self):
4449
list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']]
@@ -47,7 +52,9 @@ def finalize(self):
4752
for sample, samplespec in modspec.items():
4853
for channel, channelspec in samplespec['channels'].items():
4954
if channelspec['parsed'] is not None:
50-
channelspec['jaxfunc'] = sympy.lambdify(list_of_symbols, channelspec['parsed'], 'jax')
55+
channelspec['jaxfunc'] = sympy.lambdify(
56+
list_of_symbols, channelspec['parsed'], 'jax'
57+
)
5158
else:
5259
channelspec['jaxfunc'] = lambda *args: 1.0
5360
return self.builder_data
@@ -73,28 +80,37 @@ def __init__(
7380
else (pdfconfig.npars,)
7481
)
7582

76-
self.param_viewer = ParamViewer(parfield_shape, pdfconfig.par_map, self.inputs)
83+
self.param_viewer = ParamViewer(
84+
parfield_shape, pdfconfig.par_map, self.inputs
85+
)
7786
self.create_jax_eval()
7887

7988
def create_jax_eval(self):
8089
def eval_func(pars):
81-
return jnp.array([
90+
return jnp.array(
8291
[
83-
jnp.concatenate([
84-
self.builder_data['local'][m][s]['channels'][c]['jaxfunc'](*pars)*jnp.ones(self.pdfconfig.channel_nbins[c])
85-
for c in self.pdfconfig.channels
86-
])
87-
for s in self.pdfconfig.samples
92+
[
93+
jnp.concatenate(
94+
[
95+
self.builder_data['local'][m][s]['channels'][c][
96+
'jaxfunc'
97+
](*pars)
98+
* jnp.ones(self.pdfconfig.channel_nbins[c])
99+
for c in self.pdfconfig.channels
100+
]
101+
)
102+
for s in self.pdfconfig.samples
103+
]
104+
for m in self.keys
88105
]
89-
for m in self.keys
106+
)
90107

91-
])
92108
self.jaxeval = eval_func
93-
94-
def apply_nonbatched(self,pars):
95-
return jnp.expand_dims(self.jaxeval(pars),2)
96109

97-
def apply_batched(self,pars):
110+
def apply_nonbatched(self, pars):
111+
return jnp.expand_dims(self.jaxeval(pars), 2)
112+
113+
def apply_batched(self, pars):
98114
return jax.vmap(self.jaxeval, in_axes=(1,), out_axes=2)(pars)
99115

100116
def apply(self, pars):
@@ -107,35 +123,36 @@ def apply(self, pars):
107123
par_selection = self.param_viewer.get(pars)
108124
results_purefunc = self.apply_batched(par_selection)
109125
return results_purefunc
110-
126+
111127
return PureFunctionModifierBuilder, PureFunctionModifierApplicator
112128

113129

114130
from pyhf.modifiers import histfactory_set
115131

116-
def enable(new_params = None):
132+
133+
def enable(new_params=None):
117134
modifier_set = {}
118135
modifier_set.update(**histfactory_set)
119136

120137
builder, applicator = create_modifiers(new_params)
121138

122-
modifier_set.update(**{
123-
applicator.name: (builder, applicator)}
124-
)
139+
modifier_set.update(**{applicator.name: (builder, applicator)})
125140
return modifier_set
126141

142+
127143
def new_unconstrained_scalars(new_params):
128144
param_spec = {
129-
p['name']:
130-
[{
131-
'paramset_type': 'unconstrained',
132-
'n_parameters': 1,
133-
'is_shared': True,
134-
'inits': (p['init'],),
135-
'bounds': ((p['min'], p['max']),),
136-
'is_scalar': True,
137-
'fixed': False,
138-
}]
145+
p['name']: [
146+
{
147+
'paramset_type': 'unconstrained',
148+
'n_parameters': 1,
149+
'is_shared': True,
150+
'inits': (p['init'],),
151+
'bounds': ((p['min'], p['max']),),
152+
'is_scalar': True,
153+
'fixed': False,
154+
}
155+
]
139156
for p in new_params
140157
}
141-
return param_spec
158+
return param_spec

0 commit comments

Comments
 (0)