Skip to content

Commit 9df0688

Browse files
authored
Merge pull request #291 from ntumlgroup/sparse-tree
Use sparse weights for flattened tree
2 parents b2b2415 + ba9ebec commit 9df0688

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

libmultilabel/linear/tree.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def _train_node(y: sparse.csr_matrix,
224224
meta_y, x, options, False
225225
)
226226

227+
node.model.weights = sparse.csr_matrix(node.model.weights)
228+
227229

228230
def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]:
229231
"""Flattens tree weight matrices into a single weight matrix. The flattened weight
@@ -257,7 +259,7 @@ def visit(node):
257259

258260
model = linear.FlatModel(
259261
name='flattened-tree',
260-
weights=np.hstack(weights),
262+
weights=sparse.hstack(weights, 'csr'),
261263
bias=bias,
262264
thresholds=0,
263265
)

0 commit comments

Comments
 (0)