|
2 | 2 |
|
3 | 3 | import importlib
|
4 | 4 | import sys
|
| 5 | +from collections import defaultdict |
| 6 | + |
| 7 | +import polars as pl |
5 | 8 |
|
6 | 9 |
|
7 | 10 | def show_versions(file=sys.stdout): # pragma: no cover
|
@@ -119,3 +122,87 @@ def _update(self, kwargs):
|
119 | 122 | def __exit__(self, type, value, traceback):
|
120 | 123 | """Context management."""
|
121 | 124 | self._update(self.old)
|
| 125 | + |
| 126 | + |
| 127 | +class MinimalExploder: |
| 128 | + """ |
| 129 | + A comprehensive class for analyzing and performing minimal explosions |
| 130 | + of DataFrames with multiple list columns. |
| 131 | + """ |
| 132 | + |
| 133 | + def __init__(self, df: pl.DataFrame): |
| 134 | + self.df = df |
| 135 | + self._list_cols: list[str] | None = None |
| 136 | + self._length_patterns: dict[str, tuple[int, ...]] | None = None |
| 137 | + self._explodable_groups: list[list[str]] | None = None |
| 138 | + |
| 139 | + @property |
| 140 | + def list_columns(self) -> list[str]: |
| 141 | + """Get all list-type columns in the DataFrame.""" |
| 142 | + if self._list_cols is None: |
| 143 | + self._list_cols = [col for col in self.df.columns if self.df[col].dtype == pl.List] |
| 144 | + return self._list_cols |
| 145 | + |
| 146 | + @property |
| 147 | + def length_patterns(self) -> dict[str, tuple[int, ...]]: |
| 148 | + """Get length patterns for all list columns. |
| 149 | +
|
| 150 | + This is stored as a dictionary containing tuples of all list lengths, ie |
| 151 | + 'a' : (1,3,2), |
| 152 | + 'b' : (2,2,2), |
| 153 | +
|
| 154 | + """ |
| 155 | + if self._length_patterns is None: |
| 156 | + self._length_patterns = self._analyze_patterns() |
| 157 | + return self._length_patterns |
| 158 | + |
| 159 | + @property |
| 160 | + def explodable_groups(self) -> list[list[str]]: |
| 161 | + """Get groups of columns that can be exploded together.""" |
| 162 | + if self._explodable_groups is None: |
| 163 | + self._explodable_groups = self._compute_groups() |
| 164 | + return self._explodable_groups |
| 165 | + |
| 166 | + def _analyze_patterns(self) -> dict[str, tuple[int, ...]]: |
| 167 | + """Analyze length patterns of all list columns. Returns a value |
| 168 | + rather than setting self._length_patterns to shut up mypy.""" |
| 169 | + _length_patterns = {} |
| 170 | + |
| 171 | + for col in self.list_columns: |
| 172 | + lengths = self.df.select(pl.col(col).list.len()).to_series().to_list() |
| 173 | + _length_patterns[col] = tuple(lengths) |
| 174 | + |
| 175 | + return _length_patterns |
| 176 | + |
| 177 | + def _compute_groups(self): |
| 178 | + """Compute explodable groups based on length patterns. Returns a value |
| 179 | + rather than setting self._explodable_groups to shut up mypy.""" |
| 180 | + pattern_groups = defaultdict(list) |
| 181 | + |
| 182 | + for col, pattern in self.length_patterns.items(): |
| 183 | + pattern_groups[pattern].append(col) |
| 184 | + |
| 185 | + return list(pattern_groups.values()) |
| 186 | + |
| 187 | + @property |
| 188 | + def summary(self) -> dict: |
| 189 | + """Get a summary of the explosion analysis.""" |
| 190 | + return { |
| 191 | + 'total_columns': len(self.df.columns), |
| 192 | + 'list_columns': len(self.list_columns), |
| 193 | + 'unique_patterns': len(set(self.length_patterns.values())), |
| 194 | + 'explodable_groups': len(self.explodable_groups), |
| 195 | + 'explosion_operations_needed': len(self.explodable_groups), |
| 196 | + 'groups': self.explodable_groups, |
| 197 | + } |
| 198 | + |
| 199 | + def __call__(self) -> pl.DataFrame: |
| 200 | + """Perform the minimal explosion.""" |
| 201 | + if not self.list_columns: |
| 202 | + return self.df |
| 203 | + |
| 204 | + result_df = self.df |
| 205 | + for group in self.explodable_groups: |
| 206 | + result_df = result_df.explode(*group) |
| 207 | + |
| 208 | + return result_df |
0 commit comments