Commit e3640f6

Merge branch 'master' of github.com:GilesStrong/pytorch_inferno
2 parents: fdbd3d7 + d2068bf

6 files changed: +59 -44 lines

README.md

Lines changed: 3 additions & 4 deletions
@@ -53,10 +53,9 @@ If you have used this implementation of INFERNO in your analysis work and wish t
 ```
 @misc{giles_chatham_strong_2021_4597140,
   author = {Giles Chatham Strong},
-  title = {LUMIN},
-  month = mar,
-  year = 2021,
-  note = {{Please check https://github.com/GilesStrong/pytorch_inferno/graphs/contributors for the full list of contributors}},
+  title = {PyTorch INFERNO},
+  month = Mar,
+  year = 2021,
   doi = {10.5281/zenodo.4597140},
   url = {https://doi.org/10.5281/zenodo.4597140}
 }

docs/Gemfile

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@ source "https://rubygems.org"
 gem 'github-pages', group: :jekyll_plugins

 # Added at 2019-11-25 10:11:40 -0800 by jhoward:
-gem "nokogiri", "< 1.11.1"
+gem "nokogiri", "< 1.11.5"
 gem "jekyll", ">= 3.7"
 gem "kramdown", ">= 2.3.0"

docs/Gemfile.lock

Lines changed: 4 additions & 4 deletions
@@ -201,14 +201,14 @@ GEM
       rb-fsevent (~> 0.10, >= 0.10.3)
       rb-inotify (~> 0.9, >= 0.9.10)
     mercenary (0.3.6)
-    mini_portile2 (2.5.0)
+    mini_portile2 (2.5.1)
     minima (2.5.1)
       jekyll (>= 3.5, < 5.0)
       jekyll-feed (~> 0.9)
       jekyll-seo-tag (~> 2.1)
     minitest (5.14.1)
     multipart-post (2.1.1)
-    nokogiri (1.11.0)
+    nokogiri (1.11.4)
       mini_portile2 (~> 2.5.0)
       racc (~> 1.4)
     octokit (4.18.0)
@@ -221,7 +221,7 @@ GEM
     rb-fsevent (0.10.4)
     rb-inotify (0.10.1)
       ffi (~> 1.0)
-    rexml (3.2.4)
+    rexml (3.2.5)
     rouge (3.19.0)
     ruby-enum (0.8.0)
       i18n
@@ -257,7 +257,7 @@ DEPENDENCIES
   github-pages
   jekyll (>= 3.7)
   kramdown (>= 2.3.0)
-  nokogiri (< 1.11.1)
+  nokogiri (< 1.11.5)

 BUNDLED WITH
    2.0.2

nbs/07_inferno_exact.ipynb

Lines changed: 27 additions & 20 deletions
Large diffs are not rendered by default.

nbs/08_inferno_interp.ipynb

Lines changed: 6 additions & 4 deletions
@@ -138,7 +138,7 @@
 "            if hasattr(c, 'loss_is_meaned'): c.loss_is_meaned = False # Ensure that average losses are correct\n",
 "    \n",
 "    @abstractmethod\n",
-"    def _get_up_down(self, x_s:Tensor, x_b:Tensor) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:\n",
+"    def _get_up_down(self, x_s:Tensor, x_b:Tensor, w_s:Optional[Tensor]=None, w_b:Optional[Tensor]=None) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:\n",
 "        r'''Compute upd/down shapes for signal and background seperately. Overide this for specific problem.'''\n",
 "        pass\n",
 "    \n",
@@ -163,8 +163,10 @@
 "    def on_forwards_end(self) -> None:\n",
 "        r'''Compute loss and replace wrapper loss value'''\n",
 "        b = self.wrapper.y.squeeze() == 0\n",
-"        f_s = self.to_shape(self.wrapper.y_pred[~b])\n",
-"        f_b = self.to_shape(self.wrapper.y_pred[b])\n",
+"        w_s = self.wrapper.w[~b] if self.wrapper.w is not None else None\n",
+"        w_b = self.wrapper.w[b] if self.wrapper.w is not None else None\n",
+"        f_s = self.to_shape(self.wrapper.y_pred[~b], w_s)\n",
+"        f_b = self.to_shape(self.wrapper.y_pred[b], w_b)\n",
 "        (f_s_up,f_s_dw),(f_b_up,f_b_dw)= self._get_up_down(self.wrapper.x[~b], self.wrapper.x[b])\n",
 "        self.wrapper.loss_val = self.get_ikk(f_s_nom=f_s, f_b_nom=f_b, f_s_up=f_s_up, f_s_dw=f_s_dw, f_b_up=f_b_up, f_b_dw=f_b_dw)"
 ]
@@ -195,7 +197,7 @@
 "        self.l_mod_t[0][0,2] = self.l_mods[0]/self.l_init\n",
 "        self.l_mod_t[1][0,2] = self.l_mods[1]/self.l_init\n",
 "    \n",
-"    def _get_up_down(self, x_s:Tensor, x_b:Tensor) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:\n",
+"    def _get_up_down(self, x_s:Tensor, x_b:Tensor, **kwargs) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:\n",
 "        if self.r_mods is None and self.l_mods is None: return (None,None),(None,None)\n",
 "        u,d = [],[]\n",
 "        if self.r_mods is not None:\n",

pytorch_inferno/inferno.py

Lines changed: 18 additions & 11 deletions
@@ -84,18 +84,23 @@ def get_inv_ikk(self, f_s:Tensor, f_b:Tensor, f_s_asimov:Tensor, f_b_asimov:Tens
         return torch.inverse(h)[self.poi_idx,self.poi_idx]

     @staticmethod
-    def to_shape(p:Tensor) -> Tensor:
-        f = p.sum(0)+1e-7
+    def to_shape(p:Tensor, w:Optional[Tensor]=None) -> Tensor:
+        f = (p*w).sum(0)+1e-7 if w is not None else p.sum(0)+1e-7
         return f/f.sum()

     def on_forwards_end(self) -> None:
         r'''Compute loss and replace wrapper loss value'''
+
+        w_s = self.wrapper.w[~self.b_mask] if self.wrapper.w is not None else None
+        w_b = self.wrapper.w[self.b_mask] if self.wrapper.w is not None else None
+
         # Shapes with derivatives w.r.t. nuisances
-        f_s = self.to_shape(self.wrapper.y_pred[~self.b_mask])
-        f_b = self.to_shape(self.wrapper.y_pred[self.b_mask])
+        f_s = self.to_shape(self.wrapper.y_pred[~self.b_mask], w_s)
+        f_b = self.to_shape(self.wrapper.y_pred[self.b_mask], w_b)
+
         # Shapes without derivatives w.r.t. nuisances
-        f_s_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[~self.b_mask].detach())) if self.s_shape_alpha else f_s
-        f_b_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[self.b_mask].detach())) if self.b_shape_alpha else f_b
+        f_s_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[~self.b_mask].detach()), w_s) if self.s_shape_alpha else f_s
+        f_b_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[self.b_mask].detach()), w_b) if self.b_shape_alpha else f_b

         self.wrapper.loss_val = self.get_inv_ikk(f_s=f_s, f_b=f_b, f_s_asimov=f_s_asimov, f_b_asimov=f_b_asimov)
@@ -141,7 +146,7 @@ def on_train_begin(self) -> None:
             if hasattr(c, 'loss_is_meaned'): c.loss_is_meaned = False # Ensure that average losses are correct

     @abstractmethod
-    def _get_up_down(self, x_s:Tensor, x_b:Tensor) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
+    def _get_up_down(self, x_s:Tensor, x_b:Tensor, w_s:Optional[Tensor]=None, w_b:Optional[Tensor]=None) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
         r'''Compute upd/down shapes for signal and background seperately. Overide this for specific problem.'''
         pass
@@ -166,8 +171,10 @@ def get_ikk(self, f_s_nom:Tensor, f_b_nom:Tensor, f_s_up:Optional[Tensor], f_s_d
     def on_forwards_end(self) -> None:
         r'''Compute loss and replace wrapper loss value'''
         b = self.wrapper.y.squeeze() == 0
-        f_s = self.to_shape(self.wrapper.y_pred[~b])
-        f_b = self.to_shape(self.wrapper.y_pred[b])
+        w_s = self.wrapper.w[~b] if self.wrapper.w is not None else None
+        w_b = self.wrapper.w[b] if self.wrapper.w is not None else None
+        f_s = self.to_shape(self.wrapper.y_pred[~b], w_s)
+        f_b = self.to_shape(self.wrapper.y_pred[b], w_b)
         (f_s_up,f_s_dw),(f_b_up,f_b_dw)= self._get_up_down(self.wrapper.x[~b], self.wrapper.x[b])
         self.wrapper.loss_val = self.get_ikk(f_s_nom=f_s, f_b_nom=f_b, f_s_up=f_s_up, f_s_dw=f_s_dw, f_b_up=f_b_up, f_b_dw=f_b_dw)
@@ -191,8 +198,8 @@ def on_train_begin(self) -> None:
         self.l_mod_t[0][0,2] = self.l_mods[0]/self.l_init
         self.l_mod_t[1][0,2] = self.l_mods[1]/self.l_init

-    def _get_up_down(self, x_s:Tensor, x_b:Tensor) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
-        if self.r_mods is None and self.l_mods is None: return None,None
+    def _get_up_down(self, x_s:Tensor, x_b:Tensor, **kwargs) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
+        if self.r_mods is None and self.l_mods is None: return (None,None),(None,None)
         u,d = [],[]
         if self.r_mods is not None:
             with torch.no_grad(): x_b = x_b+self.r_mod_t[0]
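
The core of the Python-side change is the weighted shape sum in `to_shape`: predicted class probabilities are summed over events, optionally multiplied by per-event weights first, offset by a small epsilon, and normalised to unit sum. Below is a minimal standalone sketch of that idea; the tensor shapes used (probabilities as `(n_events, n_bins)`, weights as an `(n_events, 1)` column) are illustrative assumptions rather than something stated in the diff.

```python
# Sketch only: a standalone version of the weighted shape sum introduced above.
from typing import Optional

import torch
from torch import Tensor


def to_shape(p: Tensor, w: Optional[Tensor] = None) -> Tensor:
    r'''p: assumed (n_events, n_bins) class probabilities; w: optional assumed (n_events, 1) event weights.'''
    f = (p*w).sum(0)+1e-7 if w is not None else p.sum(0)+1e-7  # small offset avoids exactly-empty bins
    return f/f.sum()  # normalise so the binned shape sums to one


p = torch.softmax(torch.randn(128, 10), dim=-1)  # stand-in network outputs: 128 events, 10 bins
w = torch.rand(128, 1)                           # hypothetical per-event weights
print(to_shape(p))     # unweighted shape, as before this commit
print(to_shape(p, w))  # weighted shape
print(torch.allclose(to_shape(p), to_shape(p, torch.ones(128, 1))))  # unit weights reproduce the unweighted shape
```

With `w=None` the new signature reduces exactly to the previous behaviour, which is why the weighted call sites added in `on_forwards_end` remain backwards compatible when no weights are set on the wrapper.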
