Skip to content

Add tools property and entrypoint for workflows #6884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ requires-python = '>=3.9'
'core.pcod' = 'aiida.tools.dbimporters.plugins.pcod:PcodDbImporter'
'core.tcod' = 'aiida.tools.dbimporters.plugins.tcod:TcodDbImporter'

[project.entry-points.'aiida.tools.workflows']

[project.entry-points.'aiida.transports']
'core.local' = 'aiida.transports.plugins.local:LocalTransport'
'core.ssh' = 'aiida.transports.plugins.ssh:SshTransport'
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/orm/nodes/process/calculation/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class Model(CalculationNode.Model):
_tools = None

@property
def tools(self) -> 'CalculationTools':
def tools(self) -> Optional['CalculationTools']:
"""Return the calculation tools that are registered for the process type associated with this calculation.

If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the
Expand Down
40 changes: 39 additions & 1 deletion src/aiida/orm/nodes/process/workflow/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
###########################################################################
"""Module with `Node` sub class for workchain processes."""

from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from aiida.common import exceptions
from aiida.common.lang import classproperty

from .workflow import WorkflowNode

if TYPE_CHECKING:
from aiida.tools.workflows import WorkflowTools

Check warning on line 19 in src/aiida/orm/nodes/process/workflow/workchain.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/orm/nodes/process/workflow/workchain.py#L19

Added line #L19 was not covered by tests

__all__ = ('WorkChainNode',)


Expand All @@ -22,6 +26,40 @@

STEPPER_STATE_INFO_KEY = 'stepper_state_info'

# An optional entry point for a WorkflowTools instance
_tools = None

@property
def tools(self) -> Optional['WorkflowTools']:
"""Return the workflow tools that are registered for the process type associated with this workflow.

If the entry point name stored in the `process_type` of the WorkChainNode has an accompanying entry point in the
`aiida.tools.workflows` entry point category, it will attempt to load the entry point and instantiate it
passing the node to the constructor. If the entry point does not exist, cannot be resolved or loaded, a warning
will be logged and the base WorkflowTools class will be instantiated and returned.

:return: WorkflowsTools instance
"""
from aiida.plugins.entry_point import get_entry_point_from_string, is_valid_entry_point_string, load_entry_point
from aiida.tools.workflows import WorkflowTools

if self._tools is None:
entry_point_string = self.process_type

if entry_point_string and is_valid_entry_point_string(entry_point_string):
entry_point = get_entry_point_from_string(entry_point_string)

try:
tools_class = load_entry_point('aiida.tools.workflows', entry_point.name)
self._tools = tools_class(self)
except exceptions.EntryPointError as exception:
self._tools = WorkflowTools(self)
self.logger.warning(
f'could not load the workflow tools entry point {entry_point.name}: {exception}'
)

return self._tools

@classproperty
def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] # noqa: N805
return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,)
Expand Down
1 change: 1 addition & 0 deletions src/aiida/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class EntryPointFormat(enum.Enum):
'aiida.storage': 'aiida.storage',
'aiida.transports': 'aiida.transports.plugins',
'aiida.tools.calculations': 'aiida.tools.calculations',
'aiida.tools.workflows': 'aiida.tools.workflows',
'aiida.tools.data.orbitals': 'aiida.tools.data.orbitals',
'aiida.tools.dbexporters': 'aiida.tools.dbexporters',
'aiida.tools.dbimporters': 'aiida.tools.dbimporters.plugins',
Expand Down
21 changes: 21 additions & 0 deletions src/aiida/tools/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Workflow tool plugins for Workflow classes."""

# AUTO-GENERATED

# fmt: off

from .base import *

__all__ = (
'WorkflowTools',
)

# fmt: on
22 changes: 22 additions & 0 deletions src/aiida/tools/workflows/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Base class for WorkflowTools

Sub-classes can be registered in the `aiida.tools.workflows` category to enable the `CalcJobNode` class from being
able to find the tools plugin, load it and expose it through the `tools` property of the `CalcJobNode`.
"""

__all__ = ('WorkflowTools',)


class WorkflowTools:
"""Base class for WorkflowTools."""

def __init__(self, node):
self._node = node
2 changes: 1 addition & 1 deletion tests/restapi/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def test_count_consistency(restapi_server, server_url):

for full_type, count in statistics_dict.items():
if full_type in type_count_dict:
assert count == type_count_dict[full_type]
assert count == type_count_dict[full_type], f'Found inconsistency for full_type {full_type!r}'
8 changes: 8 additions & 0 deletions tests/tools/calculations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
36 changes: 36 additions & 0 deletions tests/tools/calculations/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for CalculationTools"""


class MockCalculation: ...


class MockCalculationTools:
def __init__(self, node):
self._node = node


def test_mock_calculation_tools(entry_points, generate_calcjob_node):
"""Test if the calculation tools is correctly loaded from the entry point."""
entry_points.add(MockCalculation, 'aiida.calculations:MockCalculation')
entry_points.add(MockCalculationTools, 'aiida.tools.calculations:MockCalculation')
node = generate_calcjob_node(entry_point='aiida.calculations:MockCalculation')
assert isinstance(node.tools, MockCalculationTools)
assert node.tools._node == node


def test_fallback_calculation_tools(entry_points, generate_calcjob_node):
"""Test if the calculation tools is falling back to `CalculationTools` if it cannot be loaded from entry point."""
from aiida.tools.calculations import CalculationTools

entry_points.add(MockCalculation, 'aiida.calculations:MockCalculation')
node = generate_calcjob_node(entry_point='aiida.calculations:MockCalculation')
assert isinstance(node.tools, CalculationTools)
assert node.tools._node == node
8 changes: 8 additions & 0 deletions tests/tools/workflows/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
38 changes: 38 additions & 0 deletions tests/tools/workflows/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for WorkflowTools"""

from aiida.orm import WorkChainNode


class MockWorkflow: ...


class MockWorkflowTools:
def __init__(self, node):
self._node = node


def test_mock_workflow_tools(entry_points):
"""Test if the workflow tools is correctly loaded from the entry point."""
entry_points.add(MockWorkflow, 'aiida.workflows:MockWorkflow')
entry_points.add(MockWorkflowTools, 'aiida.tools.workflows:MockWorkflow')
node = WorkChainNode(process_type='aiida.workflows:MockWorkflow')
assert isinstance(node.tools, MockWorkflowTools)
assert node.tools._node == node


def test_fallback_workflow_tools(entry_points, generate_work_chain):
"""Test if the workflow tools is falling back to `WorkflowTools` if it cannot be loaded from entry point."""
from aiida.tools.workflows import WorkflowTools

entry_points.add(MockWorkflow, 'aiida.workflows:MockWorkflow')
node = WorkChainNode(process_type='aiida.workflows:MockWorkflow')
assert isinstance(node.tools, WorkflowTools)
assert node.tools._node == node
Loading