Skip to content

Commit 2255a70

Browse files
authored
Merge pull request #284 from ntumlgroup/linear_multiclass
Add model name to linear models
2 parents 3f33d9f + e772be0 commit 2255a70

File tree

3 files changed

+31 -25 lines changed

3 files changed

+31 -25 lines changed

libmultilabel/linear/linear.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818

1919

2020
class FlatModel:
21-
def __init__(self, weights: np.matrix,
21+
def __init__(self, name: str,
22+
weights: np.matrix,
2223
bias: float,
2324
thresholds: float | np.ndarray,
2425
):
26+
self.name = name
2527
self.weights = weights
2628
self.bias = bias
2729
self.thresholds = thresholds
@@ -84,7 +86,8 @@ def train_1vsrest(y: sparse.csr_matrix,
8486
yi = y[:, i].toarray().reshape(-1)
8587
weights[:, i] = _do_train(2*yi - 1, x, options).ravel()
8688

87-
return FlatModel(weights=np.asmatrix(weights),
89+
return FlatModel(name='1vsrest',
90+
weights=np.asmatrix(weights),
8891
bias=bias,
8992
thresholds=0)
9093

@@ -169,15 +172,16 @@ def train_thresholding(y: sparse.csr_matrix,
169172
weights[:, i] = w.ravel()
170173
thresholds[i] = t
171174

172-
return FlatModel(weights=np.asmatrix(weights),
175+
return FlatModel(name='thresholding',
176+
weights=np.asmatrix(weights),
173177
bias=bias,
174178
thresholds=thresholds)
175179

176180

177181
def _thresholding_one_label(y: np.ndarray,
178-
x: sparse.csr_matrix,
179-
options: str
180-
) -> tuple[np.ndarray, float]:
182+
x: sparse.csr_matrix,
183+
options: str
184+
) -> tuple[np.ndarray, float]:
181185
"""Outer cross-validation for thresholding on a single label.
182186
183187
Args:
@@ -223,10 +227,10 @@ def _thresholding_one_label(y: np.ndarray,
223227

224228

225229
def _scutfbr(y: np.ndarray,
226-
x: sparse.csr_matrix,
227-
fbr_list: list[float],
228-
options: str
229-
) -> tuple[np.matrix, np.ndarray]:
230+
x: sparse.csr_matrix,
231+
fbr_list: list[float],
232+
options: str
233+
) -> tuple[np.matrix, np.ndarray]:
230234
"""Inner cross-validation for SCutfbr heuristic.
231235
232236
Args:
@@ -414,15 +418,16 @@ def train_cost_sensitive(y: sparse.csr_matrix,
414418
w = _cost_sensitive_one_label(2*yi - 1, x, options)
415419
weights[:, i] = w.ravel()
416420

417-
return FlatModel(weights=np.asmatrix(weights),
421+
return FlatModel(name='cost_sensitive',
422+
weights=np.asmatrix(weights),
418423
bias=bias,
419424
thresholds=0)
420425

421426

422427
def _cost_sensitive_one_label(y: np.ndarray,
423-
x: sparse.csr_matrix,
424-
options: str
425-
) -> np.ndarray:
428+
x: sparse.csr_matrix,
429+
options: str
430+
) -> np.ndarray:
426431
"""Loop over parameter space for cost-sensitive on a single label.
427432
428433
Args:
@@ -453,10 +458,10 @@ def _cost_sensitive_one_label(y: np.ndarray,
453458

454459

455460
def _cross_validate(y: np.ndarray,
456-
x: sparse.csr_matrix,
457-
options: str,
458-
perm: np.ndarray
459-
) -> np.ndarray:
461+
x: sparse.csr_matrix,
462+
options: str,
463+
perm: np.ndarray
464+
) -> np.ndarray:
460465
"""Cross-validation for cost-sensitive.
461466
462467
Args:
@@ -542,7 +547,8 @@ def train_cost_sensitive_micro(y: sparse.csr_matrix,
542547
w = _do_train(2*yi - 1, x, final_options)
543548
weights[:, i] = w.ravel()
544549

545-
return FlatModel(weights=np.asmatrix(weights),
550+
return FlatModel(name='cost_sensitive_micro',
551+
weights=np.asmatrix(weights),
546552
bias=bias,
547553
thresholds=0)
548554

@@ -590,7 +596,8 @@ def train_binary_and_multiclass(y: sparse.csr_matrix,
590596
# For labels not appeared in training, assign thresholds to -inf so they won't be predicted.
591597
thresholds = np.full(num_labels, -np.inf)
592598
thresholds[train_labels] = 0
593-
return FlatModel(weights=np.asmatrix(weights),
599+
return FlatModel(name='binary_and_multiclass',
600+
weights=np.asmatrix(weights),
594601
bias=bias,
595602
thresholds=thresholds)
596603

@@ -615,7 +622,7 @@ def get_topk_labels(label_mapping: np.ndarray,
615622
"""Get top k predictions from decision values.
616623
617624
Args:
618-
label_mapping (np.ndarray): A ndarray of class labels that maps each index (from 0 to ``num_class-1``) to its label.
625+
label_mapping (np.ndarray): A ndarray of class labels that maps each index (from 0 to ``num_class-1``) to its label.
619626
preds (np.ndarray): A matrix of decision values with dimension number of instances * number of classes.
620627
top_k (int): Determine how many classes per instance should be predicted.
621628

libmultilabel/linear/tree.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self,
4242
flat_model: linear.FlatModel,
4343
weight_map: np.ndarray,
4444
):
45+
self.name = 'tree'
4546
self.root = root
4647
self.flat_model = flat_model
4748
self.weight_map = weight_map

linear_trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from tqdm import tqdm
66

77
import libmultilabel.linear as linear
8-
from libmultilabel.common_utils import (argsort_top_k, dump_log,
9-
is_multiclass_dataset)
8+
from libmultilabel.common_utils import argsort_top_k, dump_log
109
from libmultilabel.linear.utils import LINEAR_TECHNIQUES
1110

1211

@@ -15,7 +14,7 @@ def linear_test(config, model, datasets):
1514
config.metric_threshold,
1615
config.monitor_metrics,
1716
datasets['test']['y'].shape[1],
18-
multiclass=config.multiclass
17+
multiclass=model.name=='binary_and_multiclass'
1918
)
2019
num_instance = datasets['test']['x'].shape[0]
2120

@@ -73,7 +72,6 @@ def linear_run(config):
7372
config.label_file,
7473
config.include_test_labels,
7574
config.remove_no_label_data)
76-
config.multiclass = is_multiclass_dataset(datasets['train'], label='y')
7775
model = linear_train(datasets, config)
7876
linear.save_pipeline(config.checkpoint_dir, preprocessor, model)
7977

0 commit comments

Comments
 (0)