Skip to content

Commit a35dd54

Browse files
committed
Issue #346 Some more ProcessArgs porting
for less boilerplate code and better/earlier error messages
1 parent b26527c commit a35dd54

File tree

3 files changed

+116
-159
lines changed

3 files changed

+116
-159
lines changed

openeo_driver/ProcessGraphDeserializer.py

+112-156
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
import warnings
1515
from pathlib import Path
16-
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence
16+
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence, Optional
1717

1818
import geopandas as gpd
1919
import numpy as np
@@ -935,22 +935,18 @@ def save_result(args: Dict, env: EvalEnv) -> SaveResult: # TODO: return type no
935935

936936
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
937937
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
938-
def save_ml_model(args: dict, env: EvalEnv) -> MlModelResult:
939-
data: DriverMlModel = extract_arg(args, "data", process_id="save_ml_model")
940-
if not isinstance(data, DriverMlModel):
941-
raise ProcessParameterInvalidException(
942-
parameter="data", process="save_ml_model", reason=f"Invalid data type {type(data)!r} expected raster-cube."
943-
)
944-
options = args.get("options", {})
938+
def save_ml_model(args: ProcessArgs, env: EvalEnv) -> MlModelResult:
939+
data = args.get_required("data", expected_type=DriverMlModel)
940+
options = args.get_optional("options", default={}, expected_type=dict)
945941
return MlModelResult(ml_model=data, options=options)
946942

947943

948944
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
949945
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
950-
def load_ml_model(args: dict, env: EvalEnv) -> DriverMlModel:
946+
def load_ml_model(args: ProcessArgs, env: EvalEnv) -> DriverMlModel:
951947
if env.get(ENV_DRY_RUN_TRACER):
952948
return DriverMlModel()
953-
job_id = extract_arg(args, "id")
949+
job_id = args.get_required("id", expected_type=str)
954950
return env.backend_implementation.load_ml_model(job_id)
955951

956952

@@ -1186,51 +1182,34 @@ def add_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
11861182

11871183

11881184
@process
1189-
def drop_dimension(args: dict, env: EvalEnv) -> DriverDataCube:
1190-
data_cube = extract_arg(args, 'data')
1191-
if not isinstance(data_cube, DriverDataCube):
1192-
raise ProcessParameterInvalidException(
1193-
parameter="data", process="drop_dimension",
1194-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1195-
)
1196-
return data_cube.drop_dimension(name=extract_arg(args, 'name'))
1185+
def drop_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1186+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1187+
name: str = args.get_required("name", expected_type=str)
1188+
return cube.drop_dimension(name=name)
11971189

11981190

11991191
@process
1200-
def dimension_labels(args: dict, env: EvalEnv) -> DriverDataCube:
1201-
data_cube = extract_arg(args, 'data')
1202-
if not isinstance(data_cube, DriverDataCube):
1203-
raise ProcessParameterInvalidException(
1204-
parameter="data", process="dimension_labels",
1205-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1206-
)
1207-
return data_cube.dimension_labels(dimension=extract_arg(args, 'dimension'))
1192+
def dimension_labels(args: ProcessArgs, env: EvalEnv) -> List[str]:
1193+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1194+
dimension: str = args.get_required("dimension", expected_type=str)
1195+
return cube.dimension_labels(dimension=dimension)
12081196

12091197

12101198
@process
1211-
def rename_dimension(args: dict, env: EvalEnv) -> DriverDataCube:
1212-
data_cube = extract_arg(args, 'data')
1213-
if not isinstance(data_cube, DriverDataCube):
1214-
raise ProcessParameterInvalidException(
1215-
parameter="data", process="rename_dimension",
1216-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1217-
)
1218-
return data_cube.rename_dimension(source=extract_arg(args, 'source'),target=extract_arg(args, 'target'))
1199+
def rename_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1200+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1201+
source: str = args.get_required("source", expected_type=str)
1202+
target: str = args.get_required("target", expected_type=str)
1203+
return cube.rename_dimension(source=source, target=target)
12191204

12201205

12211206
@process
1222-
def rename_labels(args: dict, env: EvalEnv) -> DriverDataCube:
1223-
data_cube = extract_arg(args, 'data')
1224-
if not isinstance(data_cube, DriverDataCube):
1225-
raise ProcessParameterInvalidException(
1226-
parameter="data", process="rename_labels",
1227-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1228-
)
1229-
return data_cube.rename_labels(
1230-
dimension=extract_arg(args, 'dimension'),
1231-
target=extract_arg(args, 'target'),
1232-
source=args.get('source',[])
1233-
)
1207+
def rename_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1208+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1209+
dimension: str = args.get_required("dimension", expected_type=str)
1210+
target: List = args.get_required("target", expected_type=list)
1211+
source: Optional[list] = args.get_optional("source", default=None, expected_type=list)
1212+
return cube.rename_labels(dimension=dimension, target=target, source=source)
12341213

12351214

12361215
@process
@@ -1376,14 +1355,10 @@ def aggregate_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
13761355

13771356

13781357
@process
1379-
def mask(args: dict, env: EvalEnv) -> DriverDataCube:
1380-
cube = extract_arg(args, 'data')
1381-
if not isinstance(cube, DriverDataCube):
1382-
raise ProcessParameterInvalidException(
1383-
parameter="data", process="mask", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1384-
)
1385-
mask = extract_arg(args, 'mask')
1386-
replacement = args.get('replacement', None)
1358+
def mask(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1359+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1360+
mask: DriverDataCube = args.get_required("mask", expected_type=DriverDataCube)
1361+
replacement = args.get_optional("replacement", default=None)
13871362
return cube.mask(mask=mask, replacement=replacement)
13881363

13891364

@@ -1415,7 +1390,10 @@ def mask_polygon(args: dict, env: EvalEnv) -> DriverDataCube:
14151390
return image_collection
14161391

14171392

1418-
def _extract_temporal_extent(args: dict, field="extent", process_id="filter_temporal") -> Tuple[str, str]:
1393+
def _extract_temporal_extent(
1394+
args: Union[dict, ProcessArgs], field="extent", process_id="filter_temporal"
1395+
) -> Tuple[str, str]:
1396+
# TODO #346: make this a ProcessArgs method?
14191397
extent = extract_arg(args, name=field, process_id=process_id)
14201398
if len(extent) != 2:
14211399
raise ProcessParameterInvalidException(
@@ -1440,29 +1418,27 @@ def _extract_temporal_extent(args: dict, field="extent", process_id="filter_temp
14401418

14411419

14421420
@process
1443-
def filter_temporal(args: dict, env: EvalEnv) -> DriverDataCube:
1444-
cube = extract_arg(args, 'data')
1445-
if not isinstance(cube, DriverDataCube):
1446-
raise ProcessParameterInvalidException(
1447-
parameter="data", process="filter_temporal",
1448-
reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1449-
)
1421+
def filter_temporal(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1422+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
14501423
extent = _extract_temporal_extent(args, field="extent", process_id="filter_temporal")
14511424
return cube.filter_temporal(start=extent[0], end=extent[1])
14521425

1426+
14531427
@process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/filter_labels.json"))
14541428
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/filter_labels.json"))
1455-
def filter_labels(args: dict, env: EvalEnv) -> DriverDataCube:
1456-
cube = extract_arg(args, 'data')
1457-
if not isinstance(cube, DriverDataCube):
1458-
raise ProcessParameterInvalidException(
1459-
parameter="data", process="filter_labels",
1460-
reason=f"Invalid data type {type(cube)!r} expected cube."
1461-
)
1429+
def filter_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1430+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1431+
# TODO: validation that condition is a process graph construct
1432+
condition = args.get_required("condition", expected_type=dict)
1433+
dimension = args.get_required("dimension", expected_type=str)
1434+
context = args.get_optional("context", default=None)
1435+
return cube.filter_labels(condition=condition, dimension=dimension, context=context, env=env)
14621436

1463-
return cube.filter_labels(condition=extract_arg(args,"condition"),dimension=extract_arg(args,"dimension"),context=args.get("context",None),env=env)
14641437

1465-
def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", handle_geojson=False) -> dict:
1438+
def _extract_bbox_extent(
1439+
args: Union[dict, ProcessArgs], field="extent", process_id="filter_bbox", handle_geojson=False
1440+
) -> dict:
1441+
# TODO #346: make this a ProcessArgs method?
14661442
extent = extract_arg(args, name=field, process_id=process_id)
14671443
if handle_geojson and extent.get("type") in [
14681444
"Polygon",
@@ -1487,24 +1463,16 @@ def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", h
14871463

14881464

14891465
@process
1490-
def filter_bbox(args: Dict, env: EvalEnv) -> DriverDataCube:
1491-
cube = extract_arg(args, 'data')
1492-
if not isinstance(cube, DriverDataCube):
1493-
raise ProcessParameterInvalidException(
1494-
parameter="data", process="filter_bbox", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1495-
)
1466+
def filter_bbox(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1467+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
14961468
spatial_extent = _extract_bbox_extent(args, "extent", process_id="filter_bbox")
14971469
return cube.filter_bbox(**spatial_extent)
14981470

14991471

15001472
@process
1501-
def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube:
1502-
cube = extract_arg(args, 'data')
1503-
geometries = extract_arg(args, 'geometries')
1504-
if not isinstance(cube, DriverDataCube):
1505-
raise ProcessParameterInvalidException(
1506-
parameter="data", process="filter_spatial", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1507-
)
1473+
def filter_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1474+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1475+
geometries = args.get_required("geometries")
15081476

15091477
if isinstance(geometries, dict):
15101478
if "type" in geometries and geometries["type"] != "GeometryCollection":
@@ -1533,32 +1501,22 @@ def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube:
15331501

15341502

15351503
@process
1536-
def filter_bands(args: Dict, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
1537-
cube: Union[DriverDataCube, DriverVectorCube] = extract_arg(args, "data")
1538-
if not isinstance(cube, DriverDataCube) and not isinstance(cube, DriverVectorCube):
1539-
raise ProcessParameterInvalidException(
1540-
parameter="data", process="filter_bands", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1541-
)
1542-
bands = extract_arg(args, "bands", process_id="filter_bands")
1504+
def filter_bands(args: ProcessArgs, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
1505+
cube: Union[DriverDataCube, DriverVectorCube] = args.get_required(
1506+
"data", expected_type=(DriverDataCube, DriverVectorCube)
1507+
)
1508+
bands = args.get_required("bands", expected_type=list)
15431509
return cube.filter_bands(bands=bands)
15441510

15451511

15461512
@process
1547-
def apply_kernel(args: Dict, env: EvalEnv) -> DriverDataCube:
1548-
image_collection = extract_arg(args, 'data')
1549-
kernel = np.asarray(extract_arg(args, 'kernel'))
1550-
factor = args.get('factor', 1.0)
1551-
border = args.get('border', 0)
1552-
if not isinstance(image_collection, DriverDataCube):
1553-
raise ProcessParameterInvalidException(
1554-
parameter="data", process="apply_kernel",
1555-
reason=f"Invalid data type {type(image_collection)!r} expected raster-cube."
1556-
)
1557-
if border == "0":
1558-
# R-client sends `0` border as a string
1559-
border = 0
1560-
replace_invalid = args.get('replace_invalid', 0)
1561-
return image_collection.apply_kernel(kernel=kernel, factor=factor, border=border, replace_invalid=replace_invalid)
1513+
def apply_kernel(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1514+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1515+
kernel = np.asarray(args.get_required("kernel", expected_type=list))
1516+
factor = args.get_optional("factor", default=1.0, expected_type=(int, float))
1517+
border = args.get_optional("border", default=0, expected_type=int)
1518+
replace_invalid = args.get_optional("replace_invalid", default=0, expected_type=(int, float))
1519+
return cube.apply_kernel(kernel=kernel, factor=factor, border=border, replace_invalid=replace_invalid)
15621520

15631521

15641522
@process
@@ -1590,16 +1548,30 @@ def resample_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
15901548

15911549

15921550
@process
1593-
def resample_cube_spatial(args: dict, env: EvalEnv) -> DriverDataCube:
1594-
image_collection = extract_arg(args, 'data')
1595-
target_image_collection = extract_arg(args, 'target')
1596-
method = args.get('method', 'near')
1597-
if not isinstance(image_collection, DriverDataCube):
1598-
raise ProcessParameterInvalidException(
1599-
parameter="data", process="resample_cube_spatial",
1600-
reason=f"Invalid data type {type(image_collection)!r} expected raster-cube."
1601-
)
1602-
return image_collection.resample_cube_spatial(target=target_image_collection, method=method)
1551+
def resample_cube_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1552+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1553+
target: DriverDataCube = args.get_required("target", expected_type=DriverDataCube)
1554+
method = args.get_enum(
1555+
"method",
1556+
options=[
1557+
"average",
1558+
"bilinear",
1559+
"cubic",
1560+
"cubicspline",
1561+
"lanczos",
1562+
"max",
1563+
"med",
1564+
"min",
1565+
"mode",
1566+
"near",
1567+
"q1",
1568+
"q3",
1569+
"rms",
1570+
"sum",
1571+
],
1572+
default="near",
1573+
)
1574+
return cube.resample_cube_spatial(target=target, method=method)
16031575

16041576

16051577
@process
@@ -1698,20 +1670,17 @@ def run_udf(args: dict, env: EvalEnv):
16981670

16991671

17001672
@process
1701-
def linear_scale_range(args: dict, env: EvalEnv) -> DriverDataCube:
1702-
image_collection = extract_arg(args, 'x')
1703-
1704-
inputMin = extract_arg(args, "inputMin")
1705-
inputMax = extract_arg(args, "inputMax")
1706-
outputMax = args.get("outputMax", 1.0)
1707-
outputMin = args.get("outputMin", 0.0)
1708-
if not isinstance(image_collection, DriverDataCube):
1709-
raise ProcessParameterInvalidException(
1710-
parameter="data", process="linear_scale_range",
1711-
reason=f"Invalid data type {type(image_collection)!r} expected raster-cube."
1712-
)
1713-
1714-
return image_collection.linear_scale_range(inputMin, inputMax, outputMin, outputMax)
1673+
def linear_scale_range(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1674+
# TODO: eliminate this top-level linear_scale_range process implementation (should be used as `apply` callback)
1675+
_log.warning("DEPRECATED: linear_scale_range usage directly on cube is deprecated/non-standard.")
1676+
cube: DriverDataCube = args.get_required("x", expected_type=DriverDataCube)
1677+
# Note: non-standard camelCase parameter names (https://github.com/Open-EO/openeo-processes/issues/302)
1678+
input_min = args.get_required("inputMin")
1679+
input_max = args.get_required("inputMax")
1680+
output_min = args.get_optional("outputMin", default=0.0)
1681+
output_max = args.get_optional("outputMax", default=1.0)
1682+
# TODO linear_scale_range is defined on GeopysparkDataCube, but not on DriverDataCube
1683+
return cube.linear_scale_range(input_min, input_max, output_min, output_max)
17151684

17161685

17171686
@process
@@ -1991,14 +1960,10 @@ def get_geometries(args: Dict, env: EvalEnv) -> Union[DelayedVector, dict]:
19911960
.param('data', description="A raster data cube.", schema={"type": "object", "subtype": "raster-cube"})
19921961
.returns("vector-cube", schema={"type": "object", "subtype": "vector-cube"})
19931962
)
1994-
def raster_to_vector(args: Dict, env: EvalEnv):
1995-
image_collection = extract_arg(args, 'data')
1996-
if not isinstance(image_collection, DriverDataCube):
1997-
raise ProcessParameterInvalidException(
1998-
parameter="data", process="raster_to_vector",
1999-
reason=f"Invalid data type {type(image_collection)!r} expected raster-cube."
2000-
)
2001-
return image_collection.raster_to_vector()
1963+
def raster_to_vector(args: ProcessArgs, env: EvalEnv):
1964+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1965+
# TODO: raster_to_vector is only defined on GeopysparkDataCube, not DriverDataCube
1966+
return cube.raster_to_vector()
20021967

20031968

20041969
@non_standard_process(
@@ -2238,13 +2203,8 @@ def discard_result(args: ProcessArgs, env: EvalEnv):
22382203

22392204
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/mask_scl_dilation.json"))
22402205
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/mask_scl_dilation.json"))
2241-
def mask_scl_dilation(args: Dict, env: EvalEnv):
2242-
cube: DriverDataCube = extract_arg(args, 'data')
2243-
if not isinstance(cube, DriverDataCube):
2244-
raise ProcessParameterInvalidException(
2245-
parameter="data", process="mask_scl_dilation",
2246-
reason=f"Invalid data type {type(cube)!r} expected raster-cube."
2247-
)
2206+
def mask_scl_dilation(args: ProcessArgs, env: EvalEnv):
2207+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
22482208
if hasattr(cube, "mask_scl_dilation"):
22492209
the_args = args.copy()
22502210
del the_args["data"]
@@ -2275,13 +2235,8 @@ def to_scl_dilation_mask(args: ProcessArgs, env: EvalEnv):
22752235

22762236
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/mask_l1c.json"))
22772237
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/mask_l1c.json"))
2278-
def mask_l1c(args: Dict, env: EvalEnv):
2279-
cube: DriverDataCube = extract_arg(args, 'data')
2280-
if not isinstance(cube, DriverDataCube):
2281-
raise ProcessParameterInvalidException(
2282-
parameter="data", process="mask_l1c",
2283-
reason=f"Invalid data type {type(cube)!r} expected raster-cube."
2284-
)
2238+
def mask_l1c(args: ProcessArgs, env: EvalEnv):
2239+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
22852240
if hasattr(cube, "mask_l1c"):
22862241
return cube.mask_l1c()
22872242
else:
@@ -2376,10 +2331,11 @@ def load_result(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
23762331

23772332
@process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/inspect.json"))
23782333
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/inspect.json"))
2379-
def inspect(args: dict, env: EvalEnv):
2380-
data = extract_arg(args, "data")
2381-
message = args.get("message", "")
2382-
level = args.get("level", "info")
2334+
def inspect(args: ProcessArgs, env: EvalEnv):
2335+
data = args.get_required("data")
2336+
message = args.get_optional("message", default="")
2337+
code = args.get_optional("code", default="User")
2338+
level = args.get_optional("level", default="info")
23832339
if message:
23842340
_log.log(level=logging.getLevelName(level.upper()), msg=message)
23852341
data_message = str(data)

0 commit comments

Comments
 (0)