-
Notifications
You must be signed in to change notification settings - Fork 76
Message Passing Module #516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
dario-coscia
wants to merge
6
commits into
dev
Choose a base branch
from
messagepassing
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
5a4fc44
add buggy egnn block
dario-coscia a7c8c35
add deep tensor network block
GiovanniCanali 9a09821
add interaction network block
GiovanniCanali 9269702
radial field
AleDinve b6f7c17
fix radial field
AleDinve 1c6bef4
radial_field fix + schnet block
AleDinve File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Module for the message passing blocks of the graph neural models.""" | ||
|
||
__all__ = [ | ||
"InteractionNetworkBlock", | ||
"DeepTensorNetworkBlock", | ||
] | ||
|
||
from .interaction_network_block import InteractionNetworkBlock | ||
from .deep_tensor_network_block import DeepTensorNetworkBlock |
152 changes: 152 additions & 0 deletions
152
pina/model/block/message_passing/deep_tensor_network_block.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
"""Module for the Deep Tensor Network block.""" | ||
|
||
import torch | ||
from torch_geometric.nn import MessagePassing | ||
from ....utils import check_consistency | ||
|
||
|
||
class DeepTensorNetworkBlock(MessagePassing): | ||
""" | ||
Implementation of the Deep Tensor Network block. | ||
|
||
This block is used to perform message-passing between nodes and edges in a | ||
graph neural network, following the scheme proposed by Schutt et al. (2017). | ||
It serves as an inner block in a larger graph neural network architecture. | ||
|
||
The message between two nodes connected by an edge is computed by applying a | ||
linear transformation to the sender node features and the edge features, | ||
followed by a non-linear activation function. Messages are then aggregated | ||
using an aggregation scheme (e.g., sum, mean, min, max, or product). | ||
|
||
The update step is performed by a simple addition of the incoming messages | ||
to the node features. | ||
|
||
.. seealso:: | ||
|
||
**Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al. | ||
*Quantum-Chemical Insights from Deep Tensor Neural Networks*. | ||
Nature Communications 8, 13890 (2017). | ||
DOI: `<https://doi.org/10.1038/ncomms13890>_`. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
node_feature_dim, | ||
edge_feature_dim, | ||
activation=torch.nn.Tanh, | ||
aggr="add", | ||
node_dim=-2, | ||
flow="source_to_target", | ||
): | ||
""" | ||
Initialization of the :class:`DeepTensorNetworkBlocklock` class. | ||
|
||
:param int node_feature_dim: The dimension of the node features. | ||
:param int edge_feature_dim: The dimension of the edge features. | ||
:param torch.nn.Module activation: The activation function. | ||
Default is :class:`torch.nn.Tanh`. | ||
:param str aggr: The aggregation scheme to use for message passing. | ||
Available options are "add", "mean", "min", "max", "mul". | ||
See :class:`torch_geometric.nn.MessagePassing` for more details. | ||
Default is "add". | ||
:param int node_dim: The axis along which to propagate. Default is -2. | ||
:param str flow: The direction of message passing. Available options | ||
are "source_to_target" and "target_to_source". | ||
The "source_to_target" flow means that messages are sent from | ||
the source node to the target node, while the "target_to_source" | ||
flow means that messages are sent from the target node to the | ||
source node. See :class:`torch_geometric.nn.MessagePassing` for more | ||
details. Default is "source_to_target". | ||
:raises ValueError: If `node_feature_dim` is not a positive integer. | ||
:raises ValueError: If `edge_feature_dim` is not a positive integer. | ||
""" | ||
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow) | ||
|
||
# Check consistency | ||
check_consistency(node_feature_dim, int) | ||
check_consistency(edge_feature_dim, int) | ||
|
||
# Check values | ||
if node_feature_dim <= 0: | ||
raise ValueError( | ||
"`node_feature_dim` must be a positive integer," | ||
f" got {node_feature_dim}." | ||
) | ||
|
||
if edge_feature_dim <= 0: | ||
raise ValueError( | ||
"`edge_feature_dim` must be a positive integer," | ||
f" got {edge_feature_dim}." | ||
) | ||
|
||
# Initialize parameters | ||
self.node_feature_dim = node_feature_dim | ||
self.edge_feature_dim = edge_feature_dim | ||
self.activation = activation | ||
|
||
# Layer for processing node features | ||
self.node_layer = torch.nn.Linear( | ||
in_features=self.node_feature_dim, | ||
out_features=self.node_feature_dim, | ||
bias=True, | ||
) | ||
|
||
# Layer for processing edge features | ||
self.edge_layer = torch.nn.Linear( | ||
in_features=self.edge_feature_dim, | ||
out_features=self.node_feature_dim, | ||
bias=True, | ||
) | ||
|
||
# Layer for computing the message | ||
self.message_layer = torch.nn.Linear( | ||
in_features=self.node_feature_dim, | ||
out_features=self.node_feature_dim, | ||
bias=False, | ||
) | ||
|
||
def forward(self, x, edge_index, edge_attr): | ||
""" | ||
Forward pass of the block, triggering the message-passing routine. | ||
|
||
:param x: The node features. | ||
:type x: torch.Tensor | LabelTensor | ||
:param torch.Tensor edge_index: The edge indeces. | ||
:param edge_attr: The edge attributes. | ||
:type edge_attr: torch.Tensor | LabelTensor | ||
:return: The updated node features. | ||
:rtype: torch.Tensor | ||
""" | ||
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr) | ||
|
||
def message(self, x_j, edge_attr): | ||
""" | ||
Compute the message to be passed between nodes and edges. | ||
|
||
:param x_j: The node features of the sender nodes. | ||
:type x_j: torch.Tensor | LabelTensor | ||
:param edge_attr: The edge attributes. | ||
:type edge_attr: torch.Tensor | LabelTensor | ||
:return: The message to be passed. | ||
:rtype: torch.Tensor | ||
""" | ||
# Process node and edge features | ||
filter_node = self.node_layer(x_j) | ||
filter_edge = self.edge_layer(edge_attr) | ||
|
||
# Compute the message to be passed | ||
message = self.message_layer(filter_node * filter_edge) | ||
|
||
return self.activation(message) | ||
|
||
def update(self, message, x): | ||
""" | ||
Update the node features with the received messages. | ||
|
||
:param torch.Tensor message: The message to be passed. | ||
:param x: The node features. | ||
:type x: torch.Tensor | LabelTensor | ||
:return: The updated node features. | ||
:rtype: torch.Tensor | ||
""" | ||
return x + message |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
"""Module for the E(n) Equivariant Graph Neural Network block.""" | ||
|
||
import torch | ||
from torch_geometric.nn import MessagePassing | ||
from torch_geometric.utils import degree | ||
|
||
|
||
class EnEquivariantGraphBlock(MessagePassing): | ||
""" | ||
TODO | ||
""" | ||
|
||
def __init__( | ||
self, | ||
channels_h, | ||
channels_m, | ||
channels_a, | ||
aggr: str = "add", | ||
hidden_channels: int = 64, | ||
**kwargs, | ||
): | ||
""" | ||
TODO | ||
""" | ||
super().__init__(aggr=aggr, **kwargs) | ||
|
||
self.phi_e = torch.nn.Sequential( | ||
torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), | ||
torch.nn.LayerNorm(hidden_channels), | ||
torch.nn.SiLU(), | ||
torch.nn.Linear(hidden_channels, channels_m), | ||
torch.nn.LayerNorm(channels_m), | ||
torch.nn.SiLU(), | ||
) | ||
self.phi_x = torch.nn.Sequential( | ||
torch.nn.Linear(channels_m, hidden_channels), | ||
torch.nn.LayerNorm(hidden_channels), | ||
torch.nn.SiLU(), | ||
torch.nn.Linear(hidden_channels, 1), | ||
) | ||
self.phi_h = torch.nn.Sequential( | ||
torch.nn.Linear(channels_h + channels_m, hidden_channels), | ||
torch.nn.LayerNorm(hidden_channels), | ||
torch.nn.SiLU(), | ||
torch.nn.Linear(hidden_channels, channels_h), | ||
) | ||
|
||
def forward(self, x, h, edge_attr, edge_index, c=None): | ||
""" | ||
TODO | ||
""" | ||
if c is None: | ||
c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) | ||
return self.propagate( | ||
edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c | ||
) | ||
|
||
def message(self, x_i, x_j, h_i, h_j, edge_attr): | ||
""" | ||
TODO | ||
""" | ||
mh_ij = self.phi_e( | ||
torch.cat( | ||
[ | ||
h_i, | ||
h_j, | ||
torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2, | ||
edge_attr, | ||
], | ||
dim=-1, | ||
) | ||
) | ||
mx_ij = (x_i - x_j) * self.phi_x(mh_ij) | ||
return torch.cat((mx_ij, mh_ij), dim=-1) | ||
|
||
def update(self, aggr_out, x, h, edge_attr, c): | ||
""" | ||
TODO | ||
""" | ||
m_x, m_h = aggr_out[:, : self.m_len], aggr_out[:, self.m_len :] | ||
h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) | ||
x_l1 = x + (m_x / c) | ||
return x_l1, h_l1 | ||
|
||
@property | ||
def edge_function(self): | ||
""" | ||
TODO | ||
""" | ||
return self._edge_function | ||
|
||
@property | ||
def attribute_function(self): | ||
""" | ||
TODO | ||
""" | ||
return self._attribute_function |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok for me!