@@ -111,6 +111,7 @@ def set_backends_for_collection(self, cid: str, backends: Iterable[str]):
111
111
self ._data [cid ]["backends" ] = list (backends )
112
112
113
113
def get_backends_for_collection (self , cid : str ) -> List [str ]:
114
+ """Get backend ids that provide given collection id."""
114
115
if cid not in self ._data :
115
116
raise CollectionNotFoundException (collection_id = cid )
116
117
return self ._data [cid ]["backends" ]
@@ -205,6 +206,11 @@ def evaluate(backend_id, pg):
205
206
206
207
return [functools .partial (evaluate , pg = pg ) for pg in process_graphs ]
207
208
209
+ def get_backends_for_collection (self , cid : str ) -> List [str ]:
210
+ """Get backend ids that provide given collection id."""
211
+ metadata , internal = self ._get_all_metadata_cached ()
212
+ return internal .get_backends_for_collection (cid = cid )
213
+
208
214
def get_backend_candidates_for_collections (self , collections : Iterable [str ]) -> List [str ]:
209
215
"""
210
216
Get backend ids providing all given collections
@@ -568,13 +574,16 @@ def _process_load_ml_model(
568
574
class AggregatorBatchJobs (BatchJobs ):
569
575
570
576
def __init__ (
571
- self ,
572
- backends : MultiBackendConnection ,
573
- processing : AggregatorProcessing ,
574
- partitioned_job_tracker : Optional [PartitionedJobTracker ] = None ,
577
+ self ,
578
+ * ,
579
+ backends : MultiBackendConnection ,
580
+ catalog : AggregatorCollectionCatalog ,
581
+ processing : AggregatorProcessing ,
582
+ partitioned_job_tracker : Optional [PartitionedJobTracker ] = None ,
575
583
):
576
584
super (AggregatorBatchJobs , self ).__init__ ()
577
585
self .backends = backends
586
+ self ._catalog = catalog
578
587
self .processing = processing
579
588
self .partitioned_job_tracker = partitioned_job_tracker
580
589
@@ -1127,8 +1136,9 @@ def __init__(self, backends: MultiBackendConnection, config: AggregatorConfig):
1127
1136
1128
1137
batch_jobs = AggregatorBatchJobs (
1129
1138
backends = backends ,
1139
+ catalog = catalog ,
1130
1140
processing = processing ,
1131
- partitioned_job_tracker = partitioned_job_tracker
1141
+ partitioned_job_tracker = partitioned_job_tracker ,
1132
1142
)
1133
1143
1134
1144
secondary_services = AggregatorSecondaryServices (backends = backends , processing = processing , config = config )
0 commit comments