Skip to content

Commit 0b638b0

Browse files
committed
estimate and print tree model size
1 parent eb711ba commit 0b638b0

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

libmultilabel/linear/tree.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sklearn.cluster
88
import sklearn.preprocessing
99
from tqdm import tqdm
10+
import psutil
1011

1112
from . import linear
1213

@@ -135,13 +136,25 @@ def train_tree(
135136
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
136137

137138
num_nodes = 0
139+
label_feature_used = (x != 0).T * y
138140

139141
def count(node):
140142
nonlocal num_nodes
141143
num_nodes += 1
144+
node.num_nnz_feat = np.count_nonzero(label_feature_used[:, node.label_map].sum(axis=0))
142145

143146
root.dfs(count)
144147

148+
# Calculate the total memory (excluding swap) on the local machine
149+
total_memory = psutil.virtual_memory().total
150+
print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')
151+
152+
model_size = get_estimated_model_size(root, num_nodes)
153+
print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
154+
155+
if (total_memory <= model_size):
156+
raise MemoryError(f'Not enough memory to train the model. model_size: {model_size / (1024**3):.3f} GB')
157+
145158
pbar = tqdm(total=num_nodes, disable=not verbose)
146159

147160
def visit(node):
@@ -195,6 +208,26 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
195208
return Node(label_map=label_map, children=children)
196209

197210

211+
def get_estimated_model_size(root: "Node", num_nodes: int) -> float:
    """Estimate the number of bytes needed to store the weights of the tree model.

    Every node of the tree is visited once through ``root.dfs``. For each node
    the estimate multiplies ``node.num_nnz_feat`` by the node's number of
    branches: ``len(node.label_map)`` for a leaf, ``len(node.children)``
    otherwise. The per-node products are summed and scaled by the storage cost
    of one sparse-matrix entry.

    Args:
        root: Root node of the tree. Every node reached by ``root.dfs`` must
            already have its ``num_nnz_feat`` attribute set.
        num_nodes: Number of nodes in the tree. Kept for backward
            compatibility; the statistics are gathered during the traversal
            itself, so an inaccurate count can no longer cause an
            out-of-bounds write.

    Returns:
        The estimated model size in bytes, as a NumPy float.
    """
    # When storing sparse matrices, each entry needs an index (int64, 8 bytes)
    # and a float value (8 bytes).
    bytes_per_entry = 16

    nnz_feat_per_node = []
    branches_per_node = []

    def collect_stat(node):
        nnz_feat_per_node.append(node.num_nnz_feat)
        if node.isLeaf():
            branches_per_node.append(len(node.label_map))
        else:
            branches_per_node.append(len(node.children))

    root.dfs(collect_stat)

    # Float arrays keep the result type identical to the previous
    # np.zeros-based implementation.
    estimated_entries = np.dot(
        np.asarray(nnz_feat_per_node, dtype=float),
        np.asarray(branches_per_node, dtype=float),
    )
    return estimated_entries * bytes_per_entry
230+
198231
def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
199232
"""If node is internal, computes the metalabels representing each child and trains
200233
on the metalabels. Otherwise, train on y.

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ PyYAML
55
scikit-learn
66
scipy
77
tqdm
8+
psutil

0 commit comments

Comments
 (0)