Skip to content

Commit 73b8ddd

Browse files
committed
not use momory to store nnz statistics
1 parent 56bf1a5 commit 73b8ddd

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

libmultilabel/linear/tree.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def count(node):
149149
total_memory = psutil.virtual_memory().total
150150
print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')
151151

152-
model_size = get_estimated_model_size(root, num_nodes)
152+
model_size = get_estimated_model_size(root)
153153
print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
154154

155155
if (total_memory <= model_size):
@@ -208,24 +208,21 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
208208
return Node(label_map=label_map, children=children)
209209

210210

211-
def get_estimated_model_size(root, num_nodes):
212-
num_nnz_feat, num_branches = np.zeros(num_nodes), np.zeros(num_nodes)
213-
num_nodes = 0
211+
def get_estimated_model_size(root):
212+
total_num_weights = 0
213+
214214
def collect_stat(node: Node):
215-
nonlocal num_nodes
216-
num_nnz_feat[num_nodes] = node.num_nnz_feat
215+
nonlocal total_num_weights
217216

218217
if node.isLeaf():
219-
num_branches[num_nodes] = len(node.label_map)
218+
total_num_weights += len(node.label_map) * node.num_nnz_feat
220219
else:
221-
num_branches[num_nodes] = len(node.children)
222-
223-
num_nodes += 1
220+
total_num_weights += len(node.children) * node.num_nnz_feat
224221

225222
root.dfs(collect_stat)
226223

227224
# 16 is because when storing sparse matrices, indices (int64) require 8 bytes and floats require 8 bytes
228-
return np.dot(num_nnz_feat, num_branches) * 16
225+
return total_num_weights * 16
229226

230227

231228
def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):

0 commit comments

Comments
 (0)