Commit 4b284b8
add deep tensor network block
1 parent 9c1b738 commit 4b284b8

4 files changed: +213 -30 lines changed
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
"""Module for the message passing blocks of the graph neural models."""


__all__ = [
    "InteractionNetworkBlock",
    "DeepTensorNetworkBlock",
]

from .interaction_network_block import InteractionNetworkBlock
from .deep_tensor_network_block import DeepTensorNetworkBlock
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
"""Module for the Deep Tensor Network block."""

import torch
from torch_geometric.nn import MessagePassing


class DeepTensorNetworkBlock(MessagePassing):
    """
    Implementation of the Deep Tensor Network block.

    This block is used to perform message-passing between nodes and edges in
    a graph neural network, following the scheme proposed by Schütt et al.
    (2017). It serves as an inner block in a larger graph neural network
    architecture.

    The message between two nodes connected by an edge is computed by applying
    a linear transformation to the sender node features and the edge features,
    followed by a non-linear activation function. Messages are then aggregated
    using an aggregation scheme (e.g., sum, mean, min, max, or product).

    The update step is performed by a simple addition of the incoming messages
    to the node features.

    .. seealso::

        **Original reference**: Schütt, K., Arbabzadah, F., Chmiela, S. et al.
        *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
        Nature Communications 8, 13890 (2017).
        DOI: `<https://doi.org/10.1038/ncomms13890>`_
    """

    def __init__(
        self,
        node_feature_dim,
        edge_feature_dim,
        activation=torch.nn.Tanh,
        aggr="add",
        node_dim=-2,
        flow="source_to_target",
    ):
        """
        Initialization of the :class:`DeepTensorNetworkBlock` class.

        :param int node_feature_dim: The dimension of the node features.
        :param int edge_feature_dim: The dimension of the edge features.
        :param torch.nn.Module activation: The activation function.
            Default is :class:`torch.nn.Tanh`.
        :param str aggr: The aggregation scheme to use for message passing.
            Available options are "add", "mean", "min", "max", "mul".
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "add".
        :param int node_dim: The axis along which to propagate. Default is -2.
        :param str flow: The direction of message passing.
            See :class:`torch_geometric.nn.MessagePassing` for more details.
            Default is "source_to_target".
        """
        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)

        self.node_feature_dim = node_feature_dim
        self.edge_feature_dim = edge_feature_dim
        # The activation is passed as a class (e.g. torch.nn.Tanh), so it is
        # instantiated here before being applied in the message computation.
        self.activation = activation()

        # Layer for processing node features
        self.node_layer = torch.nn.Linear(
            in_features=self.node_feature_dim,
            out_features=self.node_feature_dim,
            bias=True,
        )

        # Layer for processing edge features
        self.edge_layer = torch.nn.Linear(
            in_features=self.edge_feature_dim,
            out_features=self.node_feature_dim,
            bias=True,
        )

        # Layer for computing the message
        self.message_layer = torch.nn.Linear(
            in_features=self.node_feature_dim,
            out_features=self.node_feature_dim,
            bias=False,
        )

    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass of the block. It performs a message-passing operation
        between nodes and edges.

        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :param torch.Tensor edge_index: The edge indices.
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """
        Compute the message to be passed between nodes and edges.

        :param x_j: The node features of the sender nodes.
        :type x_j: torch.Tensor | LabelTensor
        :param edge_attr: The edge attributes.
        :type edge_attr: torch.Tensor | LabelTensor
        :return: The message to be passed.
        :rtype: torch.Tensor
        """
        # Process node and edge features
        filter_node = self.node_layer(x_j)
        filter_edge = self.edge_layer(edge_attr)

        # Compute the message as the element-wise product of the two filters,
        # followed by a linear layer and the activation function
        message = self.message_layer(filter_node * filter_edge)

        return self.activation(message)

    def update(self, message, x):
        """
        Update the node features with the received messages.

        :param torch.Tensor message: The aggregated messages.
        :param x: The node features.
        :type x: torch.Tensor | LabelTensor
        :return: The updated node features.
        :rtype: torch.Tensor
        """
        return x + message
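As a sanity check, a minimal usage sketch of the new block on a toy graph (the dimensions, random features, and `edge_index` below are illustrative, not part of the commit; the import path is hypothetical and depends on the package layout):

import torch

from deep_tensor_network_block import DeepTensorNetworkBlock  # hypothetical path

# Toy graph: 4 nodes with 8 features each, 5 directed edges with 3 features each.
x = torch.rand(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 0, 2]])
edge_attr = torch.rand(edge_index.shape[1], 3)

block = DeepTensorNetworkBlock(node_feature_dim=8, edge_feature_dim=3)
out = block(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([4, 8]): original features plus aggregated messages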
Lines changed: 66 additions & 30 deletions
@@ -1,61 +1,97 @@
"""Module for the E(n) Equivariant Graph Neural Network block."""

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree


class EnEquivariantGraphBlock(MessagePassing):
    """
    Implementation of the E(n) equivariant graph block.

    This block performs message-passing over both node coordinates and node
    features, following the scheme proposed by Satorras et al. (2021) in
    *E(n) Equivariant Graph Neural Networks*.
    """

    def __init__(
        self,
        channels_h,
        channels_m,
        channels_a,
        aggr: str = "add",
        hidden_channels: int = 64,
        **kwargs,
    ):
        """
        Initialization of the :class:`EnEquivariantGraphBlock` class.

        :param int channels_h: The dimension of the node features.
        :param int channels_m: The dimension of the messages.
        :param int channels_a: The dimension of the edge attributes.
        :param str aggr: The aggregation scheme to use for message passing.
            Default is "add".
        :param int hidden_channels: The dimension of the hidden layers.
            Default is 64.
        :param kwargs: Additional keyword arguments passed to
            :class:`torch_geometric.nn.MessagePassing`.
        """
        super().__init__(aggr=aggr, **kwargs)

        # Store the message dimension to split the aggregated output into its
        # coordinate and feature parts in the update step.
        self.channels_m = channels_m

        self.phi_e = torch.nn.Sequential(
            torch.nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
            torch.nn.LayerNorm(hidden_channels),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_channels, channels_m),
            torch.nn.LayerNorm(channels_m),
            torch.nn.SiLU(),
        )
        self.phi_x = torch.nn.Sequential(
            torch.nn.Linear(channels_m, hidden_channels),
            torch.nn.LayerNorm(hidden_channels),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_channels, 1),
        )
        self.phi_h = torch.nn.Sequential(
            torch.nn.Linear(channels_h + channels_m, hidden_channels),
            torch.nn.LayerNorm(hidden_channels),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_channels, channels_h),
        )

    def forward(self, x, h, edge_attr, edge_index, c=None):
        """
        Forward pass of the block.

        :param torch.Tensor x: The node coordinates.
        :param torch.Tensor h: The node features.
        :param torch.Tensor edge_attr: The edge attributes.
        :param torch.Tensor edge_index: The edge indices.
        :param torch.Tensor c: The per-node normalization factor for the
            coordinate update. If None, the node degrees are used.
        :return: The updated node coordinates and node features.
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        if c is None:
            c = degree(edge_index[0], x.shape[0]).unsqueeze(-1)
        return self.propagate(
            edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c
        )

    def message(self, x_i, x_j, h_i, h_j, edge_attr):
        """
        Compute the coordinate and feature messages for each edge.

        :param torch.Tensor x_i: The coordinates of the receiver nodes.
        :param torch.Tensor x_j: The coordinates of the sender nodes.
        :param torch.Tensor h_i: The features of the receiver nodes.
        :param torch.Tensor h_j: The features of the sender nodes.
        :param torch.Tensor edge_attr: The edge attributes.
        :return: The concatenated coordinate and feature messages.
        :rtype: torch.Tensor
        """
        # Feature message: computed from both node features, the squared
        # distance between the nodes, and the edge attributes
        mh_ij = self.phi_e(
            torch.cat(
                [
                    h_i,
                    h_j,
                    torch.norm(x_i - x_j, dim=-1, keepdim=True) ** 2,
                    edge_attr,
                ],
                dim=-1,
            )
        )
        # Coordinate message: the relative position scaled by a learned weight
        mx_ij = (x_i - x_j) * self.phi_x(mh_ij)
        return torch.cat((mx_ij, mh_ij), dim=-1)

    def update(self, aggr_out, x, h, edge_attr, c):
        """
        Update the node coordinates and features with the aggregated messages.

        :param torch.Tensor aggr_out: The aggregated messages.
        :param torch.Tensor x: The node coordinates.
        :param torch.Tensor h: The node features.
        :param torch.Tensor edge_attr: The edge attributes.
        :param torch.Tensor c: The per-node normalization factor.
        :return: The updated node coordinates and node features.
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        # The coordinate message occupies the leading columns and its width
        # (the spatial dimension) is known only at runtime, so the split is
        # taken from the end using the stored message dimension.
        m_x = aggr_out[:, : -self.channels_m]
        m_h = aggr_out[:, -self.channels_m :]
        h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
        x_l1 = x + (m_x / c)
        return x_l1, h_l1

    @property
    def edge_function(self):
        """
        The edge function.
        """
        return self._edge_function

    @property
    def attribute_function(self):
        """
        The attribute function.
        """
        return self._attribute_function
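A similar toy-graph sketch for this block (illustrative dimensions and data, not part of the commit; it relies on the message dimension being stored as shown above, and the forward pass returns both updated coordinates and features):

import torch

# Toy graph: 5 nodes in 3-D space, 8 node features, 4 edge attributes.
x = torch.rand(5, 3)  # node coordinates
h = torch.rand(5, 8)  # node features
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
edge_attr = torch.rand(5, 4)  # one attribute vector per edge

block = EnEquivariantGraphBlock(channels_h=8, channels_m=16, channels_a=4)
x_new, h_new = block(x, h, edge_attr, edge_index)
print(x_new.shape, h_new.shape)  # torch.Size([5, 3]) torch.Size([5, 8])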
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
"""Module for the Interaction Network block."""

import torch
from torch_geometric.nn import MessagePassing


class InteractionNetworkBlock(MessagePassing):
    """
    Implementation of the Interaction Network block, following the scheme
    proposed by Battaglia et al. (2016), *Interaction Networks for Learning
    about Objects, Relations and Physics*.
    """
