3
3
import os
4
4
import warnings
5
5
from dataclasses import dataclass
6
- from typing import List , Optional , Union
6
+ from typing import Any , List , Optional , Union
7
7
8
8
import numpy as np
9
9
@@ -75,15 +75,71 @@ def parse_tiledb_kwargs(kwargs):
75
75
return parsed_args
76
76
77
77
78
def _infer_dtype_from_pandas(values):
    """Translate the pandas-inferred kind of ``values`` into a NumPy dtype.

    The kind label is obtained from ``pandas.api.types.infer_dtype`` and
    looked up in the tables below.

    :param values: list-like object handed to ``infer_dtype``
    :return: a NumPy scalar type, the string ``"<U0"`` for string data, or
        ``None`` when the inferred label is not listed in either table
    :raises NotImplementedError: if the inferred label is one of the kinds
        explicitly marked as unsupported
    """
    from pandas.api import types as pd_types

    # Inferred label -> dtype to use for the values.
    dtype_by_label = {
        "bytes": np.bytes_,
        "string": "<U0",
        "floating": np.float64,
        "integer": np.int64,
        "decimal": np.float64,
        "complex": np.complex128,
        "boolean": np.bool_,
        "datetime64": np.datetime64,
        "datetime": np.datetime64,
        "date": np.datetime64,
        "timedelta64": np.timedelta64,
        "timedelta": np.timedelta64,
        "time": np.timedelta64,
    }
    # Inferred label -> error message for kinds that are rejected outright.
    rejection_by_label = {
        "mixed-integer": "Pandas type 'mixed-integer' is not supported",
        "mixed-integer-float": "Pandas type 'mixed-integer-float' is not supported",
        "categorical": "Pandas type 'categorical of categorical' is not supported",
        "period": "Pandas type 'period' is not supported",
        "mixed": "Pandas type 'mixed' is not supported",
        "unknown-array": "Pandas type 'unknown-array' is not supported",
    }

    label = pd_types.infer_dtype(values)
    if label in rejection_by_label:
        raise NotImplementedError(rejection_by_label[label])

    # Labels absent from both tables yield None.
    return dtype_by_label.get(label)
122
+
123
+
124
@dataclass(frozen=True)
class EnumerationInfo:
    """Immutable description of the enumeration attached to a column.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # dtype handed to ``np.array`` when the enumeration values are materialised
    dtype: np.dtype
    # ordering flag forwarded to the enumeration that is created from this info
    ordered: bool = False
    # the enumeration's values; ``None`` when none were supplied
    values: Optional[List[Any]] = None
129
+
130
+
78
131
@dataclass (frozen = True )
79
132
class ColumnInfo :
80
133
dtype : np .dtype
81
134
repr : Optional [str ] = None
82
135
nullable : bool = False
83
136
var : bool = False
137
+ enumeration : bool = False
138
+ enumeration_info : Optional [EnumerationInfo ] = None
84
139
85
140
@classmethod
86
141
def from_values (cls , array_like , varlen_types = ()):
142
+ from pandas import CategoricalDtype
87
143
from pandas .api import types as pd_types
88
144
89
145
if pd_types .is_object_dtype (array_like ):
@@ -100,11 +156,31 @@ def from_values(cls, array_like, varlen_types=()):
100
156
raise NotImplementedError (
101
157
f"{ inferred_dtype } inferred dtype not supported"
102
158
)
159
+ elif hasattr (array_like , "dtype" ) and isinstance (
160
+ array_like .dtype , CategoricalDtype
161
+ ):
162
+ return cls .from_categorical (array_like .cat , array_like .dtype )
103
163
else :
104
164
if not hasattr (array_like , "dtype" ):
105
165
array_like = np .asanyarray (array_like )
106
166
return cls .from_dtype (array_like .dtype , varlen_types )
107
167
168
@classmethod
def from_categorical(cls, cat, dtype):
    """Create the column description for a categorical column.

    :param cat: object exposing ``categories`` and ``ordered``
        (``from_values`` passes the column's ``.cat`` accessor)
    :param dtype: the column's dtype; its ``name`` is stored as ``repr``
    :return: an instance flagged as an enumeration, with ``np.int32`` as
        its dtype and the categories recorded in ``enumeration_info``
    """
    categories = cat.categories.values
    category_dtype = _infer_dtype_from_pandas(categories)
    dtype_label = dtype.name

    enum_details = EnumerationInfo(
        dtype=category_dtype,
        ordered=cat.ordered,
        values=categories,
    )
    return cls(
        np.int32,
        repr=dtype_label,
        nullable=False,
        var=False,
        enumeration=True,
        enumeration_info=enum_details,
    )
183
+
108
184
@classmethod
109
185
def from_dtype (cls , dtype , varlen_types = ()):
110
186
from pandas .api import types as pd_types
@@ -206,21 +282,54 @@ def _get_attr_dim_filters(name, filters):
206
282
return _get_schema_filters (filters )
207
283
208
284
285
def _get_enums(names, column_infos):
    """Build a ``tiledb.Enumeration`` for every enumeration-backed column.

    :param names: iterable of column names to inspect, in output order
    :param column_infos: mapping from column name to its column info
    :return: list of ``tiledb.Enumeration`` objects, one per name whose
        column info has ``enumeration`` set; other names are skipped
    """
    collected = []
    for name in names:
        info = column_infos[name]
        if info.enumeration:
            details = info.enumeration_info
            collected.append(
                tiledb.Enumeration(
                    name=name,
                    # Forward the ordering flag stored for this column.
                    ordered=details.ordered,
                    values=np.array(details.values, dtype=details.dtype),
                )
            )

    return collected
304
+
305
+
209
306
def _get_attrs (names , column_infos , attr_filters ):
210
307
attrs = []
211
308
attr_reprs = {}
212
309
for name in names :
213
310
filters = _get_attr_dim_filters (name , attr_filters )
214
311
column_info = column_infos [name ]
215
- attrs .append (
216
- tiledb .Attr (
217
- name = name ,
218
- filters = filters ,
219
- dtype = column_info .dtype ,
220
- nullable = column_info .nullable ,
221
- var = column_info .var ,
312
+ if column_info .enumeration :
313
+ attrs .append (
314
+ tiledb .Attr (
315
+ name = name ,
316
+ filters = filters ,
317
+ dtype = np .int32 ,
318
+ enum_label = name ,
319
+ nullable = column_info .nullable ,
320
+ var = column_info .var ,
321
+ )
322
+ )
323
+ else :
324
+ attrs .append (
325
+ tiledb .Attr (
326
+ name = name ,
327
+ filters = filters ,
328
+ dtype = column_info .dtype ,
329
+ nullable = column_info .nullable ,
330
+ var = column_info .var ,
331
+ )
222
332
)
223
- )
224
333
225
334
if column_info .repr is not None :
226
335
attr_reprs [name ] = column_info .repr
@@ -375,6 +484,7 @@ def _df_to_np_arrays(df, column_infos, fillna):
375
484
column = column .fillna (fillna [name ])
376
485
377
486
to_numpy_kwargs = {}
487
+
378
488
if not column_info .var :
379
489
to_numpy_kwargs .update (dtype = column_info .dtype )
380
490
@@ -383,7 +493,11 @@ def _df_to_np_arrays(df, column_infos, fillna):
383
493
to_numpy_kwargs .update (na_value = column_info .dtype .type ())
384
494
nullmaps [name ] = (~ column .isna ()).to_numpy (dtype = np .uint8 )
385
495
386
- ret [name ] = column .to_numpy (** to_numpy_kwargs )
496
+ if column_info .enumeration :
497
+ # Enumerations should get the numerical codes instead of converting enumeration values
498
+ ret [name ] = column .cat .codes .to_numpy (** to_numpy_kwargs )
499
+ else :
500
+ ret [name ] = column .to_numpy (** to_numpy_kwargs )
387
501
388
502
return ret , nullmaps
389
503
@@ -537,6 +651,7 @@ def _create_array(uri, df, sparse, full_domain, index_dims, column_infos, tiledb
537
651
attrs , attr_metadata = _get_attrs (
538
652
attr_names , column_infos , tiledb_args .get ("attr_filters" , True )
539
653
)
654
+ enums = _get_enums (attr_names , column_infos )
540
655
541
656
# create the ArraySchema
542
657
with warnings .catch_warnings ():
@@ -546,6 +661,7 @@ def _create_array(uri, df, sparse, full_domain, index_dims, column_infos, tiledb
546
661
sparse = sparse ,
547
662
domain = tiledb .Domain (* dims ),
548
663
attrs = attrs ,
664
+ enums = enums ,
549
665
cell_order = tiledb_args ["cell_order" ],
550
666
tile_order = tiledb_args ["tile_order" ],
551
667
coords_filters = None
0 commit comments