
Commit 329469b

Issue #115 CrossBackendSplitter: add "streamed" split to allow injecting batch job ids on the fly
1 parent 40474c8 commit 329469b
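
A quick sketch of how the new streamed split is meant to be consumed (hypothetical driver code, not part of this commit: `process_graph` and `connection` are assumed inputs, and the `load_result` replacement mirrors the default one). Because `split_streaming` is a generator, the consumer's loop body runs between the yield of a dependency subgraph and the later `get_replacement` call that refers to it, which is what makes on-the-fly injection of batch job ids possible:

from openeo_aggregator.partitionedjobs.crossbackend import (
    CrossBackendSplitter,
    SubGraphId,
)

splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
batch_jobs = {}  # subgraph id -> batch job id, filled in while iterating

def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
    # Inject the batch job id created for the dependency subgraph
    # earlier in the iteration (dependencies are always yielded first).
    return {
        node_id: {
            "process_id": "load_result",
            "arguments": {"id": batch_jobs[subgraph_id]},
        }
    }

for sub_id, subjob, dependencies in splitter.split_streaming(
    process_graph, get_replacement=get_replacement
):
    job = connection.create_job(subjob.process_graph)  # assumed openeo Connection
    batch_jobs[sub_id] = job.job_id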

2 files changed (+247 -36 lines)


src/openeo_aggregator/partitionedjobs/crossbackend.py (+85 -33)
@@ -5,7 +5,7 @@
 import logging
 import time
 from contextlib import nullcontext
-from typing import Callable, Dict, List, Sequence
+from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple
 
 import openeo
 from openeo import BatchJob
@@ -20,6 +20,42 @@
 
 _LOAD_RESULT_PLACEHOLDER = "_placeholder:"
 
+# Some type annotation aliases to make things more self-documenting
+SubGraphId = str
+
+
+class GetReplacementCallable(Protocol):
+    """
+    Type annotation for callback functions that produce a node replacement
+    for a node that is split off from the main process graph
+
+    Also see `_default_get_replacement`
+    """
+
+    def __call__(self, node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
+        """
+        :param node_id: original id of the node in the process graph (e.g. `loadcollection2`)
+        :param node: original node in the process graph (e.g. `{"process_id": "load_collection", "arguments": {...}}`)
+        :param subgraph_id: id of the corresponding dependency subgraph
+            (to be handled as opaque id, but possibly something like `backend1:loadcollection2`)
+
+        :return: new process graph nodes. Should contain at least a node keyed under `node_id`
+        """
+        ...
+
+
+def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
+    """
+    Default `get_replacement` function to replace a node that has been split off.
+    """
+    return {
+        node_id: {
+            # TODO: use `load_stac` instead of `load_result`
+            "process_id": "load_result",
+            "arguments": {"id": f"{_LOAD_RESULT_PLACEHOLDER}{subgraph_id}"},
+        }
+    }
+
 
 class CrossBackendSplitter(AbstractJobSplitter):
     """
@@ -42,10 +78,25 @@ def __init__(
         self.backend_for_collection = backend_for_collection
         self._always_split = always_split
 
-    def split(
-        self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None
-    ) -> PartitionedJob:
-        process_graph = process["process_graph"]
+    def split_streaming(
+        self,
+        process_graph: FlatPG,
+        get_replacement: GetReplacementCallable = _default_get_replacement,
+    ) -> Iterator[Tuple[SubGraphId, SubJob, List[SubGraphId]]]:
+        """
+        Split given process graph in sub-process graphs and return these as an iterator
+        in an order so that a subgraph comes after all subgraphs it depends on
+        (e.g. main "primary" graph comes last).
+
+        The iterator approach allows working with a dynamic `get_replacement` implementation
+        that adapts to previously produced subgraphs
+        (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
+
+        :return: iterator of tuples containing:
+            - subgraph id
+            - SubJob
+            - dependencies as list of subgraph ids
+        """
 
         # Extract necessary back-ends from `load_collection` usage
         backend_per_collection: Dict[str, str] = {
@@ -57,55 +108,60 @@ def split(
         backend_usage = collections.Counter(backend_per_collection.values())
         _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")
 
+        # TODO: more options to determine primary backend?
         primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
         secondary_backends = {b for b in backend_usage if b != primary_backend}
         _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
 
         primary_id = "main"
-        primary_pg = SubJob(process_graph={}, backend_id=primary_backend)
+        primary_pg = {}
         primary_has_load_collection = False
-
-        subjobs: Dict[str, SubJob] = {primary_id: primary_pg}
-        dependencies: Dict[str, List[str]] = {primary_id: []}
+        primary_dependencies = []
 
         for node_id, node in process_graph.items():
             if node["process_id"] == "load_collection":
                 bid = backend_per_collection[node["arguments"]["id"]]
-                if bid == primary_backend and not (
-                    self._always_split and primary_has_load_collection
-                ):
+                if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
                     # Add to primary pg
-                    primary_pg.process_graph[node_id] = node
+                    primary_pg[node_id] = node
                     primary_has_load_collection = True
                 else:
                     # New secondary pg
-                    pg = {
+                    sub_id = f"{bid}:{node_id}"
+                    sub_pg = {
                         node_id: node,
                         "sr1": {
                             # TODO: other/better choices for save_result format (e.g. based on backend support)?
-                            # TODO: particular format options?
                             "process_id": "save_result",
                             "arguments": {
                                 "data": {"from_node": node_id},
+                                # TODO: particular format options?
                                 # "format": "NetCDF",
                                 "format": "GTiff",
                             },
                             "result": True,
                         },
                     }
-                    dependency_id = f"{bid}:{node_id}"
-                    subjobs[dependency_id] = SubJob(process_graph=pg, backend_id=bid)
-                    dependencies[primary_id].append(dependency_id)
-                    # Link to primary pg with load_result
-                    primary_pg.process_graph[node_id] = {
-                        # TODO: encapsulate this placeholder process/id better?
-                        "process_id": "load_result",
-                        "arguments": {
-                            "id": f"{_LOAD_RESULT_PLACEHOLDER}{dependency_id}"
-                        },
-                    }
+
+                    yield (sub_id, SubJob(process_graph=sub_pg, backend_id=bid), [])
+
+                    # Link secondary pg into primary pg
+                    primary_pg.update(get_replacement(node_id=node_id, node=node, subgraph_id=sub_id))
+                    primary_dependencies.append(sub_id)
             else:
-                primary_pg.process_graph[node_id] = node
+                primary_pg[node_id] = node
+
+        yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)
+
+    def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
+        """Split given process graph into a `PartitionedJob`"""
+
+        subjobs: Dict[SubGraphId, SubJob] = {}
+        dependencies: Dict[SubGraphId, List[SubGraphId]] = {}
+        for sub_id, subjob, sub_dependencies in self.split_streaming(process_graph=process["process_graph"]):
+            subjobs[sub_id] = subjob
+            if sub_dependencies:
+                dependencies[sub_id] = sub_dependencies
 
         return PartitionedJob(
             process=process,
@@ -116,9 +172,7 @@ def split(
         )
 
 
-def resolve_dependencies(
-    process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]
-) -> FlatPG:
+def _resolve_dependencies(process_graph: FlatPG, batch_jobs: Dict[str, BatchJob]) -> FlatPG:
     """
     Replace placeholders in given process graph
     based on given subjob_id to batch_job_id mapping.
@@ -235,9 +289,7 @@ def run_partitioned_job(
             # Handle job (start, poll status, ...)
             if states[subjob_id] == SUBJOB_STATES.READY:
                 try:
-                    process_graph = resolve_dependencies(
-                        subjob.process_graph, batch_jobs=batch_jobs
-                    )
+                    process_graph = _resolve_dependencies(subjob.process_graph, batch_jobs=batch_jobs)
 
                     _log.info(
                         f"Starting new batch job for subjob {subjob_id!r} on backend {subjob.backend_id!r}"

tests/partitionedjobs/test_crossbackend.py (+162 -3)
@@ -1,5 +1,6 @@
 import dataclasses
 import re
+import types
 from typing import Dict, List, Optional
 from unittest import mock
 
@@ -13,22 +14,30 @@
 from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
 from openeo_aggregator.partitionedjobs.crossbackend import (
     CrossBackendSplitter,
+    SubGraphId,
     run_partitioned_job,
 )
 
 
 class TestCrossBackendSplitter:
-    def test_simple(self):
+    def test_split_simple(self):
         process_graph = {
             "add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}
         }
         splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
         res = splitter.split({"process_graph": process_graph})
 
         assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)}
-        assert res.dependencies == {"main": []}
+        assert res.dependencies == {}
 
-    def test_basic(self):
+    def test_split_streaming_simple(self):
+        process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
+        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
+        res = splitter.split_streaming(process_graph)
+        assert isinstance(res, types.GeneratorType)
+        assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])]
+
+    def test_split_basic(self):
         process_graph = {
             "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
             "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
@@ -93,6 +102,156 @@ def test_basic(self):
         }
         assert res.dependencies == {"main": ["B2:lc2"]}
 
+    def test_split_streaming_basic(self):
+        process_graph = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
+            "mc1": {
+                "process_id": "merge_cubes",
+                "arguments": {
+                    "cube1": {"from_node": "lc1"},
+                    "cube2": {"from_node": "lc2"},
+                },
+            },
+            "sr1": {
+                "process_id": "save_result",
+                "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
+                "result": True,
+            },
+        }
+        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        result = splitter.split_streaming(process_graph)
+        assert isinstance(result, types.GeneratorType)
+
+        assert list(result) == [
+            (
+                "B2:lc2",
+                SubJob(
+                    process_graph={
+                        "lc2": {
+                            "process_id": "load_collection",
+                            "arguments": {"id": "B2_FAPAR"},
+                        },
+                        "sr1": {
+                            "process_id": "save_result",
+                            "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
+                            "result": True,
+                        },
+                    },
+                    backend_id="B2",
+                ),
+                [],
+            ),
+            (
+                "main",
+                SubJob(
+                    process_graph={
+                        "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+                        "lc2": {"process_id": "load_result", "arguments": {"id": "_placeholder:B2:lc2"}},
+                        "mc1": {
+                            "process_id": "merge_cubes",
+                            "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+                        },
+                        "sr1": {
+                            "process_id": "save_result",
+                            "arguments": {"data": {"from_node": "mc1"}, "format": "NetCDF"},
+                            "result": True,
+                        },
+                    },
+                    backend_id="B1",
+                ),
+                ["B2:lc2"],
+            ),
+        ]
+
+    def test_split_streaming_get_replacement(self):
+        process_graph = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
+            "lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}},
+            "merge": {
+                "process_id": "merge",
+                "arguments": {
+                    "cube1": {"from_node": "lc1"},
+                    "cube2": {"from_node": "lc2"},
+                    "cube3": {"from_node": "lc3"},
+                },
+                "result": True,
+            },
+        }
+        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+
+        batch_jobs = {}
+
+        def get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) -> dict:
+            return {
+                node_id: {
+                    "process_id": "load_batch_job",
+                    "arguments": {"batch_job": batch_jobs[subgraph_id]},
+                }
+            }
+
+        substream = splitter.split_streaming(process_graph, get_replacement=get_replacement)
+
+        result = []
+        for subgraph_id, subjob, dependencies in substream:
+            batch_jobs[subgraph_id] = f"job-{111 * (len(batch_jobs) + 1)}"
+            result.append((subgraph_id, subjob, dependencies))
+
+        assert list(result) == [
+            (
+                "B2:lc2",
+                SubJob(
+                    process_graph={
+                        "lc2": {"process_id": "load_collection", "arguments": {"id": "B2_FAPAR"}},
+                        "sr1": {
+                            "process_id": "save_result",
+                            "arguments": {"data": {"from_node": "lc2"}, "format": "GTiff"},
+                            "result": True,
+                        },
+                    },
+                    backend_id="B2",
+                ),
+                [],
+            ),
+            (
+                "B3:lc3",
+                SubJob(
+                    process_graph={
+                        "lc3": {"process_id": "load_collection", "arguments": {"id": "B3_SCL"}},
+                        "sr1": {
+                            "process_id": "save_result",
+                            "arguments": {"data": {"from_node": "lc3"}, "format": "GTiff"},
+                            "result": True,
+                        },
+                    },
+                    backend_id="B3",
+                ),
+                [],
+            ),
+            (
+                "main",
+                SubJob(
+                    process_graph={
+                        "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+                        "lc2": {"process_id": "load_batch_job", "arguments": {"batch_job": "job-111"}},
+                        "lc3": {"process_id": "load_batch_job", "arguments": {"batch_job": "job-222"}},
+                        "merge": {
+                            "process_id": "merge",
+                            "arguments": {
+                                "cube1": {"from_node": "lc1"},
+                                "cube2": {"from_node": "lc2"},
+                                "cube3": {"from_node": "lc3"},
+                            },
+                            "result": True,
+                        },
+                    },
+                    backend_id="B1",
+                ),
+                ["B2:lc2", "B3:lc3"],
+            ),
+        ]
+
 
 @dataclasses.dataclass
 class _FakeJob:
