Skip to content

Commit f5e17b8

Browse files
committed
Issue #346 Some more ProcessArgs porting
for less boilerplate code and better/earlier error messages
1 parent cff6869 commit f5e17b8

File tree

3 files changed

+80
-116
lines changed

3 files changed

+80
-116
lines changed

openeo_driver/ProcessGraphDeserializer.py

+73-114
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
import warnings
1515
from pathlib import Path
16-
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence
16+
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence, Optional
1717

1818
import geopandas as gpd
1919
import numpy as np
@@ -140,7 +140,7 @@ def wrapped(args: dict, env: EvalEnv):
140140

141141
# Type hint alias for a "process function":
142142
# a Python function that implements some openEO process (as used in `apply_process`)
143-
ProcessFunction = Callable[[dict, EvalEnv], Any]
143+
ProcessFunction = Callable[[Union[dict, ProcessArgs], EvalEnv], Any]
144144

145145

146146
def process(f: ProcessFunction) -> ProcessFunction:
@@ -750,14 +750,15 @@ def load_collection(args: dict, env: EvalEnv) -> DriverDataCube:
750750
.param(name='options', description="options specific to the file format", schema={"type": "object"})
751751
.returns(description="the data as a data cube", schema={})
752752
)
753-
def load_disk_data(args: Dict, env: EvalEnv) -> DriverDataCube:
753+
def load_disk_data(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
754754
"""
755755
Deprecated, use load_uploaded_files or load_stac
756756
"""
757+
_log.warning("Deprecated: usage of load_disk_data")
757758
kwargs = dict(
758-
glob_pattern=extract_arg(args, 'glob_pattern'),
759-
format=extract_arg(args, 'format'),
760-
options=args.get('options', {}),
759+
glob_pattern=args.get_required("glob_pattern", expected_type=str),
760+
format=args.get_required("format", expected_type=str),
761+
options=args.get_optional("options", default={}, expected_type=dict),
761762
)
762763
dry_run_tracer: DryRunDataTracer = env.get(ENV_DRY_RUN_TRACER)
763764
if dry_run_tracer:
@@ -916,22 +917,18 @@ def save_result(args: Dict, env: EvalEnv) -> SaveResult: # TODO: return type no
916917

917918
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
918919
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
919-
def save_ml_model(args: dict, env: EvalEnv) -> MlModelResult:
920-
data: DriverMlModel = extract_arg(args, "data", process_id="save_ml_model")
921-
if not isinstance(data, DriverMlModel):
922-
raise ProcessParameterInvalidException(
923-
parameter="data", process="save_ml_model", reason=f"Invalid data type {type(data)!r} expected raster-cube."
924-
)
925-
options = args.get("options", {})
920+
def save_ml_model(args: ProcessArgs, env: EvalEnv) -> MlModelResult:
921+
data = args.get_required("data", expected_type=DriverMlModel)
922+
options = args.get_optional("options", default={}, expected_type=dict)
926923
return MlModelResult(ml_model=data, options=options)
927924

928925

929926
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
930927
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
931-
def load_ml_model(args: dict, env: EvalEnv) -> DriverMlModel:
928+
def load_ml_model(args: ProcessArgs, env: EvalEnv) -> DriverMlModel:
932929
if env.get(ENV_DRY_RUN_TRACER):
933930
return DriverMlModel()
934-
job_id = extract_arg(args, "id")
931+
job_id = args.get_required("id", expected_type=str)
935932
return env.backend_implementation.load_ml_model(job_id)
936933

937934

@@ -1138,19 +1135,19 @@ def get_validated_parameter(args, param_name, default_value, expected_type, min_
11381135

11391136
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_random_forest.json"))
11401137
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_random_forest.json"))
1141-
def predict_random_forest(args: dict, env: EvalEnv):
1138+
def predict_random_forest(args: ProcessArgs, env: EvalEnv):
11421139
raise NotImplementedError
11431140

11441141

11451142
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_catboost.json"))
11461143
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_catboost.json"))
1147-
def predict_catboost(args: dict, env: EvalEnv):
1144+
def predict_catboost(args: ProcessArgs, env: EvalEnv):
11481145
raise NotImplementedError
11491146

11501147

11511148
@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_probabilities.json"))
11521149
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_probabilities.json"))
1153-
def predict_probabilities(args: dict, env: EvalEnv):
1150+
def predict_probabilities(args: ProcessArgs, env: EvalEnv):
11541151
raise NotImplementedError
11551152

11561153

@@ -1165,51 +1162,34 @@ def add_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
11651162

11661163

11671164
@process
1168-
def drop_dimension(args: dict, env: EvalEnv) -> DriverDataCube:
1169-
data_cube = extract_arg(args, 'data')
1170-
if not isinstance(data_cube, DriverDataCube):
1171-
raise ProcessParameterInvalidException(
1172-
parameter="data", process="drop_dimension",
1173-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1174-
)
1175-
return data_cube.drop_dimension(name=extract_arg(args, 'name'))
1165+
def drop_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1166+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1167+
name: str = args.get_required("name", expected_type=str)
1168+
return cube.drop_dimension(name=name)
11761169

11771170

11781171
@process
1179-
def dimension_labels(args: dict, env: EvalEnv) -> DriverDataCube:
1180-
data_cube = extract_arg(args, 'data')
1181-
if not isinstance(data_cube, DriverDataCube):
1182-
raise ProcessParameterInvalidException(
1183-
parameter="data", process="dimension_labels",
1184-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1185-
)
1186-
return data_cube.dimension_labels(dimension=extract_arg(args, 'dimension'))
1172+
def dimension_labels(args: ProcessArgs, env: EvalEnv) -> List[str]:
1173+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1174+
dimension: str = args.get_required("dimension", expected_type=str)
1175+
return cube.dimension_labels(dimension=dimension)
11871176

11881177

11891178
@process
1190-
def rename_dimension(args: dict, env: EvalEnv) -> DriverDataCube:
1191-
data_cube = extract_arg(args, 'data')
1192-
if not isinstance(data_cube, DriverDataCube):
1193-
raise ProcessParameterInvalidException(
1194-
parameter="data", process="rename_dimension",
1195-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1196-
)
1197-
return data_cube.rename_dimension(source=extract_arg(args, 'source'),target=extract_arg(args, 'target'))
1179+
def rename_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1180+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1181+
source: str = args.get_required("source", expected_type=str)
1182+
target: str = args.get_required("target", expected_type=str)
1183+
return cube.rename_dimension(source=source, target=target)
11981184

11991185

12001186
@process
1201-
def rename_labels(args: dict, env: EvalEnv) -> DriverDataCube:
1202-
data_cube = extract_arg(args, 'data')
1203-
if not isinstance(data_cube, DriverDataCube):
1204-
raise ProcessParameterInvalidException(
1205-
parameter="data", process="rename_labels",
1206-
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
1207-
)
1208-
return data_cube.rename_labels(
1209-
dimension=extract_arg(args, 'dimension'),
1210-
target=extract_arg(args, 'target'),
1211-
source=args.get('source',[])
1212-
)
1187+
def rename_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1188+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1189+
dimension: str = args.get_required("dimension", expected_type=str)
1190+
target: List[str] = args.get_required("target", expected_type=list)
1191+
source: Optional[str] = args.get_optional("source", default=None)
1192+
return cube.rename_labels(dimension=dimension, target=target, source=source)
12131193

12141194

12151195
@process
@@ -1355,14 +1335,10 @@ def aggregate_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
13551335

13561336

13571337
@process
1358-
def mask(args: dict, env: EvalEnv) -> DriverDataCube:
1359-
cube = extract_arg(args, 'data')
1360-
if not isinstance(cube, DriverDataCube):
1361-
raise ProcessParameterInvalidException(
1362-
parameter="data", process="mask", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1363-
)
1364-
mask = extract_arg(args, 'mask')
1365-
replacement = args.get('replacement', None)
1338+
def mask(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1339+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1340+
mask = args.get_required("mask")
1341+
replacement = args.get_optional("replacement", default=None)
13661342
return cube.mask(mask=mask, replacement=replacement)
13671343

13681344

@@ -1394,7 +1370,10 @@ def mask_polygon(args: dict, env: EvalEnv) -> DriverDataCube:
13941370
return image_collection
13951371

13961372

1397-
def _extract_temporal_extent(args: dict, field="extent", process_id="filter_temporal") -> Tuple[str, str]:
1373+
def _extract_temporal_extent(
1374+
args: Union[dict, ProcessArgs], field="extent", process_id="filter_temporal"
1375+
) -> Tuple[str, str]:
1376+
# TODO: make this a ProcessArgs method?
13981377
extent = extract_arg(args, name=field, process_id=process_id)
13991378
if len(extent) != 2:
14001379
raise ProcessParameterInvalidException(
@@ -1419,29 +1398,27 @@ def _extract_temporal_extent(args: dict, field="extent", process_id="filter_temp
14191398

14201399

14211400
@process
1422-
def filter_temporal(args: dict, env: EvalEnv) -> DriverDataCube:
1423-
cube = extract_arg(args, 'data')
1424-
if not isinstance(cube, DriverDataCube):
1425-
raise ProcessParameterInvalidException(
1426-
parameter="data", process="filter_temporal",
1427-
reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1428-
)
1401+
def filter_temporal(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1402+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
14291403
extent = _extract_temporal_extent(args, field="extent", process_id="filter_temporal")
14301404
return cube.filter_temporal(start=extent[0], end=extent[1])
14311405

1406+
14321407
@process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/filter_labels.json"))
14331408
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/filter_labels.json"))
1434-
def filter_labels(args: dict, env: EvalEnv) -> DriverDataCube:
1435-
cube = extract_arg(args, 'data')
1436-
if not isinstance(cube, DriverDataCube):
1437-
raise ProcessParameterInvalidException(
1438-
parameter="data", process="filter_labels",
1439-
reason=f"Invalid data type {type(cube)!r} expected cube."
1440-
)
1409+
def filter_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1410+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1411+
# TODO: validation that condition is a process graph construct
1412+
condition = args.get_required("condition", expected_type=dict)
1413+
dimension = args.get_required("dimension", expected_type=str)
1414+
context = args.get_optional("context", default=None)
1415+
return cube.filter_labels(condition=condition, dimension=dimension, context=context, env=env)
14411416

1442-
return cube.filter_labels(condition=extract_arg(args,"condition"),dimension=extract_arg(args,"dimension"),context=args.get("context",None),env=env)
14431417

1444-
def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", handle_geojson=False) -> dict:
1418+
def _extract_bbox_extent(
1419+
args: Union[dict, ProcessArgs], field="extent", process_id="filter_bbox", handle_geojson=False
1420+
) -> dict:
1421+
# TODO: make this a ProcessArgs method?
14451422
extent = extract_arg(args, name=field, process_id=process_id)
14461423
if handle_geojson and extent.get("type") in [
14471424
"Polygon",
@@ -1466,24 +1443,16 @@ def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", h
14661443

14671444

14681445
@process
1469-
def filter_bbox(args: Dict, env: EvalEnv) -> DriverDataCube:
1470-
cube = extract_arg(args, 'data')
1471-
if not isinstance(cube, DriverDataCube):
1472-
raise ProcessParameterInvalidException(
1473-
parameter="data", process="filter_bbox", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1474-
)
1446+
def filter_bbox(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1447+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
14751448
spatial_extent = _extract_bbox_extent(args, "extent", process_id="filter_bbox")
14761449
return cube.filter_bbox(**spatial_extent)
14771450

14781451

14791452
@process
1480-
def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube:
1481-
cube = extract_arg(args, 'data')
1482-
geometries = extract_arg(args, 'geometries')
1483-
if not isinstance(cube, DriverDataCube):
1484-
raise ProcessParameterInvalidException(
1485-
parameter="data", process="filter_spatial", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1486-
)
1453+
def filter_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1454+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1455+
geometries = args.get_required("geometries")
14871456

14881457
if isinstance(geometries, dict):
14891458
if "type" in geometries and geometries["type"] != "GeometryCollection":
@@ -1512,32 +1481,22 @@ def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube:
15121481

15131482

15141483
@process
1515-
def filter_bands(args: Dict, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
1516-
cube: Union[DriverDataCube, DriverVectorCube] = extract_arg(args, "data")
1517-
if not isinstance(cube, DriverDataCube) and not isinstance(cube, DriverVectorCube):
1518-
raise ProcessParameterInvalidException(
1519-
parameter="data", process="filter_bands", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
1520-
)
1521-
bands = extract_arg(args, "bands", process_id="filter_bands")
1484+
def filter_bands(args: ProcessArgs, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
1485+
cube: Union[DriverDataCube, DriverVectorCube] = args.get_required(
1486+
"data", expected_type=(DriverDataCube, DriverVectorCube)
1487+
)
1488+
bands = args.get_required("bands", expected_type=list)
15221489
return cube.filter_bands(bands=bands)
15231490

15241491

15251492
@process
1526-
def apply_kernel(args: Dict, env: EvalEnv) -> DriverDataCube:
1527-
image_collection = extract_arg(args, 'data')
1528-
kernel = np.asarray(extract_arg(args, 'kernel'))
1529-
factor = args.get('factor', 1.0)
1530-
border = args.get('border', 0)
1531-
if not isinstance(image_collection, DriverDataCube):
1532-
raise ProcessParameterInvalidException(
1533-
parameter="data", process="apply_kernel",
1534-
reason=f"Invalid data type {type(image_collection)!r} expected raster-cube."
1535-
)
1536-
if border == "0":
1537-
# R-client sends `0` border as a string
1538-
border = 0
1539-
replace_invalid = args.get('replace_invalid', 0)
1540-
return image_collection.apply_kernel(kernel=kernel, factor=factor, border=border, replace_invalid=replace_invalid)
1493+
def apply_kernel(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
1494+
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
1495+
kernel = np.asarray(args.get_required("kernel", expected_type=list))
1496+
factor = args.get_optional("factor", default=1.0, expected_type=(int, float))
1497+
border = args.get_optional("border", default=0, expected_type=int)
1498+
replace_invalid = args.get_optional("replace_invalid", default=0, expected_type=(int, float))
1499+
return cube.apply_kernel(kernel=kernel, factor=factor, border=border, replace_invalid=replace_invalid)
15411500

15421501

15431502
@process

openeo_driver/datacube.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,15 @@ def filter_spatial(self, geometries) -> 'DriverDataCube':
8585
def filter_bands(self, bands) -> 'DriverDataCube':
8686
self._not_implemented()
8787

88-
def filter_labels(self, condition: dict,dimensin: str, context: Optional[dict] = None, env: EvalEnv = None ) -> 'DriverDataCube':
88+
def filter_labels(
89+
self, condition: dict, dimension: str, context: Optional[dict] = None, env: EvalEnv = None
90+
) -> "DriverDataCube":
8991
self._not_implemented()
9092

9193
def apply(self, process: dict, *, context: Optional[dict] = None, env: EvalEnv) -> "DriverDataCube":
9294
self._not_implemented()
9395

94-
def apply_kernel(self, kernel: list, factor=1, border=0, replace_invalid=0) -> 'DriverDataCube':
96+
def apply_kernel(self, kernel: numpy.ndarray, factor=1, border=0, replace_invalid=0) -> "DriverDataCube":
9597
self._not_implemented()
9698

9799
def apply_neighborhood(

openeo_driver/processes.py

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union
77

8+
from openeo_driver.datacube import DriverDataCube
89
from openeo_driver.errors import (
910
OpenEOApiException,
1011
ProcessParameterInvalidException,
@@ -325,6 +326,8 @@ def _check_value(
325326
):
326327
if expected_type:
327328
if not isinstance(value, expected_type):
329+
if expected_type is DriverDataCube:
330+
expected_type = "raster cube"
328331
raise ProcessParameterInvalidException(
329332
parameter=name, process=self.process_id, reason=f"Expected {expected_type} but got {type(value)}."
330333
)

0 commit comments

Comments
 (0)