Skip to content

Commit 3070f4b

Browse files
authored
Support categoricals in from_pandas (#1832)
This adds support in `from_pandas` for converting and handling categorical datatypes.

Example usage:

```python
df = pd.DataFrame(
    {
        "int": [0, 1, 2, 3],
        "categorical_string": pd.Series(["A", "B", "A", "B"], dtype="category"),
        "categorical_int": pd.Series(
            np.array([1, 2, 3, 4], dtype=np.int64), dtype="category"
        ),
    }
)
tiledb.from_pandas("my_array", df)
```
1 parent 3a21ee4 commit 3070f4b

File tree

2 files changed

+154
-10
lines changed

2 files changed

+154
-10
lines changed

tiledb/dataframe_.py

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import warnings
55
from dataclasses import dataclass
6-
from typing import List, Optional, Union
6+
from typing import Any, List, Optional, Union
77

88
import numpy as np
99

@@ -75,15 +75,71 @@ def parse_tiledb_kwargs(kwargs):
7575
return parsed_args
7676

7777

78+
def _infer_dtype_from_pandas(values):
    """Translate the pandas-inferred kind of ``values`` into a NumPy dtype specifier.

    The kind is obtained from ``pandas.api.types.infer_dtype``.

    :param values: array-like handed to ``infer_dtype``
    :return: a NumPy scalar type, or the string ``"<U0"`` for string data;
        ``None`` when the inferred kind is in neither table below
    :raises NotImplementedError: for the kinds "mixed-integer",
        "mixed-integer-float", "categorical", "period", "mixed" and
        "unknown-array"
    """
    from pandas.api import types as pd_types

    # Inferred kinds that have a NumPy counterpart.
    numpy_equivalents = {
        "bytes": np.bytes_,
        "string": "<U0",
        "floating": np.float64,
        "integer": np.int64,
        "decimal": np.float64,
        "complex": np.complex128,
        "boolean": np.bool_,
        "datetime64": np.datetime64,
        "datetime": np.datetime64,
        "date": np.datetime64,
        "timedelta64": np.timedelta64,
        "timedelta": np.timedelta64,
        "time": np.timedelta64,
    }

    # Rejected kinds, mapped to the label used in the error message. The
    # label is not always the kind itself (see "categorical").
    unsupported_labels = {
        "mixed-integer": "mixed-integer",
        "mixed-integer-float": "mixed-integer-float",
        "categorical": "categorical of categorical",
        "period": "period",
        "mixed": "mixed",
        "unknown-array": "unknown-array",
    }

    kind = pd_types.infer_dtype(values)

    label = unsupported_labels.get(kind)
    if label is not None:
        raise NotImplementedError(f"Pandas type '{label}' is not supported")

    # Kinds missing from both tables yield None, as before.
    return numpy_equivalents.get(kind)
122+
123+
124+
@dataclass(frozen=True)
class EnumerationInfo:
    """Immutable description of the enumeration backing a categorical column.

    ``ColumnInfo.from_categorical`` fills these fields from a pandas
    categorical accessor, and ``_get_enums`` reads them back to build the
    enumeration values array.
    """

    # NumPy dtype specifier used when the values are turned into an array.
    dtype: np.dtype
    # Whether the enumeration is ordered.
    ordered: bool = False
    # Enumeration values; the default is None, so the annotation is Optional.
    values: Optional[List[Any]] = None
129+
130+
78131
@dataclass(frozen=True)
79132
class ColumnInfo:
80133
dtype: np.dtype
81134
repr: Optional[str] = None
82135
nullable: bool = False
83136
var: bool = False
137+
enumeration: bool = False
138+
enumeration_info: Optional[EnumerationInfo] = None
84139

85140
@classmethod
86141
def from_values(cls, array_like, varlen_types=()):
142+
from pandas import CategoricalDtype
87143
from pandas.api import types as pd_types
88144

89145
if pd_types.is_object_dtype(array_like):
@@ -100,11 +156,31 @@ def from_values(cls, array_like, varlen_types=()):
100156
raise NotImplementedError(
101157
f"{inferred_dtype} inferred dtype not supported"
102158
)
159+
elif hasattr(array_like, "dtype") and isinstance(
160+
array_like.dtype, CategoricalDtype
161+
):
162+
return cls.from_categorical(array_like.cat, array_like.dtype)
103163
else:
104164
if not hasattr(array_like, "dtype"):
105165
array_like = np.asanyarray(array_like)
106166
return cls.from_dtype(array_like.dtype, varlen_types)
107167

168+
@classmethod
def from_categorical(cls, cat, dtype):
    """Build the column description for a pandas categorical column.

    The column is described as a non-nullable, fixed-length ``np.int32``
    attribute flagged as an enumeration. The category values, their inferred
    NumPy dtype and the ordered flag are carried in an ``EnumerationInfo``.

    :param cat: categorical accessor of the column (``Series.cat``)
    :param dtype: the column's dtype; only its ``name`` is used, as ``repr``
    :raises NotImplementedError: when the category values have an
        unsupported inferred type (see ``_infer_dtype_from_pandas``)
    """
    categories = cat.categories.values

    enum_info = EnumerationInfo(
        dtype=_infer_dtype_from_pandas(categories),
        ordered=cat.ordered,
        values=categories,
    )

    return cls(
        np.int32,
        repr=dtype.name,
        nullable=False,
        var=False,
        enumeration=True,
        enumeration_info=enum_info,
    )
183+
108184
@classmethod
109185
def from_dtype(cls, dtype, varlen_types=()):
110186
from pandas.api import types as pd_types
@@ -206,21 +282,54 @@ def _get_attr_dim_filters(name, filters):
206282
return _get_schema_filters(filters)
207283

208284

285+
def _get_enums(names, column_infos):
    """Create one ``tiledb.Enumeration`` per enumeration-backed column.

    :param names: column names to inspect, in output order
    :param column_infos: mapping of column name to its column info
    :return: list of enumerations, each named after its column; columns
        whose ``enumeration`` flag is falsy are skipped
    """
    named_infos = ((name, column_infos[name]) for name in names)

    return [
        tiledb.Enumeration(
            name=name,
            # Forward the ordered flag recorded for the column as-is.
            ordered=column_info.enumeration_info.ordered,
            values=np.array(
                column_info.enumeration_info.values,
                dtype=column_info.enumeration_info.dtype,
            ),
        )
        for name, column_info in named_infos
        if column_info.enumeration
    ]
304+
305+
209306
def _get_attrs(names, column_infos, attr_filters):
210307
attrs = []
211308
attr_reprs = {}
212309
for name in names:
213310
filters = _get_attr_dim_filters(name, attr_filters)
214311
column_info = column_infos[name]
215-
attrs.append(
216-
tiledb.Attr(
217-
name=name,
218-
filters=filters,
219-
dtype=column_info.dtype,
220-
nullable=column_info.nullable,
221-
var=column_info.var,
312+
if column_info.enumeration:
313+
attrs.append(
314+
tiledb.Attr(
315+
name=name,
316+
filters=filters,
317+
dtype=np.int32,
318+
enum_label=name,
319+
nullable=column_info.nullable,
320+
var=column_info.var,
321+
)
322+
)
323+
else:
324+
attrs.append(
325+
tiledb.Attr(
326+
name=name,
327+
filters=filters,
328+
dtype=column_info.dtype,
329+
nullable=column_info.nullable,
330+
var=column_info.var,
331+
)
222332
)
223-
)
224333

225334
if column_info.repr is not None:
226335
attr_reprs[name] = column_info.repr
@@ -375,6 +484,7 @@ def _df_to_np_arrays(df, column_infos, fillna):
375484
column = column.fillna(fillna[name])
376485

377486
to_numpy_kwargs = {}
487+
378488
if not column_info.var:
379489
to_numpy_kwargs.update(dtype=column_info.dtype)
380490

@@ -383,7 +493,11 @@ def _df_to_np_arrays(df, column_infos, fillna):
383493
to_numpy_kwargs.update(na_value=column_info.dtype.type())
384494
nullmaps[name] = (~column.isna()).to_numpy(dtype=np.uint8)
385495

386-
ret[name] = column.to_numpy(**to_numpy_kwargs)
496+
if column_info.enumeration:
497+
# Enumerations should get the numerical codes instead of converting enumeration values
498+
ret[name] = column.cat.codes.to_numpy(**to_numpy_kwargs)
499+
else:
500+
ret[name] = column.to_numpy(**to_numpy_kwargs)
387501

388502
return ret, nullmaps
389503

@@ -537,6 +651,7 @@ def _create_array(uri, df, sparse, full_domain, index_dims, column_infos, tiledb
537651
attrs, attr_metadata = _get_attrs(
538652
attr_names, column_infos, tiledb_args.get("attr_filters", True)
539653
)
654+
enums = _get_enums(attr_names, column_infos)
540655

541656
# create the ArraySchema
542657
with warnings.catch_warnings():
@@ -546,6 +661,7 @@ def _create_array(uri, df, sparse, full_domain, index_dims, column_infos, tiledb
546661
sparse=sparse,
547662
domain=tiledb.Domain(*dims),
548663
attrs=attrs,
664+
enums=enums,
549665
cell_order=tiledb_args["cell_order"],
550666
tile_order=tiledb_args["tile_order"],
551667
coords_filters=None

tiledb/tests/test_pandas_dataframe.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,21 @@ def make_dataframe_basic3(col_size=10, time_range=(None, None)):
106106
return df
107107

108108

109+
def make_dataframe_categorical():
    """Return a four-row frame with one integer and two categorical columns.

    The categorical columns hold string categories and int64 categories.
    """
    string_categories = pd.Series(["A", "B", "A", "B"], dtype="category")
    int_categories = pd.Series(
        np.array([1, 2, 3, 4], dtype=np.int64), dtype="category"
    )

    columns = {
        "int": [0, 1, 2, 3],
        "categorical_string": string_categories,
        "categorical_int": int_categories,
        # NOTE: a boolean categorical column is kept disabled here:
        # 'categorical_bool': pd.Series([True, False, True, False], dtype="category"),
    }

    return pd.DataFrame(columns)
122+
123+
109124
class TestColumnInfo:
110125
def assertColumnInfo(self, info, info_dtype, info_repr=None, info_nullable=False):
111126
assert isinstance(info.dtype, np.dtype)
@@ -352,6 +367,19 @@ def test_dataframe_basic2(self):
352367
with tiledb.open(uri) as B:
353368
tm.assert_frame_equal(df, B.df[:])
354369

370+
def test_dataframe_categorical(self):
    """Round-trip the categorical fixture frame through a sparse array."""
    array_uri = self.path("dataframe_categorical_rt")
    expected = make_dataframe_categorical()

    tiledb.from_pandas(array_uri, expected, sparse=True)

    # Read back through the module-level helper ...
    tm.assert_frame_equal(expected, tiledb.open_dataframe(array_uri))

    # ... and through the opened array's ``df`` indexer.
    with tiledb.open(array_uri) as array:
        tm.assert_frame_equal(expected, array.df[:])
382+
355383
def test_dataframe_csv_rt1(self):
356384
def rand_dtype(dtype, size):
357385
nbytes = size * np.dtype(dtype).itemsize

0 commit comments

Comments
 (0)