import logging
import time
from contextlib import nullcontext
- from typing import Callable, Dict, List, Sequence
+ from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple

import openeo
from openeo import BatchJob

_LOAD_RESULT_PLACEHOLDER = "_placeholder:"

+ # Some type annotation aliases to make things more self-documenting
+ SubGraphId = str
+
+
+ class GetReplacementCallable(Protocol):
+     """
+     Type annotation for callback functions that produce a node replacement
+     for a node that is split off from the main process graph.
+
+     Also see `_default_get_replacement`.
+     """
+
+     def __call__(self, node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
+         """
+         :param node_id: original id of the node in the process graph (e.g. `loadcollection2`)
+         :param node: original node in the process graph (e.g. `{"process_id": "load_collection", "arguments": {...}}`)
+         :param subgraph_id: id of the corresponding dependency subgraph
+             (to be handled as an opaque id, but possibly something like `backend1:loadcollection2`)
+
+         :return: new process graph nodes. Should contain at least a node keyed under `node_id`.
+         """
+         ...
+
+
+ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
+     """
+     Default `get_replacement` function to replace a node that has been split off.
+     """
+     return {
+         node_id: {
+             # TODO: use `load_stac` instead of `load_result`
+             "process_id": "load_result",
+             "arguments": {"id": f"{_LOAD_RESULT_PLACEHOLDER}{subgraph_id}"},
+         }
+     }

class CrossBackendSplitter(AbstractJobSplitter):
    """
@@ -42,10 +78,25 @@ def __init__(
        self.backend_for_collection = backend_for_collection
        self._always_split = always_split

-     def split(
-         self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None
-     ) -> PartitionedJob:
-         process_graph = process["process_graph"]
+     def split_streaming(
+         self,
+         process_graph: FlatPG,
+         get_replacement: GetReplacementCallable = _default_get_replacement,
+     ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]:
+         """
+         Split given process graph into sub-process graphs and return these as an iterator,
+         ordered so that each subgraph comes after all subgraphs it depends on
+         (e.g. the main "primary" graph comes last).
+
+         The iterator approach allows working with a dynamic `get_replacement` implementation
+         that adapts to previously produced subgraphs
+         (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
+
+         :return: iterator of tuples, each containing:
+             - subgraph id
+             - SubJob
+             - dependencies as a list of subgraph ids
+         """

        # Extract necessary back-ends from `load_collection` usage
        backend_per_collection: Dict[str, str] = {
@@ -57,55 +108,60 @@ def split(
        backend_usage = collections.Counter(backend_per_collection.values())
        _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")

+         # TODO: more options to determine primary backend?
        primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
        secondary_backends = {b for b in backend_usage if b != primary_backend}
        _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")

        primary_id = "main"
-         primary_pg = SubJob(process_graph={}, backend_id=primary_backend)
+         primary_pg = {}
        primary_has_load_collection = False
-
-         subjobs: Dict[str, SubJob] = {primary_id: primary_pg}
-         dependencies: Dict[str, List[str]] = {primary_id: []}
+         primary_dependencies = []

        for node_id, node in process_graph.items():
            if node["process_id"] == "load_collection":
                bid = backend_per_collection[node["arguments"]["id"]]
-                 if bid == primary_backend and not (
-                     self._always_split and primary_has_load_collection
-                 ):
+                 if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
                    # Add to primary pg
-                     primary_pg.process_graph[node_id] = node
+                     primary_pg[node_id] = node
                    primary_has_load_collection = True
                else:
                    # New secondary pg
-                     pg = {
+                     sub_id = f"{bid}:{node_id}"
+                     sub_pg = {
                        node_id: node,
                        "sr1": {
                            # TODO: other/better choices for save_result format (e.g. based on backend support)?
-                             # TODO: particular format options?
                            "process_id": "save_result",
                            "arguments": {
                                "data": {"from_node": node_id},
+                                 # TODO: particular format options?
                                # "format": "NetCDF",
                                "format": "GTiff",
                            },
                            "result": True,
                        },
                    }
-                     dependency_id = f"{bid}:{node_id}"
-                     subjobs[dependency_id] = SubJob(process_graph=pg, backend_id=bid)
-                     dependencies[primary_id].append(dependency_id)
-                     # Link to primary pg with load_result
-                     primary_pg.process_graph[node_id] = {
-                         # TODO: encapsulate this placeholder process/id better?
-                         "process_id": "load_result",
-                         "arguments": {
-                             "id": f"{_LOAD_RESULT_PLACEHOLDER}{dependency_id}"
-                         },
-                     }
+
+                     yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), [])
+
+                     # Link secondary pg into primary pg
+                     primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id))
+                     primary_dependencies.append(sub_id)
            else:
-                 primary_pg.process_graph[node_id] = node
+                 primary_pg[node_id] = node
+
+         yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)
+
+     def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
+         """Split given process graph into a `PartitionedJob`"""
+
+         subjobs: Dict[SubGraphId, SubJob] = {}
+         dependencies: Dict[SubGraphId, List[SubGraphId]] = {}
+         for sub_id, subjob, sub_dependencies in self.split_streaming(process_graph=process["process_graph"]):
+             subjobs[sub_id] = subjob
+             if sub_dependencies:
+                 dependencies[sub_id] = sub_dependencies

        return PartitionedJob(
            process=process,
@@ -116,9 +172,7 @@ def split(
        )


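A minimal usage sketch of `split_streaming` with a dynamic `get_replacement` that creates openEO batch jobs on the fly. The `connection_for_backend` helper and the collection-to-backend mapping are assumptions of this example, not part of the module:

    splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])

    batch_jobs: Dict[SubGraphId, BatchJob] = {}

    def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
        # Point the primary graph directly at the batch job created for this subgraph.
        return {
            node_id: {
                "process_id": "load_result",
                "arguments": {"id": batch_jobs[subgraph_id].job_id},
            }
        }

    for sub_id, subjob, _dependencies in splitter.split_streaming(
        process_graph=process_graph, get_replacement=get_replacement
    ):
        connection = connection_for_backend(subjob.backend_id)  # hypothetical helper
        batch_jobs[sub_id] = connection.create_job(process_graph=subjob.process_graph)

Because each secondary subgraph is yielded before `get_replacement` is invoked for it, the consumer can create the corresponding batch job at that point, and the callback can then look up its id when the primary graph is assembled.
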
- def resolve_dependencies(
-     process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]
- ) -> FlatPG:
+ def _resolve_dependencies(process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]) -> FlatPG:
    """
    Replace placeholders in given process graph
    based on given subjob_id to batch_job_id mapping.
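
A sketch of the intended transformation (illustrative; the exact substitution target is an assumption based on the `_LOAD_RESULT_PLACEHOLDER` usage above):

    # Input node, as produced by `_default_get_replacement`:
    #     {"process_id": "load_result", "arguments": {"id": "_placeholder:backend1:lc2"}}
    # Given batch_jobs = {"backend1:lc2": job}, the resolved node would become:
    #     {"process_id": "load_result", "arguments": {"id": job.job_id}}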
@@ -235,9 +289,7 @@ def run_partitioned_job(
            # Handle job (start, poll status, ...)
            if states[subjob_id] == SUBJOB_STATES.READY:
                try:
-                     process_graph = resolve_dependencies(
-                         subjob.process_graph, batch_jobs=batch_jobs
-                     )
+                     process_graph = _resolve_dependencies(subjob.process_graph, batch_jobs=batch_jobs)

                    _log.info(
                        f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}"