Skip to content

Commit 1eedbe4

Browse files
committed
switch to on the fly creation of required parameters
1 parent d45fc58 commit 1eedbe4

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

src/pyhf/contrib/extended_modifiers/purefunc.py

+25-21
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,37 @@
55
import jax.numpy as jnp
66
import jax
77

8-
def create_modifiers(additional_parameters = None):
8+
def create_modifiers():
99

1010
class PureFunctionModifierBuilder:
1111
is_shared = True
1212
def __init__(self, pdfconfig):
1313
self.config = pdfconfig
14-
self.required_parsets = additional_parameters or {}
14+
self.required_parsets = {}
1515
self.builder_data = {'local': {},'global': {'symbols': set()}}
1616

1717
def collect(self, thismod, nom):
1818
maskval = True if thismod else False
1919
mask = [maskval] * len(nom)
2020
return {'mask': mask}
2121

22+
def require_synbols_as_scalars(self, symbols):
23+
param_spec = {
24+
p:
25+
[{
26+
'paramset_type': 'unconstrained',
27+
'n_parameters': 1,
28+
'is_shared': True,
29+
'inits': (1.0,),
30+
'bounds': ((0,10),),
31+
'is_scalar': True,
32+
'fixed': False,
33+
}]
34+
for p in symbols
35+
}
36+
return param_spec
37+
38+
2239
def append(self, key, channel, sample, thismod, defined_samp):
2340
self.builder_data['local'].setdefault(key, {}).setdefault(sample, {}).setdefault('data', {'mask': []})
2441

@@ -42,6 +59,9 @@ def append(self, key, channel, sample, thismod, defined_samp):
4259

4360
def finalize(self):
4461
list_of_symbols = [str(x) for x in self.builder_data['global']['symbols']]
62+
63+
self.required_parsets = self.require_synbols_as_scalars(list_of_symbols)
64+
4565
self.builder_data['global']['symbol_names'] = list_of_symbols
4666
for modname, modspec in self.builder_data['local'].items():
4767
for sample, samplespec in modspec.items():
@@ -113,29 +133,13 @@ def apply(self, pars):
113133

114134
from pyhf.modifiers import histfactory_set
115135

116-
def enable(new_params = None):
136+
def enable():
117137
modifier_set = {}
118138
modifier_set.update(**histfactory_set)
119139

120-
builder, applicator = create_modifiers(new_params)
140+
builder, applicator = create_modifiers()
121141

122142
modifier_set.update(**{
123143
applicator.name: (builder, applicator)}
124144
)
125-
return modifier_set
126-
127-
def new_unconstrained_scalars(new_params):
128-
param_spec = {
129-
p['name']:
130-
[{
131-
'paramset_type': 'unconstrained',
132-
'n_parameters': 1,
133-
'is_shared': True,
134-
'inits': (p['init'],),
135-
'bounds': ((p['min'], p['max']),),
136-
'is_scalar': True,
137-
'fixed': False,
138-
}]
139-
for p in new_params
140-
}
141-
return param_spec
145+
return modifier_set

0 commit comments

Comments
 (0)