19
19
'JaccardScoreCallback' ]
20
20
21
21
22
- def pixel_accuracy (outputs , targets ):
23
- """Compute the pixel accuracy
22
+ def pixel_accuracy (outputs : torch . Tensor ,
23
+ targets : torch . Tensor , ignore_index = None ):
24
24
"""
25
- outputs = (outputs .detach () > 0 ).float ()
25
+ Compute the pixel accuracy
26
+ """
27
+ outputs = outputs .detach ()
28
+ targets = targets .detach ()
29
+ if ignore_index is not None :
30
+ mask = targets != ignore_index
31
+ outputs = outputs [mask ]
32
+ targets = targets [mask ]
33
+
34
+ outputs = (outputs > 0 ).float ()
26
35
27
- correct = float (torch .sum (outputs == targets . detach () ))
36
+ correct = float (torch .sum (outputs == targets ))
28
37
total = targets .numel ()
29
38
return correct / total
30
39
@@ -38,16 +47,18 @@ def __init__(
38
47
input_key : str = "targets" ,
39
48
output_key : str = "logits" ,
40
49
prefix : str = "accuracy" ,
50
+ ignore_index = None
41
51
):
42
52
"""
43
53
:param input_key: input key to use for iou calculation;
44
54
specifies our `y_true`.
45
55
:param output_key: output key to use for iou calculation;
46
56
specifies our `y_pred`
57
+ :param ignore_index: same meaning as in nn.CrossEntropyLoss
47
58
"""
48
59
super ().__init__ (
49
60
prefix = prefix ,
50
- metric_fn = pixel_accuracy ,
61
+ metric_fn = partial ( pixel_accuracy , ignore_index = ignore_index ) ,
51
62
input_key = input_key ,
52
63
output_key = output_key ,
53
64
)
@@ -64,20 +75,23 @@ def __init__(
64
75
input_key : str = "targets" ,
65
76
output_key : str = "logits" ,
66
77
prefix : str = "confusion_matrix" ,
67
- class_names = None
78
+ class_names = None ,
79
+ ignore_index = None
68
80
):
69
81
"""
70
82
:param input_key: input key to use for precision calculation;
71
83
specifies our `y_true`.
72
84
:param output_key: output key to use for precision calculation;
73
85
specifies our `y_pred`.
86
+ :param ignore_index: same meaning as in nn.CrossEntropyLoss
74
87
"""
75
88
self .prefix = prefix
76
89
self .class_names = class_names
77
90
self .output_key = output_key
78
91
self .input_key = input_key
79
92
self .outputs = []
80
93
self .targets = []
94
+ self .ignore_index = ignore_index
81
95
82
96
def on_loader_start (self , state ):
83
97
self .outputs = []
@@ -89,6 +103,11 @@ def on_batch_end(self, state: RunnerState):
89
103
90
104
outputs = np .argmax (outputs , axis = 1 )
91
105
106
+ if self .ignore_index is not None :
107
+ mask = targets != self .ignore_index
108
+ outputs = outputs [mask ]
109
+ targets = targets [mask ]
110
+
92
111
self .outputs .extend (outputs )
93
112
self .targets .extend (targets )
94
113
@@ -124,7 +143,8 @@ def __init__(
124
143
self ,
125
144
input_key : str = "targets" ,
126
145
output_key : str = "logits" ,
127
- prefix : str = "macro_f1"
146
+ prefix : str = "macro_f1" ,
147
+ ignore_index = None
128
148
):
129
149
"""
130
150
:param input_key: input key to use for precision calculation;
@@ -138,13 +158,21 @@ def __init__(
138
158
self .input_key = input_key
139
159
self .outputs = []
140
160
self .targets = []
161
+ self .ignore_index = ignore_index
141
162
142
163
def on_batch_end (self , state : RunnerState ):
143
164
outputs = to_numpy (state .output [self .output_key ])
144
165
targets = to_numpy (state .input [self .input_key ])
166
+
145
167
num_classes = outputs .shape [1 ]
168
+ outputs = np .argmax (outputs , axis = 1 )
169
+
170
+ if self .ignore_index is not None :
171
+ mask = targets != self .ignore_index
172
+ outputs = outputs [mask ]
173
+ targets = targets [mask ]
146
174
147
- outputs = [np .eye (num_classes )[y ] for y in np . argmax ( outputs , axis = 1 ) ]
175
+ outputs = [np .eye (num_classes )[y ] for y in outputs ]
148
176
targets = [np .eye (num_classes )[y ] for y in targets ]
149
177
150
178
self .outputs .extend (outputs )
0 commit comments