Commit b2b2415

Merge pull request #287 from ntumlgroup/linear-metric
Linear metric
2 parents 79ba02a + 32f1376 commit b2b2415

10 files changed: +153 −57 lines changed

docs/api/linear.rst

Lines changed: 21 additions & 1 deletion
@@ -11,7 +11,6 @@ The simplest usage is::
     model = linear.train_1vsrest(train_y, train_x, options)
     predict = linear.predict_values(model, test_x)
 
-.. See `the user guide <../guides/linear_guides.html>`_ for more details.
 
 .. currentmodule:: libmultilabel.linear
 
@@ -51,6 +50,27 @@ Load and Save Pipeline
 .. autofunction:: load_pipeline
 
 
+Metrics
+^^^^^^^
+Metrics are specified by their names in ``compute_metrics`` and ``get_metrics``.
+The possible metric names are:
+
+* ``'P@K'``, where ``K`` is a positive integer
+* ``'RP@K'``, where ``K`` is a positive integer
+* ``'Macro-F1'``
+* ``'Micro-F1'``
+
+.. Their definitions are given in the `user guide <https://www.csie.ntu.edu.tw/~cjlin/papers/libmultilabel/userguide.pdf>`_.
+
+.. autofunction:: compute_metrics
+
+.. autofunction:: get_metrics
+
+.. autoclass:: MetricCollection
+    :members:
+
+.. autofunction:: tabulate_metrics
+
 Grid Search with Sklearn Estimators
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

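As a quick illustration of the metric names documented above, a minimal sketch (the arrays and values here are hypothetical, not taken from the docs):

    import numpy as np
    import libmultilabel.linear as linear

    # hypothetical decision values and 0/1 labels for 2 instances and 5 classes
    preds = np.array([[ 0.9, -0.2,  0.1,  0.4, -0.5],
                      [-0.3,  0.8,  0.7, -0.1,  0.2]])
    target = np.array([[1, 0, 0, 1, 0],
                       [0, 1, 1, 0, 0]])

    # metric names follow the list above: P@K, RP@K, Macro-F1, Micro-F1
    scores = linear.compute_metrics(preds, target,
                                    monitor_metrics=['P@3', 'RP@3', 'Macro-F1', 'Micro-F1'])
    print(scores)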
docs/examples/plot_linear_quickstart.py

Lines changed: 10 additions & 13 deletions
@@ -73,24 +73,21 @@
 #
 # To see how well we performed, we may want to check various
 # metrics with the test set.
-# For that we may use:
-
-metrics = linear.get_metrics(metric_threshold=0,
-                             monitor_metrics=['Macro-F1', 'Micro-F1', 'P@1', 'P@3', 'P@5'],
-                             num_classes=datasets['test']['y'].shape[1])
-
-######################################################################
-# This creates the set of metrics we wish to see.
 # Since the dataset we loaded are stored as ``scipy.sparse.csr_matrix``,
-# we need to transform them to ``np.array`` before we can compute the metrics:
+# we will first transform the dataset to ``np.array``.
 
 target = datasets['test']['y'].toarray()
 
-######################################################################
-# Finally, we compute and print the metrics:
+##############################################################################
+# Then we will compute the metrics with ``compute_metrics``.
+
+metrics = linear.compute_metrics(
+    preds,
+    target,
+    monitor_metrics=['Macro-F1', 'Micro-F1', 'P@1', 'P@3', 'P@5'],
+)
 
-metrics.update(preds, target)
-print(metrics.compute())
+print(metrics)
 
 ######################################################################
 # The results will look similar to::

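Taken together, the evaluation part of the quickstart now reads roughly as follows (a sketch assuming model and datasets come from the earlier training and data-loading steps of the quickstart):

    preds = linear.predict_values(model, datasets['test']['x'])
    target = datasets['test']['y'].toarray()

    metrics = linear.compute_metrics(
        preds,
        target,
        monitor_metrics=['Macro-F1', 'Micro-F1', 'P@1', 'P@3', 'P@5'],
    )
    print(metrics)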
libmultilabel/linear/linear.py

Lines changed: 13 additions & 10 deletions
@@ -58,15 +58,15 @@ def predict_values(self, x: sparse.csr_matrix) -> np.ndarray:
 
 def train_1vsrest(y: sparse.csr_matrix,
                   x: sparse.csr_matrix,
-                  options: str,
+                  options: str = '',
                   verbose: bool = True
                   ) -> FlatModel:
     """Trains a linear model for multiabel data using a one-vs-rest strategy.
 
     Args:
         y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
         x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
-        options (str): The option string passed to liblinear.
+        options (str, optional): The option string passed to liblinear. Defaults to ''.
         verbose (bool, optional): Output extra progress information. Defaults to True.
 
     Returns:
@@ -116,6 +116,9 @@ def _prepare_options(x: sparse.csr_matrix, options: str) -> tuple[sparse.csr_mat
         if solver_type < 0 or solver_type > 7:
             raise ValueError(
                 "Invalid LIBLINEAR solver type. Only classification solvers are allowed.")
+    else:
+        # workaround for liblinear warning about unspecified solver
+        options_split.extend(['-s', '2'])
 
     bias = -1.
     if '-B' in options_split:
@@ -137,7 +140,7 @@ def _prepare_options(x: sparse.csr_matrix, options: str) -> tuple[sparse.csr_mat
 
 def train_thresholding(y: sparse.csr_matrix,
                        x: sparse.csr_matrix,
-                       options: str,
+                       options: str = '',
                        verbose: bool = True
                        ) -> FlatModel:
     """Trains a linear model for multilabel data using a one-vs-rest strategy
@@ -149,7 +152,7 @@ def train_thresholding(y: sparse.csr_matrix,
     Args:
         y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
         x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
-        options (str): The option string passed to liblinear.
+        options (str, optional): The option string passed to liblinear. Defaults to ''.
         verbose (bool, optional): Output extra progress information. Defaults to True.
 
     Returns:
@@ -383,7 +386,7 @@ def _fmeasure(y_true: np.ndarray, y_pred: np.ndarray) -> float:
 
 def train_cost_sensitive(y: sparse.csr_matrix,
                          x: sparse.csr_matrix,
-                         options: str,
+                         options: str = '',
                          verbose: bool = True
                          ) -> FlatModel:
     """Trains a linear model for multilabel data using a one-vs-rest strategy
@@ -396,7 +399,7 @@
     Args:
         y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
         x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
-        options (str): The option string passed to liblinear.
+        options (str, optional): The option string passed to liblinear. Defaults to ''.
         verbose (bool, optional): Output extra progress information. Defaults to True.
 
     Returns:
@@ -489,7 +492,7 @@ def _cross_validate(y: np.ndarray,
 
 def train_cost_sensitive_micro(y: sparse.csr_matrix,
                                x: sparse.csr_matrix,
-                               options: str,
+                               options: str = '',
                                verbose: bool = True
                                ) -> FlatModel:
     """Trains a linear model for multilabel data using a one-vs-rest strategy
@@ -502,7 +505,7 @@
     Args:
         y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
         x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
-        options (str): The option string passed to liblinear.
+        options (str, optional): The option string passed to liblinear. Defaults to ''.
         verbose (bool, optional): Output extra progress information. Defaults to True.
 
     Returns:
@@ -555,15 +558,15 @@ def train_cost_sensitive_micro(y: sparse.csr_matrix,
 
 def train_binary_and_multiclass(y: sparse.csr_matrix,
                                 x: sparse.csr_matrix,
-                                options: str,
+                                options: str = '',
                                 verbose: bool = True
                                 ) -> FlatModel:
     """Trains a linear model for binary and multi-class data.
 
     Args:
         y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
         x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
-        options (str): The option string passed to liblinear.
+        options (str, optional): The option string passed to liblinear. Defaults to ''.
         verbose (bool, optional): Output extra progress information. Defaults to True.
 
     Returns:

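With the new default of options='', the training functions can be called without an option string; when no -s flag is given, _prepare_options appends '-s 2' itself so liblinear does not warn about an unspecified solver. A minimal sketch (variable names follow the quickstart):

    # both calls end up with the same liblinear options
    model = linear.train_1vsrest(train_y, train_x)
    model = linear.train_1vsrest(train_y, train_x, '-s 2')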
libmultilabel/linear/metrics.py

Lines changed: 90 additions & 19 deletions
@@ -5,16 +5,18 @@
 import numpy as np
 
 __all__ = ['get_metrics',
-           'tabulate_metrics']
+           'compute_metrics',
+           'tabulate_metrics',
+           'MetricCollection']
 
 
 class RPrecision:
-    def __init__(self, top_k: int) -> None:
+    def __init__(self, top_k: int):
         self.top_k = top_k
         self.score = 0
         self.num_sample = 0
 
-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
         assert preds.shape == target.shape  # (batch_size, num_classes)
         top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:]
         num_relevant = np.take_along_axis(
@@ -28,14 +30,21 @@ def update(self, preds: np.ndarray, target: np.ndarray) -> None:
     def compute(self) -> float:
         return self.score / self.num_sample
 
+    def reset(self):
+        self.score = 0
+        self.num_sample = 0
+
 
 class Precision:
-    def __init__(self, num_classes: int, average: str, top_k: int) -> None:
+    def __init__(self, num_classes: int, average: str, top_k: int):
+        if average != 'samples':
+            raise ValueError('unsupported average')
+
         self.top_k = top_k
         self.score = 0
         self.num_sample = 0
 
-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
         assert preds.shape == target.shape  # (batch_size, num_classes)
         top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:]
         num_relevant = np.take_along_axis(target, top_k_ind, -1).sum()
@@ -45,25 +54,28 @@ def update(self, preds: np.ndarray, target: np.ndarray) -> None:
     def compute(self) -> float:
         return self.score / self.num_sample
 
+    def reset(self):
+        self.score = 0
+        self.num_sample = 0
+
 
 class F1:
-    def __init__(self, num_classes: int, metric_threshold: float, average: str, multiclass=False) -> None:
+    def __init__(self, num_classes: int, average: str, multiclass=False):
         self.num_classes = num_classes
-        self.metric_threshold = metric_threshold
         if average not in {'macro', 'micro', 'another-macro'}:
             raise ValueError('unsupported average')
         self.average = average
         self.multiclass = multiclass
         self.tp = self.fp = self.fn = 0
 
-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
         assert preds.shape == target.shape  # (batch_size, num_classes)
         if self.multiclass:
             max_idx = np.argmax(preds, axis=1).reshape(-1, 1)
             preds = np.zeros(preds.shape)
             np.put_along_axis(preds, max_idx, 1, axis=1)
         else:
-            preds = preds > self.metric_threshold
+            preds = preds > 0
         self.tp += np.logical_and(target == 1, preds == 1).sum(axis=0)
         self.fn += np.logical_and(target == 1, preds == 0).sum(axis=0)
         self.fp += np.logical_and(target == 0, preds == 1).sum(axis=0)
@@ -88,34 +100,58 @@ def compute(self) -> float:
         np.seterr(**prev_settings)
         return score
 
+    def reset(self):
+        self.tp = self.fp = self.fn = 0
+
 
 class MetricCollection(dict):
-    def __init__(self, metrics) -> None:
+    """A collection of metrics created by get_metrics.
+    MetricCollection computes metric values in two steps. First, batches of
+    decision values and labels are added with update(). After all instances have been
+    added, compute() computes the metric values from the accumulated batches.
+    """
+
+    def __init__(self, metrics):
         self.metrics = metrics
 
-    def update(self, preds: np.ndarray, target: np.ndarray) -> None:
+    def update(self, preds: np.ndarray, target: np.ndarray):
+        """Adds a batch of decision values and labels.
+
+        Args:
+            preds (np.ndarray): A matrix of decision values with dimensions number of instances * number of classes.
+            target (np.ndarray): A 0/1 matrix of labels with dimensions number of instances * number of classes.
+        """
         assert preds.shape == target.shape  # (batch_size, num_classes)
         for metric in self.metrics.values():
             metric.update(preds, target)
 
     def compute(self) -> dict[str, float]:
+        """Computes the metrics from the accumulated batches of decision values and labels.
+
+        Returns:
+            dict[str, float]: A dictionary of metric values.
+        """
        ret = {}
         for name, metric in self.metrics.items():
             ret[name] = metric.compute()
         return ret
 
+    def reset(self):
+        """Clears the accumulated batches of decision values and labels.
+        """
+        for metric in self.metrics.values():
+            metric.reset()
+
 
-def get_metrics(metric_threshold: float,
-                monitor_metrics: list[str],
+def get_metrics(monitor_metrics: list[str],
                 num_classes: int,
                 multiclass: bool = False
                 ) -> MetricCollection:
     """Get a collection of metrics by their names.
+    See MetricCollection for more details.
 
     Args:
-        metric_threshold (float): The decision value threshold over which a
-            label is predicted as positive.
-        monitor_metrics (list[str]): A list metric names.
+        monitor_metrics (list[str]): A list of metric names.
         num_classes (int): The number of classes.
         multiclass (bool, optional): Enable multiclass mode. Defaults to False.
 
@@ -132,19 +168,54 @@ def get_metrics(metric_threshold: float,
         elif re.match('RP@\d+', metric):
             metrics[metric] = RPrecision(top_k=int(metric[3:]))
         elif metric in {'Another-Macro-F1', 'Macro-F1', 'Micro-F1'}:
-            metrics[metric] = F1(num_classes, metric_threshold,
+            metrics[metric] = F1(num_classes,
                                  average=metric[:-3].lower(),
                                  multiclass=multiclass)
         else:
-            raise ValueError(f'Invalid metric: {metric}')
+            raise ValueError(f'invalid metric: {metric}')
 
     return MetricCollection(metrics)
 
 
+def compute_metrics(preds: np.ndarray,
+                    target: np.ndarray,
+                    monitor_metrics: list[str],
+                    multiclass: bool = False
+                    ) -> dict[str, float]:
+    """Compute metrics with decision values and labels.
+    See get_metrics and MetricCollection if decision values and labels are too
+    large to hold in memory.
+
+
+    Args:
+        preds (np.ndarray): A matrix of decision values with dimensions number of instances * number of classes.
+        target (np.ndarray): A 0/1 matrix of labels with dimensions number of instances * number of classes.
+        monitor_metrics (list[str]): A list of metric names.
+        multiclass (bool, optional): Enable multiclass mode. Defaults to False.
+
+    Returns:
+        dict[str, float]: A dictionary of metric values.
+    """
+    assert preds.shape == target.shape
+
+    metric = get_metrics(monitor_metrics, preds.shape[1], multiclass)
+    metric.update(preds, target)
+    return metric.compute()
+
+
 def tabulate_metrics(metric_dict: dict[str, float], split: str) -> str:
+    """Convert a dictionary of metric values into a pretty formatted string for printing.
+
+    Args:
+        metric_dict (dict[str, float]): A dictionary of metric values.
+        split (str): Name of the data split.
+
+    Returns:
+        str: Pretty formatted string.
+    """
     msg = f'====== {split} dataset evaluation result =======\n'
     header = '|'.join([f'{k:^18}' for k in metric_dict.keys()])
-    values = '|'.join([f'{x * 100:^18.4f}' if isinstance(x, (np.floating,
+    values = '|'.join([f'{x:^18.4f}' if isinstance(x, (np.floating,
                        float)) else f'{x:^18}' for x in metric_dict.values()])
     msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n"
     return msg

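When the decision values and labels are too large to score in one call to compute_metrics, the batched path it wraps can be used directly. A minimal sketch (the batch size and metric names are illustrative; preds and target are as in the quickstart):

    metrics = linear.get_metrics(['P@1', 'Macro-F1', 'Micro-F1'],
                                 num_classes=target.shape[1])
    batch_size = 256
    for i in range(0, target.shape[0], batch_size):
        metrics.update(preds[i:i+batch_size], target[i:i+batch_size])
    print(linear.tabulate_metrics(metrics.compute(), 'test'))
    metrics.reset()  # clear accumulated batches before reusing the collection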