|
7 | 7 | import sklearn.cluster
|
8 | 8 | import sklearn.preprocessing
|
9 | 9 | from tqdm import tqdm
|
| 10 | +import psutil |
10 | 11 |
|
11 | 12 | from . import linear
|
12 | 13 |
|
@@ -135,13 +136,28 @@ def train_tree(
|
135 | 136 | root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
|
136 | 137 |
|
137 | 138 | num_nodes = 0
|
| 139 | + # Both type(x) and type(y) are sparse.csr_matrix |
| 140 | + # However, type((x != 0).T) becomes sparse.csc_matrix |
| 141 | + # So type((x != 0).T * y) results in sparse.csc_matrix |
| 142 | + features_used_perlabel = (x != 0).T * y |
138 | 143 |
|
139 | 144 | def count(node):
|
140 | 145 | nonlocal num_nodes
|
141 | 146 | num_nodes += 1
|
| 147 | + node.num_features_used = np.count_nonzero(features_used_perlabel[:, node.label_map].sum(axis=1)) |
142 | 148 |
|
143 | 149 | root.dfs(count)
|
144 | 150 |
|
| 151 | + model_size = get_estimated_model_size(root) |
| 152 | + print(f'The estimated tree model size is: {model_size / (1024**3):.3f} GB') |
| 153 | + |
| 154 | + # Calculate the total memory (excluding swap) on the local machine |
| 155 | + total_memory = psutil.virtual_memory().total |
| 156 | + print(f'Your system memory is: {total_memory / (1024**3):.3f} GB') |
| 157 | + |
| 158 | + if (total_memory <= model_size): |
| 159 | + raise MemoryError(f'Not enough memory to train the model.') |
| 160 | + |
145 | 161 | pbar = tqdm(total=num_nodes, disable=not verbose)
|
146 | 162 |
|
147 | 163 | def visit(node):
|
@@ -195,6 +211,23 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
|
195 | 211 | return Node(label_map=label_map, children=children)
|
196 | 212 |
|
197 | 213 |
|
def get_estimated_model_size(root):
    """Estimate the size in bytes of the tree model's sparse weight matrices.

    Walks the tree once, summing one weight vector per classifier trained at
    each node (per label at a leaf, per child metalabel at an internal node),
    restricted to the features actually used by that node's labels.
    """
    total_weights = 0

    def accumulate(node: "Node"):
        nonlocal total_weights
        # A leaf trains one classifier per label; an internal node trains one
        # classifier per child metalabel.
        if node.isLeaf():
            classifiers = len(node.label_map)
        else:
            classifiers = len(node.children)
        total_weights += classifiers * node.num_features_used

    root.dfs(accumulate)

    # Sparse storage costs 16 bytes per nonzero weight: an int64 index
    # (8 bytes) plus a float64 value (8 bytes).
    return total_weights * 16
| 229 | + |
| 230 | + |
198 | 231 | def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: Node):
|
199 | 232 | """If node is internal, computes the metalabels representing each child and trains
|
200 | 233 | on the metalabels. Otherwise, train on y.
|
|
0 commit comments