@@ -136,24 +136,27 @@ def train_tree(
136136 root = _build_tree (label_representation , np .arange (y .shape [1 ]), 0 , K , dmax )
137137
138138 num_nodes = 0
139- label_feature_used = (x != 0 ).T * y
139+ # Both type(x) and type(y) are sparse.csr_matrix
140+ # However, type((x != 0).T) becomes sparse.csc_matrix
141+ # So type((x != 0).T * y) results in sparse.csc_matrix
142+ features_used_perlabel = (x != 0 ).T * y
140143
141144 def count (node ):
142145 nonlocal num_nodes
143146 num_nodes += 1
144- node .num_nnz_feat = np .count_nonzero (label_feature_used [:, node .label_map ].sum (axis = 1 ))
147+ node .num_features_used = np .count_nonzero (features_used_perlabel [:, node .label_map ].sum (axis = 1 ))
145148
146149 root .dfs (count )
147150
151+ model_size = get_estimated_model_size (root )
152+ print (f'The estimated tree model size is: { model_size / (1024 ** 3 ):.3f} GB' )
153+
148154 # Calculate the total memory (excluding swap) on the local machine
149155 total_memory = psutil .virtual_memory ().total
150156 print (f'Your system memory is: { total_memory / (1024 ** 3 ):.3f} GB' )
151157
152- model_size = get_estimated_model_size (root )
153- print (f'The estimated tree model size is: { model_size / (1024 ** 3 ):.3f} GB' )
154-
155158 if (total_memory <= model_size ):
156- raise MemoryError (f'Not enough memory to train the model. model_size: { model_size / ( 1024 ** 3 ):.3f } GB ' )
159+ raise MemoryError (f'Not enough memory to train the model.' )
157160
158161 pbar = tqdm (total = num_nodes , disable = not verbose )
159162
@@ -215,9 +218,9 @@ def collect_stat(node: Node):
215218 nonlocal total_num_weights
216219
217220 if node .isLeaf ():
218- total_num_weights += len (node .label_map ) * node .num_nnz_feat
221+ total_num_weights += len (node .label_map ) * node .num_features_used
219222 else :
220- total_num_weights += len (node .children ) * node .num_nnz_feat
223+ total_num_weights += len (node .children ) * node .num_features_used
221224
222225 root .dfs (collect_stat )
223226
0 commit comments