Skip to content

Message Passing Module #516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ Blocks
Continuous Convolution Block <model/block/convolution.rst>
Orthogonal Block <model/block/orthogonal.rst>

Message Passing
-------------------

.. toctree::
:titlesonly:

Deep Tensor Network Block <model/block/message_passing/deep_tensor_network_block.rst>
E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
Schnet Block <model/block/message_passing/schnet_block.rst>


Reduction and Embeddings
--------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Deep Tensor Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block

.. autoclass:: DeepTensorNetworkBlock
:members:
:show-inheritance:
:noindex:
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
E(n) Equivariant Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block

.. autoclass:: EnEquivariantNetworkBlock
:members:
:show-inheritance:
:noindex:
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Interaction Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.interaction_network_block

.. autoclass:: InteractionNetworkBlock
:members:
:show-inheritance:
:noindex:
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Radial Field Network Block
==================================
.. currentmodule:: pina.model.block.message_passing.radial_field_network_block

.. autoclass:: RadialFieldNetworkBlock
:members:
:show-inheritance:
:noindex:
8 changes: 8 additions & 0 deletions docs/source/_rst/model/block/message_passing/schnet_block.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Schnet Block
==================================
.. currentmodule:: pina.model.block.message_passing.schnet_block

.. autoclass:: SchnetBlock
:members:
:show-inheritance:
:noindex:
15 changes: 15 additions & 0 deletions pina/model/block/message_passing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Module for the message passing blocks of the graph neural models."""

__all__ = [
"InteractionNetworkBlock",
"DeepTensorNetworkBlock",
"EnEquivariantNetworkBlock",
"RadialFieldNetworkBlock",
"SchnetBlock",
]

from .interaction_network_block import InteractionNetworkBlock
from .deep_tensor_network_block import DeepTensorNetworkBlock
from .en_equivariant_network_block import EnEquivariantNetworkBlock
from .radial_field_network_block import RadialFieldNetworkBlock
from .schnet_block import SchnetBlock
138 changes: 138 additions & 0 deletions pina/model/block/message_passing/deep_tensor_network_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Module for the Deep Tensor Network block."""

import torch
from torch_geometric.nn import MessagePassing
from ....utils import check_positive_integer


class DeepTensorNetworkBlock(MessagePassing):
"""
Implementation of the Deep Tensor Network block.

This block is used to perform message-passing between nodes and edges in a
graph neural network, following the scheme proposed by Schutt et al. in
2017. It serves as an inner block in a larger graph neural network
architecture.

The message between two nodes connected by an edge is computed by applying a
linear transformation to the sender node features and the edge features,
followed by a non-linear activation function. Messages are then aggregated
using an aggregation scheme (e.g., sum, mean, min, max, or product).

The update step is performed by a simple addition of the incoming messages
to the node features.

.. seealso::

**Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
(2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
Nature Communications 8, 13890 (2017).
DOI: `<https://doi.org/10.1038/ncomms13890>`_.
"""

def __init__(
self,
node_feature_dim,
edge_feature_dim,
activation=torch.nn.Tanh,
aggr="add",
node_dim=-2,
flow="source_to_target",
):
"""
Initialization of the :class:`DeepTensorNetworkBlock` class.

:param int node_feature_dim: The dimension of the node features.
:param int edge_feature_dim: The dimension of the edge features.
:param torch.nn.Module activation: The activation function.
Default is :class:`torch.nn.Tanh`.
:param str aggr: The aggregation scheme to use for message passing.
Available options are "add", "mean", "min", "max", "mul".
See :class:`torch_geometric.nn.MessagePassing` for more details.
Default is "add".
:param int node_dim: The axis along which to propagate. Default is -2.
:param str flow: The direction of message passing. Available options
are "source_to_target" and "target_to_source".
The "source_to_target" flow means that messages are sent from
the source node to the target node, while the "target_to_source"
flow means that messages are sent from the target node to the
source node. See :class:`torch_geometric.nn.MessagePassing` for more
details. Default is "source_to_target".
:raises AssertionError: If `node_feature_dim` is not a positive integer.
:raises AssertionError: If `edge_feature_dim` is not a positive integer.
"""
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

# Check values
check_positive_integer(node_feature_dim, strict=True)
check_positive_integer(edge_feature_dim, strict=True)

# Activation function
self.activation = activation()

# Layer for processing node features
self.node_layer = torch.nn.Linear(
in_features=node_feature_dim,
out_features=node_feature_dim,
bias=True,
)

# Layer for processing edge features
self.edge_layer = torch.nn.Linear(
in_features=edge_feature_dim,
out_features=node_feature_dim,
bias=True,
)

# Layer for computing the message
self.message_layer = torch.nn.Linear(
in_features=node_feature_dim,
out_features=node_feature_dim,
bias=False,
)

def forward(self, x, edge_index, edge_attr):
"""
Forward pass of the block, triggering the message-passing routine.

:param x: The node features.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indeces.
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)

def message(self, x_j, edge_attr):
"""
Compute the message to be passed between nodes and edges.

:param x_j: The node features of the sender nodes.
:type x_j: torch.Tensor | LabelTensor
:param edge_attr: The edge attributes.
:type edge_attr: torch.Tensor | LabelTensor
:return: The message to be passed.
:rtype: torch.Tensor
"""
# Process node and edge features
filter_node = self.node_layer(x_j)
filter_edge = self.edge_layer(edge_attr)

# Compute the message to be passed
message = self.message_layer(filter_node * filter_edge)

return self.activation(message)

def update(self, message, x):
"""
Update the node features with the received messages.

:param torch.Tensor message: The message to be passed.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:return: The updated node features.
:rtype: torch.Tensor
"""
return x + message
Loading