Skip to content

Commit d9fe8b6

Browse files
authored
Merge pull request #18 from BloodAxe/develop
Release 0.1.3
2 parents 5bad076 + 38d870d commit d9fe8b6

File tree

6 files changed

+57
-25
lines changed

6 files changed

+57
-25
lines changed

.appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ environment:
1414
install:
1515
- '%PYTHON%\python.exe -m pip install -U pip wheel setuptools'
1616
- '%PYTHON%\python.exe -m pip install .[tests]'
17-
- '%PYTHON%\python.exe -m pip install flake8 flake8-docstrings'
17+
- '%PYTHON%\python.exe -m pip install "pydocstyle<4.0.0" flake8 flake8-docstrings'
1818

1919
test_script:
2020
- '%PYTHON%\python.exe -m pytest'

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ matrix:
1818
install:
1919
- if [[ "$TEST_MODE" == "DOCS" ]]; then pip install -q .; fi
2020
- if [[ "$TEST_MODE" != "DOCS" ]]; then pip install -q .[tests]; fi
21-
- if [[ "$TEST_MODE" != "DOCS" ]]; then pip install flake8 flake8-docstrings; fi
21+
- if [[ "$TEST_MODE" != "DOCS" ]]; then pip install "pydocstyle<4.0.0" flake8 flake8-docstrings; fi
2222
script:
2323
- if [[ "$TEST_MODE" != "DOCS" ]]; then pytest; fi
2424
- if [[ "$TEST_MODE" != "DOCS" ]]; then flake8; fi

pytorch_toolbelt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from __future__ import absolute_import
22

3-
__version__ = '0.1.2'
3+
__version__ = '0.1.3'

pytorch_toolbelt/losses/focal.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88

99

1010
class BinaryFocalLoss(_Loss):
11-
def __init__(self, alpha=0.5, gamma=2, ignore=None, reduction='mean', reduced=False, threshold=0.5):
11+
def __init__(self, alpha=0.5, gamma=2, ignore_index=None, reduction='mean', reduced=False, threshold=0.5):
1212
"""
1313
1414
:param alpha:
1515
:param gamma:
16-
:param ignore:
16+
:param ignore_index:
1717
:param reduced:
1818
:param threshold:
1919
"""
2020
super().__init__()
2121
self.alpha = alpha
2222
self.gamma = gamma
23-
self.ignore = ignore
23+
self.ignore_index = ignore_index
2424
if reduced:
2525
self.focal_loss = partial(reduced_focal_loss, gamma=gamma, threshold=threshold, reduction=reduction)
2626
else:
@@ -32,9 +32,9 @@ def forward(self, label_input, label_target):
3232
label_target = label_target.view(-1)
3333
label_input = label_input.view(-1)
3434

35-
if self.ignore is not None:
35+
if self.ignore_index is not None:
3636
# Filter predictions with ignore label from loss computation
37-
not_ignored = label_target != self.ignore
37+
not_ignored = label_target != self.ignore_index
3838
label_input = label_input[not_ignored]
3939
label_target = label_target[not_ignored]
4040

@@ -43,28 +43,32 @@ def forward(self, label_input, label_target):
4343

4444

4545
class FocalLoss(_Loss):
46-
def __init__(self, alpha=0.5, gamma=2, ignore=None):
46+
def __init__(self, alpha=0.5, gamma=2, ignore_index=None):
47+
"""
48+
Focal loss for multi-class problem.
49+
50+
:param alpha:
51+
:param gamma:
52+
:param ignore_index: If not None, targets with given index are ignored
53+
"""
4754
super().__init__()
4855
self.alpha = alpha
4956
self.gamma = gamma
50-
self.ignore = ignore
57+
self.ignore_index = ignore_index
5158

5259
def forward(self, label_input, label_target):
53-
"""Compute focal loss for multi-class problem.
54-
Ignores anchors having -1 target label
55-
"""
5660
num_classes = label_input.size(1)
5761
loss = 0
5862

5963
# Filter anchors with -1 label from loss computation
60-
if self.ignore is not None:
61-
not_ignored = label_target != self.ignore
64+
if self.ignore_index is not None:
65+
not_ignored = label_target != self.ignore_index
6266

6367
for cls in range(num_classes):
6468
cls_label_target = (label_target == cls).long()
6569
cls_label_input = label_input[:, cls, ...]
6670

67-
if self.ignore is not None:
71+
if self.ignore_index is not None:
6872
cls_label_target = cls_label_target[not_ignored]
6973
cls_label_input = cls_label_input[not_ignored]
7074

pytorch_toolbelt/utils/catalyst/metrics.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,21 @@
1919
'JaccardScoreCallback']
2020

2121

22-
def pixel_accuracy(outputs, targets):
23-
"""Compute the pixel accuracy
22+
def pixel_accuracy(outputs: torch.Tensor,
23+
targets: torch.Tensor, ignore_index=None):
2424
"""
25-
outputs = (outputs.detach() > 0).float()
25+
Compute the pixel accuracy
26+
"""
27+
outputs = outputs.detach()
28+
targets = targets.detach()
29+
if ignore_index is not None:
30+
mask = targets != ignore_index
31+
outputs = outputs[mask]
32+
targets = targets[mask]
33+
34+
outputs = (outputs > 0).float()
2635

27-
correct = float(torch.sum(outputs == targets.detach()))
36+
correct = float(torch.sum(outputs == targets))
2837
total = targets.numel()
2938
return correct / total
3039

@@ -38,16 +47,18 @@ def __init__(
3847
input_key: str = "targets",
3948
output_key: str = "logits",
4049
prefix: str = "accuracy",
50+
ignore_index=None
4151
):
4252
"""
4353
:param input_key: input key to use for iou calculation;
4454
specifies our `y_true`.
4555
:param output_key: output key to use for iou calculation;
4656
specifies our `y_pred`
57+
:param ignore_index: same meaning as in nn.CrossEntropyLoss
4758
"""
4859
super().__init__(
4960
prefix=prefix,
50-
metric_fn=pixel_accuracy,
61+
metric_fn=partial(pixel_accuracy, ignore_index=ignore_index),
5162
input_key=input_key,
5263
output_key=output_key,
5364
)
@@ -64,20 +75,23 @@ def __init__(
6475
input_key: str = "targets",
6576
output_key: str = "logits",
6677
prefix: str = "confusion_matrix",
67-
class_names=None
78+
class_names=None,
79+
ignore_index=None
6880
):
6981
"""
7082
:param input_key: input key to use for precision calculation;
7183
specifies our `y_true`.
7284
:param output_key: output key to use for precision calculation;
7385
specifies our `y_pred`.
86+
:param ignore_index: same meaning as in nn.CrossEntropyLoss
7487
"""
7588
self.prefix = prefix
7689
self.class_names = class_names
7790
self.output_key = output_key
7891
self.input_key = input_key
7992
self.outputs = []
8093
self.targets = []
94+
self.ignore_index = ignore_index
8195

8296
def on_loader_start(self, state):
8397
self.outputs = []
@@ -89,6 +103,11 @@ def on_batch_end(self, state: RunnerState):
89103

90104
outputs = np.argmax(outputs, axis=1)
91105

106+
if self.ignore_index is not None:
107+
mask = targets != self.ignore_index
108+
outputs = outputs[mask]
109+
targets = targets[mask]
110+
92111
self.outputs.extend(outputs)
93112
self.targets.extend(targets)
94113

@@ -124,7 +143,8 @@ def __init__(
124143
self,
125144
input_key: str = "targets",
126145
output_key: str = "logits",
127-
prefix: str = "macro_f1"
146+
prefix: str = "macro_f1",
147+
ignore_index=None
128148
):
129149
"""
130150
:param input_key: input key to use for precision calculation;
@@ -138,13 +158,21 @@ def __init__(
138158
self.input_key = input_key
139159
self.outputs = []
140160
self.targets = []
161+
self.ignore_index = ignore_index
141162

142163
def on_batch_end(self, state: RunnerState):
143164
outputs = to_numpy(state.output[self.output_key])
144165
targets = to_numpy(state.input[self.input_key])
166+
145167
num_classes = outputs.shape[1]
168+
outputs = np.argmax(outputs, axis=1)
169+
170+
if self.ignore_index is not None:
171+
mask = targets != self.ignore_index
172+
outputs = outputs[mask]
173+
targets = targets[mask]
146174

147-
outputs = [np.eye(num_classes)[y] for y in np.argmax(outputs, axis=1)]
175+
outputs = [np.eye(num_classes)[y] for y in outputs]
148176
targets = [np.eye(num_classes)[y] for y in targets]
149177

150178
self.outputs.extend(outputs)

pytorch_toolbelt/utils/fs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def has_image_ext(fname: str):
1212
name, ext = os.path.splitext(fname)
13-
return ext.lower() in {'.bmp', '.png', '.jpeg', '.jpg', '.tiff'}
13+
return ext.lower() in {'.bmp', '.png', '.jpeg', '.jpg', '.tiff', 'tif'}
1414

1515

1616
def find_in_dir(dirname: str):

0 commit comments

Comments
 (0)