Skip to content

Commit b6a294c

Browse files
committed
use sparse weights instead
1 parent 79ba02a commit b6a294c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

libmultilabel/linear/tree.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def _train_node(y: sparse.csr_matrix,
223223
node.model = linear.train_1vsrest(
224224
meta_y, x, options, False
225225
)
226+
node.model.weights = sparse.csr_matrix(node.model.weights)
226227

227228

228229
def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]:
@@ -257,7 +258,7 @@ def visit(node):
257258

258259
model = linear.FlatModel(
259260
name='flattened-tree',
260-
weights=np.hstack(weights),
261+
weights=sparse.hstack(weights, 'csr'),
261262
bias=bias,
262263
thresholds=0,
263264
)

0 commit comments

Comments
 (0)