Skip to content

Commit 656f54f

Browse files
authored
Additional Support For Nullable Attributes (#1836)
**Background** As detailed in sc-34754, this fixes a bug found by a customer using the TileDB-SOMA Python API where the `SOMADataFrame` containing an enumerated nullable attribute was not being readback correctly. This highlights a larger deficit in the TileDB-Py codebase in which we have [little support](https://docs.tiledb.com/main/how-to/arrays/writing-arrays/nullable-attributes) for writing nullable attributes outside of utilizing `tiledb.from_pandas` with Pandas's `ExtensionDtype`. **Changes** - This PR supports writing Pyarrow arrays and Pandas dataframes that contain nullable values (`pd.NA`, `pa.na`, `None`, etc.). - Nullable attributes are now represented in Numpy as [masked arrays](https://numpy.org/doc/stable/reference/maskedarray.html). - `PyQuery` results now also return the validity buffer. - Note that in Pyarrow, the validity values represent 0 = invalid, 1 = valid, whereas in Numpy, this is inverted and mask values represent 0 = valid, 1 = invalid. **Future Proposals** - Support writing `numpy.ma` for nullable attributes. ``` with tiledb.open(uri, "w') as A: A[:] = np.ma.array(data, mask) ``` - Support writing with built-in sequences (eg. `list`, `tuple`). Internally, we check if the attribute `.isnullable()` and then cast using `np.ma.masked_invalid()`. ``` with tiledb.open(uri, "w') as A: A[:] = [1, 2, None, 3] ```
1 parent f219f16 commit 656f54f

File tree

4 files changed

+153
-62
lines changed

4 files changed

+153
-62
lines changed

tiledb/core.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1183,7 +1183,8 @@ class PyQuery {
11831183
py::dict results;
11841184
for (auto &buffer_name : buffers_order_) {
11851185
auto bp = buffers_.at(buffer_name);
1186-
results[py::str(buffer_name)] = py::make_tuple(bp.data, bp.offsets);
1186+
results[py::str(buffer_name)] =
1187+
py::make_tuple(bp.data, bp.offsets, bp.validity);
11871188
}
11881189
return results;
11891190
}

tiledb/libtiledb.pyx

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,15 +1925,26 @@ cdef class DenseArrayImpl(Array):
19251925
if self.view_attr:
19261926
result = self.subarray(selection, attrs=(self.view_attr,))
19271927
return result[self.view_attr]
1928-
else:
1929-
result = self.subarray(selection)
1930-
for i in range(self.schema.nattr):
1931-
attr = self.schema.attr(i)
1932-
enum_label = attr.enum_label
1933-
if enum_label is not None:
1934-
values = self.enum(enum_label).values()
1935-
result[attr.name] = np.array([values[idx] for idx in result[attr.name]])
1936-
return result
1928+
1929+
result = self.subarray(selection)
1930+
for i in range(self.schema.nattr):
1931+
attr = self.schema.attr(i)
1932+
enum_label = attr.enum_label
1933+
if enum_label is not None:
1934+
values = self.enum(enum_label).values()
1935+
if attr.isnullable:
1936+
data = np.array([values[idx] for idx in result[attr.name].data])
1937+
result[attr.name] = np.ma.array(
1938+
data, mask=~result[attr.name].mask)
1939+
else:
1940+
result[attr.name] = np.array(
1941+
[values[idx] for idx in result[attr.name]])
1942+
else:
1943+
if attr.isnullable:
1944+
result[attr.name] = np.ma.array(result[attr.name].data,
1945+
mask=~result[attr.name].mask)
1946+
1947+
return result
19371948

19381949
def __repr__(self):
19391950
if self.isopen:
@@ -2182,6 +2193,10 @@ cdef class DenseArrayImpl(Array):
21822193
arr.shape = np.prod(output_shape)
21832194

21842195
out[name] = arr
2196+
2197+
if self.schema.has_attr(name) and self.attr(name).isnullable:
2198+
out[name] = np.ma.array(out[name], mask=results[name][2].astype(bool))
2199+
21852200
return out
21862201

21872202
def __setitem__(self, object selection, object val):
@@ -2272,14 +2287,34 @@ cdef class DenseArrayImpl(Array):
22722287
# Create list of attribute names and values
22732288
for attr_idx in range(self.schema.nattr):
22742289
attr = self.schema.attr(attr_idx)
2275-
k = attr.name
2276-
v = val[k]
2277-
attr = self.schema.attr(k)
2290+
name = attr.name
2291+
attr_val = val[name]
2292+
22782293
attributes.append(attr._internal_name)
22792294
# object arrays are var-len and handled later
2280-
if type(v) is np.ndarray and v.dtype is not np.dtype('O'):
2281-
v = np.ascontiguousarray(v, dtype=attr.dtype)
2282-
values.append(v)
2295+
if type(attr_val) is np.ndarray and attr_val.dtype is not np.dtype('O'):
2296+
attr_val = np.ascontiguousarray(attr_val, dtype=attr.dtype)
2297+
2298+
try:
2299+
if attr.isvar:
2300+
# ensure that the value is array-convertible, for example: pandas.Series
2301+
attr_val = np.asarray(attr_val)
2302+
if attr.isnullable and name not in nullmaps:
2303+
nullmaps[name] = np.array([int(v is not None) for v in attr_val], dtype=np.uint8)
2304+
else:
2305+
if (np.issubdtype(attr.dtype, np.string_) and not
2306+
(np.issubdtype(attr_val.dtype, np.string_) or attr_val.dtype == np.dtype('O'))):
2307+
raise ValueError("Cannot write a string value to non-string "
2308+
"typed attribute '{}'!".format(name))
2309+
2310+
if attr.isnullable and name not in nullmaps:
2311+
nullmaps[name] = ~np.ma.masked_invalid(attr_val).mask
2312+
attr_val = np.nan_to_num(attr_val)
2313+
attr_val = np.ascontiguousarray(attr_val, dtype=attr.dtype)
2314+
except Exception as exc:
2315+
raise ValueError(f"NumPy array conversion check failed for attr '{name}'") from exc
2316+
2317+
values.append(attr_val)
22832318

22842319
elif np.isscalar(val):
22852320
for i in range(self.schema.nattr):
@@ -2290,10 +2325,29 @@ cdef class DenseArrayImpl(Array):
22902325
values.append(A)
22912326
elif self.schema.nattr == 1:
22922327
attr = self.schema.attr(0)
2328+
name = attr.name
22932329
attributes.append(attr._internal_name)
22942330
# object arrays are var-len and handled later
22952331
if type(val) is np.ndarray and val.dtype is not np.dtype('O'):
22962332
val = np.ascontiguousarray(val, dtype=attr.dtype)
2333+
try:
2334+
if attr.isvar:
2335+
# ensure that the value is array-convertible, for example: pandas.Series
2336+
val = np.asarray(val)
2337+
if attr.isnullable and name not in nullmaps:
2338+
nullmaps[name] = np.array([int(v is not None) for v in val], dtype=np.uint8)
2339+
else:
2340+
if (np.issubdtype(attr.dtype, np.string_) and not
2341+
(np.issubdtype(val.dtype, np.string_) or val.dtype == np.dtype('O'))):
2342+
raise ValueError("Cannot write a string value to non-string "
2343+
"typed attribute '{}'!".format(name))
2344+
2345+
if attr.isnullable and name not in nullmaps:
2346+
nullmaps[name] = ~np.ma.fix_invalid(val).mask
2347+
val = np.nan_to_num(val)
2348+
val = np.ascontiguousarray(val, dtype=attr.dtype)
2349+
except Exception as exc:
2350+
raise ValueError(f"NumPy array conversion check failed for attr '{name}'") from exc
22972351
values.append(val)
22982352
elif self.view_attr is not None:
22992353
# Support single-attribute assignment for multi-attr array
@@ -2329,9 +2383,6 @@ cdef class DenseArrayImpl(Array):
23292383
if not isinstance(val, np.ndarray):
23302384
raise TypeError(f"Expected NumPy array for attribute '{key}' "
23312385
f"validity bitmap, got {type(val)}")
2332-
if val.dtype != np.uint8:
2333-
raise TypeError(f"Expected NumPy uint8 array for attribute '{key}' "
2334-
f"validity bitmap, got {val.dtype}")
23352386

23362387
_write_array(
23372388
ctx_ptr,
@@ -2769,17 +2820,19 @@ def _setitem_impl_sparse(self: Array, selection, val, dict nullmaps):
27692820
if attr.isvar:
27702821
# ensure that the value is array-convertible, for example: pandas.Series
27712822
attr_val = np.asarray(attr_val)
2823+
if attr.isnullable and name not in nullmaps:
2824+
nullmaps[name] = np.array([int(v is not None) for v in attr_val], dtype=np.uint8)
27722825
else:
27732826
if (np.issubdtype(attr.dtype, np.string_) and not
27742827
(np.issubdtype(attr_val.dtype, np.string_) or attr_val.dtype == np.dtype('O'))):
27752828
raise ValueError("Cannot write a string value to non-string "
27762829
"typed attribute '{}'!".format(name))
2777-
2830+
2831+
if attr.isnullable and name not in nullmaps:
2832+
nullmaps[name] = ~np.ma.masked_invalid(attr_val).mask
2833+
attr_val = np.nan_to_num(attr_val)
27782834
attr_val = np.ascontiguousarray(attr_val, dtype=attr.dtype)
27792835

2780-
if attr.isnullable and attr.name not in nullmaps:
2781-
nullmaps[attr.name] = np.array([int(v is not None) for v in attr_val], dtype=np.uint8)
2782-
27832836
except Exception as exc:
27842837
raise ValueError(f"NumPy array conversion check failed for attr '{name}'") from exc
27852838

@@ -2919,7 +2972,18 @@ cdef class SparseArrayImpl(Array):
29192972
enum_label = attr.enum_label
29202973
if enum_label is not None:
29212974
values = self.enum(enum_label).values()
2922-
result[attr.name] = np.array([values[idx] for idx in result[attr.name]])
2975+
if attr.isnullable:
2976+
data = np.array([values[idx] for idx in result[attr.name].data])
2977+
result[attr.name] = np.ma.array(
2978+
data, mask=~result[attr.name].mask)
2979+
else:
2980+
result[attr.name] = np.array(
2981+
[values[idx] for idx in result[attr.name]])
2982+
else:
2983+
if attr.isnullable:
2984+
result[attr.name] = np.ma.array(result[attr.name].data,
2985+
mask=~result[attr.name].mask)
2986+
29232987
return result
29242988

29252989
def query(self, attrs=None, cond=None, attr_cond=None, dims=None,
@@ -3207,6 +3271,9 @@ cdef class SparseArrayImpl(Array):
32073271
else:
32083272
arr.dtype = el_dtype
32093273
out[final_name] = arr
3274+
3275+
if self.schema.has_attr(final_name) and self.attr(final_name).isnullable:
3276+
out[final_name] = np.ma.array(out[final_name], mask=results[name][2])
32103277

32113278
return out
32123279

tiledb/tests/test_enumeration.py

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import tiledb
66

7-
from .common import DiskTestCase, has_pandas
7+
from .common import DiskTestCase, has_pandas, has_pyarrow
88

99

1010
class EnumerationTest(DiskTestCase):
@@ -82,47 +82,37 @@ def test_array_schema_enumeration(self):
8282
assert_array_equal(A.df[:]["attr1"], A[:]["attr1"])
8383
assert_array_equal(A.df[:]["attr2"], A[:]["attr2"])
8484

85-
def test_array_schema_enumeration_nullable(self):
86-
uri = self.path("test_array_schema_enumeration")
87-
dom = tiledb.Domain(tiledb.Dim(domain=(1, 8), tile=1))
88-
enum1 = tiledb.Enumeration("enmr1", False, np.arange(3) * 10)
89-
enum2 = tiledb.Enumeration("enmr2", False, ["a", "bb", "ccc"])
90-
attr1 = tiledb.Attr("attr1", dtype=np.int32, enum_label="enmr1")
91-
attr2 = tiledb.Attr("attr2", dtype=np.int32, enum_label="enmr2")
92-
attr3 = tiledb.Attr("attr3", dtype=np.int32)
85+
@pytest.mark.skipif(
86+
not has_pyarrow() or not has_pandas(),
87+
reason="pyarrow and/or pandas not installed",
88+
)
89+
@pytest.mark.parametrize("sparse", [True, False])
90+
@pytest.mark.parametrize("pass_df", [True, False])
91+
def test_array_schema_enumeration_nullable(self, sparse, pass_df):
92+
import pyarrow as pa
93+
94+
uri = self.path("test_array_schema_enumeration_nullable")
95+
enmr = tiledb.Enumeration("e", False, ["alpha", "beta", "gamma"])
96+
dom = tiledb.Domain(tiledb.Dim("d", domain=(1, 5), dtype="int64"))
97+
att = tiledb.Attr("a", dtype="int8", nullable=True, enum_label="e")
9398
schema = tiledb.ArraySchema(
94-
domain=dom, attrs=(attr1, attr2, attr3), enums=(enum1, enum2)
99+
domain=dom, attrs=[att], enums=[enmr], sparse=sparse
95100
)
96101
tiledb.Array.create(uri, schema)
97102

98-
data1 = np.random.randint(0, 3, 8)
99-
data2 = np.random.randint(0, 3, 8)
100-
data3 = np.random.randint(0, 3, 8)
101-
102103
with tiledb.open(uri, "w") as A:
103-
A[:] = {"attr1": data1, "attr2": data2, "attr3": data3}
104-
105-
with tiledb.open(uri, "r") as A:
106-
assert A.enum("enmr1") == enum1
107-
assert attr1.enum_label == "enmr1"
108-
assert A.attr("attr1").enum_label == "enmr1"
104+
dims = pa.array([1, 2, 3, 4, 5])
105+
data = pa.array([1.0, 2.0, None, 0, 1.0])
106+
if pass_df:
107+
dims = dims.to_pandas()
108+
data = data.to_pandas()
109109

110-
assert A.enum("enmr2") == enum2
111-
assert attr2.enum_label == "enmr2"
112-
assert A.attr("attr2").enum_label == "enmr2"
113-
114-
with self.assertRaises(tiledb.TileDBError) as excinfo:
115-
assert A.enum("enmr3") == []
116-
assert " No enumeration named 'enmr3'" in str(excinfo.value)
117-
assert attr3.enum_label is None
118-
assert A.attr("attr3").enum_label is None
110+
if sparse:
111+
A[dims] = data
112+
else:
113+
A[:] = data
119114

120-
if has_pandas():
121-
assert_array_equal(A.df[:]["attr1"].cat.codes, data1)
122-
assert_array_equal(A.df[:]["attr2"].cat.codes, data2)
123-
124-
assert_array_equal(A.df[:]["attr1"], A.multi_index[:]["attr1"])
125-
assert_array_equal(A.df[:]["attr2"], A.multi_index[:]["attr2"])
126-
127-
assert_array_equal(A.df[:]["attr1"], A[:]["attr1"])
128-
assert_array_equal(A.df[:]["attr2"], A[:]["attr2"])
115+
with tiledb.open(uri, "r") as A:
116+
expected_validity = [False, False, True, False, False]
117+
assert_array_equal(A[:]["a"].mask, expected_validity)
118+
assert_array_equal(A.df[:]["a"].isna(), expected_validity)

tiledb/tests/test_libtiledb.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
assert_unordered_equal,
2828
fx_sparse_cell_order, # noqa: F401
2929
has_pandas,
30+
has_pyarrow,
3031
rand_ascii,
3132
rand_ascii_bytes,
3233
rand_utf8,
@@ -381,6 +382,38 @@ def test_array_delete(self):
381382

382383
assert tiledb.array_exists(uri) is False
383384

385+
@pytest.mark.skipif(
386+
not has_pyarrow() or not has_pandas(),
387+
reason="pyarrow and/or pandas not installed",
388+
)
389+
@pytest.mark.parametrize("sparse", [True, False])
390+
@pytest.mark.parametrize("pass_df", [True, False])
391+
def test_array_write_nullable(self, sparse, pass_df):
392+
import pyarrow as pa
393+
394+
uri = self.path("test_array_write_nullable")
395+
dom = tiledb.Domain(tiledb.Dim("d", domain=(1, 5), dtype="int64"))
396+
att = tiledb.Attr("a", dtype="int8", nullable=True)
397+
schema = tiledb.ArraySchema(domain=dom, attrs=[att], sparse=sparse)
398+
tiledb.Array.create(uri, schema)
399+
400+
with tiledb.open(uri, "w") as A:
401+
dims = pa.array([1, 2, 3, 4, 5])
402+
data = pa.array([1.0, 2.0, None, 0, 1.0])
403+
if pass_df:
404+
dims = dims.to_pandas()
405+
data = data.to_pandas()
406+
407+
if sparse:
408+
A[dims] = data
409+
else:
410+
A[:] = data
411+
412+
with tiledb.open(uri, "r") as A:
413+
expected_validity = [False, False, True, False, False]
414+
assert_array_equal(A[:]["a"].mask, expected_validity)
415+
assert_array_equal(A.df[:]["a"].isna(), expected_validity)
416+
384417

385418
class DenseArrayTest(DiskTestCase):
386419
def test_array_1d(self):

0 commit comments

Comments
 (0)