Skip to content

Commit 4787d2e

Browse files
committed
Added ignore_index param to metrics
1 parent 4f736ee commit 4787d2e

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

pytorch_toolbelt/utils/catalyst/metrics.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,21 @@
1919
'JaccardScoreCallback']
2020

2121

22-
def pixel_accuracy(outputs, targets):
23-
"""Compute the pixel accuracy
22+
def pixel_accuracy(outputs: torch.Tensor,
23+
targets: torch.Tensor, ignore_index=None):
2424
"""
25-
outputs = (outputs.detach() > 0).float()
25+
Compute the pixel accuracy
26+
"""
27+
outputs = outputs.detach()
28+
targets = targets.detach()
29+
if ignore_index is not None:
30+
mask = targets != ignore_index
31+
outputs = outputs[mask]
32+
targets = targets[mask]
33+
34+
outputs = (outputs > 0).float()
2635

27-
correct = float(torch.sum(outputs == targets.detach()))
36+
correct = float(torch.sum(outputs == targets))
2837
total = targets.numel()
2938
return correct / total
3039

@@ -38,16 +47,18 @@ def __init__(
3847
input_key: str = "targets",
3948
output_key: str = "logits",
4049
prefix: str = "accuracy",
50+
ignore_index=None
4151
):
4252
"""
4353
:param input_key: input key to use for iou calculation;
4454
specifies our `y_true`.
4555
:param output_key: output key to use for iou calculation;
4656
specifies our `y_pred`
57+
:param ignore_index: same meaning as in nn.CrossEntropyLoss
4758
"""
4859
super().__init__(
4960
prefix=prefix,
50-
metric_fn=pixel_accuracy,
61+
metric_fn=partial(pixel_accuracy, ignore_index=ignore_index),
5162
input_key=input_key,
5263
output_key=output_key,
5364
)
@@ -64,20 +75,23 @@ def __init__(
6475
input_key: str = "targets",
6576
output_key: str = "logits",
6677
prefix: str = "confusion_matrix",
67-
class_names=None
78+
class_names=None,
79+
ignore_index=None
6880
):
6981
"""
7082
:param input_key: input key to use for precision calculation;
7183
specifies our `y_true`.
7284
:param output_key: output key to use for precision calculation;
7385
specifies our `y_pred`.
86+
:param ignore_index: same meaning as in nn.CrossEntropyLoss
7487
"""
7588
self.prefix = prefix
7689
self.class_names = class_names
7790
self.output_key = output_key
7891
self.input_key = input_key
7992
self.outputs = []
8093
self.targets = []
94+
self.ignore_index = ignore_index
8195

8296
def on_loader_start(self, state):
8397
self.outputs = []
@@ -89,6 +103,11 @@ def on_batch_end(self, state: RunnerState):
89103

90104
outputs = np.argmax(outputs, axis=1)
91105

106+
if self.ignore_index is not None:
107+
mask = targets != self.ignore_index
108+
outputs = outputs[mask]
109+
targets = targets[mask]
110+
92111
self.outputs.extend(outputs)
93112
self.targets.extend(targets)
94113

@@ -124,7 +143,8 @@ def __init__(
124143
self,
125144
input_key: str = "targets",
126145
output_key: str = "logits",
127-
prefix: str = "macro_f1"
146+
prefix: str = "macro_f1",
147+
ignore_index=None
128148
):
129149
"""
130150
:param input_key: input key to use for precision calculation;
@@ -138,13 +158,21 @@ def __init__(
138158
self.input_key = input_key
139159
self.outputs = []
140160
self.targets = []
161+
self.ignore_index = ignore_index
141162

142163
def on_batch_end(self, state: RunnerState):
143164
outputs = to_numpy(state.output[self.output_key])
144165
targets = to_numpy(state.input[self.input_key])
166+
145167
num_classes = outputs.shape[1]
168+
outputs = np.argmax(outputs, axis=1)
169+
170+
if self.ignore_index is not None:
171+
mask = targets != self.ignore_index
172+
outputs = outputs[mask]
173+
targets = targets[mask]
146174

147-
outputs = [np.eye(num_classes)[y] for y in np.argmax(outputs, axis=1)]
175+
outputs = [np.eye(num_classes)[y] for y in outputs]
148176
targets = [np.eye(num_classes)[y] for y in targets]
149177

150178
self.outputs.extend(outputs)

0 commit comments

Comments
 (0)