Skip to content

Commit fd7c9a4

Browse files
authored
Round trip tests & various fixes (#42)
Round trip tests for pyarrow conversion.
1 parent 49f678f commit fd7c9a4

17 files changed

+24273
-49
lines changed

stac_geoparquet/from_arrow.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,46 @@ def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
121121
new_chunks = []
122122
for chunk in bbox_col.chunks:
123123
assert pa.types.is_struct(chunk.type)
124-
xmin = chunk.field(0).to_numpy()
125-
ymin = chunk.field(1).to_numpy()
126-
xmax = chunk.field(2).to_numpy()
127-
ymax = chunk.field(3).to_numpy()
128-
coords = np.column_stack(
129-
[
130-
xmin,
131-
ymin,
132-
xmax,
133-
ymax,
134-
]
135-
)
136124

137-
list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)
125+
if bbox_col.type.num_fields == 4:
126+
xmin = chunk.field("xmin").to_numpy()
127+
ymin = chunk.field("ymin").to_numpy()
128+
xmax = chunk.field("xmax").to_numpy()
129+
ymax = chunk.field("ymax").to_numpy()
130+
coords = np.column_stack(
131+
[
132+
xmin,
133+
ymin,
134+
xmax,
135+
ymax,
136+
]
137+
)
138+
139+
list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)
140+
141+
elif bbox_col.type.num_fields == 6:
142+
xmin = chunk.field("xmin").to_numpy()
143+
ymin = chunk.field("ymin").to_numpy()
144+
zmin = chunk.field("zmin").to_numpy()
145+
xmax = chunk.field("xmax").to_numpy()
146+
ymax = chunk.field("ymax").to_numpy()
147+
zmax = chunk.field("zmax").to_numpy()
148+
coords = np.column_stack(
149+
[
150+
xmin,
151+
ymin,
152+
zmin,
153+
xmax,
154+
ymax,
155+
zmax,
156+
]
157+
)
158+
159+
list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6)
160+
161+
else:
162+
raise ValueError("Expected 4 or 6 fields in bbox struct.")
163+
138164
new_chunks.append(list_arr)
139165

140166
return table.set_column(bbox_col_idx, "bbox", new_chunks)

stac_geoparquet/to_arrow.py

Lines changed: 113 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def parse_stac_items_to_arrow(
2727
*,
2828
chunk_size: int = 8192,
2929
schema: Optional[pa.Schema] = None,
30+
downcast: bool = True,
3031
) -> pa.Table:
3132
"""Parse a collection of STAC Items to a :class:`pyarrow.Table`.
3233
@@ -41,6 +42,7 @@ def parse_stac_items_to_arrow(
4142
schema: The schema of the input data. If provided, can improve memory use;
4243
otherwise all items need to be parsed into a single array for schema
4344
inference. Defaults to None.
45+
downcast: if True, store bbox as float32 for memory and disk saving.
4446
4547
Returns:
4648
a pyarrow Table with the STAC-GeoParquet representation of items.
@@ -53,22 +55,23 @@ def parse_stac_items_to_arrow(
5355
for chunk in _chunks(items, chunk_size):
5456
batches.append(_stac_items_to_arrow(chunk, schema=schema))
5557

56-
stac_table = pa.Table.from_batches(batches, schema=schema)
58+
table = pa.Table.from_batches(batches, schema=schema)
5759
else:
5860
# If schema is _not_ provided, then we must convert to Arrow all at once, or
5961
# else it would be possible for a STAC item late in the collection (after the
6062
# first chunk) to have a different schema and not match the schema inferred for
6163
# the first chunk.
62-
stac_table = pa.Table.from_batches([_stac_items_to_arrow(items)])
64+
table = pa.Table.from_batches([_stac_items_to_arrow(items)])
6365

64-
return _process_arrow_table(stac_table)
66+
return _process_arrow_table(table, downcast=downcast)
6567

6668

6769
def parse_stac_ndjson_to_arrow(
6870
path: Union[str, Path],
6971
*,
7072
chunk_size: int = 8192,
7173
schema: Optional[pa.Schema] = None,
74+
downcast: bool = True,
7275
) -> pa.Table:
7376
# Define outside of if/else to make mypy happy
7477
items: List[dict] = []
@@ -98,14 +101,14 @@ def parse_stac_ndjson_to_arrow(
98101
if len(items) > 0:
99102
batches.append(_stac_items_to_arrow(items, schema=schema))
100103

101-
stac_table = pa.Table.from_batches(batches, schema=schema)
102-
return _process_arrow_table(stac_table)
104+
table = pa.Table.from_batches(batches, schema=schema)
105+
return _process_arrow_table(table, downcast=downcast)
103106

104107

105-
def _process_arrow_table(table: pa.Table) -> pa.Table:
108+
def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table:
106109
table = _bring_properties_to_top_level(table)
107110
table = _convert_timestamp_columns(table)
108-
table = _convert_bbox_to_struct(table)
111+
table = _convert_bbox_to_struct(table, downcast=downcast)
109112
return table
110113

111114

@@ -192,11 +195,21 @@ def _convert_timestamp_columns(table: pa.Table) -> pa.Table:
192195
except KeyError:
193196
continue
194197

198+
field_index = table.schema.get_field_index(column_name)
199+
195200
if pa.types.is_timestamp(column.type):
196201
continue
202+
203+
# STAC allows datetimes to be null. If all rows are null, the column type may be
204+
# inferred as null. We cast this to a timestamp column.
205+
elif pa.types.is_null(column.type):
206+
table = table.set_column(
207+
field_index, column_name, column.cast(pa.timestamp("us"))
208+
)
209+
197210
elif pa.types.is_string(column.type):
198-
table = table.drop(column_name).append_column(
199-
column_name, _convert_timestamp_column(column)
211+
table = table.set_column(
212+
field_index, column_name, _convert_timestamp_column(column)
200213
)
201214
else:
202215
raise ValueError(
@@ -224,7 +237,26 @@ def _convert_timestamp_column(column: pa.ChunkedArray) -> pa.ChunkedArray:
224237
return pa.chunked_array(chunks)
225238

226239

227-
def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Table:
240+
def is_bbox_3d(bbox_col: pa.ChunkedArray) -> bool:
    """Infer whether the bounding box column represents 2d or 3d bounding boxes.

    The number of coordinates per bounding box is collected across every chunk
    of the column: 4 values per row means 2d, 6 values per row means 3d.

    Args:
        bbox_col: The ``bbox`` column, stored as a list, large list, or
            fixed-size list array of coordinates.

    Returns:
        True if every bounding box has 6 coordinates, False if every bounding
        box has 4 coordinates. An empty column (no chunks, or no rows) gives no
        evidence of 3d data and is reported as 2d (False).

    Raises:
        ValueError: If the column mixes bounding boxes of different lengths, or
            if the common length is neither 4 nor 6.
    """
    lengths = set()
    for chunk in bbox_col.chunks:
        if pa.types.is_fixed_size_list(chunk.type):
            # Fixed-size list arrays have no offsets buffer to inspect; the
            # per-row length is declared by the type itself.
            lengths.add(chunk.type.list_size)
            continue

        # Consecutive offset differences are the per-row list lengths.
        # ``tolist`` yields plain Python ints so the set and the error message
        # below do not depend on numpy scalar types.
        offsets = chunk.offsets.to_numpy()
        lengths.update(np.diff(offsets).tolist())

    if not lengths:
        # Nothing to inspect: previously this fell through to an IndexError.
        return False

    if len(lengths) > 1:
        raise ValueError("Mixed 2d-3d bounding boxes not yet supported")

    (offset,) = lengths
    if offset == 6:
        return True
    if offset == 4:
        return False
    raise ValueError(f"Unexpected bbox offset: {offset=}")
257+
258+
259+
def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table:
228260
"""Convert bbox column to a struct representation
229261
230262
Since the bbox in JSON is stored as an array, pyarrow automatically converts the
@@ -244,6 +276,7 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Tab
244276
"""
245277
bbox_col_idx = table.schema.get_field_index("bbox")
246278
bbox_col = table.column(bbox_col_idx)
279+
bbox_3d = is_bbox_3d(bbox_col)
247280

248281
new_chunks = []
249282
for chunk in bbox_col.chunks:
@@ -252,36 +285,80 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Tab
252285
or pa.types.is_large_list(chunk.type)
253286
or pa.types.is_fixed_size_list(chunk.type)
254287
)
255-
coords = chunk.flatten().to_numpy().reshape(-1, 4)
256-
xmin = coords[:, 0]
257-
ymin = coords[:, 1]
258-
xmax = coords[:, 2]
259-
ymax = coords[:, 3]
288+
if bbox_3d:
289+
coords = chunk.flatten().to_numpy().reshape(-1, 6)
290+
else:
291+
coords = chunk.flatten().to_numpy().reshape(-1, 4)
260292

261293
if downcast:
262294
coords = coords.astype(np.float32)
263295

264-
# Round min values down to the next float32 value
265-
# Round max values up to the next float32 value
266-
xmin = np.nextafter(xmin, -np.Infinity)
267-
ymin = np.nextafter(ymin, -np.Infinity)
268-
xmax = np.nextafter(xmax, np.Infinity)
269-
ymax = np.nextafter(ymax, np.Infinity)
270-
271-
struct_arr = pa.StructArray.from_arrays(
272-
[
273-
xmin,
274-
ymin,
275-
xmax,
276-
ymax,
277-
],
278-
names=[
279-
"xmin",
280-
"ymin",
281-
"xmax",
282-
"ymax",
283-
],
284-
)
296+
if bbox_3d:
297+
xmin = coords[:, 0]
298+
ymin = coords[:, 1]
299+
zmin = coords[:, 2]
300+
xmax = coords[:, 3]
301+
ymax = coords[:, 4]
302+
zmax = coords[:, 5]
303+
304+
if downcast:
305+
# Round min values down to the next float32 value
306+
# Round max values up to the next float32 value
307+
xmin = np.nextafter(xmin, -np.Infinity)
308+
ymin = np.nextafter(ymin, -np.Infinity)
309+
zmin = np.nextafter(zmin, -np.Infinity)
310+
xmax = np.nextafter(xmax, np.Infinity)
311+
ymax = np.nextafter(ymax, np.Infinity)
312+
zmax = np.nextafter(zmax, np.Infinity)
313+
314+
struct_arr = pa.StructArray.from_arrays(
315+
[
316+
xmin,
317+
ymin,
318+
zmin,
319+
xmax,
320+
ymax,
321+
zmax,
322+
],
323+
names=[
324+
"xmin",
325+
"ymin",
326+
"zmin",
327+
"xmax",
328+
"ymax",
329+
"zmax",
330+
],
331+
)
332+
333+
else:
334+
xmin = coords[:, 0]
335+
ymin = coords[:, 1]
336+
xmax = coords[:, 2]
337+
ymax = coords[:, 3]
338+
339+
if downcast:
340+
# Round min values down to the next float32 value
341+
# Round max values up to the next float32 value
342+
xmin = np.nextafter(xmin, -np.Infinity)
343+
ymin = np.nextafter(ymin, -np.Infinity)
344+
xmax = np.nextafter(xmax, np.Infinity)
345+
ymax = np.nextafter(ymax, np.Infinity)
346+
347+
struct_arr = pa.StructArray.from_arrays(
348+
[
349+
xmin,
350+
ymin,
351+
xmax,
352+
ymax,
353+
],
354+
names=[
355+
"xmin",
356+
"ymin",
357+
"xmax",
358+
"ymax",
359+
],
360+
)
361+
285362
new_chunks.append(struct_arr)
286363

287364
return table.set_column(bbox_col_idx, "bbox", new_chunks)

0 commit comments

Comments
 (0)