7 | 7 | import sklearn.cluster
8 | 8 | import sklearn.preprocessing
9 | 9 | from tqdm import tqdm
  | 10 | +import psutil
10 | 11 |
11 | 12 | from . import linear
12 | 13 |
@@ -135,13 +136,25 @@ def train_tree(
135 | 136 |     root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
136 | 137 |
137 | 138 |     num_nodes = 0
    | 139 | +    label_feature_used = (x != 0).T * y
138 | 140 |
139 | 141 |     def count(node):
140 | 142 |         nonlocal num_nodes
141 | 143 |         num_nodes += 1
    | 144 | +        node.num_nnz_feat = np.count_nonzero(label_feature_used[:, node.label_map].sum(axis=1))
142 | 145 |
143 | 146 |     root.dfs(count)
144 | 147 |
    | 148 | +    # Calculate the total memory (excluding swap) on the local machine
    | 149 | +    total_memory = psutil.virtual_memory().total
    | 150 | +    print(f'Your system memory is: {total_memory / (1024**3):.3f} GB')
    | 151 | +
    | 152 | +    model_size = get_estimated_model_size(root, num_nodes)
    | 153 | +    print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB')
    | 154 | +
    | 155 | +    if total_memory <= model_size:
    | 156 | +        raise MemoryError(f'Not enough memory to train the model. model_size: {model_size / (1024**3):.3f} GB')
    | 157 | +
145 | 158 |     pbar = tqdm(total=num_nodes, disable=not verbose)
146 | 159 |
147 | 160 |     def visit(node):
@@ -195,6 +208,26 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
195 | 208 |     return Node(label_map=label_map, children=children)
196 | 209 |
197 | 210 |
    | 211 | +def get_estimated_model_size(root, num_nodes):
    | 212 | +    num_nnz_feat, num_branches = np.zeros(num_nodes), np.zeros(num_nodes)
    | 213 | +    num_nodes = 0
    | 214 | +    def collect_stat(node: Node):
    | 215 | +        nonlocal num_nodes
    | 216 | +        num_nnz_feat[num_nodes] = node.num_nnz_feat
    | 217 | +
    | 218 | +        if node.isLeaf():
    | 219 | +            num_branches[num_nodes] = len(node.label_map)
    | 220 | +        else:
    | 221 | +            num_branches[num_nodes] = len(node.children)
    | 222 | +
    | 223 | +        num_nodes += 1
    | 224 | +
    | 225 | +    root.dfs(collect_stat)
    | 226 | +
    | 227 | +    # 16 bytes per stored weight: an int64 index (8 bytes) plus a float64 value (8 bytes) in the sparse weight matrices
    | 228 | +    return np.dot(num_nnz_feat, num_branches) * 16
    | 229 | +
    | 230 | +
198 | 231 | def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
199 | 232 |     """If node is internal, computes the metalabels representing each child and trains
200 | 233 |     on the metalabels. Otherwise, train on y.
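
Note on the statistics added in train_tree: assuming x is the (instances x features) training matrix and y the (instances x labels) label matrix, (x != 0).T * y is a (features x labels) count matrix, so node.num_nnz_feat is the number of features that occur in at least one training instance of the node's labels. A minimal sketch of that computation on a made-up dataset (the matrices and label_map below are hypothetical, not from the commit):

import numpy as np
import scipy.sparse as sparse

# Toy data: 2 instances, 3 features, 2 labels
x = sparse.csr_matrix([[1.0, 0.0, 2.0],
                       [0.0, 3.0, 0.0]])
y = sparse.csr_matrix([[1, 0],
                       [1, 1]])

# Entry (f, l) counts the instances of label l in which feature f is nonzero
label_feature_used = (x != 0).T * y

# For a node covering both labels, count features used by at least one of its labels
label_map = np.array([0, 1])
num_nnz_feat = np.count_nonzero(label_feature_used[:, label_map].sum(axis=1))
print(num_nnz_feat)  # 3: every feature appears in some instance of label 0 or label 1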
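
The 16 in get_estimated_model_size is the assumed cost per stored weight (one int64 index plus one float64 value), so the returned size is the total number of weight-matrix entries across all nodes times 16 bytes. A rough sketch of the arithmetic with made-up per-node statistics (not from the commit):

import numpy as np

num_nnz_feat = np.array([120_000, 80_000, 75_000])  # hypothetical features used per node
num_branches = np.array([64, 32, 41])               # hypothetical branches (or labels at a leaf)

# Each node's weight matrix has num_nnz_feat[i] * num_branches[i] entries at 16 bytes each
model_size = np.dot(num_nnz_feat, num_branches) * 16
print(f'{model_size / (1024**3):.3f} GB')           # 0.198 GB for these numbers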