Skip to content

Commit cb9d61e

Browse files
bastoneroagoscinski
andcommitted
Add tools property and entrypoint for workflows
This is an implementation of feature request in issue #6865. Further changes: * Extend the typing hints from `tools` output as it can return `None` * Add tests for `CalcJobNode.tools` and `WorkChainNode.tools` --------- Co-authored-by: Alexander Goscinski <alex.goscinski@posteo.de>
1 parent 43176cb commit cb9d61e

File tree

10 files changed

+176
-2
lines changed

10 files changed

+176
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ requires-python = '>=3.9'
189189
'core.pcod' = 'aiida.tools.dbimporters.plugins.pcod:PcodDbImporter'
190190
'core.tcod' = 'aiida.tools.dbimporters.plugins.tcod:TcodDbImporter'
191191

192+
[project.entry-points.'aiida.tools.workflows']
193+
192194
[project.entry-points.'aiida.transports']
193195
'core.local' = 'aiida.transports.plugins.local:LocalTransport'
194196
'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport'

src/aiida/orm/nodes/process/calculation/calcjob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class Model(CalculationNode.Model):
106106
_tools = None
107107

108108
@property
109-
def tools(self) -> 'CalculationTools':
109+
def tools(self) -> Optional['CalculationTools']:
110110
"""Return the calculation tools that are registered for the process type associated with this calculation.
111111
112112
If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the

src/aiida/orm/nodes/process/workflow/workchain.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
###########################################################################
99
"""Module with `Node` sub class for workchain processes."""
1010

11-
from typing import Optional, Tuple
11+
from typing import TYPE_CHECKING, Optional, Tuple
1212

13+
from aiida.common import exceptions
1314
from aiida.common.lang import classproperty
1415

1516
from .workflow import WorkflowNode
1617

18+
if TYPE_CHECKING:
19+
from aiida.tools.workflows import WorkflowTools
20+
1721
__all__ = ('WorkChainNode',)
1822

1923

@@ -22,6 +26,40 @@ class WorkChainNode(WorkflowNode):
2226

2327
STEPPER_STATE_INFO_KEY = 'stepper_state_info'
2428

29+
# An optional entry point for a WorkflowTools instance
30+
_tools = None
31+
32+
@property
33+
def tools(self) -> Optional['WorkflowTools']:
34+
"""Return the workflow tools that are registered for the process type associated with this workflow.
35+
36+
If the entry point name stored in the `process_type` of the WorkChainNode has an accompanying entry point in the
37+
`aiida.tools.workflows` entry point category, it will attempt to load the entry point and instantiate it
38+
passing the node to the constructor. If the entry point does not exist, cannot be resolved or loaded, a warning
39+
will be logged and the base WorkflowTools class will be instantiated and returned.
40+
41+
:return: WorkflowsTools instance
42+
"""
43+
from aiida.plugins.entry_point import get_entry_point_from_string, is_valid_entry_point_string, load_entry_point
44+
from aiida.tools.workflows import WorkflowTools
45+
46+
if self._tools is None:
47+
entry_point_string = self.process_type
48+
49+
if entry_point_string and is_valid_entry_point_string(entry_point_string):
50+
entry_point = get_entry_point_from_string(entry_point_string)
51+
52+
try:
53+
tools_class = load_entry_point('aiida.tools.workflows', entry_point.name)
54+
self._tools = tools_class(self)
55+
except exceptions.EntryPointError as exception:
56+
self._tools = WorkflowTools(self)
57+
self.logger.warning(
58+
f'could not load the workflow tools entry point {entry_point.name}: {exception}'
59+
)
60+
61+
return self._tools
62+
2563
@classproperty
2664
def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] # noqa: N805
2765
return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,)

src/aiida/plugins/entry_point.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class EntryPointFormat(enum.Enum):
9696
'aiida.storage': 'aiida.storage',
9797
'aiida.transports': 'aiida.transports.plugins',
9898
'aiida.tools.calculations': 'aiida.tools.calculations',
99+
'aiida.tools.workflows': 'aiida.tools.workflows',
99100
'aiida.tools.data.orbitals': 'aiida.tools.data.orbitals',
100101
'aiida.tools.dbexporters': 'aiida.tools.dbexporters',
101102
'aiida.tools.dbimporters': 'aiida.tools.dbimporters.plugins',

src/aiida/tools/workflows/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Workflow tool plugins for Workflow classes."""
10+
11+
# AUTO-GENERATED
12+
13+
# fmt: off
14+
15+
from .base import *
16+
17+
__all__ = (
18+
'WorkflowTools',
19+
)
20+
21+
# fmt: on

src/aiida/tools/workflows/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Base class for WorkflowTools
10+
11+
Sub-classes can be registered in the `aiida.tools.workflows` category to enable the `CalcJobNode` class from being
12+
able to find the tools plugin, load it and expose it through the `tools` property of the `CalcJobNode`.
13+
"""
14+
15+
__all__ = ('WorkflowTools',)
16+
17+
18+
class WorkflowTools:
19+
"""Base class for WorkflowTools."""
20+
21+
def __init__(self, node):
22+
self._node = node

tests/tools/calculations/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################

tests/tools/calculations/test_base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Tests for CalculationTools"""
10+
11+
12+
class MockCalculation: ...
13+
14+
15+
class MockCalculationTools:
16+
def __init__(self, node):
17+
self._node = node
18+
19+
20+
def test_mock_calculation_tools(entry_points, generate_calcjob_node):
21+
"""Test if the calculation tools is correctly loaded from the entry point."""
22+
entry_points.add(MockCalculation, 'aiida.calculations:MockCalculation')
23+
entry_points.add(MockCalculationTools, 'aiida.tools.calculations:MockCalculation')
24+
node = generate_calcjob_node(entry_point='aiida.calculations:MockCalculation')
25+
assert isinstance(node.tools, MockCalculationTools)
26+
assert node.tools._node == node
27+
28+
29+
def test_fallback_calculation_tools(entry_points, generate_calcjob_node):
30+
"""Test if the calculation tools is falling back to `CalculationTools` if it cannot be loaded from entry point."""
31+
from aiida.tools.calculations import CalculationTools
32+
33+
entry_points.add(MockCalculation, 'aiida.calculations:MockCalculation')
34+
node = generate_calcjob_node(entry_point='aiida.calculations:MockCalculation')
35+
assert isinstance(node.tools, CalculationTools)
36+
assert node.tools._node == node

tests/tools/workflows/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################

tests/tools/workflows/test_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
###########################################################################
2+
# Copyright (c), The AiiDA team. All rights reserved. #
3+
# This file is part of the AiiDA code. #
4+
# #
5+
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
6+
# For further information on the license, see the LICENSE.txt file #
7+
# For further information please visit http://www.aiida.net #
8+
###########################################################################
9+
"""Tests for WorkflowTools"""
10+
11+
from aiida.orm import WorkChainNode
12+
13+
14+
class MockWorkflow: ...
15+
16+
17+
class MockWorkflowTools:
18+
def __init__(self, node):
19+
self._node = node
20+
21+
22+
def test_mock_workflow_tools(entry_points):
23+
"""Test if the workflow tools is correctly loaded from the entry point."""
24+
entry_points.add(MockWorkflow, 'aiida.workflows:MockWorkflow')
25+
entry_points.add(MockWorkflowTools, 'aiida.tools.workflows:MockWorkflow')
26+
node = WorkChainNode(process_type='aiida.workflows:MockWorkflow')
27+
assert isinstance(node.tools, MockWorkflowTools)
28+
assert node.tools._node == node
29+
30+
31+
def test_fallback_workflow_tools(entry_points, generate_work_chain):
32+
"""Test if the workflow tools is falling back to `WorkflowTools` if it cannot be loaded from entry point."""
33+
from aiida.tools.workflows import WorkflowTools
34+
35+
entry_points.add(MockWorkflow, 'aiida.workflows:MockWorkflow')
36+
node = WorkChainNode(process_type='aiida.workflows:MockWorkflow')
37+
assert isinstance(node.tools, WorkflowTools)
38+
assert node.tools._node == node

0 commit comments

Comments
 (0)