Skip to content

Commit 799a434

Browse files
authored
Add Support For SchemaEvolution on Enumerations (#1834)
This PR adds `SchemaEvolution.add_enumeration` and `SchemaEvolution.drop_enumeration`. The bindings were done in a previous PR but never added to `SchemaEvolution`.
1 parent 3e13a83 commit 799a434

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

tiledb/schema_evolution.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import tiledb
44

5+
from .enumeration import Enumeration
56
from .main import ArraySchemaEvolution as ASE
67

78

@@ -29,6 +30,20 @@ def drop_attribute(self, attr_name: str):
2930

3031
self.ase.drop_attribute(attr_name)
3132

33+
def add_enumeration(self, enmr: Enumeration):
34+
"""Add the given enumeration to the schema evolution plan.
35+
Note: this function does not apply any changes; the changes are
36+
only applied when `ArraySchemaEvolution.array_evolve` is called."""
37+
38+
self.ase.add_enumeration(enmr)
39+
40+
def drop_enumeration(self, enmr_name: str):
41+
"""Drop the given enumeration (by name) in the schema evolution.
42+
Note: this function does not apply any changes; the changes are
43+
only applied when `ArraySchemaEvolution.array_evolve` is called."""
44+
45+
self.ase.drop_enumeration(enmr_name)
46+
3247
def array_evolve(self, uri: str):
3348
"""Apply ArraySchemaEvolution actions to Array at given URI."""
3449

tiledb/tests/test_schema_evolution.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,56 @@ def get_schema_timestamps(schema_uri):
115115
se.array_evolve(uri)
116116

117117
assert 123456789 in get_schema_timestamps(schema_uri)
118+
119+
120+
def test_schema_evolution_with_enmr(tmp_path):
121+
ctx = tiledb.default_ctx()
122+
se = tiledb.ArraySchemaEvolution(ctx)
123+
124+
uri = str(tmp_path)
125+
126+
attrs = [
127+
tiledb.Attr(name="a1", dtype=np.float64),
128+
tiledb.Attr(name="a2", dtype=np.int32),
129+
]
130+
dims = [tiledb.Dim(domain=(0, 3), dtype=np.uint64)]
131+
domain = tiledb.Domain(*dims)
132+
schema = tiledb.ArraySchema(domain=domain, attrs=attrs, sparse=False)
133+
tiledb.Array.create(uri, schema)
134+
135+
data1 = {
136+
"a1": np.arange(5, 9),
137+
"a2": np.random.randint(0, 1e7, size=4).astype(np.int32),
138+
}
139+
140+
with tiledb.open(uri, "w") as A:
141+
A[:] = data1
142+
143+
with tiledb.open(uri) as A:
144+
assert not A.schema.has_attr("a3")
145+
146+
newattr = tiledb.Attr("a3", dtype=np.int8, enum_label="e3")
147+
se.add_attribute(newattr)
148+
149+
with pytest.raises(tiledb.TileDBError) as excinfo:
150+
se.array_evolve(uri)
151+
assert " Attribute refers to an unknown enumeration" in str(excinfo.value)
152+
153+
se.add_enumeration(tiledb.Enumeration("e3", True, np.arange(0, 8)))
154+
se.array_evolve(uri)
155+
156+
with tiledb.open(uri) as A:
157+
assert A.schema.has_attr("a3")
158+
assert A.attr("a3").enum_label == "e3"
159+
160+
se.drop_enumeration("e3")
161+
162+
with pytest.raises(tiledb.TileDBError) as excinfo:
163+
se.array_evolve(uri)
164+
assert "the enumeration has not been loaded" in str(excinfo.value)
165+
166+
se.drop_attribute("a3")
167+
se.array_evolve(uri)
168+
169+
with tiledb.open(uri) as A:
170+
assert not A.schema.has_attr("a3")

0 commit comments

Comments
 (0)