@@ -149,7 +149,7 @@ def count(node):
149
149
total_memory = psutil .virtual_memory ().total
150
150
print (f'Your system memory is: { total_memory / (1024 ** 3 ):.3f} GB' )
151
151
152
- model_size = get_estimated_model_size (root , num_nodes )
152
+ model_size = get_estimated_model_size (root )
153
153
print (f'The estimated tree model size is: { model_size / (1024 ** 3 ):.3f} GB' )
154
154
155
155
if (total_memory <= model_size ):
@@ -208,24 +208,21 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
208
208
return Node (label_map = label_map , children = children )
209
209
210
210
211
- def get_estimated_model_size (root , num_nodes ):
212
- num_nnz_feat , num_branches = np . zeros ( num_nodes ), np . zeros ( num_nodes )
213
- num_nodes = 0
211
+ def get_estimated_model_size (root ):
212
+ total_num_weights = 0
213
+
214
214
def collect_stat (node : Node ):
215
- nonlocal num_nodes
216
- num_nnz_feat [num_nodes ] = node .num_nnz_feat
215
+ nonlocal total_num_weights
217
216
218
217
if node .isLeaf ():
219
- num_branches [ num_nodes ] = len (node .label_map )
218
+ total_num_weights + = len (node .label_map ) * node . num_nnz_feat
220
219
else :
221
- num_branches [num_nodes ] = len (node .children )
222
-
223
- num_nodes += 1
220
+ total_num_weights += len (node .children ) * node .num_nnz_feat
224
221
225
222
root .dfs (collect_stat )
226
223
227
224
# 16 is because when storing sparse matrices, indices (int64) require 8 bytes and floats require 8 bytes
228
- return np . dot ( num_nnz_feat , num_branches ) * 16
225
+ return total_num_weights * 16
229
226
230
227
231
228
def _train_node (y : sparse .csr_matrix , x : sparse .csr_matrix , options : str , node : Node ):
0 commit comments