Skip to content

Commit 5c0a682

Browse files
kylebarron
and
Tom Augspurger
authored
Move arrow-based code into arrow module (#47)
* Move arrow-based code into arrow module * fix tests import * deprecation --------- Co-authored-by: Tom Augspurger <taugspurger@microsoft.com>
1 parent fd7c9a4 commit 5c0a682

File tree

9 files changed

+620
-574
lines changed

9 files changed

+620
-574
lines changed

stac_geoparquet/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""stac-geoparquet"""
22

3-
from .stac_geoparquet import to_geodataframe, to_dict, to_item_collection
3+
from . import arrow
44
from ._version import __version__
5-
5+
from .stac_geoparquet import to_dict, to_geodataframe, to_item_collection
66

77
__all__ = [
88
"__version__",

stac_geoparquet/arrow/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ._from_arrow import stac_table_to_items, stac_table_to_ndjson
2+
from ._to_arrow import parse_stac_items_to_arrow, parse_stac_ndjson_to_arrow
3+
from ._to_parquet import to_parquet

stac_geoparquet/arrow/_from_arrow.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""Convert STAC Items in Arrow Table format to JSON Lines or Python dicts."""
2+
3+
import os
4+
import json
5+
from typing import Iterable, List, Union
6+
7+
import numpy as np
8+
import pyarrow as pa
9+
import pyarrow.compute as pc
10+
import shapely
11+
12+
13+
def stac_table_to_ndjson(table: pa.Table, dest: Union[str, os.PathLike[str]]) -> None:
    """Serialize a STAC Table to ``dest`` as newline-delimited JSON.

    Each row of the table becomes one compact JSON object per line. Any
    existing file at ``dest`` is overwritten.
    """
    with open(dest, "w") as sink:
        for item in stac_table_to_items(table):
            # Compact separators keep each line as small as possible.
            sink.write(json.dumps(item, separators=(",", ":")))
            sink.write("\n")
19+
20+
21+
def stac_table_to_items(table: pa.Table) -> Iterable[dict]:
    """Yield each row of a STAC Table as a STAC Item ``dict``."""
    table = _undo_stac_table_transformations(table)

    for batch in table.to_batches():
        # Decode the WKB geometry column to GeoJSON strings up front; the
        # geometry is re-attached to each row dict as it is yielded.
        geojson_per_row = shapely.to_geojson(shapely.from_wkb(batch["geometry"]))

        # RecordBatch has no `drop()` method, so instead select every column
        # except the geometry.
        non_geometry = [c for c in batch.column_names if c != "geometry"]
        rows = batch.select(non_geometry).to_struct_array()

        for row, geojson in zip(rows, geojson_per_row):
            item = row.as_py()
            item["geometry"] = json.loads(geojson)
            yield item
40+
41+
42+
def _undo_stac_table_transformations(table: pa.Table) -> pa.Table:
    """Reverse the transformations applied when ingesting STAC JSON into Arrow.

    The WKB -> GeoJSON geometry conversion is intentionally left out here;
    it is easier to perform while materializing each row as a dict.
    """
    # Apply each inverse transformation in sequence.
    for undo in (
        _convert_timestamp_columns_to_string,
        _lower_properties_from_top_level,
        _convert_bbox_to_array,
    ):
        table = undo(table)
    return table
52+
53+
54+
def _convert_timestamp_columns_to_string(table: pa.Table) -> pa.Table:
    """Convert any datetime columns in the table to a string representation.

    Formats values as ``YYYY-mm-ddTHH:MM:SSZ``. NOTE(review): this format
    truncates any sub-second precision present in the column — confirm that
    is acceptable for round-tripping.

    Args:
        table: Table that may contain timestamp columns.

    Returns:
        Table with the known STAC timestamp columns re-encoded as strings
        (each converted column is moved to the end of the schema).
    """
    # Column names defined as timestamps by STAC common metadata and the
    # timestamps extension.
    allowed_column_names = {
        "datetime",  # common metadata
        "start_datetime",
        "end_datetime",
        "created",
        "updated",
        "expires",  # timestamps extension
        "published",
        "unpublished",
    }
    for column_name in allowed_column_names:
        try:
            column = table[column_name]
        except KeyError:
            # Column not present in this table; nothing to convert.
            continue

        # Use drop_columns: Table.drop(name) is deprecated in recent pyarrow,
        # and this matches the usage in _lower_properties_from_top_level.
        table = table.drop_columns([column_name]).append_column(
            column_name, pc.strftime(column, format="%Y-%m-%dT%H:%M:%SZ")
        )

    return table
77+
78+
79+
def _lower_properties_from_top_level(table: pa.Table) -> pa.Table:
    """Gather all non-core columns into a single ``properties`` struct column."""
    # Keys that live at the top level of a STAC Item; every other column
    # belongs under "properties".
    stac_top_level_keys = {
        "stac_version",
        "stac_extensions",
        "type",
        "id",
        "bbox",
        "geometry",
        "collection",
        "links",
        "assets",
    }

    properties_column_names: List[str] = []
    properties_column_fields: List[pa.Field] = []
    for column_idx, column_name in enumerate(table.column_names):
        if column_name not in stac_top_level_keys:
            properties_column_names.append(column_name)
            properties_column_fields.append(table.schema.field(column_idx))

    # Build one struct chunk per record batch, preserving the table's chunking.
    properties_chunks = [
        pa.StructArray.from_arrays(batch.columns, fields=properties_column_fields)
        for batch in table.select(properties_column_names).to_batches()
    ]

    return table.drop_columns(properties_column_names).append_column(
        "properties", pa.chunked_array(properties_chunks)
    )
113+
114+
115+
def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
    """Convert the struct bbox column back to a list column for writing to JSON.

    Supports both 2D (``xmin``/``ymin``/``xmax``/``ymax``) and 3D (adding
    ``zmin``/``zmax``) bounding boxes, producing a fixed-size list of 4 or 6
    values per row.

    Raises:
        ValueError: If the bbox struct does not have exactly 4 or 6 fields.
    """
    bbox_col_idx = table.schema.get_field_index("bbox")
    bbox_col = table.column(bbox_col_idx)

    # The field layout is a property of the column *type*, so resolve the
    # coordinate ordering once instead of duplicating the logic per chunk.
    # This also fails fast on a malformed schema even when the column has no
    # chunks (the original check was only reachable inside the chunk loop).
    num_fields = bbox_col.type.num_fields
    if num_fields == 4:
        field_names = ["xmin", "ymin", "xmax", "ymax"]
    elif num_fields == 6:
        field_names = ["xmin", "ymin", "zmin", "xmax", "ymax", "zmax"]
    else:
        raise ValueError("Expected 4 or 6 fields in bbox struct.")

    new_chunks = []
    for chunk in bbox_col.chunks:
        assert pa.types.is_struct(chunk.type)

        # Stack the per-coordinate arrays into shape (num_rows, num_fields),
        # then flatten row-major ("C" order) so each row's coordinates are
        # contiguous for the fixed-size list layout.
        coords = np.column_stack(
            [chunk.field(name).to_numpy() for name in field_names]
        )
        new_chunks.append(
            pa.FixedSizeListArray.from_arrays(coords.flatten("C"), num_fields)
        )

    return table.set_column(bbox_col_idx, "bbox", new_chunks)

0 commit comments

Comments
 (0)