Skip to content

Commit 52a418b

Browse files
committed
Rename ignore to ignore_index for consistency
1 parent 4787d2e commit 52a418b

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

pytorch_toolbelt/losses/focal.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88

99

1010
class BinaryFocalLoss(_Loss):
11-
def __init__(self, alpha=0.5, gamma=2, ignore=None, reduction='mean', reduced=False, threshold=0.5):
11+
def __init__(self, alpha=0.5, gamma=2, ignore_index=None, reduction='mean', reduced=False, threshold=0.5):
1212
"""
1313
1414
:param alpha:
1515
:param gamma:
16-
:param ignore:
16+
:param ignore_index:
1717
:param reduced:
1818
:param threshold:
1919
"""
2020
super().__init__()
2121
self.alpha = alpha
2222
self.gamma = gamma
23-
self.ignore = ignore
23+
self.ignore_index = ignore_index
2424
if reduced:
2525
self.focal_loss = partial(reduced_focal_loss, gamma=gamma, threshold=threshold, reduction=reduction)
2626
else:
@@ -32,9 +32,9 @@ def forward(self, label_input, label_target):
3232
label_target = label_target.view(-1)
3333
label_input = label_input.view(-1)
3434

35-
if self.ignore is not None:
35+
if self.ignore_index is not None:
3636
# Filter predictions with ignore label from loss computation
37-
not_ignored = label_target != self.ignore
37+
not_ignored = label_target != self.ignore_index
3838
label_input = label_input[not_ignored]
3939
label_target = label_target[not_ignored]
4040

@@ -43,28 +43,32 @@ def forward(self, label_input, label_target):
4343

4444

4545
class FocalLoss(_Loss):
46-
def __init__(self, alpha=0.5, gamma=2, ignore=None):
46+
def __init__(self, alpha=0.5, gamma=2, ignore_index=None):
47+
"""
48+
Focal loss for multi-class problem.
49+
50+
:param alpha:
51+
:param gamma:
52+
:param ignore_index: If not None, targets with given index are ignored
53+
"""
4754
super().__init__()
4855
self.alpha = alpha
4956
self.gamma = gamma
50-
self.ignore = ignore
57+
self.ignore_index = ignore_index
5158

5259
def forward(self, label_input, label_target):
53-
"""Compute focal loss for multi-class problem.
54-
Ignores anchors having -1 target label
55-
"""
5660
num_classes = label_input.size(1)
5761
loss = 0
5862

5963
# Filter anchors with -1 label from loss computation
60-
if self.ignore is not None:
61-
not_ignored = label_target != self.ignore
64+
if self.ignore_index is not None:
65+
not_ignored = label_target != self.ignore_index
6266

6367
for cls in range(num_classes):
6468
cls_label_target = (label_target == cls).long()
6569
cls_label_input = label_input[:, cls, ...]
6670

67-
if self.ignore is not None:
71+
if self.ignore_index is not None:
6872
cls_label_target = cls_label_target[not_ignored]
6973
cls_label_input = cls_label_input[not_ignored]
7074

0 commit comments

Comments
 (0)