We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents b2b2415 + ba9ebec commit 9df0688Copy full SHA for 9df0688
libmultilabel/linear/tree.py
@@ -224,6 +224,8 @@ def _train_node(y: sparse.csr_matrix,
224
meta_y, x, options, False
225
)
226
227
+ node.model.weights = sparse.csr_matrix(node.model.weights)
228
+
229
230
def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]:
231
"""Flattens tree weight matrices into a single weight matrix. The flattened weight
@@ -257,7 +259,7 @@ def visit(node):
257
259
258
260
model = linear.FlatModel(
261
name='flattened-tree',
- weights=np.hstack(weights),
262
+ weights=sparse.hstack(weights, 'csr'),
263
bias=bias,
264
thresholds=0,
265
0 commit comments