From 45a7164ecb06bf8acae433c5f35a01e7f027e1b2 Mon Sep 17 00:00:00 2001 From: Richard Iannone Date: Wed, 8 Oct 2025 14:28:04 -0400 Subject: [PATCH 1/3] Add column merging through cols_merge() method --- great_tables/_gt_data.py | 16 ++++ great_tables/_spanners.py | 187 +++++++++++++++++++++++++++++++++++++- great_tables/_utils.py | 130 ++++++++++++++++++++++++++ great_tables/gt.py | 104 ++++++++++++++++++++- 4 files changed, 432 insertions(+), 5 deletions(-) diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py index 1456d9454..193cc1e5e 100644 --- a/great_tables/_gt_data.py +++ b/great_tables/_gt_data.py @@ -81,6 +81,7 @@ class GTData: _locale: Locale | None _formats: Formats _substitutions: Formats + _col_merge: ColMerges _options: Options _google_font_imports: GoogleFontImports = field(default_factory=GoogleFontImports) _has_built: bool = False @@ -128,6 +129,7 @@ def from_data( _locale=Locale(locale), _formats=[], _substitutions=[], + _col_merge=[], _options=options, _google_font_imports=GoogleFontImports(), ) @@ -972,6 +974,20 @@ def __init__(self, func: FormatFns, cols: list[str], rows: list[int]): Formats = list +# Column Merge ---- + + +@dataclass(frozen=True) +class ColMergeInfo: + vars: list[str] + rows: list[int] + type: str # type of merge operation (only 'merge' used currently) + pattern: str | None = None + + +ColMerges = list[ColMergeInfo] + + # Options ---- default_fonts_list = [ diff --git a/great_tables/_spanners.py b/great_tables/_spanners.py index 1062c2923..2488ac1ba 100644 --- a/great_tables/_spanners.py +++ b/great_tables/_spanners.py @@ -10,8 +10,8 @@ from typing_extensions import TypeAlias, TypedDict from ._boxhead import cols_label -from ._gt_data import SpannerInfo, Spanners -from ._locations import resolve_cols_c +from ._gt_data import ColMergeInfo, SpannerInfo, Spanners +from ._locations import resolve_cols_c, resolve_rows_i from ._tbl_data import SelectExpr from ._text import BaseText, Text from ._utils import OrderedSet, _assert_list_is_subset @@ -1026,3 +1026,186 @@ def cols_width(self: GTSelf, cases: dict[str, str] | None = None, **kwargs: str) curr_boxhead = curr_boxhead._set_column_width(col, width) return self._replace(_boxhead=curr_boxhead) + + +def cols_merge( + self: GTSelf, + columns: SelectExpr, + hide_columns: SelectExpr | Literal[False] = None, + rows: int | list[int] | None = None, + pattern: str | None = None, +) -> GTSelf: + """Merge data from two or more columns into a single column. + + This method takes input from two or more columns and allows the contents to be merged into a + single column by using a pattern that specifies the arrangement. The first column in the + `columns=` parameter operates as the target column (i.e., the column that will undergo mutation) + whereas all following columns will be untouched. There is the option to hide the non-target + columns. The formatting of values in different columns will be preserved upon merging. + + Parameters + ---------- + columns + The columns for which the merging operations should be applied. The first column name + resolved will be the target column (i.e., undergo mutation) and the other columns will serve + to provide input. Can be a list of column names or a selection expression, though a list is + preferred here to ensure the order of columns is exactly as intended (since order matters + for the `pattern=` parameter). + hide_columns + Any column names provided here will have their state changed to hidden (via internal use + of `.cols_hide()`) if they aren't already hidden. This is convenient if the shared purpose + of these specified columns is only to provide string input to the target column. To + suppress any hiding of columns, `False` can be used here. By default, all columns other + than the first one specified in `columns=` will be hidden. + rows + In conjunction with `columns=`, we can specify which of their rows should participate in + the merging process. The default is all rows, resulting in all rows in `columns=` being + formatted. Alternatively, we can supply a list of row indices. + pattern + A formatting pattern that specifies the arrangement of the column values and any string + literals. The pattern uses numbers (within `{}`) that correspond to the indices of columns + provided in `columns=`. If two columns are provided in `columns=` and we would like to + combine the cell data onto the first column, `"{1} {2}"` could be used. If a pattern isn't + provided then a space-separated pattern that includes all columns will be generated + automatically. The pattern can also use `<<`/`>>` to surround spans of text that will be + removed if any of the contained `{}` yields a missing value. Further details are provided in + the *How the pattern works* section. + + Returns + ------- + GT + The GT object is returned. This is the same object that the method is called on so that we + can facilitate method chaining. + + Details + ------- + ### How the pattern works + + There are two types of templating for the `pattern` string: + + - `{` `}` for arranging single column values in a row-wise fashion + - `<<` `>>` to surround spans of text that will be removed if any of the contained `{` `}` + yields a missing value + + Integer values are placed in `{}` and those values correspond to the columns involved in the + merge, in the order they are provided in the `columns=` argument. So the pattern + `"{1} ({2}-{3})"` corresponds to the target column value listed first in `columns` and the + second and third columns cited (formatted as a range in parentheses). With hypothetical values, + this might result as the merged string `"38.2 (3-8)"`. + + Because some values involved in merging may be missing, it is likely that something like + `"38.2 (3-None)"` would be undesirable. For such cases, placing sections of text in `<<>>` + results in the entire span being eliminated if there were to be an `None` value (arising from + `{}` values). We could instead opt for a pattern like `"{1}<< ({2}-{3})>>"`, which results in + `"38.2"` if either columns `{2}` or `{3}` have a `None` value. We can even use a more complex + nesting pattern like `"{1}<< ({2}-<<{3}>>)>>"` to retain a lower limit in parentheses (where + `{3}` is `None`) but remove the range altogether if `{2}` is `None`. + + One more thing to note here is that if `.sub_missing()` is used on values in a column, those + specific values affected won't be considered truly missing by `.cols_merge()` (since they have + been explicitly handled with substitute text). + + Examples + -------- + Let's use a subset of the `sp500` dataset to create a table. We'll merge the `open` & `close` + columns together, and the `low` & `high` columns (putting an em dash between both). + + ```{python} + from great_tables import GT + from great_tables.data import sp500 + import polars as pl + + sp500_mini = ( + pl.from_pandas(sp500) + .slice(49, 6) + .select("open", "close", "low", "high") + ) + + ( + GT(sp500_mini) + .fmt_number( + columns=["open", "close", "low", "high"], + decimals=2, + use_seps=False + ) + .cols_merge(columns=["open", "close"], pattern="{1}—{2}") + .cols_merge(columns=["low", "high"], pattern="{1}—{2}") + .cols_label(open="open/close", low="low/high") + ) + ``` + + Now we'll use a portion of the `gtcars` for the next example that accounts for missing values in + the `pattern=` parameter. Use the `.cols_merge()` method twice to merge together the: (1) `trq` + and `trq_rpm` columns, and (2) `mpg_c` & `mpg_h` columns. Given the presence of missing values, + we can use patterns with `<<`/`>>` to create conditional text spans, avoiding results where + any of the merged columns have missing values. + + ```{python} + from great_tables.data import gtcars + import polars.selectors as cs + + gtcars_pl = ( + pl.from_pandas(gtcars) + .filter(pl.col("year") == 2017) + .select(["mfr", "model", "trq", "trq_rpm", "mpg_c", "mpg_h"]) + ) + + ( + GT(gtcars_pl) + .fmt_integer(columns=[cs.starts_with("trq"), cs.starts_with("mpg")]) + .cols_merge(columns=["trq", "trq_rpm"], pattern="{1}<< ({2} rpm)>>") + .cols_merge(columns=["mpg_c", "mpg_h"], pattern="<<{1} city<>>>") + .cols_label(mfr="Manufacturer", model="Car Model", trq="Torque", mpg_c="MPG") + ) + ``` + """ + # Get the columns supplied in `columns` as a list of column names + columns_resolved = resolve_cols_c(data=self, expr=columns) + + if len(columns_resolved) < 2: + raise ValueError("At least two columns must be specified for merging.") + + # Generate default pattern if not provided + if pattern is None: + pattern = " ".join(f"{{{i+1}}}" for i in range(len(columns_resolved))) + + # Resolve the rows supplied in the `rows` argument + row_res = resolve_rows_i(self, rows) + row_pos = [name_pos[1] for name_pos in row_res] + + # Determine which columns to hide + # Default behavior: hide all columns except the first (target) column + if hide_columns is None: + hide_columns = columns_resolved[1:] + elif hide_columns is False: + hide_columns = [] + else: + # Resolve hide_columns expression + hide_columns = resolve_cols_c(data=self, expr=hide_columns) + + # Filter hide_columns to only include those in columns_resolved + hide_columns_filtered = [col for col in hide_columns if col in columns_resolved] + + # Warn if some hide_columns are not in columns_resolved + if len(hide_columns_filtered) < len(hide_columns): + warnings.warn( + "Only columns supplied in `columns` will be hidden. " + "Use an additional `cols_hide()` call to hide any out-of-scope columns.", + UserWarning, + ) + + # Hide the specified columns + result = self + if hide_columns_filtered: + result = cols_hide(result, columns=hide_columns_filtered) + + # Create column merge entry + col_merge_entry = ColMergeInfo( + vars=columns_resolved, + rows=row_pos, + type="merge", + pattern=pattern, + ) + + # Add to _col_merge list + return result._replace(_col_merge=[*result._col_merge, col_merge_entry]) diff --git a/great_tables/_utils.py b/great_tables/_utils.py index 6025ab881..666f90670 100644 --- a/great_tables/_utils.py +++ b/great_tables/_utils.py @@ -285,3 +285,133 @@ def _get_visible_cells(data: TblData) -> list[tuple[str, int]]: def is_valid_http_schema(url: str) -> bool: return url.startswith(("http://", "https://")) + + +# Column merge pattern processing utilities ---- + +# Token used to represent missing values during pattern processing +_MISSING_VAL_TOKEN = "::NA::" + + +def _process_col_merge_pattern( + pattern: str, + col_values: dict[str, str], + col_is_missing: dict[str, bool], +) -> str: + """Process a column merge pattern by substituting values and handling missing data. + + Parameters + ---------- + pattern : str + The pattern string with {n} placeholders and optional <<>> conditional sections. + col_values : dict[str, str] + Dictionary mapping column indices (as strings, e.g., "1", "2") to their formatted values. + col_is_missing : dict[str, bool] + Dictionary mapping column indices to whether the original value was missing/NA. + + Returns + ------- + str + The processed pattern with values substituted and conditional sections resolved. + """ + # Replace values with tokens if they are truly missing + processed_values = {} + for key, value in col_values.items(): + # Consider a value missing if it's "NA" or "" (pandas representation) + # and the original data was also NA + if col_is_missing.get(key, False) and value in ("NA", ""): + processed_values[key] = _MISSING_VAL_TOKEN + else: + processed_values[key] = value + + # Substitute {n} placeholders with values + result = pattern + for key, value in processed_values.items(): + result = result.replace(f"{{{key}}}", value) + + # Process conditional sections (<<...>>) + if "<<" in result and ">>" in result: + result = _resolve_conditional_sections(result) + + # Clean up any remaining missing value tokens + result = result.replace(_MISSING_VAL_TOKEN, "NA") + + return result + + +def _resolve_conditional_sections(text: str) -> str: + """Resolve conditional sections marked with <<...>> in text. + + Removes any section enclosed in <<...>> if it contains the missing value token. + Processes innermost sections first to handle nesting. + + Parameters + ---------- + text : str + The text containing conditional sections. + + Returns + ------- + str + The text with conditional sections resolved. + """ + # Process from innermost to outermost sections + max_iterations = 100 # Prevent infinite loops + iteration = 0 + + while "<<" in text and ">>" in text and iteration < max_iterations: + iteration += 1 + + # Find the last occurrence of << (innermost section start) + last_open = text.rfind("<<") + if last_open == -1: + break + + # Find the first >> after that << + first_close = text.find(">>", last_open) + if first_close == -1: + break + + # Extract the content between << and >> + section_content = text[last_open + 2 : first_close] + + # Check if the section contains a missing value token + if _MISSING_VAL_TOKEN in section_content: + # Remove the entire section (including markers) + replacement = "" + else: + # Keep the content without the markers + replacement = section_content + + # Replace this section in the text + text = text[:last_open] + replacement + text[first_close + 2 :] + + # Clean up any remaining markers (shouldn't normally happen) + text = text.replace("<<", "").replace(">>", "") + + return text + + +def _extract_pattern_columns(pattern: str) -> list[str]: + """Extract column references from a pattern string. + + Parameters + ---------- + pattern : str + The pattern string with {n} placeholders. + + Returns + ------- + list[str] + List of unique column indices referenced in the pattern (as strings). + """ + # Find all {n} patterns + matches = re.findall(r"\{(\d+)\}", pattern) + # Return unique matches in order they appear + seen = set() + result = [] + for match in matches: + if match not in seen: + result.append(match) + seen.add(match) + return result diff --git a/great_tables/gt.py b/great_tables/gt.py index e38875950..948409ffe 100644 --- a/great_tables/gt.py +++ b/great_tables/gt.py @@ -51,6 +51,7 @@ from ._source_notes import tab_source_note from ._spanners import ( cols_hide, + cols_merge, cols_move, cols_move_to_end, cols_move_to_start, @@ -63,8 +64,12 @@ from ._stubhead import tab_stubhead from ._substitution import sub_missing, sub_zero from ._tab_create_modify import tab_style -from ._tbl_data import _get_cell, n_rows -from ._utils import _migrate_unformatted_to_output +from ._tbl_data import _get_cell, _set_cell, is_na, n_rows +from ._utils import ( + _extract_pattern_columns, + _migrate_unformatted_to_output, + _process_col_merge_pattern, +) from ._utils_render_html import ( _get_table_defs, create_body_component_h, @@ -258,6 +263,7 @@ def __init__( cols_align = cols_align cols_width = cols_width cols_label = cols_label + cols_merge = cols_merge cols_move = cols_move cols_move_to_start = cols_move_to_start cols_move_to_end = cols_move_to_end @@ -313,6 +319,96 @@ def _render_formats(self, context: str) -> Self: new_body.render_formats(self._tbl_data, self._substitutions, context) return self._replace(_body=new_body) + def _perform_col_merge(self) -> Self: + # If no column merging defined, return unchanged + if not self._col_merge: + return self # pragma: no cover + + # Create a copy of the body for modification + new_body = self._body.copy() + + # Process each column merge operation in order + for col_merge in self._col_merge: + if col_merge.type != "merge": + # TODO: incorporate other specialized merging operations (e.g., "merge_range") but + # only handle the basic 'merge' type for now + continue + + # Get the target column (column that receives the merged values) + target_column = col_merge.vars[0] + + # Get all columns, rows, and the pattern for this merge operation + columns = col_merge.vars + rows = col_merge.rows + pattern = col_merge.pattern + + # Pattern should always be set by `.cols_merge()`, but check here just in case + if pattern is None: # pragma: no cover + raise ValueError("Pattern must be provided for column merge operations.") + + # Validate that pattern references are valid + pattern_cols = _extract_pattern_columns(pattern) + for col_ref in pattern_cols: + # The pattern syntax uses 1-based indexing so adjust here + col_idx = int(col_ref) - 1 + + # Check that the referenced column exists in the provided columns + if col_idx < 0: + raise ValueError( + f"Pattern references column {{{col_ref}}} but column indexing starts " + f"at {{1}}, not {{0}}. Please use 1-based indexing in patterns." + ) + if col_idx >= len(columns): + raise ValueError( + f"Pattern references column {{{col_ref}}} but only {len(columns)} " + f"columns were provided to cols_merge()." + ) + + # Process each row (according to the `rows=` parameter in `cols_merge()`) + for row_idx in rows: + # Collect values and missing status for all columns + col_values = {} + col_is_missing = {} + + for i, col_name in enumerate(columns): + # Get the formatted value from the body + formatted_value = _get_cell(new_body.body, row_idx, col_name) + + # Get the original value from the data table + original_value = _get_cell(self._tbl_data, row_idx, col_name) + + # If the body cell is missing (unformatted) and the original has a value, + # use the original value; otherwise use the formatted value + if is_na(new_body.body, formatted_value) and not is_na( + self._tbl_data, original_value + ): + # Cell is unformatted but has a value in the original data + display_value = str(original_value) + else: + # If the cell is formatted OR the original is missing then use the + # formatted value (which has the proper NA representation like "") + display_value = str(formatted_value) + + # Store with 1-based index (as used in the pattern) + col_key = str(i + 1) + col_values[col_key] = display_value + col_is_missing[col_key] = is_na(self._tbl_data, original_value) + + # Process the pattern with the collected values + merged_value = _process_col_merge_pattern( + pattern=pattern, col_values=col_values, col_is_missing=col_is_missing + ) + + # Set the merged value in the target column + result = _set_cell(new_body.body, row_idx, target_column, merged_value) + + # For Pandas and Polars, _set_cell() modifies in place and returns None but + # for PyArrow, _set_cell() returns a new table + if result is not None: + new_body.body = result + + return self._replace(_body=new_body) + def _build_data(self, context: str) -> Self: # Build the body of the table by generating a dictionary # of lists with cells initially set to nan values @@ -323,7 +419,9 @@ def _build_data(self, context: str) -> Self: data=built, data_tbl=self._tbl_data, formats=self._formats, context=context ) - # built._perform_col_merge() + # Perform column merging + built = built._perform_col_merge() + final_body = body_reassemble(built._body) # Reordering of the metadata elements of the table From 9148a4e8468fd3883484fd3a2dc855b5b9e45ad5 Mon Sep 17 00:00:00 2001 From: Richard Iannone Date: Wed, 8 Oct 2025 14:28:22 -0400 Subject: [PATCH 2/3] Add several tests for cols_merge() --- tests/test_cols_merge.py | 251 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 tests/test_cols_merge.py diff --git a/tests/test_cols_merge.py b/tests/test_cols_merge.py new file mode 100644 index 000000000..157ee6b95 --- /dev/null +++ b/tests/test_cols_merge.py @@ -0,0 +1,251 @@ +import pandas as pd +import pytest +from great_tables import GT +from great_tables._gt_data import ColMergeInfo + + +@pytest.fixture +def simple_df(): + return pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + + +@pytest.fixture +def missing_df(): + return pd.DataFrame( + { + "number": ["Three", "Four", "Five"], + "val1": [10, 20, 30], + "val2": [15.0, 25.0, None], + "val3": [5.0, None, None], + } + ) + + +def test_cols_merge_basic(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b"]) + + # Check that merge was registered + assert len(gt._col_merge) == 1 + assert gt._col_merge[0].vars == ["a", "b"] + assert gt._col_merge[0].pattern == "{1} {2}" + assert gt._col_merge[0].type == "merge" + + # Check rendered output + html = gt.as_raw_html() + assert "1 4" in html + assert "2 5" in html + assert "3 6" in html + + +def test_cols_merge_custom_pattern(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b"], pattern="{1}—{2}") + + assert gt._col_merge[0].pattern == "{1}—{2}" + + # Check rendered output with em dash + html = gt.as_raw_html() + assert "1—4" in html + assert "2—5" in html + assert "3—6" in html + + +def test_cols_merge_three_columns(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b", "c"]) + + assert gt._col_merge[0].vars == ["a", "b", "c"] + assert gt._col_merge[0].pattern == "{1} {2} {3}" + + # Check rendered output with three columns + html = gt.as_raw_html() + assert "1 4 7" in html + assert "2 5 8" in html + assert "3 6 9" in html + + +def test_cols_merge_subset_of_columns(simple_df: pd.DataFrame): + # Provide three columns but only use first two in the pattern + gt1 = GT(simple_df).cols_merge(columns=["a", "b", "c"], pattern="{1} {2}") + + html1 = gt1.as_raw_html() + assert "1 4" in html1 + assert "2 5" in html1 + assert "3 6" in html1 + + # Provide three columns but only use first and third in the pattern + gt2 = GT(simple_df).cols_merge(columns=["a", "b", "c"], pattern="{1} {3}") + + html2 = gt2.as_raw_html() + assert "1 7" in html2 + assert "2 8" in html2 + assert "3 9" in html2 + + # Provide three columns but only use the third in the pattern + gt3 = GT(simple_df).cols_merge(columns=["a", "b", "c"], pattern="Value: {3}") + + html3 = gt3.as_raw_html() + assert "Value: 7" in html3 + assert "Value: 8" in html3 + assert "Value: 9" in html3 + + +def test_cols_merge_hiding_default(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b", "c"]) + + built = gt._build_data(context="html") + + # Check visibility in boxhead + col_a = [col for col in built._boxhead if col.var == "a"][0] + col_b = [col for col in built._boxhead if col.var == "b"][0] + col_c = [col for col in built._boxhead if col.var == "c"][0] + + assert col_a.visible + assert not col_b.visible + assert not col_c.visible + + +def test_cols_merge_hiding_false(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b"], hide_columns=False) + + built = gt._build_data(context="html") + + col_a = [col for col in built._boxhead if col.var == "a"][0] + col_b = [col for col in built._boxhead if col.var == "b"][0] + + assert col_a.visible + assert col_b.visible + + +def test_cols_merge_specific_rows(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b"], rows=[0, 2], pattern="{1}-{2}") + + assert gt._col_merge[0].rows == [0, 2] + + html = gt.as_raw_html() + + # Check that only rows 0 and 2 are merged + assert "1-4" in html # Row 0 + assert "2-5" not in html # Row 1 (not merged) + assert "3-6" in html # Row 2 + + +def test_cols_merge_multiple_operations(simple_df: pd.DataFrame): + gt = ( + GT(simple_df) + .cols_merge(columns=["a", "b"], pattern="{1}+{2}") + .cols_merge(columns=["a", "c"], pattern="{1}*{2}") + ) + + assert len(gt._col_merge) == 2 + assert gt._col_merge[0].pattern == "{1}+{2}" + assert gt._col_merge[1].pattern == "{1}*{2}" + + html = gt.as_raw_html() + + # Check that the second merge uses the result of the first merge + assert "1+4*7" in html + assert "2+5*8" in html + assert "3+6*9" in html + + +def test_cols_merge_with_missing_values(missing_df: pd.DataFrame): + gt = GT(missing_df).cols_merge(columns=["val1", "val2"], pattern="{1}<< ({2})>>") + + html = gt.as_raw_html() + assert "10 (15.0)" in html or "10.0 (15.0)" in html + assert "20 (25.0)" in html or "20.0 (25.0)" in html + assert "30<" in html or "30.0<" in html # The '<' is part of the closing tag (missing val2) + + +def test_cols_merge_nested_conditionals(missing_df: pd.DataFrame): + gt = GT(missing_df).cols_merge( + columns=["val1", "val2", "val3"], pattern="{1}<< ({2}-<<{3}>>)>>" + ) + + html = gt.as_raw_html() + + assert "10 (15.0-5.0)" in html or "10.0 (15.0-5.0)" in html + assert "20 (25.0-)" in html or "20.0 (25.0-)" in html + assert "30<" in html or "30.0<" in html + + +def test_cols_merge_minimum_columns_error(simple_df: pd.DataFrame): + with pytest.raises(ValueError, match="At least two columns"): + GT(simple_df).cols_merge(columns=["a"]) + + +def test_cols_merge_invalid_pattern_reference(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b"], pattern="{1} {2} {3}") + + with pytest.raises(ValueError, match="Pattern references column"): + gt._repr_html_() + + +def test_cols_merge_zero_based_index_error(simple_df: pd.DataFrame): + gt = GT(simple_df).cols_merge(columns=["a", "b"], pattern="{0}-{1}") + + # Should raise error because pattern uses 1-based indexing, not 0-based + with pytest.raises(ValueError, match="column indexing starts at"): + gt._repr_html_() + + +def test_cols_merge_preserves_formatting(simple_df: pd.DataFrame): + gt = ( + GT(simple_df) + .fmt_number(columns="a", decimals=3) + .cols_merge(columns=["a", "b"], pattern="{1}+{2}") + ) + + html = gt.as_raw_html() + + assert "1.000+4" in html + assert "2.000+5" in html + assert "3.000+6" in html + + +def test_col_merge_info_creation(): + info = ColMergeInfo(vars=["a", "b"], rows=[0, 1, 2], type="merge", pattern="{1} {2}") + + assert info.vars == ["a", "b"] + assert info.rows == [0, 1, 2] + assert info.type == "merge" + assert info.pattern == "{1} {2}" + + +def test_col_merge_info_pattern_optional(): + info = ColMergeInfo(vars=["a", "b"], rows=[0], type="merge", pattern=None) + + assert info.pattern is None + + +def test_cols_merge_with_sub_missing(missing_df: pd.DataFrame): + # calling sub_missing() before cols_merge() + gt_1 = ( + GT(missing_df) + .sub_missing(columns=["val2", "val3"], missing_text="--") + .cols_merge(columns=["val1", "val2", "val3"], pattern="{1}<< to {2}<< to {3}>>>>") + ) + + # calling sub_missing() after cols_merge() + gt_2 = ( + GT(missing_df) + .cols_merge(columns=["val1", "val2", "val3"], pattern="{1}<< to {2}<< to {3}>>>>") + .sub_missing(columns=["val2", "val3"], missing_text="--") + ) + + # From both approaches, we should get the same output + html_1 = gt_1.as_raw_html() + html_2 = gt_2.as_raw_html() + + # Row 0: All values are present (none are missing values) + assert "10 to 15.0 to 5.0" in html_1 + assert "10 to 15.0 to 5.0" in html_2 + + # Row 1: val3 was None, but sub_missing() replaced it with "--" so that "--" + # should be included (it's treated as a non-missing value because of sub_missing()) + assert "20 to 25.0 to --" in html_1 + assert "20 to 25.0 to --" in html_2 + + # Row 2: val2 and val3 were None, both replaced with "--" by sub_missing() so both of + # the "--" values should be included in the output + assert "30 to -- to --" in html_1 + assert "30 to -- to --" in html_2 From 577e6f806e59eab20f6e2ffdd67b53e6bacd1de3 Mon Sep 17 00:00:00 2001 From: Richard Iannone Date: Wed, 8 Oct 2025 16:23:39 -0400 Subject: [PATCH 3/3] Improve cols_merge() pattern handling --- great_tables/_utils.py | 73 +++++++++++----------------------------- great_tables/gt.py | 2 ++ tests/test_cols_merge.py | 16 +++++++++ 3 files changed, 37 insertions(+), 54 deletions(-) diff --git a/great_tables/_utils.py b/great_tables/_utils.py index 666f90670..43cd9a67d 100644 --- a/great_tables/_utils.py +++ b/great_tables/_utils.py @@ -298,38 +298,27 @@ def _process_col_merge_pattern( col_values: dict[str, str], col_is_missing: dict[str, bool], ) -> str: - """Process a column merge pattern by substituting values and handling missing data. - - Parameters - ---------- - pattern : str - The pattern string with {n} placeholders and optional <<>> conditional sections. - col_values : dict[str, str] - Dictionary mapping column indices (as strings, e.g., "1", "2") to their formatted values. - col_is_missing : dict[str, bool] - Dictionary mapping column indices to whether the original value was missing/NA. - - Returns - ------- - str - The processed pattern with values substituted and conditional sections resolved. - """ + """Process a column merge pattern by substituting values and handling missing data.""" + # Replace values with tokens if they are truly missing processed_values = {} for key, value in col_values.items(): - # Consider a value missing if it's "NA" or "" (pandas representation) + # Consider a value missing if it matches a known NA representation # and the original data was also NA - if col_is_missing.get(key, False) and value in ("NA", ""): + # - "NA": generic representation + # - "": pandas representation + # - "None": Polars representation (str(None)) + if col_is_missing.get(key, False) and value in ("NA", "", "None"): processed_values[key] = _MISSING_VAL_TOKEN else: processed_values[key] = value - # Substitute {n} placeholders with values + # Substitute `{n}`` placeholders with values result = pattern for key, value in processed_values.items(): result = result.replace(f"{{{key}}}", value) - # Process conditional sections (<<...>>) + # Process conditional sections (`<<...>>`) if "<<" in result and ">>" in result: result = _resolve_conditional_sections(result) @@ -340,34 +329,20 @@ def _process_col_merge_pattern( def _resolve_conditional_sections(text: str) -> str: - """Resolve conditional sections marked with <<...>> in text. - - Removes any section enclosed in <<...>> if it contains the missing value token. - Processes innermost sections first to handle nesting. - - Parameters - ---------- - text : str - The text containing conditional sections. - - Returns - ------- - str - The text with conditional sections resolved. - """ + """Resolve conditional sections marked with <<...>> in text.""" # Process from innermost to outermost sections - max_iterations = 100 # Prevent infinite loops + max_iterations = 100 # Prevent infinite looping iteration = 0 while "<<" in text and ">>" in text and iteration < max_iterations: iteration += 1 - # Find the last occurrence of << (innermost section start) + # Find the last occurrence of `<<` (innermost section start) last_open = text.rfind("<<") if last_open == -1: break - # Find the first >> after that << + # Find the first `>>` after that `<<` first_close = text.find(">>", last_open) if first_close == -1: break @@ -377,7 +352,7 @@ def _resolve_conditional_sections(text: str) -> str: # Check if the section contains a missing value token if _MISSING_VAL_TOKEN in section_content: - # Remove the entire section (including markers) + # Remove the entire section (including the markers) replacement = "" else: # Keep the content without the markers @@ -386,32 +361,22 @@ def _resolve_conditional_sections(text: str) -> str: # Replace this section in the text text = text[:last_open] + replacement + text[first_close + 2 :] - # Clean up any remaining markers (shouldn't normally happen) - text = text.replace("<<", "").replace(">>", "") - return text def _extract_pattern_columns(pattern: str) -> list[str]: - """Extract column references from a pattern string. - - Parameters - ---------- - pattern : str - The pattern string with {n} placeholders. + """Extract column references from a pattern string.""" - Returns - ------- - list[str] - List of unique column indices referenced in the pattern (as strings). - """ - # Find all {n} patterns + # Find all `{n}` patterns matches = re.findall(r"\{(\d+)\}", pattern) + # Return unique matches in order they appear seen = set() result = [] + for match in matches: if match not in seen: result.append(match) seen.add(match) + return result diff --git a/great_tables/gt.py b/great_tables/gt.py index 948409ffe..0415693de 100644 --- a/great_tables/gt.py +++ b/great_tables/gt.py @@ -348,6 +348,8 @@ def _perform_col_merge(self) -> Self: # Validate that pattern references are valid pattern_cols = _extract_pattern_columns(pattern) + + # With each column reference in the pattern, check that it is valid for col_ref in pattern_cols: # The pattern syntax uses 1-based indexing so adjust here col_idx = int(col_ref) - 1 diff --git a/tests/test_cols_merge.py b/tests/test_cols_merge.py index 157ee6b95..ab96c5e75 100644 --- a/tests/test_cols_merge.py +++ b/tests/test_cols_merge.py @@ -1,4 +1,5 @@ import pandas as pd +import polars as pl import pytest from great_tables import GT from great_tables._gt_data import ColMergeInfo @@ -249,3 +250,18 @@ def test_cols_merge_with_sub_missing(missing_df: pd.DataFrame): # the "--" values should be included in the output assert "30 to -- to --" in html_1 assert "30 to -- to --" in html_2 + + +def test_cols_merge_with_formatted_values(): + df = pl.DataFrame({"a": [1, 2, None], "b": [10, None, 30], "c": [100, 200, 300]}) + + gt = ( + GT(df) + .cols_merge(columns=["a", "b"], pattern="{1}<< ({2})>>") + .fmt_integer(columns=["a", "b", "c"]) + ) + + html = gt.as_raw_html() + + assert "1 (10)" in html or "1(10)" in html + assert "2<" in html # The < is from the closing tag, not from the pattern