@@ -5,16 +5,18 @@
 import numpy as np

 __all__ = ['get_metrics',
-           'tabulate_metrics']
+           'compute_metrics',
+           'tabulate_metrics',
+           'MetricCollection']


 class RPrecision:
-    def __init__(self, top_k: int) -> None:
+    def __init__(self, top_k: int):
         self.top_k = top_k
         self.score = 0
         self.num_sample = 0

-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
         assert preds.shape == target.shape  # (batch_size, num_classes)
         top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:]
         num_relevant = np.take_along_axis(
@@ -28,14 +30,21 @@ def update(self, preds: np.ndarray, target: np.ndarray) -> None:
     def compute(self) -> float:
         return self.score / self.num_sample

+    def reset(self):
+        self.score = 0
+        self.num_sample = 0
+

 class Precision:
-    def __init__(self, num_classes: int, average: str, top_k: int) -> None:
+    def __init__(self, num_classes: int, average: str, top_k: int):
+        if average != 'samples':
+            raise ValueError('unsupported average')
+
         self.top_k = top_k
         self.score = 0
         self.num_sample = 0

-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
         assert preds.shape == target.shape  # (batch_size, num_classes)
         top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:]
         num_relevant = np.take_along_axis(target, top_k_ind, -1).sum()
@@ -45,25 +54,28 @@ def update(self, preds: np.ndarray, target: np.ndarray) -> None:
     def compute(self) -> float:
         return self.score / self.num_sample

+    def reset(self):
+        self.score = 0
+        self.num_sample = 0
+

 class F1:
-    def __init__(self, num_classes: int, metric_threshold: float, average: str, multiclass=False) -> None:
+    def __init__(self, num_classes: int, average: str, multiclass=False):
         self.num_classes = num_classes
-        self.metric_threshold = metric_threshold
         if average not in {'macro', 'micro', 'another-macro'}:
             raise ValueError('unsupported average')
         self.average = average
         self.multiclass = multiclass
         self.tp = self.fp = self.fn = 0

-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
         assert preds.shape == target.shape  # (batch_size, num_classes)
         if self.multiclass:
             max_idx = np.argmax(preds, axis=1).reshape(-1, 1)
             preds = np.zeros(preds.shape)
             np.put_along_axis(preds, max_idx, 1, axis=1)
         else:
-            preds = preds > self.metric_threshold
+            preds = preds > 0
         self.tp += np.logical_and(target == 1, preds == 1).sum(axis=0)
         self.fn += np.logical_and(target == 1, preds == 0).sum(axis=0)
         self.fp += np.logical_and(target == 0, preds == 1).sum(axis=0)
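
With this change, decision values are thresholded at zero instead of at a configurable metric_threshold. A minimal sketch of the F1 class under that convention (made-up arrays; assumes F1 from this module is in scope):

import numpy as np

# Two instances, three classes; decision values > 0 count as positive.
f1 = F1(num_classes=3, average='macro', multiclass=False)
preds = np.array([[0.5, -0.2, 0.1],
                  [-0.3, 0.8, -0.1]])
target = np.array([[1, 0, 0],
                   [0, 1, 1]])
f1.update(preds, target)
print(f1.compute())  # macro-averaged F1 over the three classes
f1.reset()           # clears the accumulated tp/fp/fn counts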
@@ -88,34 +100,58 @@ def compute(self) -> float:
         np.seterr(**prev_settings)
         return score

+    def reset(self):
+        self.tp = self.fp = self.fn = 0
+

 class MetricCollection(dict):
-    def __init__(self, metrics) -> None:
+    """A collection of metrics created by get_metrics.
+    MetricCollection computes metric values in two steps. First, batches of
+    decision values and labels are added with update(). After all instances have been
+    added, compute() computes the metric values from the accumulated batches.
+    """
+
+    def __init__(self, metrics):
         self.metrics = metrics

-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
+        """Adds a batch of decision values and labels.
+
+        Args:
+            preds (np.ndarray): A matrix of decision values with dimensions number of instances * number of classes.
+            target (np.ndarray): A 0/1 matrix of labels with dimensions number of instances * number of classes.
+        """
         assert preds.shape == target.shape  # (batch_size, num_classes)
         for metric in self.metrics.values():
             metric.update(preds, target)

     def compute(self) -> dict[str, float]:
+        """Computes the metrics from the accumulated batches of decision values and labels.
+
+        Returns:
+            dict[str, float]: A dictionary of metric values.
+        """
         ret = {}
         for name, metric in self.metrics.items():
             ret[name] = metric.compute()
         return ret

+    def reset(self):
+        """Clears the accumulated batches of decision values and labels.
+        """
+        for metric in self.metrics.values():
+            metric.reset()
+

-def get_metrics(metric_threshold: float,
-                monitor_metrics: list[str],
+def get_metrics(monitor_metrics: list[str],
                 num_classes: int,
                 multiclass: bool = False
                 ) -> MetricCollection:
     """Get a collection of metrics by their names.
+    See MetricCollection for more details.

     Args:
-        metric_threshold (float): The decision value threshold over which a
-            label is predicted as positive.
-        monitor_metrics (list[str]): A list metric names.
+        monitor_metrics (list[str]): A list of metric names.
         num_classes (int): The number of classes.
         multiclass (bool, optional): Enable multiclass mode. Defaults to False.

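The two-step interface documented above looks like this in use; a minimal sketch with made-up batches, assuming get_metrics (added below) is in scope:

import numpy as np

# Accumulate batches, then compute; reset() clears state for reuse.
collection = get_metrics(['Macro-F1', 'RP@2'], num_classes=4)
batches = [  # hypothetical (decision values, labels) batches
    (np.array([[0.2, -0.1, 0.7, -0.3]]), np.array([[1, 0, 1, 0]])),
    (np.array([[-0.5, 0.4, -0.2, 0.6]]), np.array([[0, 1, 0, 1]])),
]
for preds, target in batches:
    collection.update(preds, target)
print(collection.compute())  # {'Macro-F1': ..., 'RP@2': ...}
collection.reset()           # start over for the next evaluation
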
@@ -132,19 +168,54 @@ def get_metrics(metric_threshold: float,
         elif re.match('RP@\d+', metric):
             metrics[metric] = RPrecision(top_k=int(metric[3:]))
         elif metric in {'Another-Macro-F1', 'Macro-F1', 'Micro-F1'}:
-            metrics[metric] = F1(num_classes, metric_threshold,
+            metrics[metric] = F1(num_classes,
                                  average=metric[:-3].lower(),
                                  multiclass=multiclass)
         else:
-            raise ValueError(f'Invalid metric: {metric}')
+            raise ValueError(f'invalid metric: {metric}')

     return MetricCollection(metrics)


+def compute_metrics(preds: np.ndarray,
+                    target: np.ndarray,
+                    monitor_metrics: list[str],
+                    multiclass: bool = False
+                    ) -> dict[str, float]:
+    """Compute metrics with decision values and labels.
+    See get_metrics and MetricCollection if decision values and labels are too
+    large to hold in memory.
+
+
+    Args:
+        preds (np.ndarray): A matrix of decision values with dimensions number of instances * number of classes.
+        target (np.ndarray): A 0/1 matrix of labels with dimensions number of instances * number of classes.
+        monitor_metrics (list[str]): A list of metric names.
+        multiclass (bool, optional): Enable multiclass mode. Defaults to False.
+
+    Returns:
+        dict[str, float]: A dictionary of metric values.
+    """
+    assert preds.shape == target.shape
+
+    metric = get_metrics(monitor_metrics, preds.shape[1], multiclass)
+    metric.update(preds, target)
+    return metric.compute()
+
+
 def tabulate_metrics(metric_dict: dict[str, float], split: str) -> str:
+    """Convert a dictionary of metric values into a pretty formatted string for printing.
+
+    Args:
+        metric_dict (dict[str, float]): A dictionary of metric values.
+        split (str): Name of the data split.
+
+    Returns:
+        str: Pretty formatted string.
+    """
     msg = f'====== {split} dataset evaluation result =======\n'
     header = '|'.join([f'{k:^18}' for k in metric_dict.keys()])
-    values = '|'.join([f'{x * 100:^18.4f}' if isinstance(x, (np.floating,
+    values = '|'.join([f'{x:^18.4f}' if isinstance(x, (np.floating,
                        float)) else f'{x:^18}' for x in metric_dict.values()])
     msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n"
     return msg
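
End to end, the new one-shot path combines compute_metrics with tabulate_metrics; a minimal sketch with made-up data, assuming both functions are in scope:

import numpy as np

# One-shot evaluation: compute_metrics wraps the get_metrics/update/compute sequence.
preds = np.array([[0.9, -0.2, 0.4, -0.7, 0.1],
                  [-0.3, 0.6, -0.1, 0.8, -0.5]])
target = np.array([[1, 0, 1, 0, 0],
                   [0, 1, 0, 1, 0]])
results = compute_metrics(preds, target, ['Micro-F1', 'Macro-F1', 'RP@1'])
print(tabulate_metrics(results, 'test'))  # formatted table for the 'test' split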