Skip to content

Commit c09a42d

Browse files
committed
Fix issues based on code review comments
1 parent 73b8ddd commit c09a42d

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

libmultilabel/linear/tree.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,24 +136,27 @@ def train_tree(
136136
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
137137

138138
num_nodes = 0
139-
label_feature_used = (x != 0).T * y
139+
# Both type(x) and type(y) are sparse.csr_matrix
140+
# However, transposing a CSR matrix yields CSC, so type((x != 0).T) is sparse.csc_matrix
141+
# Consequently the product (x != 0).T * y is also a sparse.csc_matrix
142+
features_used_perlabel = (x != 0).T * y
140143

141144
def count(node):
142145
nonlocal num_nodes
143146
num_nodes += 1
144-
node.num_nnz_feat = np.count_nonzero(label_feature_used[:, node.label_map].sum(axis=1))
147+
node.num_features_used = np.count_nonzero(features_used_perlabel[:, node.label_map].sum(axis=1))
145148

146149
root.dfs(count)
147150

151+
model_size = get_estimated_model_size(root)
152+
print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
153+
148154
# Calculate the total memory (excluding swap) on the local machine
149155
total_memory = psutil.virtual_memory().total
150156
print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')
151157

152-
model_size = get_estimated_model_size(root)
153-
print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
154-
155158
if (total_memory <= model_size):
156-
raise MemoryError(f'Not enough memory to train the model. model_size: {model_size / (1024**3):.3f} GB')
159+
raise MemoryError(f'Not enough memory to train the model.')
157160

158161
pbar = tqdm(total=num_nodes, disable=not verbose)
159162

@@ -215,9 +218,9 @@ def collect_stat(node: Node):
215218
nonlocal total_num_weights
216219

217220
if node.isLeaf():
218-
total_num_weights += len(node.label_map) * node.num_nnz_feat
221+
total_num_weights += len(node.label_map) * node.num_features_used
219222
else:
220-
total_num_weights += len(node.children) * node.num_nnz_feat
223+
total_num_weights += len(node.children) * node.num_features_used
221224

222225
root.dfs(collect_stat)
223226

0 commit comments

Comments
 (0)