8
8
9
9
10
10
class BinaryFocalLoss (_Loss ):
11
- def __init__ (self , alpha = 0.5 , gamma = 2 , ignore = None , reduction = 'mean' , reduced = False , threshold = 0.5 ):
11
+ def __init__ (self , alpha = 0.5 , gamma = 2 , ignore_index = None , reduction = 'mean' , reduced = False , threshold = 0.5 ):
12
12
"""
13
13
14
14
:param alpha:
15
15
:param gamma:
16
- :param ignore :
16
+ :param ignore_index :
17
17
:param reduced:
18
18
:param threshold:
19
19
"""
20
20
super ().__init__ ()
21
21
self .alpha = alpha
22
22
self .gamma = gamma
23
- self .ignore = ignore
23
+ self .ignore_index = ignore_index
24
24
if reduced :
25
25
self .focal_loss = partial (reduced_focal_loss , gamma = gamma , threshold = threshold , reduction = reduction )
26
26
else :
@@ -32,9 +32,9 @@ def forward(self, label_input, label_target):
32
32
label_target = label_target .view (- 1 )
33
33
label_input = label_input .view (- 1 )
34
34
35
- if self .ignore is not None :
35
+ if self .ignore_index is not None :
36
36
# Filter predictions with ignore label from loss computation
37
- not_ignored = label_target != self .ignore
37
+ not_ignored = label_target != self .ignore_index
38
38
label_input = label_input [not_ignored ]
39
39
label_target = label_target [not_ignored ]
40
40
@@ -43,28 +43,32 @@ def forward(self, label_input, label_target):
43
43
44
44
45
45
class FocalLoss (_Loss ):
46
- def __init__ (self , alpha = 0.5 , gamma = 2 , ignore = None ):
46
+ def __init__ (self , alpha = 0.5 , gamma = 2 , ignore_index = None ):
47
+ """
48
+ Focal loss for multi-class problem.
49
+
50
+ :param alpha:
51
+ :param gamma:
52
+ :param ignore_index: If not None, targets with given index are ignored
53
+ """
47
54
super ().__init__ ()
48
55
self .alpha = alpha
49
56
self .gamma = gamma
50
- self .ignore = ignore
57
+ self .ignore_index = ignore_index
51
58
52
59
def forward (self , label_input , label_target ):
53
- """Compute focal loss for multi-class problem.
54
- Ignores anchors having -1 target label
55
- """
56
60
num_classes = label_input .size (1 )
57
61
loss = 0
58
62
59
63
# Filter anchors with -1 label from loss computation
60
- if self .ignore is not None :
61
- not_ignored = label_target != self .ignore
64
+ if self .ignore_index is not None :
65
+ not_ignored = label_target != self .ignore_index
62
66
63
67
for cls in range (num_classes ):
64
68
cls_label_target = (label_target == cls ).long ()
65
69
cls_label_input = label_input [:, cls , ...]
66
70
67
- if self .ignore is not None :
71
+ if self .ignore_index is not None :
68
72
cls_label_target = cls_label_target [not_ignored ]
69
73
cls_label_input = cls_label_input [not_ignored ]
70
74
0 commit comments