Commit f38ccbd

radial_field fix + schnet block
1 parent 805aa03

File tree: 2 files changed (+166 -20)

pina/model/block/message_passing/radial_field_network_block.py

+12 -20
@@ -1,6 +1,7 @@
 """Module for the Radial Field Network block."""
 
 import torch
+from ....model import FeedForward
 from torch_geometric.nn import MessagePassing
 from ....utils import check_consistency
 
@@ -34,7 +35,8 @@ def __init__(
         self,
         node_feature_dim,
         hidden_dim,
-        edge_feature_dim,
+        radial_hidden_dim=16,
+        n_radial_layers=2,
         activation=torch.nn.ReLU,
         aggr="add",
         node_dim=-2,
@@ -66,7 +68,6 @@ def __init__(
 
         # Check consistency
         check_consistency(node_feature_dim, int)
-        check_consistency(edge_feature_dim, int)
 
         # Check values
         if node_feature_dim <= 0:
@@ -75,27 +76,18 @@ def __init__(
                 f" got {node_feature_dim}."
             )
 
-        if edge_feature_dim <= 0:
-            raise ValueError(
-                "`edge_feature_dim` must be a positive integer,"
-                f" got {edge_feature_dim}."
-            )
-
-
         # Initialize parameters
         self.node_feature_dim = node_feature_dim
-        self.edge_feature_dim = edge_feature_dim
         self.hidden_dim = hidden_dim
         self.activation = activation
-        self.layer = lambda i,o: torch.nn.Linear(
-            in_features=i,
-            out_features=o,
-            bias=True,
-        )
+
         # Layer for processing node features
-        self.radial_field = torch.nn.Sequential([self.layer(1,self.hidden_dim),
-                                                 torch.nn.ReLU,
-                                                 self.layer(self.hidden_dim,1)]
+        self.radial_field = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=radial_hidden_dim,
+            n_layers=n_radial_layers,
+            func=self.activation,
         )
 
 
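Two things were wrong with the Sequential removed above: it passed the class torch.nn.ReLU where an activation instance is required, and it handed torch.nn.Sequential a Python list, which it does not accept; either mistake raises at construction or call time. The new code delegates to PINA's FeedForward instead. A rough plain-PyTorch equivalent of the new radial field, assuming FeedForward stacks n_layers hidden layers of inner_size units with the given activation (a sketch, not PINA's exact implementation):

import torch

# Approximate stand-in for FeedForward(input_dimensions=1, output_dimensions=1,
# inner_size=16, n_layers=2, func=torch.nn.ReLU): a small MLP mapping a
# per-edge scalar distance to a scalar radial weight.
radial_field = torch.nn.Sequential(
    torch.nn.Linear(1, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)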
@@ -124,10 +116,10 @@ def message(self, x_j, x_i):
         :return: The message to be passed.
         :rtype: torch.Tensor
         """
-        r = torch.norm(x_i-x_j)*(x_i-x_j)
+        r = torch.norm(x_i - x_j, dim=-1, keepdim=True)
 
 
-        return self.activation(self.radial_field(r))
+        return self.radial_field(r) * (x_i - x_j)
 
 
     def update(self, message, x):
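Note that torch.norm needs dim=-1, keepdim=True here: without it, the norm collapses the whole tensor to a single scalar instead of producing one distance per edge, and the result would not match the radial field's input_dimensions=1. After the fix, each message is the displacement x_i - x_j scaled by a learned scalar function of the pairwise distance. A minimal standalone sketch of the computation (the helper name and the toy MLP are illustrative, not part of the library):

import torch

def radial_message(x_i, x_j, radial_field):
    # Per-edge distance with a trailing feature dimension: [num_edges, 1].
    r = torch.norm(x_i - x_j, dim=-1, keepdim=True)
    # Scale each displacement vector by its learned radial weight.
    return radial_field(r) * (x_i - x_j)

# Toy check with a stand-in radial field phi: R -> R.
phi = torch.nn.Sequential(
    torch.nn.Linear(1, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
)
msg = radial_message(torch.randn(4, 3), torch.randn(4, 3), phi)
print(msg.shape)  # torch.Size([4, 3]) -- one vector-valued message per edge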
pina/model/block/message_passing/schnet_block.py (new file)

+154 -0
@@ -0,0 +1,154 @@
+"""Module for the Schnet block."""
+
+import torch
+from ....model import FeedForward
+from torch_geometric.nn import MessagePassing
+from ....utils import check_consistency
+
+
+class SchnetBlock(MessagePassing):
+    """
+    Implementation of the Schnet block.
+
+    This block is used to perform message-passing between nodes in a graph
+    neural network, following the continuous-filter convolution scheme
+    proposed by Schütt et al. (2017). It serves as an inner block in a
+    larger graph neural network architecture.
+
+    The message between two nodes connected by an edge is computed by a
+    feed-forward transformation of the sender node features, scaled by a
+    radial filter network evaluated on the distance between the two nodes.
+    Messages are then aggregated using an aggregation scheme (e.g., sum,
+    mean, min, max, or product).
+
+    The update step applies a feed-forward network to the concatenation of
+    the node positions and the aggregated messages.
+
+    .. seealso::
+
+        **Original reference** Schütt, K., Kindermans, P. J.,
+        Sauceda Felix, H. E., Chmiela, S., Tkatchenko, A., & Müller, K. R.
+        (2017). SchNet: A continuous-filter convolutional neural network
+        for modeling quantum interactions. Advances in Neural Information
+        Processing Systems, 30.
+    """
+
+
+    def __init__(
+        self,
+        node_feature_dim,
+        node_pos_dim,
+        hidden_dim,
+        radial_hidden_dim=16,
+        n_message_layers=2,
+        n_update_layers=2,
+        n_radial_layers=2,
+        activation=torch.nn.ReLU,
+        aggr="add",
+        node_dim=-2,
+        flow="source_to_target",
+    ):
+        """
+        Initialization of the :class:`SchnetBlock` class.
+
+        :param int node_feature_dim: The dimension of the node features.
+        :param int node_pos_dim: The dimension of the node positions.
+        :param int hidden_dim: The dimension of the hidden features.
+        :param int radial_hidden_dim: The hidden dimension of the radial
+            filter network. Default is 16.
+        :param int n_message_layers: The number of layers in the message
+            network. Default is 2.
+        :param int n_update_layers: The number of layers in the update
+            network. Default is 2.
+        :param int n_radial_layers: The number of layers in the radial
+            filter network. Default is 2.
+        :param torch.nn.Module activation: The activation function.
+            Default is :class:`torch.nn.ReLU`.
+        :param str aggr: The aggregation scheme to use for message passing.
+            Available options are "add", "mean", "min", "max", "mul".
+            See :class:`torch_geometric.nn.MessagePassing` for more details.
+            Default is "add".
+        :param int node_dim: The axis along which to propagate. Default is -2.
+        :param str flow: The direction of message passing. Available options
+            are "source_to_target" and "target_to_source".
+            The "source_to_target" flow means that messages are sent from
+            the source node to the target node, while the "target_to_source"
+            flow means that messages are sent from the target node to the
+            source node. See :class:`torch_geometric.nn.MessagePassing` for
+            more details. Default is "source_to_target".
+        :raises ValueError: If `node_feature_dim` is not a positive integer.
+        """
+        super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
+
+        # Check consistency
+        check_consistency(node_feature_dim, int)
+
+        # Check values
+        if node_feature_dim <= 0:
+            raise ValueError(
+                "`node_feature_dim` must be a positive integer,"
+                f" got {node_feature_dim}."
+            )
+
+
+        # Initialize parameters
+        self.node_feature_dim = node_feature_dim
+        self.node_pos_dim = node_pos_dim
+        self.hidden_dim = hidden_dim
+        self.activation = activation
+
+        # Radial filter network, mapping each pairwise distance to a weight
+        self.radial_field = FeedForward(
+            input_dimensions=1,
+            output_dimensions=1,
+            inner_size=radial_hidden_dim,
+            n_layers=n_radial_layers,
+            func=self.activation,
+        )
+
+        self.update_net = FeedForward(
+            input_dimensions=self.node_pos_dim + self.hidden_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_update_layers,
+            func=self.activation,
+        )
+
+        # Output dimension must be hidden_dim so that the concatenation of
+        # pos and the aggregated message matches update_net's input size.
+        self.message_net = FeedForward(
+            input_dimensions=self.node_feature_dim,
+            output_dimensions=self.hidden_dim,
+            inner_size=self.hidden_dim,
+            n_layers=n_message_layers,
+            func=self.activation,
+        )
+
+    def forward(self, x, pos, edge_index):
+        """
+        Forward pass of the block, triggering the message-passing routine.
+
+        :param x: The node features.
+        :type x: torch.Tensor | LabelTensor
+        :param pos: The node positions.
+        :type pos: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: The edge indices. In the original
+            formulation, the messages are aggregated from all nodes, not
+            only from the neighbours.
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.propagate(edge_index=edge_index, x=x, pos=pos)
+
+    def message(self, x_j, pos_i, pos_j):
+        """
+        Compute the message to be passed between nodes.
+
+        :param x_j: The node features of the sender nodes.
+        :type x_j: torch.Tensor | LabelTensor
+        :param pos_i: The positions of the receiver nodes.
+        :type pos_i: torch.Tensor | LabelTensor
+        :param pos_j: The positions of the sender nodes.
+        :type pos_j: torch.Tensor | LabelTensor
+        :return: The message to be passed.
+        :rtype: torch.Tensor
+        """
+        # Per-edge distance with a trailing feature dimension, matching the
+        # radial field's input_dimensions=1.
+        r = torch.norm(pos_i - pos_j, dim=-1, keepdim=True)
+        # Continuous-filter convolution: the distance-dependent filter
+        # scales the transformed sender features.
+        return self.radial_field(r) * self.message_net(x_j)
+
+    def update(self, message, pos):
+        """
+        Update the node features with the received messages.
+
+        :param torch.Tensor message: The aggregated messages.
+        :param pos: The node positions.
+        :type pos: torch.Tensor | LabelTensor
+        :return: The updated node features.
+        :rtype: torch.Tensor
+        """
+        return self.update_net(torch.cat((pos, message), dim=-1))
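A quick smoke test of the new block; the import path is an assumption based on the file's location in this commit, and the dimensions are arbitrary:

import torch
from pina.model.block.message_passing.schnet_block import SchnetBlock  # assumed path

# Five nodes in 3-D with 8-dimensional features; hidden_dim matches
# node_feature_dim so the output can feed a stacked block of the same size.
block = SchnetBlock(node_feature_dim=8, node_pos_dim=3, hidden_dim=8)
x = torch.randn(5, 8)                                     # node features
pos = torch.randn(5, 3)                                   # node positions
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])   # 4 directed edges
out = block(x, pos, edge_index)
print(out.shape)  # torch.Size([5, 8]): one hidden_dim vector per node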
