@@ -84,18 +84,23 @@ def get_inv_ikk(self, f_s:Tensor, f_b:Tensor, f_s_asimov:Tensor, f_b_asimov:Tens
         return torch.inverse(h)[self.poi_idx,self.poi_idx]
 
     @staticmethod
-    def to_shape(p:Tensor) -> Tensor:
-        f = p.sum(0)+1e-7
+    def to_shape(p:Tensor, w:Optional[Tensor]=None) -> Tensor:
+        f = (p*w).sum(0)+1e-7 if w is not None else p.sum(0)+1e-7
         return f/f.sum()
 
     def on_forwards_end(self) -> None:
         r'''Compute loss and replace wrapper loss value'''
+
+        w_s = self.wrapper.w[~self.b_mask] if self.wrapper.w is not None else None
+        w_b = self.wrapper.w[self.b_mask] if self.wrapper.w is not None else None
+
         # Shapes with derivatives w.r.t. nuisances
-        f_s = self.to_shape(self.wrapper.y_pred[~self.b_mask])
-        f_b = self.to_shape(self.wrapper.y_pred[self.b_mask])
+        f_s = self.to_shape(self.wrapper.y_pred[~self.b_mask], w_s)
+        f_b = self.to_shape(self.wrapper.y_pred[self.b_mask], w_b)
+
         # Shapes without derivatives w.r.t. nuisances
-        f_s_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[~self.b_mask].detach())) if self.s_shape_alpha else f_s
-        f_b_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[self.b_mask].detach())) if self.b_shape_alpha else f_b
+        f_s_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[~self.b_mask].detach()), w_s) if self.s_shape_alpha else f_s
+        f_b_asimov = self.to_shape(self.wrapper.model(self.wrapper.x[self.b_mask].detach()), w_b) if self.b_shape_alpha else f_b
 
         self.wrapper.loss_val = self.get_inv_ikk(f_s=f_s, f_b=f_b, f_s_asimov=f_s_asimov, f_b_asimov=f_b_asimov)
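For reference, the widened `to_shape` simply folds per-event weights into the bin sums before normalising, so the unweighted behaviour is unchanged when `w` is `None`. A minimal standalone sketch of that behaviour (the tensors below are made up for illustration and are not part of the patch):

import torch

p = torch.softmax(torch.randn(6, 4), dim=-1)   # per-event bin probabilities (events x bins)
w = torch.rand(6, 1)                           # per-event weights, broadcast across bins

f_unweighted = p.sum(0) + 1e-7                 # previous behaviour: every event counts equally
f_weighted   = (p * w).sum(0) + 1e-7           # new behaviour: events contribute in proportion to w

print(f_unweighted / f_unweighted.sum())       # both variants are normalised to unit sum
print(f_weighted / f_weighted.sum())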
@@ -141,7 +146,7 @@ def on_train_begin(self) -> None:
             if hasattr(c, 'loss_is_meaned'): c.loss_is_meaned = False  # Ensure that average losses are correct
 
     @abstractmethod
-    def _get_up_down(self, x_s:Tensor, x_b:Tensor) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
+    def _get_up_down(self, x_s:Tensor, x_b:Tensor, w_s:Optional[Tensor]=None, w_b:Optional[Tensor]=None) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
         r'''Compute up/down shapes for signal and background separately. Override this for the specific problem.'''
         pass
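The widened abstract signature stays backward compatible in both directions: an override may list the new weight arguments explicitly and use them, or absorb them via `**kwargs` and ignore them (as the concrete override in the last hunk below does). A tiny, self-contained illustration of the two options (function names are placeholders, not library API):

from typing import Optional
from torch import Tensor

def explicit(x_s:Tensor, x_b:Tensor, w_s:Optional[Tensor]=None, w_b:Optional[Tensor]=None):
    ...  # can make use of the per-event weights when building up/down shapes

def ignoring(x_s:Tensor, x_b:Tensor, **kwargs):
    ...  # existing logic untouched; any weights passed by the caller are silently accepted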
@@ -166,8 +171,10 @@ def get_ikk(self, f_s_nom:Tensor, f_b_nom:Tensor, f_s_up:Optional[Tensor], f_s_d
     def on_forwards_end(self) -> None:
         r'''Compute loss and replace wrapper loss value'''
         b = self.wrapper.y.squeeze() == 0
-        f_s = self.to_shape(self.wrapper.y_pred[~b])
-        f_b = self.to_shape(self.wrapper.y_pred[b])
+        w_s = self.wrapper.w[~b] if self.wrapper.w is not None else None
+        w_b = self.wrapper.w[b] if self.wrapper.w is not None else None
+        f_s = self.to_shape(self.wrapper.y_pred[~b], w_s)
+        f_b = self.to_shape(self.wrapper.y_pred[b], w_b)
         (f_s_up,f_s_dw),(f_b_up,f_b_dw) = self._get_up_down(self.wrapper.x[~b], self.wrapper.x[b])
         self.wrapper.loss_val = self.get_ikk(f_s_nom=f_s, f_b_nom=f_b, f_s_up=f_s_up, f_s_dw=f_s_dw, f_b_up=f_b_up, f_b_dw=f_b_dw)
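The signal/background split of the weights reuses the same boolean target mask that already splits the predictions, so weights and predictions stay aligned event by event. A small illustration with made-up tensors:

import torch

y = torch.tensor([[1.], [0.], [1.], [0.]])      # targets: 1 = signal, 0 = background
w = torch.tensor([[0.5], [2.0], [1.0], [1.5]])  # per-event weights

b = y.squeeze() == 0                            # background mask, as in on_forwards_end
w_s, w_b = w[~b], w[b]                          # signal weights [[0.5],[1.0]], background weights [[2.0],[1.5]]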
@@ -191,8 +198,8 @@ def on_train_begin(self) -> None:
             self.l_mod_t[0][0,2] = self.l_mods[0]/self.l_init
             self.l_mod_t[1][0,2] = self.l_mods[1]/self.l_init
 
-    def _get_up_down(self, x_s:Tensor, x_b:Tensor) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
-        if self.r_mods is None and self.l_mods is None: return None,None
+    def _get_up_down(self, x_s:Tensor, x_b:Tensor, **kwargs) -> Tuple[Tuple[Optional[Tensor],Optional[Tensor]],Tuple[Optional[Tensor],Optional[Tensor]]]:
+        if self.r_mods is None and self.l_mods is None: return (None,None),(None,None)
         u,d = [],[]
         if self.r_mods is not None:
             with torch.no_grad(): x_b = x_b + self.r_mod_t[0]
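The corrected early return in this override matters because the caller unpacks the result as two pairs, `(f_s_up,f_s_dw),(f_b_up,f_b_dw)`; returning a flat `None,None` would raise at the unpacking site even when no systematic modifications are configured. A one-line illustration:

(f_s_up, f_s_dw), (f_b_up, f_b_dw) = (None, None), (None, None)   # fine: four Nones, matching the expected structure
# (f_s_up, f_s_dw), (f_b_up, f_b_dw) = None, None                 # TypeError: cannot unpack non-iterable NoneType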