Integrate yprov4wfs into OpenEO (openeo_pg_parser_networkx) #97

Open · wants to merge 1 commit into base: main
169 changes: 163 additions & 6 deletions openeo_pg_parser_networkx/graph.py
@@ -24,6 +24,18 @@
ProcessGraphUnflattener,
parse_nested_parameter,
)
import xarray as xr
import dask.array as da
from datetime import datetime
import os
from functools import wraps
# For yprov4wfs provenance tracking
import json
from yprov4wfs.datamodel.workflow import Workflow
from yprov4wfs.datamodel.task import Task
from yprov4wfs.datamodel.data import Data
import uuid


logger = logging.getLogger(__name__)

@@ -66,6 +78,10 @@ def __repr__(self):

class OpenEOProcessGraph:
def __init__(self, pg_data: dict):
# Create the provenance workflow object that will record this run
self.workflow = Workflow('wfs1', 'Workflow 1')
self.workflow._engineWMS = "Openeo-LocalProcessing"
self.workflow._level = "0"
self.G = nx.DiGraph()

nested_raw_graph = self._unflatten_raw_process_graph(pg_data)
@@ -295,7 +311,7 @@ def to_callable(
return self._map_node_to_callable(
self.result_node, process_registry, results_cache, parameters
)

def _map_node_to_callable(
self,
node: str,
@@ -352,8 +368,27 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs):
for func in parent_callables:
func(*args, named_parameters=named_parameters, **kwargs)

cache_users = {}
try:
# If this node has already been computed once, reuse the result from results_cache instead of recomputing it.
# This cannot be done for aggregations, because the wrapped function has to be called multiple times with
# different values, so results_cache is of no use for those processes.
# TODO: track how often each function needs to be called and, once it has been called that many times,
# allow the cache to be used for aggregate functions too. This is probably not essential, though.
parent_node_id = [edge[0] for edge in self.edges if edge[1] == node]

if parent_node_id:
parent_node_process_id = [
n[1]["process_id"]
for n in self.nodes
if n[0] == parent_node_id[0]
]

if parent_node_process_id and parent_node_process_id[0] in [
"aggregate_temporal_period",
"aggregate_spatial",
]:
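# Raise deliberately so the except branch below recomputes aggregate nodes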
raise KeyError()
return results_cache.__getitem__(node)
except KeyError:
for _, source_node, data in self.G.out_edges(node, data=True):
@@ -366,13 +401,102 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs):
kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][
"resolved_kwargs"
].__getitem__(arg_sub.arg_name)

# Record which nodes consume this node's cached output so provenance links can be added later
if source_node not in cache_users:
cache_users[source_node] = []
cache_users[source_node].append(node)
# Create a provenance task for this node and profile its execution
task = Task(node, node_with_data['process_id'])
result, execution_data = self.profile_function(prebaked_process_impl)(
*args, named_parameters=named_parameters, **kwargs
)

if isinstance(result, xr.DataArray):
processed_result = {
"entity_type": "xarray.DataArray",
"info": {
"shape": result.shape,
"dimensions": list(result.dims),
# "attributes": result.attrs,
"dtype": str(result.dtype)
}
}

elif isinstance(result, da.Array):
processed_result = {
"entity_type": "dask.Array",
"info": {
"shape": result.shape,
"dtype": str(result.dtype),
"chunk_size": result.chunksize,
"chunk_type": type(result._meta).__name__
}
}
else:
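# Fallback for plain Python results: store the value itself as the provenance info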
processed_result = {}
processed_result['info'] = result
processed_result['entity_type'] = type(result).__name__
if result is not None:
results_cache_node = Data(str(uuid.uuid4()), processed_result['entity_type'])
results_cache_node._info = processed_result['info']
task.add_output(results_cache_node)
self.workflow.add_data(results_cache_node)
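# Cache the actual result so downstream nodes can reuse it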
results_cache[node] = result

# Record the input data source for loader processes
process_id = node_with_data.get("process_id")
resolved_kwargs = node_with_data.get("resolved_kwargs", {})

if process_id in ("load_stac", "load_collection"):
key = "url" if process_id == "load_stac" else "id"
raw_source = resolved_kwargs.get(key, "")
data_source = raw_source.replace("\\", "/").split("/")[-1]  # keep only the final path component

data_src = Data(str(uuid.uuid4()), data_source)
# Extract extra information
if process_id == "load_stac":
data_src._info = resolved_kwargs


task._start_time = execution_data['start_time']
task._end_time = execution_data['end_time']
task._status = execution_data['task_status']
task._level = "1"

# For now, only the loader processes get an input Data node attached
if process_id in ("load_stac", "load_collection"):
task.add_input(data_src)

self.workflow.add_task(task)

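# Link each cached output Data node to the tasks that consumed it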
if cache_users:
for source_node, target_nodes in cache_users.items():
output_data_from_source = self.workflow.get_task_by_id(source_node)._outputs[0]._id
for target in target_nodes:
self.workflow.get_task_by_id(target).add_input(
self.workflow.get_data_by_id(output_data_from_source)
)

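# Mirror the process-graph edges as task ordering in the provenance workflow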
edges = [
{"source": source, "target": target, "type": data["reference_type"]}
for source, target, data in self.G.edges(node, data=True)
]

for edge in edges:
self.workflow.get_task_by_id(edge['source']).set_next(self.workflow.get_task_by_id(edge['target']))

if node == self.result_node:
self.workflow._status = "Ok"

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = os.path.join(os.getcwd(), f"run_{timestamp}")

# Create the output directory, then write the provenance JSON into it
os.makedirs(save_path, exist_ok=True)
self.workflow.prov_to_json(directory_path=save_path)
print(f"Provenance file saved to: {save_path}")

return result

return partial(node_callable, parent_callables=parent_callables)
@@ -471,3 +595,36 @@ def plot(self, reverse=False):

if reverse:
self.G = self.G.reverse()

@staticmethod
def profile_function(func):
""" Decorator to track execution performance and return both result and profiling data.
In the case in the future there will be some more metrics of intrest (like cpu and memory
usage) to extract."""

@wraps(func)
def wrapper(*args, named_parameters, **kwargs):
start_dt = datetime.now()
start_timestamp = start_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

try:
result = func(*args, named_parameters=named_parameters, **kwargs)
status = "Ok"
except Exception as e:
result = str(e)
status = f"Error: {result[:70]}"

end_dt = datetime.now()
end_timestamp = end_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
execution_time = (end_dt - start_dt).total_seconds()
execution_data = {
# "function": func.__name__,
"task_status": status,
"start_time": start_timestamp,
"end_time": end_timestamp,
"execution_time_sec": round(execution_time, 4),
}
# Return both the result and profiling data
return result, execution_data

return wrapper
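
A minimal standalone sketch of how the profile_function decorator behaves, assuming the patched module is importable; fake_impl and its arguments are hypothetical stand-ins for a prebaked process implementation:

from openeo_pg_parser_networkx.graph import OpenEOProcessGraph

# Hypothetical process implementation; real ones come from the process registry.
def fake_impl(x, named_parameters=None, **kwargs):
    return x * 2

profiled = OpenEOProcessGraph.profile_function(fake_impl)
result, execution_data = profiled(21, named_parameters={})
print(result)                                # 42
print(execution_data["task_status"])         # "Ok"
print(execution_data["execution_time_sec"])  # e.g. 0.0001

And a sketch of the yprov4wfs objects this patch builds, with constructor and method signatures inferred from their use in the diff above (not checked against the library's docs):

from yprov4wfs.datamodel.workflow import Workflow
from yprov4wfs.datamodel.task import Task
from yprov4wfs.datamodel.data import Data

wf = Workflow('wf-demo', 'Demo workflow')  # (id, name)
t1 = Task('node-1', 'load_collection')     # (node id, process id)
t2 = Task('node-2', 'reduce_dimension')
d1 = Data('data-1', 'xarray.DataArray')    # (id, entity type)

t1.add_output(d1)   # t1 produced d1
t2.add_input(d1)    # t2 consumed d1
t1.set_next(t2)     # execution ordering
wf.add_data(d1)
wf.add_task(t1)
wf.add_task(t2)
wf._status = "Ok"
wf.prov_to_json(directory_path=".")  # writes the provenance JSON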