Skip to content

Commit d45fc58

Browse files
committed
first version of custom modifiers with sympy and jax
1 parent 40ebf6d commit d45fc58

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

src/pyhf/contrib/extended_modifiers/__init__.py

Whitespace-only changes.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
2+
import sympy.parsing.sympy_parser as parser
3+
import sympy
4+
from pyhf.parameters import ParamViewer
5+
import jax.numpy as jnp
6+
import jax
7+
8+
def create_modifiers(additional_parameters = None):
9+
10+
class PureFunctionModifierBuilder:
11+
is_shared = True
12+
def __init__(self, pdfconfig):
13+
self.config = pdfconfig
14+
self.required_parsets = additional_parameters or {}
15+
self.builder_data = {'local': {},'global': {'symbols': set()}}
16+
17+
def collect(self, thismod, nom):
18+
maskval = True if thismod else False
19+
mask = [maskval] * len(nom)
20+
return {'mask': mask}
21+
22+
def append(self, key, channel, sample, thismod, defined_samp):
23+
self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []})
24+
25+
nom = (
26+
defined_samp['data']
27+
if defined_samp
28+
else [0.0] * self.config.channel_nbins[channel]
29+
)
30+
moddata = self.collect(thismod, nom)
31+
self.builder_data['local'][key][sample]['data']['mask'] += moddata['mask']
32+
33+
if thismod is not None:
34+
formula = thismod['data']['formula']
35+
parsed = parser.parse_expr(formula)
36+
free_symbols = parsed.free_symbols
37+
for x in free_symbols:
38+
self.builder_data['global'].setdefault('symbols',set()).add(x)
39+
else:
40+
parsed = None
41+
self.builder_data['local'].setdefault(key,{}).setdefault(sample,{}).setdefault('channels',{}).setdefault(channel,{})['parsed'] = parsed
42+
43+
def finalize(self):
44+
list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']]
45+
self.builder_data['global']['symbol_names'] = list_of_symbols
46+
for modname, modspec in self.builder_data['local'].items():
47+
for sample, samplespec in modspec.items():
48+
for channel, channelspec in samplespec['channels'].items():
49+
if channelspec['parsed'] is not None:
50+
channelspec['jaxfunc'] = sympy.lambdify(list_of_symbols, channelspec['parsed'], 'jax')
51+
else:
52+
channelspec['jaxfunc'] = lambda *args: 1.0
53+
return self.builder_data
54+
55+
class PureFunctionModifierApplicator:
56+
op_code = 'multiplication'
57+
name = 'purefunc'
58+
59+
def __init__(
60+
self, modifiers=None, pdfconfig=None, builder_data=None, batch_size=None
61+
):
62+
self.builder_data = builder_data
63+
self.batch_size = batch_size
64+
self.pdfconfig = pdfconfig
65+
self.inputs = [str(x) for x in builder_data['global']['symbols']]
66+
67+
self.keys = [f'{mtype}/{m}' for m, mtype in modifiers]
68+
self.modifiers = [m for m, _ in modifiers]
69+
70+
parfield_shape = (
71+
(self.batch_size, pdfconfig.npars)
72+
if self.batch_size
73+
else (pdfconfig.npars,)
74+
)
75+
76+
self.param_viewer = ParamViewer(parfield_shape, pdfconfig.par_map, self.inputs)
77+
self.create_jax_eval()
78+
79+
def create_jax_eval(self):
80+
def eval_func(pars):
81+
return jnp.array([
82+
[
83+
jnp.concatenate([
84+
self.builder_data['local'][m][s]['channels'][c]['jaxfunc'](*pars)*jnp.ones(self.pdfconfig.channel_nbins[c])
85+
for c in self.pdfconfig.channels
86+
])
87+
for s in self.pdfconfig.samples
88+
]
89+
for m in self.keys
90+
91+
])
92+
self.jaxeval = eval_func
93+
94+
def apply_nonbatched(self,pars):
95+
return jnp.expand_dims(self.jaxeval(pars),2)
96+
97+
def apply_batched(self,pars):
98+
return jax.vmap(self.jaxeval, in_axes=(1,), out_axes=2)(pars)
99+
100+
def apply(self, pars):
101+
if not self.param_viewer.index_selection:
102+
return
103+
if self.batch_size is None:
104+
par_selection = self.param_viewer.get(pars)
105+
results_purefunc = self.apply_nonbatched(par_selection)
106+
else:
107+
par_selection = self.param_viewer.get(pars)
108+
results_purefunc = self.apply_batched(par_selection)
109+
return results_purefunc
110+
111+
return PureFunctionModifierBuilder, PureFunctionModifierApplicator
112+
113+
114+
from pyhf.modifiers import histfactory_set
115+
116+
def enable(new_params = None):
117+
modifier_set = {}
118+
modifier_set.update(**histfactory_set)
119+
120+
builder, applicator = create_modifiers(new_params)
121+
122+
modifier_set.update(**{
123+
applicator.name: (builder, applicator)}
124+
)
125+
return modifier_set
126+
127+
def new_unconstrained_scalars(new_params):
128+
param_spec = {
129+
p['name']:
130+
[{
131+
'paramset_type': 'unconstrained',
132+
'n_parameters': 1,
133+
'is_shared': True,
134+
'inits': (p['init'],),
135+
'bounds': ((p['min'], p['max']),),
136+
'is_scalar': True,
137+
'fixed': False,
138+
}]
139+
for p in new_params
140+
}
141+
return param_spec

0 commit comments

Comments
 (0)