18
18
19
19
20
20
class FlatModel :
21
- def __init__ (self , weights : np .matrix ,
21
+ def __init__ (self , name : str ,
22
+ weights : np .matrix ,
22
23
bias : float ,
23
24
thresholds : float | np .ndarray ,
24
25
):
26
+ self .name = name
25
27
self .weights = weights
26
28
self .bias = bias
27
29
self .thresholds = thresholds
@@ -84,7 +86,8 @@ def train_1vsrest(y: sparse.csr_matrix,
84
86
yi = y [:, i ].toarray ().reshape (- 1 )
85
87
weights [:, i ] = _do_train (2 * yi - 1 , x , options ).ravel ()
86
88
87
- return FlatModel (weights = np .asmatrix (weights ),
89
+ return FlatModel (name = '1vsrest' ,
90
+ weights = np .asmatrix (weights ),
88
91
bias = bias ,
89
92
thresholds = 0 )
90
93
@@ -169,15 +172,16 @@ def train_thresholding(y: sparse.csr_matrix,
169
172
weights [:, i ] = w .ravel ()
170
173
thresholds [i ] = t
171
174
172
- return FlatModel (weights = np .asmatrix (weights ),
175
+ return FlatModel (name = 'thresholding' ,
176
+ weights = np .asmatrix (weights ),
173
177
bias = bias ,
174
178
thresholds = thresholds )
175
179
176
180
177
181
def _thresholding_one_label (y : np .ndarray ,
178
- x : sparse .csr_matrix ,
179
- options : str
180
- ) -> tuple [np .ndarray , float ]:
182
+ x : sparse .csr_matrix ,
183
+ options : str
184
+ ) -> tuple [np .ndarray , float ]:
181
185
"""Outer cross-validation for thresholding on a single label.
182
186
183
187
Args:
@@ -223,10 +227,10 @@ def _thresholding_one_label(y: np.ndarray,
223
227
224
228
225
229
def _scutfbr (y : np .ndarray ,
226
- x : sparse .csr_matrix ,
227
- fbr_list : list [float ],
228
- options : str
229
- ) -> tuple [np .matrix , np .ndarray ]:
230
+ x : sparse .csr_matrix ,
231
+ fbr_list : list [float ],
232
+ options : str
233
+ ) -> tuple [np .matrix , np .ndarray ]:
230
234
"""Inner cross-validation for SCutfbr heuristic.
231
235
232
236
Args:
@@ -414,15 +418,16 @@ def train_cost_sensitive(y: sparse.csr_matrix,
414
418
w = _cost_sensitive_one_label (2 * yi - 1 , x , options )
415
419
weights [:, i ] = w .ravel ()
416
420
417
- return FlatModel (weights = np .asmatrix (weights ),
421
+ return FlatModel (name = 'cost_sensitive' ,
422
+ weights = np .asmatrix (weights ),
418
423
bias = bias ,
419
424
thresholds = 0 )
420
425
421
426
422
427
def _cost_sensitive_one_label (y : np .ndarray ,
423
- x : sparse .csr_matrix ,
424
- options : str
425
- ) -> np .ndarray :
428
+ x : sparse .csr_matrix ,
429
+ options : str
430
+ ) -> np .ndarray :
426
431
"""Loop over parameter space for cost-sensitive on a single label.
427
432
428
433
Args:
@@ -453,10 +458,10 @@ def _cost_sensitive_one_label(y: np.ndarray,
453
458
454
459
455
460
def _cross_validate (y : np .ndarray ,
456
- x : sparse .csr_matrix ,
457
- options : str ,
458
- perm : np .ndarray
459
- ) -> np .ndarray :
461
+ x : sparse .csr_matrix ,
462
+ options : str ,
463
+ perm : np .ndarray
464
+ ) -> np .ndarray :
460
465
"""Cross-validation for cost-sensitive.
461
466
462
467
Args:
@@ -542,7 +547,8 @@ def train_cost_sensitive_micro(y: sparse.csr_matrix,
542
547
w = _do_train (2 * yi - 1 , x , final_options )
543
548
weights [:, i ] = w .ravel ()
544
549
545
- return FlatModel (weights = np .asmatrix (weights ),
550
+ return FlatModel (name = 'cost_sensitive_micro' ,
551
+ weights = np .asmatrix (weights ),
546
552
bias = bias ,
547
553
thresholds = 0 )
548
554
@@ -590,7 +596,8 @@ def train_binary_and_multiclass(y: sparse.csr_matrix,
590
596
# For labels not appeared in training, assign thresholds to -inf so they won't be predicted.
591
597
thresholds = np .full (num_labels , - np .inf )
592
598
thresholds [train_labels ] = 0
593
- return FlatModel (weights = np .asmatrix (weights ),
599
+ return FlatModel (name = 'binary_and_multiclass' ,
600
+ weights = np .asmatrix (weights ),
594
601
bias = bias ,
595
602
thresholds = thresholds )
596
603
@@ -615,7 +622,7 @@ def get_topk_labels(label_mapping: np.ndarray,
615
622
"""Get top k predictions from decision values.
616
623
617
624
Args:
618
- label_mapping (np.ndarray): A ndarray of class labels that maps each index (from 0 to ``num_class-1``) to its label.
625
+ label_mapping (np.ndarray): A ndarray of class labels that maps each index (from 0 to ``num_class-1``) to its label.
619
626
preds (np.ndarray): A matrix of decision values with dimension number of instances * number of classes.
620
627
top_k (int): Determine how many classes per instance should be predicted.
621
628
0 commit comments