diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py index f0754c9a..51235b60 100644 --- a/cads_broker/dispatcher.py +++ b/cads_broker/dispatcher.py @@ -13,14 +13,15 @@ import distributed import sqlalchemy as sa import structlog +from dask.typing import Key from typing_extensions import Iterable try: - from cads_worker import worker + import cads_worker.worker except ModuleNotFoundError: pass -from cads_broker import Environment, config, factory +from cads_broker import Environment, config, factory, utils from cads_broker import database as db from cads_broker.qos import QoS @@ -197,6 +198,34 @@ def __init__(self, number_of_workers) -> None: parser.parse_rules(self.rules, self.environment) +class TempDirNannyPlugin(distributed.NannyPlugin): + def setup(self, nanny: distributed.Nanny) -> None: + path = utils.rm_task_path(nanny, None) + path.mkdir() + + def teardown(self, nanny: distributed.Nanny) -> None: + utils.rm_task_path(nanny, None) + + +class TempDirsWorkerPlugin(distributed.WorkerPlugin): + def setup(self, worker) -> None: + self.worker = worker + + def teardown(self, worker: distributed.Worker) -> None: + for key in worker.state.tasks: + utils.rm_task_path(worker, key) + + def transition( + self, + key: Key, + start: distributed.worker_state_machine.TaskStateState, + finish: distributed.worker_state_machine.TaskStateState, + **kwargs: Any, + ) -> None: + if finish in ("memory", "error"): + utils.rm_task_path(self.worker, key) + + @attrs.define class Broker: client: distributed.Client @@ -218,6 +247,10 @@ class Broker: internal_scheduler: Scheduler = Scheduler() queue: Queue = Queue() + def __attrs_post_init__(self): + self.client.register_plugin(TempDirNannyPlugin()) + self.client.register_plugin(TempDirsWorkerPlugin()) + @classmethod def from_address( cls, @@ -563,7 +596,7 @@ def submit_request( ) self.queue.pop(request.request_uid) future = self.client.submit( - worker.submit_workflow, + cads_worker.worker.submit_workflow, key=request.request_uid, setup_code=request.request_body.get("setup_code", ""), entry_point=request.entry_point, diff --git a/cads_broker/utils.py b/cads_broker/utils.py new file mode 100644 index 00000000..82c80dbb --- /dev/null +++ b/cads_broker/utils.py @@ -0,0 +1,35 @@ +import pathlib +import shutil +from typing import Any + +import distributed +from dask.typing import Key + + +def get_task_path( + worker_or_nanny: distributed.Worker | distributed.Nanny, key: Key | None +) -> pathlib.Path: + if isinstance(worker_or_nanny, distributed.Worker): + root = worker_or_nanny.local_directory + elif isinstance(worker_or_nanny, distributed.Nanny): + root = worker_or_nanny.worker_dir + else: + raise TypeError( + f"`worker_or_nanny` is of the wrong type: {type(worker_or_nanny)}" + ) + path = pathlib.Path(root) / "tasks_working_dir" + if key is not None: + path /= str(key) + return path + + +def rm_task_path( + worker_or_nanny: distributed.Worker | distributed.Nanny, + key: Key | None, + **kwargs: Any, +) -> pathlib.Path: + # This function is used by cads-worker as well. + path = get_task_path(worker_or_nanny, key) + if path.exists(): + shutil.rmtree(path, **kwargs) + return path diff --git a/tests/test_20_dispatcher.py b/tests/test_20_dispatcher.py index 9aeb38fa..624fc735 100644 --- a/tests/test_20_dispatcher.py +++ b/tests/test_20_dispatcher.py @@ -1,4 +1,5 @@ import datetime +import pathlib import uuid from typing import Any @@ -120,3 +121,35 @@ def mock_get_tasks() -> dict[str, str]: # with pytest.raises(db.NoResultFound): # with session_obj() as session: # db.get_request(dismissed_request_uid, session=session) + + +def test_plugins( + mocker: pytest_mock.plugin.MockerFixture, session_obj: sa.orm.sessionmaker +) -> None: + environment = Environment.Environment() + qos = QoS.QoS(rules=Rule.RuleSet(), environment=environment, rules_hash="") + broker = dispatcher.Broker( + client=CLIENT, + environment=environment, + qos=qos, + address="scheduler-address", + session_maker_read=session_obj, + session_maker_write=session_obj, + ) + + def func() -> pathlib.Path: + worker = distributed.get_worker() + key = worker.get_current_task() + task_path = ( + pathlib.Path(worker.local_directory) / "tasks_working_dir" / str(key) + ) + task_path.mkdir() + return task_path + + future = broker.client.submit(func) + task_path = future.result() + assert not task_path.exists() + + assert task_path.parent.exists() + broker.client.shutdown() + assert not task_path.parent.exists()