 def focal_loss_with_logits(
-    input: torch.Tensor,
+    output: torch.Tensor,
     target: torch.Tensor,
-    gamma=2.0,
+    gamma: float = 2.0,
     alpha: Optional[float] = 0.25,
-    reduction="mean",
-    normalized=False,
+    reduction: str = "mean",
+    normalized: bool = False,
     reduced_threshold: Optional[float] = None,
+    eps: float = 1e-6,
 ) -> torch.Tensor:
     """Compute binary focal loss between target and output logits.

     See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

     Args:
-        input: Tensor of arbitrary shape
+        output: Tensor of arbitrary shape (predictions of the model)
         target: Tensor of the same shape as input
+        gamma: Focal loss power factor
+        alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range,
+            high values will give more weight to positive class.
         reduction (string, optional): Specifies the reduction to apply to the output:
             'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied,
             'mean': the sum of the output will be divided by the number of
@@ -32,18 +36,18 @@ def focal_loss_with_logits(
             'batchwise_mean' computes mean loss per sample in batch. Default: 'mean'
         normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
         reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
-    References::
 
+    References:
         https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
     """
-    target = target.type(input.type())
+    target = target.type(output.type())

-    logpt = F.binary_cross_entropy_with_logits(input, target, reduction="none")
+    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
     pt = torch.exp(-logpt)

     # compute the loss
     if reduced_threshold is None:
-        focal_term = (1 - pt).pow(gamma)
+        focal_term = (1.0 - pt).pow(gamma)
     else:
         focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
         focal_term[pt < reduced_threshold] = 1
@@ -54,7 +58,7 @@ def focal_loss_with_logits(
         loss *= alpha * target + (1 - alpha) * (1 - target)

     if normalized:
-        norm_factor = focal_term.sum() + 1e-5
+        norm_factor = focal_term.sum().clamp_min(eps)
         loss /= norm_factor

     if reduction == "mean":
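For orientation, a minimal usage sketch of the renamed API (not part of the diff). It assumes the function is importable from pytorch_toolbelt.losses.functional; shapes and values are illustrative. The second call, with reduced_threshold set, is exactly what the deprecated reduced_focal_loss wrapper below delegates to.

import torch
from pytorch_toolbelt.losses.functional import focal_loss_with_logits

logits = torch.randn(4, 1, 8, 8)                     # raw model outputs, no sigmoid applied
targets = torch.randint(0, 2, (4, 1, 8, 8)).float()  # binary ground-truth masks

# Standard binary focal loss with class balancing.
loss = focal_loss_with_logits(logits, targets, gamma=2.0, alpha=0.25, reduction="mean")

# Reduced focal loss: samples with pt below the threshold keep full weight
# (focal_term == 1); only easy samples above it are down-weighted.
loss_rfl = focal_loss_with_logits(logits, targets, alpha=None, reduced_threshold=0.5)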
@@ -72,19 +76,22 @@ def focal_loss_with_logits(


 # TODO: Mark as deprecated and emit warning
-def reduced_focal_loss(input: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"):
+def reduced_focal_loss(output: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"):
     return focal_loss_with_logits(
-        input, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold
+        output, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold
     )


-def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
+def soft_jaccard_score(
+    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
+) -> torch.Tensor:
     """

-    :param y_pred:
-    :param y_true:
+    :param output:
+    :param target:
     :param smooth:
     :param eps:
+    :param dims:
     :return:

     Shape:
@@ -94,25 +101,27 @@ def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, e
         - Output: scalar.

     """
-    assert y_pred.size() == y_true.size()
+    assert output.size() == target.size()

     if dims is not None:
-        intersection = torch.sum(y_pred * y_true, dim=dims)
-        cardinality = torch.sum(y_pred + y_true, dim=dims)
+        intersection = torch.sum(output * target, dim=dims)
+        cardinality = torch.sum(output + target, dim=dims)
     else:
-        intersection = torch.sum(y_pred * y_true)
-        cardinality = torch.sum(y_pred + y_true)
+        intersection = torch.sum(output * target)
+        cardinality = torch.sum(output + target)

     union = cardinality - intersection
-    jaccard_score = (intersection + smooth) / (union.clamp_min(eps) + smooth)
+    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
     return jaccard_score
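A quick sketch of how the dims argument changes the reduction (illustrative; assumes sigmoid probabilities and the pytorch_toolbelt.losses.functional module path):

import torch
from pytorch_toolbelt.losses.functional import soft_jaccard_score

probs = torch.rand(4, 3, 8, 8)                    # predicted probabilities after sigmoid
masks = torch.randint(0, 2, (4, 3, 8, 8)).float() # binary targets, same shape

# One global score over all elements.
global_iou = soft_jaccard_score(probs, masks)

# One score per class: reduce over batch and spatial dims, keep channels.
per_class_iou = soft_jaccard_score(probs, masks, dims=(0, 2, 3))
print(per_class_iou.shape)  # torch.Size([3])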


-def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e-7, dims=None) -> torch.Tensor:
+def soft_dice_score(
+    output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
+) -> torch.Tensor:
     """

-    :param y_pred:
-    :param y_true:
+    :param output:
+    :param target:
     :param smooth:
     :param eps:
     :return:
@@ -124,28 +133,28 @@ def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e
         - Output: scalar.

     """
-    assert y_pred.size() == y_true.size()
+    assert output.size() == target.size()
     if dims is not None:
-        intersection = torch.sum(y_pred * y_true, dim=dims)
-        cardinality = torch.sum(y_pred + y_true, dim=dims)
+        intersection = torch.sum(output * target, dim=dims)
+        cardinality = torch.sum(output + target, dim=dims)
     else:
-        intersection = torch.sum(y_pred * y_true)
-        cardinality = torch.sum(y_pred + y_true)
-    dice_score = (2.0 * intersection + smooth) / (cardinality.clamp_min(eps) + smooth)
+        intersection = torch.sum(output * target)
+        cardinality = torch.sum(output + target)
+    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
     return dice_score
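With smooth=0 the two overlap scores are related by Dice = 2J / (1 + J); a quick numeric check of that identity (illustrative only, same import assumption as above):

import torch
from pytorch_toolbelt.losses.functional import soft_dice_score, soft_jaccard_score

probs = torch.rand(2, 1, 16, 16)
masks = torch.randint(0, 2, (2, 1, 16, 16)).float()

j = soft_jaccard_score(probs, masks)  # I / (cardinality - I)
d = soft_dice_score(probs, masks)     # 2I / cardinality

# Same overlap, measured on two different scales.
assert torch.allclose(d, 2 * j / (1 + j), atol=1e-5)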


-def wing_loss(prediction: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
+def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
     """
     https://arxiv.org/pdf/1711.06753.pdf
-    :param prediction:
+    :param output:
     :param target:
     :param width:
     :param curvature:
     :param reduction:
     :return:
     """
-    diff_abs = (target - prediction).abs()
+    diff_abs = (target - output).abs()
     loss = diff_abs.clone()

     idx_smaller = diff_abs < width
@@ -163,3 +172,49 @@ def wing_loss(prediction: torch.Tensor, target: torch.Tensor, width=5, curvature
         loss = loss.mean()

     return loss
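The hunk elides the logarithmic branch of the loss; for reference, a self-contained sketch of the paper's piecewise definition (arXiv:1711.06753) using the same parameter names. This is an illustration of the formula, not the library's exact code:

import math

import torch

def wing_loss_reference(output: torch.Tensor, target: torch.Tensor, width: float = 5.0, curvature: float = 0.5) -> torch.Tensor:
    # wing(x) = width * ln(1 + |x| / curvature)   if |x| < width
    #         = |x| - C                           otherwise,
    # where C = width - width * ln(1 + width / curvature) keeps the pieces continuous.
    diff_abs = (target - output).abs()
    C = width - width * math.log(1.0 + width / curvature)
    loss = torch.where(
        diff_abs < width,
        width * torch.log(1.0 + diff_abs / curvature),
        diff_abs - C,
    )
    return loss.mean()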
+
+
+def label_smoothed_nll_loss(
+    lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1
+) -> torch.Tensor:
+    """
+
+    Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py
+
+    :param lprobs: Log-probabilities of predictions (e.g. after log_softmax)
+    :param target:
+    :param epsilon:
+    :param ignore_index:
+    :param reduction:
+    :return:
+    """
+    if target.dim() == lprobs.dim() - 1:
+        target = target.unsqueeze(dim)
+
+    if ignore_index is not None:
+        pad_mask = target.eq(ignore_index)
+        target = target.masked_fill(pad_mask, 0)
+        nll_loss = -lprobs.gather(dim=dim, index=target)
+        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
+
+        # nll_loss.masked_fill_(pad_mask, 0.0)
+        # smooth_loss.masked_fill_(pad_mask, 0.0)
+        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
+        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
+    else:
+        nll_loss = -lprobs.gather(dim=dim, index=target)
+        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
+
+    nll_loss = nll_loss.squeeze(dim)
+    smooth_loss = smooth_loss.squeeze(dim)
+
+    if reduction == "sum":
+        nll_loss = nll_loss.sum()
+        smooth_loss = smooth_loss.sum()
+    if reduction == "mean":
+        nll_loss = nll_loss.mean()
+        smooth_loss = smooth_loss.mean()
+
+    eps_i = epsilon / lprobs.size(dim)
+    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+    return loss
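A usage sketch for the new helper (illustrative; assumes the same module path and that class scores are first passed through log_softmax, as the docstring requires). The -100 index here is an arbitrary choice marking padded positions to exclude:

import torch
import torch.nn.functional as F
from pytorch_toolbelt.losses.functional import label_smoothed_nll_loss

scores = torch.randn(8, 10)          # logits for 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))
labels[0] = -100                     # padding position to be ignored

lprobs = F.log_softmax(scores, dim=-1)
loss = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1, ignore_index=-100)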