From ff4e4fd78b75b1a92877ed2dfe7a78ef01bcd777 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Wed, 30 Jul 2025 23:39:40 +0200 Subject: [PATCH 1/3] fix: Address slicing pyarrow array in data_color --- great_tables/_data_color/base.py | 11 +- great_tables/_tbl_data.py | 28 ++- .../__snapshots__/test_data_color.ambr | 182 ++++++++++++++++++ tests/data_color/test_data_color.py | 2 + 4 files changed, 220 insertions(+), 3 deletions(-) diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index 904a1d137..6003c54b4 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -6,7 +6,14 @@ from typing_extensions import TypeAlias from great_tables._locations import RowSelectExpr, resolve_cols_c, resolve_rows_i -from great_tables._tbl_data import DataFrameLike, SelectExpr, get_column_names, is_na +from great_tables._tbl_data import ( + DataFrameLike, + SelectExpr, + get_at_row_positions, + get_column_names, + is_na, + to_list, +) from great_tables.loc import body from great_tables.style import fill, text @@ -228,7 +235,7 @@ def data_color( # For each column targeted, get the data values as a new list object for col in columns_resolved: # This line handles both pandas and polars dataframes - column_vals = data_table[col][row_pos].to_list() + column_vals = to_list(get_at_row_positions(data_table[col], indexes=row_pos)) # Filter out NA values from `column_vals` filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)] diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 510522f23..bc552a4c6 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -759,7 +759,7 @@ def _(df: PyArrowTable, x: Any) -> bool: import pyarrow as pa arr = pa.array([x]) - return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0] + return arr.is_null(nan_is_null=True).to_pylist()[0] @singledispatch @@ -870,3 +870,29 @@ def _(ser: PyArrowChunkedArray, name: Optional[str] = None) -> PyArrowTable: import pyarrow as pa return pa.table({name: ser}) + + +@singledispatch +def get_at_row_positions(ser: SeriesLike, indexes: list[int]) -> SeriesLike: + """Returns values of the series at `indexes` position.`""" + raise NotImplementedError(f"Unsupported type: {type(ser)}") + + +@get_at_row_positions.register +def _(ser: PdSeries, indexes: list[int]) -> PdSeries: + return ser.iloc[indexes] + + +@get_at_row_positions.register +def _(ser: PlSeries, indexes: list[int]) -> PlSeries: + return ser[indexes] + + +@get_at_row_positions.register +def _(ser: PyArrowArray, indexes: list[int]) -> PyArrowArray: + return ser.take(indexes) + + +@get_at_row_positions.register +def _(ser: PyArrowChunkedArray, indexes: list[int]) -> PyArrowChunkedArray: + return ser.take(indexes) diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index 022e471c5..f8abd32ab 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -123,6 +123,32 @@ ''' # --- +# name: test_data_color_autocolor_text_false[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_colorbrewer_snap ''' @@ -206,6 +232,32 @@ ''' # --- +# name: test_data_color_domain_na_color_reverse_snap[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_domain_na_color_snap[pandas] ''' @@ -258,6 +310,32 @@ ''' # --- +# name: test_data_color_domain_na_color_snap[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_overlapping_domain[pandas] ''' @@ -310,6 +388,32 @@ ''' # --- +# name: test_data_color_overlapping_domain[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_palette_snap[pandas] ''' @@ -362,6 +466,32 @@ ''' # --- +# name: test_data_color_palette_snap[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_pd_cols_rows_snap ''' @@ -495,6 +625,32 @@ ''' # --- +# name: test_data_color_simple_exibble_snap[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_subset_domain[pandas] ''' @@ -547,6 +703,32 @@ ''' # --- +# name: test_data_color_subset_domain[pyarrow] + ''' + + + 0.1111 + apricot + 49.95 + + + 2.222 + banana + 17.95 + + + 33.33 + coconut + 1.39 + + + 444.4 + durian + 65100 + + + ''' +# --- # name: test_data_color_viridis_snap ''' diff --git a/tests/data_color/test_data_color.py b/tests/data_color/test_data_color.py index f952f83ae..d2e85b3d4 100644 --- a/tests/data_color/test_data_color.py +++ b/tests/data_color/test_data_color.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import polars as pl +import pyarrow as pa import pytest from great_tables import GT, style @@ -16,6 +17,7 @@ params_frames = [ pytest.param(pd.DataFrame, id="pandas"), pytest.param(pl.DataFrame, id="polars"), + pytest.param(pa.table, id="pyarrow"), ] From c0af9e671083934c2474d36a4aceba783b4fc501 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Wed, 30 Jul 2025 23:45:09 +0200 Subject: [PATCH 2/3] rename function --- great_tables/_data_color/base.py | 4 ++-- great_tables/_tbl_data.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index 6003c54b4..3f4471eb2 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -9,8 +9,8 @@ from great_tables._tbl_data import ( DataFrameLike, SelectExpr, - get_at_row_positions, get_column_names, + get_rows, is_na, to_list, ) @@ -235,7 +235,7 @@ def data_color( # For each column targeted, get the data values as a new list object for col in columns_resolved: # This line handles both pandas and polars dataframes - column_vals = to_list(get_at_row_positions(data_table[col], indexes=row_pos)) + column_vals = to_list(get_rows(data_table[col], indexes=row_pos)) # Filter out NA values from `column_vals` filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)] diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index bc552a4c6..f8725c5c5 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -873,26 +873,26 @@ def _(ser: PyArrowChunkedArray, name: Optional[str] = None) -> PyArrowTable: @singledispatch -def get_at_row_positions(ser: SeriesLike, indexes: list[int]) -> SeriesLike: +def get_rows(ser: SeriesLike, indexes: list[int]) -> SeriesLike: """Returns values of the series at `indexes` position.`""" raise NotImplementedError(f"Unsupported type: {type(ser)}") -@get_at_row_positions.register +@get_rows.register def _(ser: PdSeries, indexes: list[int]) -> PdSeries: return ser.iloc[indexes] -@get_at_row_positions.register +@get_rows.register def _(ser: PlSeries, indexes: list[int]) -> PlSeries: return ser[indexes] -@get_at_row_positions.register +@get_rows.register def _(ser: PyArrowArray, indexes: list[int]) -> PyArrowArray: return ser.take(indexes) -@get_at_row_positions.register +@get_rows.register def _(ser: PyArrowChunkedArray, indexes: list[int]) -> PyArrowChunkedArray: return ser.take(indexes) From 1f09f61e8fa7636217e6e077dbe30da96a41e38e Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sat, 2 Aug 2025 17:42:03 +0200 Subject: [PATCH 3/3] collapse pyarrow into one func: --- great_tables/_tbl_data.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index f8725c5c5..640935ff0 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -29,10 +29,10 @@ PlSelectExpr = _selector_proxy_ PlExpr = pl.Expr - PdSeries = pd.Series + PdSeries = pd.Series[Any] PlSeries = pl.Series - PyArrowArray = pa.Array - PyArrowChunkedArray = pa.ChunkedArray + PyArrowArray = pa.Array[Any] + PyArrowChunkedArray = pa.ChunkedArray[Any] PdNA = pd.NA PlNull = pl.Null @@ -888,11 +888,7 @@ def _(ser: PlSeries, indexes: list[int]) -> PlSeries: return ser[indexes] -@get_rows.register -def _(ser: PyArrowArray, indexes: list[int]) -> PyArrowArray: - return ser.take(indexes) - - -@get_rows.register -def _(ser: PyArrowChunkedArray, indexes: list[int]) -> PyArrowChunkedArray: +@get_rows.register(PyArrowArray) +@get_rows.register(PyArrowChunkedArray) +def _(ser: Any, indexes: list[int]) -> PyArrowArray | PyArrowChunkedArray: return ser.take(indexes)