
Commit dfd384a

Author: Tom Augspurger
Merge pull request #31 from TomAugspurger/feature/arrow-types
Optionally use pyarrow types in to_geodataframe
2 parents 3901d33 + 9c60219 commit dfd384a

5 files changed: +348 -159 lines changed

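The headline change is that ``to_geodataframe`` grows two keyword arguments, ``dtype_backend`` and ``datetime_precision``. A minimal usage sketch (illustrative only; the item below is a made-up minimal STAC item, and the import path is the module touched in this commit):

    from stac_geoparquet.stac_geoparquet import to_geodataframe

    # A minimal, made-up STAC item dict; real inputs would normally be parsed from JSON.
    items = [
        {
            "type": "Feature",
            "stac_version": "1.0.0",
            "id": "example-item",
            "collection": "example",
            "geometry": {"type": "Point", "coordinates": [0.0, 0.0]},
            "bbox": [0.0, 0.0, 0.0, 0.0],
            "properties": {"datetime": "2021-01-01T00:00:00Z"},
            "links": [],
            "assets": {"data": {"href": "a.tif"}},
        }
    ]

    # Opt in to the new pyarrow-backed dtypes (and silence the FutureWarning),
    # storing datetimes at microsecond precision.
    gdf = to_geodataframe(items, dtype_backend="pyarrow", datetime_precision="us")

    # Keep the previous behavior explicitly.
    gdf_legacy = to_geodataframe(items, dtype_backend="numpy_nullable")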

stac_geoparquet/stac_geoparquet.py

Lines changed: 120 additions & 33 deletions
@@ -1,13 +1,17 @@
 """
 Generate geoparquet from a sequence of STAC items.
 """
+
 from __future__ import annotations
+import collections
 
-from typing import Sequence, Any
+from typing import Sequence, Any, Literal
+import warnings
 
 import pystac
 import geopandas
 import pandas as pd
+import pyarrow as pa
 import numpy as np
 import shapely.geometry
 
@@ -16,7 +20,7 @@
 from stac_geoparquet.utils import fix_empty_multipolygon
 
 STAC_ITEM_TYPES = ["application/json", "application/geo+json"]
-
+DTYPE_BACKEND = Literal["numpy_nullable", "pyarrow"]
 SELF_LINK_COLUMN = "self_link"
 
 
@@ -31,7 +35,10 @@ def _fix_array(v):
 
 
 def to_geodataframe(
-    items: Sequence[dict[str, Any]], add_self_link: bool = False
+    items: Sequence[dict[str, Any]],
+    add_self_link: bool = False,
+    dtype_backend: DTYPE_BACKEND | None = None,
+    datetime_precision: str = "ns",
 ) -> geopandas.GeoDataFrame:
     """
     Convert a sequence of STAC items to a :class:`geopandas.GeoDataFrame`.
@@ -42,19 +49,72 @@ def to_geodataframe(
     Parameters
     ----------
     items: A sequence of STAC items.
-    add_self_link: Add the absolute link (if available) to the source STAC Item as a separate column named "self_link"
+    add_self_link: bool, default False
+        Add the absolute link (if available) to the source STAC Item
+        as a separate column named "self_link"
+    dtype_backend: {'pyarrow', 'numpy_nullable'}, optional
+        The dtype backend to use for storing arrays.
+
+        By default, this will use 'numpy_nullable' and emit a
+        FutureWarning that the default will change to 'pyarrow' in
+        the next release.
+
+        Set to 'numpy_nullable' to silence the warning and accept the
+        old behavior.
+
+        Set to 'pyarrow' to silence the warning and accept the new behavior.
+
+        There are some differences in the output as well: with
+        ``dtype_backend="pyarrow"``, struct-like fields will explicitly
+        contain null values for fields that appear in only some of the
+        records. For example, given an ``assets`` like::
+
+            {
+                "a": {
+                    "href": "a.tif",
+                },
+                "b": {
+                    "href": "b.tif",
+                    "title": "B",
+                }
+            }
+
+        The ``assets`` field of the output for the first row with
+        ``dtype_backend="numpy_nullable"`` will be a Python dictionary with
+        just ``{"href": "a.tif"}``.
+
+        With ``dtype_backend="pyarrow"``, this will be a pyarrow struct
+        with fields ``{"href": "a.tif", "title": None}``. pyarrow will
+        infer that the struct field ``asset.title`` is nullable.
+
+    datetime_precision: str, default "ns"
+        The precision to use for the datetime columns. For example,
+        "us" is microsecond and "ns" is nanosecond.
 
     Returns
     -------
     The converted GeoDataFrame.
     """
-    items2 = []
+    items2 = collections.defaultdict(list)
+
     for item in items:
-        item2 = {k: v for k, v in item.items() if k != "properties"}
+        keys = set(item) - {"properties", "geometry"}
+
+        for k in keys:
+            items2[k].append(item[k])
+
+        item_geometry = item["geometry"]
+        if item_geometry:
+            item_geometry = fix_empty_multipolygon(item_geometry)
+
+        items2["geometry"].append(item_geometry)
+
         for k, v in item["properties"].items():
-            if k in item2:
-                raise ValueError("k", k)
-            item2[k] = v
+            if k in item:
+                msg = f"Key '{k}' appears in both 'properties' and the top level."
+                raise ValueError(msg)
+            items2[k].append(v)
+
         if add_self_link:
             self_href = None
             for link in item["links"]:
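The struct/null behavior described in the new docstring above is just pyarrow's type inference for lists of dicts; a standalone sketch, not part of this diff:

    import pyarrow as pa

    assets = [
        {"href": "a.tif"},
        {"href": "b.tif", "title": "B"},
    ]

    arr = pa.array(assets)
    print(arr.type)         # struct<href: string, title: string>
    print(arr[0].as_py())   # {'href': 'a.tif', 'title': None} -- 'title' filled with null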
@@ -65,23 +125,11 @@ def to_geodataframe(
             ):
                 self_href = link["href"]
                 break
-            item2[SELF_LINK_COLUMN] = self_href
-        items2.append(item2)
-
-    # Filter out missing geoms in MultiPolygons
-    # https://github.com/shapely/shapely/issues/1407
-    # geometry = [shapely.geometry.shape(x["geometry"]) for x in items2]
-
-    geometry = []
-    for item2 in items2:
-        item_geometry = item2["geometry"]
-        if item_geometry:
-            item_geometry = fix_empty_multipolygon(item_geometry)  # type: ignore
-        geometry.append(item_geometry)
-
-    gdf = geopandas.GeoDataFrame(items2, geometry=geometry, crs="WGS84")
+            items2[SELF_LINK_COLUMN].append(self_href)
 
-    for column in [
+    # TODO: Ideally we wouldn't have to hard-code this list.
+    # Could we get it from the JSON schema?
+    DATETIME_COLUMNS = {
         "datetime",  # common metadata
         "start_datetime",
         "end_datetime",
@@ -90,9 +138,43 @@
         "expires",  # timestamps extension
         "published",
         "unpublished",
-    ]:
-        if column in gdf.columns:
-            gdf[column] = pd.to_datetime(gdf[column], format="ISO8601")
+    }
+
+    items2["geometry"] = geopandas.array.from_shapely(items2["geometry"])
+
+    if dtype_backend is None:
+        msg = (
+            "The default argument for 'dtype_backend' will change from "
+            "'numpy_nullable' to 'pyarrow'. To keep the previous default "
+            "specify ``dtype_backend='numpy_nullable'``. To accept the future "
+            "behavior specify ``dtype_backend='pyarrow'``."
+        )
+        warnings.warn(FutureWarning(msg))
+        dtype_backend = "numpy_nullable"
+
+    if dtype_backend == "pyarrow":
+        for k, v in items2.items():
+            if k in DATETIME_COLUMNS:
+                dt = pd.to_datetime(v, format="ISO8601").as_unit(datetime_precision)
+                items2[k] = pd.arrays.ArrowExtensionArray(pa.array(dt))
+
+            elif k != "geometry":
+                items2[k] = pd.arrays.ArrowExtensionArray(pa.array(v))
+
+    elif dtype_backend == "numpy_nullable":
+        for k, v in items2.items():
+            if k in DATETIME_COLUMNS:
+                items2[k] = pd.to_datetime(v, format="ISO8601").as_unit(
+                    datetime_precision
+                )
+
+            if k in {"type", "stac_version", "id", "collection", SELF_LINK_COLUMN}:
+                items2[k] = pd.array(v, dtype="string")
+    else:
+        msg = f"Invalid 'dtype_backend={dtype_backend}'."
+        raise TypeError(msg)
+
+    gdf = geopandas.GeoDataFrame(items2, geometry="geometry", crs="WGS84")
 
     columns = [
         "type",
@@ -111,10 +193,6 @@
         columns.remove(col)
 
     gdf = pd.concat([gdf[columns], gdf.drop(columns=columns)], axis="columns")
-    for k in ["type", "stac_version", "id", "collection", SELF_LINK_COLUMN]:
-        if k in gdf:
-            gdf[k] = gdf[k].astype("string")
-
     return gdf
 
 
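The two backend branches above differ mainly in which extension array ends up holding each column: ``pd.arrays.ArrowExtensionArray`` wrapping a pyarrow array versus pandas' nullable dtypes (which also replaces the old post-hoc ``astype("string")`` loop removed in the second hunk). A small sketch of the same conversions on their own (illustrative; the column values are made up):

    import pandas as pd
    import pyarrow as pa

    ids = ["item-1", "item-2", None]

    # dtype_backend="pyarrow": wrap a pyarrow array in an ArrowExtensionArray
    arrow_ids = pd.arrays.ArrowExtensionArray(pa.array(ids))
    print(arrow_ids.dtype)      # string[pyarrow]

    # dtype_backend="numpy_nullable": pandas' nullable string dtype
    nullable_ids = pd.array(ids, dtype="string")
    print(nullable_ids.dtype)   # string

    # Datetime columns are parsed first, then converted the same way.
    dt = pd.to_datetime(["2021-01-01T00:00:00Z"], format="ISO8601").as_unit("us")
    arrow_dt = pd.arrays.ArrowExtensionArray(pa.array(dt))
    print(arrow_dt.dtype)       # a pyarrow-backed, UTC-aware timestamp dtype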
@@ -144,12 +222,16 @@ def to_dict(record: dict) -> dict:
 
         if k == SELF_LINK_COLUMN:
             continue
+        elif k == "assets":
+            item[k] = {k2: v2 for k2, v2 in v.items() if v2 is not None}
         elif k in top_level_keys:
             item[k] = v
         else:
             properties[k] = v
 
-    item["geometry"] = shapely.geometry.mapping(item["geometry"])
+    if item["geometry"]:
+        item["geometry"] = shapely.geometry.mapping(item["geometry"])
+
     item["properties"] = properties
 
     return item
@@ -175,6 +257,11 @@ def to_item_collection(df: geopandas.GeoDataFrame) -> pystac.ItemCollection:
         include=["datetime64[ns, UTC]", "datetime64[ns]"]
     ).columns
     for k in datelike:
+        # %f isn't implemented in pyarrow
+        # https://github.com/apache/arrow/issues/20146
+        if isinstance(df2[k].dtype, pd.ArrowDtype):
+            df2[k] = df2[k].astype("datetime64[ns, utc]")
+
         df2[k] = (
             df2[k].dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ").fillna("").replace({"": None})
         )
stac_geoparquet/utils.py

Lines changed: 36 additions & 8 deletions
@@ -8,23 +8,27 @@
 
 
 @functools.singledispatch
-def assert_equal(result: Any, expected: Any) -> bool:
+def assert_equal(result: Any, expected: Any, ignore_none: bool = False) -> bool:
     raise TypeError(f"Invalid type {type(result)}")
 
 
 @assert_equal.register(pystac.ItemCollection)
 def assert_equal_ic(
-    result: pystac.ItemCollection, expected: pystac.ItemCollection
+    result: pystac.ItemCollection,
+    expected: pystac.ItemCollection,
+    ignore_none: bool = False,
 ) -> None:
     assert type(result) == type(expected)
     assert len(result) == len(expected)
     assert result.extra_fields == expected.extra_fields
     for a, b in zip(result.items, expected.items):
-        assert_equal(a, b)
+        assert_equal(a, b, ignore_none=ignore_none)
 
 
 @assert_equal.register(pystac.Item)
-def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
+def assert_equal_item(
+    result: pystac.Item, expected: pystac.Item, ignore_none: bool = False
+) -> None:
     assert type(result) == type(expected)
     assert result.id == expected.id
     assert shapely.geometry.shape(result.geometry) == shapely.geometry.shape(
@@ -41,20 +45,44 @@ def assert_equal_item(result: pystac.Item, expected: pystac.Item) -> None:
     expected_links = sorted(expected.links, key=lambda x: x.href)
     assert len(result_links) == len(expected_links)
     for a, b in zip(result_links, expected_links):
-        assert_equal(a, b)
+        assert_equal(a, b, ignore_none=ignore_none)
 
     assert set(result.assets) == set(expected.assets)
     for k in result.assets:
-        assert_equal(result.assets[k], expected.assets[k])
+        assert_equal(result.assets[k], expected.assets[k], ignore_none=ignore_none)
 
 
 @assert_equal.register(pystac.Link)
 @assert_equal.register(pystac.Asset)
 def assert_link_equal(
-    result: pystac.Link | pystac.Asset, expected: pystac.Link | pystac.Asset
+    result: pystac.Link | pystac.Asset,
+    expected: pystac.Link | pystac.Asset,
+    ignore_none: bool = False,
 ) -> None:
     assert type(result) == type(expected)
-    assert result.to_dict() == expected.to_dict()
+    resultd = result.to_dict()
+    expectedd = expected.to_dict()
+
+    left = {}
+
+    if ignore_none:
+        for k, v in resultd.items():
+            if v is None and k not in expectedd:
+                pass
+            elif isinstance(v, list) and k in expectedd:
+                out = []
+                for val in v:
+                    if isinstance(val, dict):
+                        out.append({k: v2 for k, v2 in val.items() if v2 is not None})
+                    else:
+                        out.append(val)
+                left[k] = out
+            else:
+                left[k] = v
+    else:
+        left = resultd
+
+    assert left == expectedd
 
 
 def fix_empty_multipolygon(
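The new ``ignore_none`` flag exists because items round-tripped through pyarrow come back with explicit ``None`` entries (e.g. ``"title": None``) that the original asset dicts never contained. A self-contained sketch of the comparison it enables (illustrative only, not the library's API; the helper name is made up):

    def dicts_equal_ignoring_none(result: dict, expected: dict) -> bool:
        # Drop keys whose value is None and which the expected dict never had.
        left = {k: v for k, v in result.items() if not (v is None and k not in expected)}
        return left == expected

    result = {"href": "a.tif", "title": None}   # as recovered from a pyarrow struct
    expected = {"href": "a.tif"}                # the original asset dict

    assert dicts_equal_ignoring_none(result, expected)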

tests/test_pgstac_reader.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def test_naip_item():
     expected.remove_links(rel=pystac.RelType.SELF)
     result.remove_links(rel=pystac.RelType.SELF)
 
-    assert_equal(result, expected)
+    assert_equal(result, expected, ignore_none=True)
 
 
 def test_sentinel2_l2a():
@@ -139,7 +139,7 @@ def test_sentinel2_l2a():
     result.remove_links(rel=pystac.RelType.SELF)
 
     expected.remove_links(rel=pystac.RelType.LICENSE)
-    assert_equal(result, expected)
+    assert_equal(result, expected, ignore_none=True)
 
 
 def test_generate_endpoints():
