Skip to content

first implementation of run_udf #307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openeo_processes_dask/process_implementations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .cubes import *
from .logic import *
from .math import *
from .udf import *

try:
from .ml import *
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .udf import run_udf
23 changes: 23 additions & 0 deletions openeo_processes_dask/process_implementations/udf/udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import xarray as xr
import dask.array as da

from typing import Optional

from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo.udf import UdfData
from openeo.udf.run_code import run_udf_code
from openeo.udf.xarraydatacube import XarrayDataCube

__all__ = ["run_udf"]


def run_udf(
    data: da.Array, udf: str, runtime: str, context: Optional[dict] = None
) -> RasterCube:
    """Run user-defined function source code against a single data cube.

    The input is wrapped in an ``xr.DataArray`` and then in an
    ``XarrayDataCube``, packed into a ``UdfData`` container and handed to
    ``run_udf_code`` together with the UDF source.

    Args:
        data: Array-like input passed to ``xr.DataArray``.
        udf: Source code of the UDF, forwarded as ``code`` to ``run_udf_code``.
        runtime: Accepted for interface compatibility; not read by this
            implementation.
        context: Optional user context, forwarded as ``user_context``.

    Returns:
        The ``array`` attribute of the single data cube produced by the UDF.

    Raises:
        ValueError: If the UDF result does not contain exactly one data cube.
    """
    input_cube = XarrayDataCube(xr.DataArray(data))
    # NOTE(review): the projection is hard-coded rather than taken from the
    # input cube -- confirm this is intended.
    udf_data = UdfData(
        proj={"EPSG": 900913},
        datacube_list=[input_cube],
        user_context=context,
    )
    result = run_udf_code(code=udf, data=udf_data)

    output_cubes = result.get_datacube_list()
    if len(output_cubes) == 1:
        return output_cubes[0].array
    raise ValueError(f"The provided UDF should return one datacube, but got: {result}")
2 changes: 1 addition & 1 deletion openeo_processes_dask/specs/openeo-processes
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def polygon_geometry_small(

@pytest.fixture
def temporal_interval(interval=["2018-05-01", "2018-06-01"]) -> TemporalInterval:
    # Builds a TemporalInterval from the default two-element list of date strings.
    # NOTE(review): two return statements follow back to back, so only the
    # first one can execute. This looks like a removed/added pair from a
    # rendered diff -- confirm which single line the real file contains.
    return TemporalInterval.parse_obj(interval)
    return TemporalInterval(interval)


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions tests/mockdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def create_fake_rastercube(
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
t_coords = pd.date_range(
start=np.datetime64(temporal_extent.__root__[0].__root__),
end=np.datetime64(temporal_extent.__root__[1].__root__),
start=temporal_extent.start.to_numpy(),
end=temporal_extent.end.to_numpy(),
periods=data.shape[2],
).values

Expand Down
39 changes: 39 additions & 0 deletions tests/test_udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import openeo
import pytest
import xarray as xr

from openeo_processes_dask.process_implementations.udf.udf import run_udf
from tests.general_checks import general_output_checks
from tests.mockdata import create_fake_rastercube


@pytest.mark.parametrize("size", [(6, 5, 4, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_run_udf(temporal_interval, bounding_box, random_raster_data):
    """Check that an identity UDF passed through ``run_udf`` returns its input.

    A fake four-band raster cube is built with the dask backend, handed to
    ``run_udf`` with a UDF whose ``apply_datacube`` returns the cube it
    receives, and the output is compared against the input.
    """
    # NOTE(review): ``size`` and ``dtype`` are presumably consumed by the
    # ``random_raster_data`` fixture -- confirm in conftest.
    input_cube = create_fake_rastercube(
        data=random_raster_data,
        spatial_extent=bounding_box,
        temporal_extent=temporal_interval,
        bands=["B02", "B03", "B04", "B08"],
        backend="dask",
    )

    # The UDF is passed to run_udf as source text, so the function body must
    # be indented relative to its ``def`` line for the text to be valid Python.
    udf = """
from openeo.udf import XarrayDataCube

def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
    return cube
"""

    output_cube = run_udf(data=input_cube, udf=udf, runtime="Python")

    general_output_checks(
        input_cube=input_cube,
        output_cube=output_cube,
        verify_attrs=True,
        verify_crs=True,
        expected_results=input_cube,
    )

    xr.testing.assert_equal(output_cube, input_cube)