@@ -95,8 +95,7 @@ async def __post_create__(self):
9595 AssignerActor .gen_uid (self ._session_id ), address = self .address
9696 )
9797
98- @alru_cache
99- async def _get_task_api (self ):
98+ async def _get_task_api (self ) -> TaskAPI :
10099 return await TaskAPI .create (self ._session_id , self .address )
101100
102101 def _put_subtask_with_priority (self , subtask : Subtask , priority : Tuple = None ):
@@ -272,21 +271,47 @@ async def update_subtask_priorities(
272271
273272 @alru_cache (maxsize = 10000 )
274273 async def _get_execution_ref (self , address : str ):
275- from ..worker .exec import SubtaskExecutionActor
274+ from ..worker .execution import SubtaskExecutionActor
276275
277276 return await mo .actor_ref (SubtaskExecutionActor .default_uid (), address = address )
278277
279- async def finish_subtasks (self , subtask_ids : List [str ], schedule_next : bool = True ):
280- band_tasks = defaultdict (lambda : 0 )
281- for subtask_id in subtask_ids :
282- subtask_info = self ._subtask_infos .pop (subtask_id , None )
278+ async def set_subtask_results (
279+ self , subtask_results : List [SubtaskResult ], source_bands : List [BandType ]
280+ ):
281+ delays = []
282+ task_api = await self ._get_task_api ()
283+ for result , band in zip (subtask_results , source_bands ):
284+ if result .status == SubtaskStatus .errored :
285+ subtask_info = self ._subtask_infos .get (result .subtask_id )
286+ if (
287+ subtask_info is not None
288+ and subtask_info .subtask .retryable
289+ and subtask_info .num_reschedules < subtask_info .max_reschedules
290+ and isinstance (result .error , (MarsError , OSError ))
291+ ):
292+ subtask_info .num_reschedules += 1
293+ logger .warning (
294+ "Resubmit subtask %s at attempt %d" ,
295+ subtask_info .subtask .subtask_id ,
296+ subtask_info .num_reschedules ,
297+ )
298+ execution_ref = await self ._get_execution_ref (band [0 ])
299+ await execution_ref .submit_subtasks .tell (
300+ [subtask_info .subtask ],
301+ [subtask_info .priority ],
302+ self .address ,
303+ band [1 ],
304+ )
305+ continue
306+
307+ subtask_info = self ._subtask_infos .pop (result .subtask_id , None )
283308 if subtask_info is not None :
284- self ._subtask_summaries [subtask_id ] = subtask_info .to_summary (
309+ self ._subtask_summaries [result . subtask_id ] = subtask_info .to_summary (
285310 is_finished = True
286311 )
287- if schedule_next :
288- for band in subtask_info . submitted_bands :
289- band_tasks [ band ] += 1
312+ delays . append ( task_api . set_subtask_result . delay ( result ))
313+
314+ await task_api . set_subtask_result . batch ( * delays )
290315
291316 def _get_subtasks_by_ids (self , subtask_ids : List [str ]) -> List [Optional [Subtask ]]:
292317 subtasks = []
0 commit comments