
Commit 6544547

Merge pull request #385 from ntumlgroup/print-estimate-tree-model-size-with-x-transpose
estimate and print tree model size
2 parents 91e6a4a + d71ba2a commit 6544547

File tree: 2 files changed, +34 -0 lines changed


libmultilabel/linear/tree.py

Lines changed: 33 additions & 0 deletions
@@ -7,6 +7,7 @@
 import sklearn.cluster
 import sklearn.preprocessing
 from tqdm import tqdm
+import psutil

 from . import linear

@@ -135,13 +136,28 @@ def train_tree(
     root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)

     num_nodes = 0
+    # Both type(x) and type(y) are sparse.csr_matrix
+    # However, type((x != 0).T) becomes sparse.csc_matrix
+    # So type((x != 0).T * y) results in sparse.csc_matrix
+    features_used_perlabel = (x != 0).T * y

     def count(node):
         nonlocal num_nodes
         num_nodes += 1
+        node.num_features_used = np.count_nonzero(features_used_perlabel[:, node.label_map].sum(axis=1))

     root.dfs(count)

+    model_size = get_estimated_model_size(root)
+    print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
+
+    # Calculate the total memory (excluding swap) on the local machine
+    total_memory = psutil.virtual_memory().total
+    print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')
+
+    if (total_memory <= model_size):
+        raise MemoryError(f'Not enough memory to train the model.')
+
     pbar = tqdm(total=num_nodes, disable=not verbose)

     def visit(node):
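Note (not part of the commit): a minimal sketch with toy matrices showing why (x != 0).T * y gives per-label feature usage, and how a node's num_features_used follows from it. The matrices x, y and the label subset below are illustrative only.

import numpy as np
import scipy.sparse as sparse

# Toy data (illustrative only): 3 instances, 4 features, 2 labels.
x = sparse.csr_matrix(np.array([[1.0, 0.0, 2.0, 0.0],
                                [0.0, 3.0, 0.0, 0.0],
                                [0.0, 0.0, 4.0, 5.0]]))
y = sparse.csr_matrix(np.array([[1, 0],
                                [1, 0],
                                [0, 1]]))

# (x != 0) keeps only the nonzero pattern; transposing the CSR matrix gives CSC,
# and the product has shape (num_features, num_labels): entry [f, l] counts how
# many instances carrying label l use feature f.
features_used_perlabel = (x != 0).T * y

# Distinct features used by the label subset {0}, as count() does per node:
label_map = np.array([0])
num_features_used = np.count_nonzero(features_used_perlabel[:, label_map].sum(axis=1))
print(num_features_used)  # 3 -- features 0, 1, 2 appear in the instances of label 0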
@@ -195,6 +211,23 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
     return Node(label_map=label_map, children=children)


+def get_estimated_model_size(root):
+    total_num_weights = 0
+
+    def collect_stat(node: Node):
+        nonlocal total_num_weights
+
+        if node.isLeaf():
+            total_num_weights += len(node.label_map) * node.num_features_used
+        else:
+            total_num_weights += len(node.children) * node.num_features_used
+
+    root.dfs(collect_stat)
+
+    # 16 is because when storing sparse matrices, indices (int64) require 8 bytes and floats require 8 bytes
+    return total_num_weights * 16
+
+
 def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
     """If node is internal, computes the metalabels representing each child and trains
     on the metalabels. Otherwise, train on y.
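Note (not part of the commit): a rough sanity check of the 16-bytes-per-weight estimate. Assuming a node's weights end up in a sparse matrix with float64 values and int64 indices, each stored weight costs 8 + 8 bytes plus a small indptr overhead. The shape and density below are made up for illustration, and scipy defaults to int32 indices for small matrices, so they are cast explicitly here to match the assumption in the comment above.

import numpy as np
import scipy.sparse as sparse

# Hypothetical node: weight vectors for 5 children over 1000 used features,
# roughly 30% nonzero, stored as CSR with float64 values and int64 indices.
rng = np.random.default_rng(0)
dense = rng.random((1000, 5)) * (rng.random((1000, 5)) < 0.3)
weights = sparse.csr_matrix(dense)
weights.indices = weights.indices.astype(np.int64)
weights.indptr = weights.indptr.astype(np.int64)

num_weights = weights.nnz
actual_bytes = weights.data.nbytes + weights.indices.nbytes + weights.indptr.nbytes
estimated_bytes = num_weights * 16  # 8 bytes per float64 value + 8 bytes per int64 index

print(num_weights, actual_bytes, estimated_bytes)
# estimated_bytes matches actual_bytes up to the small indptr array.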

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ PyYAML
 scikit-learn
 scipy
 tqdm
+psutil
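Note (not part of the commit): a standalone sketch of the new memory guard that the psutil dependency supports, with a made-up model_size; psutil.virtual_memory().total reports the machine's physical RAM in bytes and does not include swap.

import psutil

model_size = 4 * (1024**3)  # pretend the estimate came out to 4 GB (illustrative)
total_memory = psutil.virtual_memory().total  # physical RAM in bytes, excluding swap

print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')

if total_memory <= model_size:
    raise MemoryError('Not enough memory to train the model.')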
