Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions great_tables/_gt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class GTData:
_locale: Locale | None
_formats: Formats
_substitutions: Formats
_col_merge: ColMerges
_options: Options
_google_font_imports: GoogleFontImports = field(default_factory=GoogleFontImports)
_has_built: bool = False
Expand Down Expand Up @@ -132,6 +133,7 @@ def from_data(
_locale=Locale(locale),
_formats=[],
_substitutions=[],
_col_merge=[],
_options=options,
_google_font_imports=GoogleFontImports(),
)
Expand Down Expand Up @@ -985,6 +987,20 @@ def __init__(self, func: FormatFns, cols: list[str], rows: list[int]):
Formats = list


# Column Merge ----


@dataclass(frozen=True)
class ColMergeInfo:
    """Immutable record describing one column-merge operation.

    One instance is created per `cols_merge()` call and stored in the `_col_merge` list of the
    table data.
    """

    # Names of the columns taking part in the merge, in the order they were resolved;
    # `cols_merge()` documents the first one as the target column
    vars: list[str]
    # Row positions the merge applies to
    rows: list[int]
    type: str  # type of merge operation (only 'merge' used currently)
    # Merge pattern string; defaults to `None` when not supplied
    pattern: str | None = None


# A collection of column-merge records is kept as a plain list of `ColMergeInfo`
ColMerges = list[ColMergeInfo]


# Summary Rows ---

# This can't conflict with actual group ids since we have a
Expand Down
187 changes: 185 additions & 2 deletions great_tables/_spanners.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing_extensions import TypeAlias, TypedDict

from ._boxhead import cols_label
from ._gt_data import SpannerInfo, Spanners
from ._locations import resolve_cols_c
from ._gt_data import ColMergeInfo, SpannerInfo, Spanners
from ._locations import resolve_cols_c, resolve_rows_i
from ._tbl_data import SelectExpr
from ._text import BaseText, Text
from ._utils import OrderedSet, _assert_list_is_subset
Expand Down Expand Up @@ -1026,3 +1026,186 @@ def cols_width(self: GTSelf, cases: dict[str, str] | None = None, **kwargs: str)
curr_boxhead = curr_boxhead._set_column_width(col, width)

return self._replace(_boxhead=curr_boxhead)


def cols_merge(
    self: GTSelf,
    columns: SelectExpr,
    hide_columns: SelectExpr | Literal[False] | None = None,
    rows: int | list[int] | None = None,
    pattern: str | None = None,
) -> GTSelf:
    """Merge data from two or more columns into a single column.

    This method takes input from two or more columns and allows the contents to be merged into a
    single column by using a pattern that specifies the arrangement. The first column in the
    `columns=` parameter operates as the target column (i.e., the column that will undergo mutation)
    whereas all following columns will be untouched. There is the option to hide the non-target
    columns. The formatting of values in different columns will be preserved upon merging.

    Parameters
    ----------
    columns
        The columns for which the merging operations should be applied. The first column name
        resolved will be the target column (i.e., undergo mutation) and the other columns will serve
        to provide input. Can be a list of column names or a selection expression, though a list is
        preferred here to ensure the order of columns is exactly as intended (since order matters
        for the `pattern=` parameter).
    hide_columns
        Any column names provided here will have their state changed to hidden (via internal use
        of `.cols_hide()`) if they aren't already hidden. This is convenient if the shared purpose
        of these specified columns is only to provide string input to the target column. To
        suppress any hiding of columns, `False` can be used here. By default (`None`), all columns
        other than the first one specified in `columns=` will be hidden. Columns named here that
        are not part of `columns=` are left visible and a `UserWarning` is issued.
    rows
        In conjunction with `columns=`, we can specify which of their rows should participate in
        the merging process. The default is all rows, resulting in all rows in `columns=` being
        merged. Alternatively, we can supply a list of row indices.
    pattern
        A formatting pattern that specifies the arrangement of the column values and any string
        literals. The pattern uses numbers (within `{}`) that correspond to the indices of columns
        provided in `columns=`. If two columns are provided in `columns=` and we would like to
        combine the cell data onto the first column, `"{1} {2}"` could be used. If a pattern isn't
        provided then a space-separated pattern that includes all columns will be generated
        automatically. The pattern can also use `<<`/`>>` to surround spans of text that will be
        removed if any of the contained `{}` yields a missing value. Further details are provided in
        the *How the pattern works* section.

    Returns
    -------
    GT
        The GT object with the column merge recorded is returned, so that we can facilitate method
        chaining.

    Raises
    ------
    ValueError
        If `columns=` resolves to fewer than two columns.

    Details
    -------
    ### How the pattern works

    There are two types of templating for the `pattern` string:

    - `{` `}` for arranging single column values in a row-wise fashion
    - `<<` `>>` to surround spans of text that will be removed if any of the contained `{` `}`
    yields a missing value

    Integer values are placed in `{}` and those values correspond to the columns involved in the
    merge, in the order they are provided in the `columns=` argument. So the pattern
    `"{1} ({2}-{3})"` corresponds to the target column value listed first in `columns` and the
    second and third columns cited (formatted as a range in parentheses). With hypothetical values,
    this might result as the merged string `"38.2 (3-8)"`.

    Because some values involved in merging may be missing, it is likely that something like
    `"38.2 (3-None)"` would be undesirable. For such cases, placing sections of text in `<<>>`
    results in the entire span being eliminated if there were to be a `None` value (arising from
    `{}` values). We could instead opt for a pattern like `"{1}<< ({2}-{3})>>"`, which results in
    `"38.2"` if either column `{2}` or `{3}` has a `None` value. We can even use a more complex
    nesting pattern like `"{1}<< ({2}-<<{3}>>)>>"` to retain a lower limit in parentheses (where
    `{3}` is `None`) but remove the range altogether if `{2}` is `None`.

    One more thing to note here is that if `.sub_missing()` is used on values in a column, those
    specific values affected won't be considered truly missing by `.cols_merge()` (since they have
    been explicitly handled with substitute text).

    Examples
    --------
    Let's use a subset of the `sp500` dataset to create a table. We'll merge the `open` & `close`
    columns together, and the `low` & `high` columns (putting an em dash between both).

    ```{python}
    from great_tables import GT
    from great_tables.data import sp500
    import polars as pl

    sp500_mini = (
        pl.from_pandas(sp500)
        .slice(49, 6)
        .select("open", "close", "low", "high")
    )

    (
        GT(sp500_mini)
        .fmt_number(
            columns=["open", "close", "low", "high"],
            decimals=2,
            use_seps=False
        )
        .cols_merge(columns=["open", "close"], pattern="{1}&mdash;{2}")
        .cols_merge(columns=["low", "high"], pattern="{1}&mdash;{2}")
        .cols_label(open="open/close", low="low/high")
    )
    ```

    Now we'll use a portion of the `gtcars` for the next example that accounts for missing values in
    the `pattern=` parameter. Use the `.cols_merge()` method twice to merge together the: (1) `trq`
    and `trq_rpm` columns, and (2) `mpg_c` & `mpg_h` columns. Given the presence of missing values,
    we can use patterns with `<<`/`>>` to create conditional text spans, avoiding results where
    any of the merged columns have missing values.

    ```{python}
    from great_tables.data import gtcars
    import polars.selectors as cs

    gtcars_pl = (
        pl.from_pandas(gtcars)
        .filter(pl.col("year") == 2017)
        .select(["mfr", "model", "trq", "trq_rpm", "mpg_c", "mpg_h"])
    )

    (
        GT(gtcars_pl)
        .fmt_integer(columns=[cs.starts_with("trq"), cs.starts_with("mpg")])
        .cols_merge(columns=["trq", "trq_rpm"], pattern="{1}<< ({2} rpm)>>")
        .cols_merge(columns=["mpg_c", "mpg_h"], pattern="<<{1} city<</{2} hwy>>>>")
        .cols_label(mfr="Manufacturer", model="Car Model", trq="Torque", mpg_c="MPG")
    )
    ```
    """
    # Get the columns supplied in `columns` as a list of column names
    columns_resolved = resolve_cols_c(data=self, expr=columns)

    if len(columns_resolved) < 2:
        raise ValueError("At least two columns must be specified for merging.")

    # Generate the default space-separated pattern (`"{1} {2} ..."`) if none was provided;
    # placeholders are 1-based
    if pattern is None:
        pattern = " ".join(f"{{{idx}}}" for idx in range(1, len(columns_resolved) + 1))

    # Resolve the rows supplied in the `rows` argument; only the second element of each
    # resolved entry is stored on the merge record
    row_res = resolve_rows_i(self, rows)
    row_pos = [name_pos[1] for name_pos in row_res]

    # Determine which columns to hide
    if hide_columns is None:
        # Default behavior: hide all columns except the first (target) column
        hide_columns = columns_resolved[1:]
    elif hide_columns is False:
        # Hiding explicitly suppressed
        hide_columns = []
    else:
        # Resolve the `hide_columns` expression to column names
        hide_columns = resolve_cols_c(data=self, expr=hide_columns)

    # Only columns that take part in the merge may be hidden by this method
    hide_columns_filtered = [col for col in hide_columns if col in columns_resolved]

    # Warn if some of the requested columns were dropped by the filter above;
    # `stacklevel=2` attributes the warning to the caller's `.cols_merge()` line
    if len(hide_columns_filtered) < len(hide_columns):
        warnings.warn(
            "Only columns supplied in `columns` will be hidden. "
            "Use an additional `cols_hide()` call to hide any out-of-scope columns.",
            UserWarning,
            stacklevel=2,
        )

    # Hide the specified columns
    result = self
    if hide_columns_filtered:
        result = cols_hide(result, columns=hide_columns_filtered)

    # Create column merge entry
    col_merge_entry = ColMergeInfo(
        vars=columns_resolved,
        rows=row_pos,
        type="merge",
        pattern=pattern,
    )

    # Append to a copy of the `_col_merge` list so the existing list is not mutated
    return result._replace(_col_merge=[*result._col_merge, col_merge_entry])
95 changes: 95 additions & 0 deletions great_tables/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,98 @@ def _get_visible_cells(data: TblData) -> list[tuple[str, int]]:

def is_valid_http_schema(url: str) -> bool:
    """Return `True` when `url` starts with the `http://` or `https://` prefix."""
    accepted_prefixes = ("http://", "https://")
    return url.startswith(accepted_prefixes)


# Column merge pattern processing utilities ----

# Token used to represent missing values during pattern processing
_MISSING_VAL_TOKEN = "::NA::"


def _process_col_merge_pattern(
    pattern: str,
    col_values: dict[str, str],
    col_is_missing: dict[str, bool],
) -> str:
    """Process a column merge pattern by substituting values and handling missing data.

    Conditional sections (`<<...>>`) are resolved on the pattern itself, before any cell text is
    inserted, and cell text is then inserted in a single pass. As a result, the content of a cell
    is always treated as literal text: a value that happens to contain `{2}`, `<<`, `>>` or the
    internal missing-value token is never re-interpreted as pattern syntax.

    Parameters
    ----------
    pattern
        The merge pattern, containing `{key}` placeholders and optional `<<...>>` sections.
    col_values
        Mapping of placeholder key (the text between the braces, e.g. `"1"`) to the string that
        should replace the placeholder.
    col_is_missing
        Mapping of placeholder key to whether the underlying data value was missing. Keys that
        are absent are treated as not missing.

    Returns
    -------
    str
        The pattern with all known placeholders replaced. Placeholders for truly missing values
        that sit outside of any removed `<<...>>` section are rendered as `"NA"`.
    """

    def substitute(text: str, replacements: dict[str, str]) -> str:
        # Replace all placeholders in one left-to-right pass so that inserted text is never
        # scanned again for further placeholders
        if not replacements:
            return text
        alternation = "|".join(
            re.escape(placeholder)
            for placeholder in sorted(replacements, key=len, reverse=True)
        )
        # A callable replacement keeps backslashes in the inserted text literal
        return re.sub(alternation, lambda found: replacements[found.group(0)], text)

    missing_tokens: dict[str, str] = {}
    present_values: dict[str, str] = {}
    for key, value in col_values.items():
        placeholder = f"{{{key}}}"
        # Consider a value missing if it matches a known NA representation
        # and the original data was also NA
        # - "NA": generic representation
        # - "<NA>": pandas representation
        # - "None": Polars representation (str(None))
        if col_is_missing.get(key, False) and value in ("NA", "<NA>", "None"):
            missing_tokens[placeholder] = _MISSING_VAL_TOKEN
        else:
            present_values[placeholder] = value

    # Mark the placeholders of missing values so conditional sections can detect them
    result = substitute(pattern, missing_tokens)

    # Process conditional sections (`<<...>>`) while the text still holds only pattern syntax
    if "<<" in result and ">>" in result:
        result = _resolve_conditional_sections(result)

    # Clean up any remaining missing value tokens (those outside of removed sections)
    result = result.replace(_MISSING_VAL_TOKEN, "NA")

    # Substitute the remaining `{n}` placeholders with their values
    return substitute(result, present_values)


def _resolve_conditional_sections(text: str) -> str:
    """Resolve conditional sections marked with `<<...>>` in text.

    Sections are processed from the innermost outwards. A section whose content contains the
    missing-value token is removed entirely (markers included); otherwise only the markers are
    removed and the content is kept. Unmatched `<<` or `>>` markers are left in the text as they
    are and do not prevent well-formed sections elsewhere from being resolved.

    Parameters
    ----------
    text
        The text in which conditional sections should be resolved.

    Returns
    -------
    str
        The text with every matched `<<...>>` section resolved.
    """
    # Every resolved section shortens the text and every skipped closer advances the scan
    # position, so the loop terminates without needing an iteration cap
    scan_from = 0

    while True:
        # The first closer at or after the scan position ends an innermost section
        close_at = text.find(">>", scan_from)
        if close_at == -1:
            break

        # Its partner is the nearest opener that lies completely before it
        open_at = text.rfind("<<", 0, close_at)
        if open_at == -1:
            # Unmatched closer: leave it in place and continue behind it
            scan_from = close_at + 2
            continue

        # Extract the content between `<<` and `>>`
        section_content = text[open_at + 2 : close_at]

        # Drop the section if it holds a missing value token, otherwise keep only its content
        replacement = "" if _MISSING_VAL_TOKEN in section_content else section_content

        # Replace this section in the text
        text = text[:open_at] + replacement + text[close_at + 2 :]

        # Step back one character: removing an empty section can join two `>` characters
        # around the splice into a new closer that starts just before `open_at`
        scan_from = max(open_at - 1, 0)

    return text


def _extract_pattern_columns(pattern: str) -> list[str]:
    """Return the distinct `{n}` column references of a pattern, in order of first appearance.

    Only placeholders made up solely of digits are recognized; the returned items are the digit
    strings without the surrounding braces.
    """
    references = re.findall(r"\{(\d+)\}", pattern)

    # A dict preserves insertion order, so its keys are the de-duplicated references
    return list(dict.fromkeys(references))
Loading
Loading