diff --git a/src/hiopbbpy/utils/evaluation_manager.py b/src/hiopbbpy/utils/evaluation_manager.py index 444a57795..465ddb327 100644 --- a/src/hiopbbpy/utils/evaluation_manager.py +++ b/src/hiopbbpy/utils/evaluation_manager.py @@ -6,6 +6,7 @@ Weslley S Pereira """ +import threading import logging import copy import os @@ -71,6 +72,7 @@ def __init__( cpu_executor=None, mpi_executor=None) -> None: self._queue = deque([]) + self._queue_lock = threading.Lock() self.logger = logging.getLogger(self.__class__.__name__) self.executors = { @@ -100,7 +102,7 @@ def sync(self) -> None: future_objs = [queue_obj[1] for queue_obj in self._queue] wait(future_objs) - def submit_tasks(self, fn, X, execute_at="cpu", *args, **kwargs) -> None: + def submit_tasks(self, fn, X, execute_at="cpu", **kwargs) -> None: """Submits tasks to the specified executor. :param fn: The function to be executed. @@ -108,16 +110,15 @@ def submit_tasks(self, fn, X, execute_at="cpu", *args, **kwargs) -> None: :param execute_at: The executor to use for task submission. It can be "cpu" for intra-node parallelism or "mpi" for inter-node parallelism. - :param args: Additional positional arguments to be passed to the - function. :param kwargs: Additional keyword arguments to be passed to the function. """ key = execute_at.lower() - for x in X: - future_obj = self.executors[key].submit(fn, x, *args, **kwargs) - self._queue.append([copy.deepcopy(x), future_obj]) - self.logger.info(f"Submitted f({x})") + with self._queue_lock: + for x in X: + future_obj = self.executors[key].submit(fn, x, **kwargs) + self._queue.append([copy.deepcopy(x), future_obj]) + self.logger.info(f"Submitted f({x})") def retrieve_results(self) -> tuple[list, list]: """Retrieves the results of completed tasks. @@ -127,33 +128,31 @@ def retrieve_results(self) -> tuple[list, list]: """ X = deque([]) F = deque([]) - Idxs = deque([]) - new_queue = deque([]) - for i, item in enumerate(self._queue): - x = item[0] - future = item[1] - if future.done(): - # Try to get result - try: - fx = future.result() - Idxs.append(i) - except CancelledError: - self.logger.warning(f"The execution of x={x} was cancelled.") - continue - - # Add result to the output - X.append(x) - F.append(fx) - self.logger.info(f"Completed: f({x}) = {fx}") - else: - # Keep the future in the queue - new_queue.append(item) - - # Remove completed tasks from the queue - self._queue = new_queue - - X = [X[Idxs[i]] for i in range(len(Idxs))] - F = [F[Idxs[i]] for i in range(len(Idxs))] + + with self._queue_lock: + new_queue = deque([]) + for item in self._queue: + x = item[0] + future = item[1] + if future.done(): + # Try to get result + try: + fx = future.result() + except CancelledError: + self.logger.warning(f"The execution of x={x} was cancelled.") + continue + + # Add result to the output + X.append(x) + F.append(fx) + self.logger.info(f"Completed: f({x}) = {fx}") + else: + # Keep the future in the queue + new_queue.append(item) + + # Remove completed tasks from the queue + self._queue = new_queue + return list(X), list(F)