diff --git a/pyproject.toml b/pyproject.toml index c622766a9b..b9ce087c20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,6 +189,8 @@ requires-python = '>=3.9' 'core.pcod' = 'aiida.tools.dbimporters.plugins.pcod:PcodDbImporter' 'core.tcod' = 'aiida.tools.dbimporters.plugins.tcod:TcodDbImporter' +[project.entry-points.'aiida.tools.workflows'] + [project.entry-points.'aiida.transports'] 'core.local' = 'aiida.transports.plugins.local:LocalTransport' 'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport' diff --git a/src/aiida/orm/nodes/process/calculation/calcjob.py b/src/aiida/orm/nodes/process/calculation/calcjob.py index a526fc3b9c..c38a4967a7 100644 --- a/src/aiida/orm/nodes/process/calculation/calcjob.py +++ b/src/aiida/orm/nodes/process/calculation/calcjob.py @@ -106,7 +106,7 @@ class Model(CalculationNode.Model): _tools = None @property - def tools(self) -> 'CalculationTools': + def tools(self) -> Optional['CalculationTools']: """Return the calculation tools that are registered for the process type associated with this calculation. If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the diff --git a/src/aiida/orm/nodes/process/workflow/workchain.py b/src/aiida/orm/nodes/process/workflow/workchain.py index 8f52dff926..6d2cc285d4 100644 --- a/src/aiida/orm/nodes/process/workflow/workchain.py +++ b/src/aiida/orm/nodes/process/workflow/workchain.py @@ -8,12 +8,16 @@ ########################################################################### """Module with `Node` sub class for workchain processes.""" -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple +from aiida.common import exceptions from aiida.common.lang import classproperty from .workflow import WorkflowNode +if TYPE_CHECKING: + from aiida.tools.workflows import WorkflowTools + __all__ = ('WorkChainNode',) @@ -22,6 +26,40 @@ class WorkChainNode(WorkflowNode): STEPPER_STATE_INFO_KEY = 'stepper_state_info' + # An optional entry point for a WorkflowTools instance + _tools = None + + @property + def tools(self) -> Optional['WorkflowTools']: + """Return the workflow tools that are registered for the process type associated with this workflow. + + If the entry point name stored in the `process_type` of the WorkChainNode has an accompanying entry point in the + `aiida.tools.workflows` entry point category, it will attempt to load the entry point and instantiate it + passing the node to the constructor. If the entry point does not exist, cannot be resolved or loaded, a warning + will be logged and the base WorkflowTools class will be instantiated and returned. + + :return: WorkflowsTools instance + """ + from aiida.plugins.entry_point import get_entry_point_from_string, is_valid_entry_point_string, load_entry_point + from aiida.tools.workflows import WorkflowTools + + if self._tools is None: + entry_point_string = self.process_type + + if entry_point_string and is_valid_entry_point_string(entry_point_string): + entry_point = get_entry_point_from_string(entry_point_string) + + try: + tools_class = load_entry_point('aiida.tools.workflows', entry_point.name) + self._tools = tools_class(self) + except exceptions.EntryPointError as exception: + self._tools = WorkflowTools(self) + self.logger.warning( + f'could not load the workflow tools entry point {entry_point.name}: {exception}' + ) + + return self._tools + @classproperty def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] # noqa: N805 return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,) diff --git a/src/aiida/plugins/entry_point.py b/src/aiida/plugins/entry_point.py index 7ec26a4c34..72761c0e46 100644 --- a/src/aiida/plugins/entry_point.py +++ b/src/aiida/plugins/entry_point.py @@ -96,6 +96,7 @@ class EntryPointFormat(enum.Enum): 'aiida.storage': 'aiida.storage', 'aiida.transports': 'aiida.transports.plugins', 'aiida.tools.calculations': 'aiida.tools.calculations', + 'aiida.tools.workflows': 'aiida.tools.workflows', 'aiida.tools.data.orbitals': 'aiida.tools.data.orbitals', 'aiida.tools.dbexporters': 'aiida.tools.dbexporters', 'aiida.tools.dbimporters': 'aiida.tools.dbimporters.plugins', diff --git a/src/aiida/tools/workflows/__init__.py b/src/aiida/tools/workflows/__init__.py new file mode 100644 index 0000000000..a91e94a057 --- /dev/null +++ b/src/aiida/tools/workflows/__init__.py @@ -0,0 +1,21 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Workflow tool plugins for Workflow classes.""" + +# AUTO-GENERATED + +# fmt: off + +from .base import * + +__all__ = ( + 'WorkflowTools', +) + +# fmt: on diff --git a/src/aiida/tools/workflows/base.py b/src/aiida/tools/workflows/base.py new file mode 100644 index 0000000000..a5e6986a91 --- /dev/null +++ b/src/aiida/tools/workflows/base.py @@ -0,0 +1,22 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Base class for WorkflowTools + +Sub-classes can be registered in the `aiida.tools.workflows` category to enable the `CalcJobNode` class from being +able to find the tools plugin, load it and expose it through the `tools` property of the `CalcJobNode`. +""" + +__all__ = ('WorkflowTools',) + + +class WorkflowTools: + """Base class for WorkflowTools.""" + + def __init__(self, node): + self._node = node diff --git a/tests/restapi/test_statistics.py b/tests/restapi/test_statistics.py index cfc025b608..612bf36a9b 100644 --- a/tests/restapi/test_statistics.py +++ b/tests/restapi/test_statistics.py @@ -53,4 +53,4 @@ def test_count_consistency(restapi_server, server_url): for full_type, count in statistics_dict.items(): if full_type in type_count_dict: - assert count == type_count_dict[full_type] + assert count == type_count_dict[full_type], f'Found inconsistency for full_type {full_type!r}' diff --git a/tests/tools/calculations/__init__.py b/tests/tools/calculations/__init__.py new file mode 100644 index 0000000000..c56ff0a1f8 --- /dev/null +++ b/tests/tools/calculations/__init__.py @@ -0,0 +1,8 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### diff --git a/tests/tools/calculations/test_base.py b/tests/tools/calculations/test_base.py new file mode 100644 index 0000000000..f521e84a64 --- /dev/null +++ b/tests/tools/calculations/test_base.py @@ -0,0 +1,36 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for CalculationTools""" + + +class MockCalculation: ... + + +class MockCalculationTools: + def __init__(self, node): + self._node = node + + +def test_mock_calculation_tools(entry_points, generate_calcjob_node): + """Test if the calculation tools is correctly loaded from the entry point.""" + entry_points.add(MockCalculation, 'aiida.calculations:MockCalculation') + entry_points.add(MockCalculationTools, 'aiida.tools.calculations:MockCalculation') + node = generate_calcjob_node(entry_point='aiida.calculations:MockCalculation') + assert isinstance(node.tools, MockCalculationTools) + assert node.tools._node == node + + +def test_fallback_calculation_tools(entry_points, generate_calcjob_node): + """Test if the calculation tools is falling back to `CalculationTools` if it cannot be loaded from entry point.""" + from aiida.tools.calculations import CalculationTools + + entry_points.add(MockCalculation, 'aiida.calculations:MockCalculation') + node = generate_calcjob_node(entry_point='aiida.calculations:MockCalculation') + assert isinstance(node.tools, CalculationTools) + assert node.tools._node == node diff --git a/tests/tools/workflows/__init__.py b/tests/tools/workflows/__init__.py new file mode 100644 index 0000000000..c56ff0a1f8 --- /dev/null +++ b/tests/tools/workflows/__init__.py @@ -0,0 +1,8 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### diff --git a/tests/tools/workflows/test_base.py b/tests/tools/workflows/test_base.py new file mode 100644 index 0000000000..1f1ffd0479 --- /dev/null +++ b/tests/tools/workflows/test_base.py @@ -0,0 +1,38 @@ +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Tests for WorkflowTools""" + +from aiida.orm import WorkChainNode + + +class MockWorkflow: ... + + +class MockWorkflowTools: + def __init__(self, node): + self._node = node + + +def test_mock_workflow_tools(entry_points): + """Test if the workflow tools is correctly loaded from the entry point.""" + entry_points.add(MockWorkflow, 'aiida.workflows:MockWorkflow') + entry_points.add(MockWorkflowTools, 'aiida.tools.workflows:MockWorkflow') + node = WorkChainNode(process_type='aiida.workflows:MockWorkflow') + assert isinstance(node.tools, MockWorkflowTools) + assert node.tools._node == node + + +def test_fallback_workflow_tools(entry_points, generate_work_chain): + """Test if the workflow tools is falling back to `WorkflowTools` if it cannot be loaded from entry point.""" + from aiida.tools.workflows import WorkflowTools + + entry_points.add(MockWorkflow, 'aiida.workflows:MockWorkflow') + node = WorkChainNode(process_type='aiida.workflows:MockWorkflow') + assert isinstance(node.tools, WorkflowTools) + assert node.tools._node == node