Skip to content

Commit 426da3e

Browse files
add interaction network block
1 parent 4b284b8 commit 426da3e

File tree

2 files changed

+190
-8
lines changed

2 files changed

+190
-8
lines changed

pina/model/block/message_passing/deep_tensor_network_block.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from torch_geometric.nn import MessagePassing
5+
from ....utils import check_consistency
56

67

78
class DeepTensorNetworkBlock(MessagePassing):
@@ -25,7 +26,7 @@ class DeepTensorNetworkBlock(MessagePassing):
2526
**Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
2627
*Quantum-Chemical Insights from Deep Tensor Neural Networks*.
2728
Nature Communications 8, 13890 (2017).
28-
DOI: `<https://doi.org/10.1038/ncomms13890>_`
29+
DOI: `<https://doi.org/10.1038/ncomms13890>_`.
2930
"""
3031

3132
def __init__(
@@ -38,7 +39,7 @@ def __init__(
3839
flow="source_to_target",
3940
):
4041
"""
41-
Initialization of the :class:`AVNOBDeepTensorNetworkBlocklock` class.
42+
Initialization of the :class:`DeepTensorNetworkBlocklock` class.
4243
4344
:param int node_feature_dim: The dimension of the node features.
4445
:param int edge_feature_dim: The dimension of the edge features.
@@ -49,12 +50,36 @@ def __init__(
4950
See :class:`torch_geometric.nn.MessagePassing` for more details.
5051
Default is "add".
5152
:param int node_dim: The axis along which to propagate. Default is -2.
52-
:param str flow: The direction of message passing.
53-
See :class:`torch_geometric.nn.MessagePassing` for more details.
54-
Default is "source_to_target".
53+
:param str flow: The direction of message passing. Available options
54+
are "source_to_target" and "target_to_source".
55+
The "source_to_target" flow means that messages are sent from
56+
the source node to the target node, while the "target_to_source"
57+
flow means that messages are sent from the target node to the
58+
source node. See :class:`torch_geometric.nn.MessagePassing` for more
59+
details. Default is "source_to_target".
60+
:raises ValueError: If `node_feature_dim` is not a positive integer.
61+
:raises ValueError: If `edge_feature_dim` is not a positive integer.
5562
"""
5663
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
5764

65+
# Check consistency
66+
check_consistency(node_feature_dim, int)
67+
check_consistency(edge_feature_dim, int)
68+
69+
# Check values
70+
if node_feature_dim <= 0:
71+
raise ValueError(
72+
"`node_feature_dim` must be a positive integer,"
73+
f" got {node_feature_dim}."
74+
)
75+
76+
if edge_feature_dim <= 0:
77+
raise ValueError(
78+
"`edge_feature_dim` must be a positive integer,"
79+
f" got {edge_feature_dim}."
80+
)
81+
82+
# Initialize parameters
5883
self.node_feature_dim = node_feature_dim
5984
self.edge_feature_dim = edge_feature_dim
6085
self.activation = activation
@@ -82,8 +107,7 @@ def __init__(
82107

83108
def forward(self, x, edge_index, edge_attr):
84109
"""
85-
Forward pass of the block. It performs a message-passing operation
86-
between nodes and edges.
110+
Forward pass of the block, triggering the message-passing routine.
87111
88112
:param x: The node features.
89113
:type x: torch.Tensor | LabelTensor

pina/model/block/message_passing/interaction_network_block.py

+159-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,167 @@
22

33
import torch
44
from torch_geometric.nn import MessagePassing
5+
from ....model import FeedForward
6+
from ....utils import check_consistency
57

68

79
class InteractionNetworkBlock(MessagePassing):
810
"""
9-
TODO
11+
Implementation of the Interaction Network block.
12+
13+
This block is used to perform message-passing between nodes and edges in a
14+
graph neural network, following the scheme proposed by Battaglia et al.
15+
(2016).
16+
It serves as an inner block in a larger graph neural network architecture.
17+
18+
The message between two nodes connected by an edge is computed by applying a
19+
multi-layer perceptron (MLP) to the concatenation of the sender and
20+
recipient node features. Messages are then aggregated using an aggregation
21+
scheme (e.g., sum, mean, min, max, or product).
22+
23+
The update step is performed by applying another MLP to the concatenation of
24+
the incoming messages and the node features.
25+
26+
.. seealso::
27+
28+
**Original reference**: Battaglia, P. W., et al. (2016).
29+
*Interaction Networks for Learning about Objects, Relations and
30+
Physics*.
31+
In Advances in Neural Information Processing Systems (NeurIPS 2016).
32+
DOI: `<https://doi.org/10.48550/arXiv.1612.00222>_`.
1033
"""
34+
35+
def __init__(
36+
self,
37+
node_feature_dim,
38+
hidden_dim,
39+
n_message_layers=2,
40+
n_update_layers=2,
41+
activation=torch.nn.SiLU,
42+
aggr="add",
43+
node_dim=-2,
44+
flow="source_to_target",
45+
):
46+
"""
47+
Initialization of the :class:`InteractionNetworkBlock` class.
48+
49+
:param int node_feature_dim: The dimension of the node features.
50+
:param int hidden_dim: The dimension of the hidden features.
51+
:param int n_message_layers: The number of layers in the message
52+
network. Default is 2.
53+
:param int n_update_layers: The number of layers in the update network.
54+
Default is 2.
55+
:param torch.nn.Module activation: The activation function.
56+
Default is :class:`torch.nn.SiLU`.
57+
:param str aggr: The aggregation scheme to use for message passing.
58+
Available options are "add", "mean", "min", "max", "mul".
59+
See :class:`torch_geometric.nn.MessagePassing` for more details.
60+
Default is "add".
61+
:param int node_dim: The axis along which to propagate. Default is -2.
62+
:param str flow: The direction of message passing. Available options
63+
are "source_to_target" and "target_to_source".
64+
The "source_to_target" flow means that messages are sent from
65+
the source node to the target node, while the "target_to_source"
66+
flow means that messages are sent from the target node to the
67+
source node. See :class:`torch_geometric.nn.MessagePassing` for more
68+
details. Default is "source_to_target".
69+
:raises ValueError: If `node_feature_dim` is not a positive integer.
70+
:raises ValueError: If `hidden_dim` is not a positive integer.
71+
:raises ValueError: If `n_message_layers` is not a positive integer.
72+
:raises ValueError: If `n_update_layers` is not a positive integer.
73+
"""
74+
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
75+
76+
# Check consistency
77+
check_consistency(node_feature_dim, int)
78+
check_consistency(hidden_dim, int)
79+
check_consistency(n_message_layers, int)
80+
check_consistency(n_update_layers, int)
81+
82+
# Check values
83+
if node_feature_dim <= 0:
84+
raise ValueError(
85+
"`node_feature_dim` must be a positive integer,"
86+
f" got {node_feature_dim}."
87+
)
88+
89+
if hidden_dim <= 0:
90+
raise ValueError(
91+
"`hidden_dim` must be a positive integer," f" got {hidden_dim}."
92+
)
93+
94+
if n_message_layers <= 0:
95+
raise ValueError(
96+
"`n_message_layers` must be a positive integer,"
97+
f" got {n_message_layers}."
98+
)
99+
100+
if n_update_layers <= 0:
101+
raise ValueError(
102+
"`n_update_layers` must be a positive integer,"
103+
f" got {n_update_layers}."
104+
)
105+
106+
# Initialize parameters
107+
self.node_feature_dim = node_feature_dim
108+
self.hidden_dim = hidden_dim
109+
self.activation = activation
110+
111+
# Message network
112+
self.message_net = FeedForward(
113+
input_dimensions=2 * self.node_feature_dim,
114+
output_dimensions=self.hidden_dim,
115+
inner_size=self.hidden_dim,
116+
n_layers=n_message_layers,
117+
func=self.activation,
118+
)
119+
120+
# Update network
121+
self.update_net = FeedForward(
122+
input_dimensions=self.node_feature_dim + self.hidden_dim,
123+
output_dimensions=self.hidden_dim,
124+
inner_size=self.node_feature_dim,
125+
n_layers=n_update_layers,
126+
func=self.activation,
127+
)
128+
129+
def forward(self, x, edge_index, edge_attr):
130+
"""
131+
Forward pass of the block, triggering the message-passing routine.
132+
133+
:param x: The node features.
134+
:type x: torch.Tensor | LabelTensor
135+
:param torch.Tensor edge_index: The edge indeces.
136+
:param edge_attr: The edge attributes.
137+
:type edge_attr: torch.Tensor | LabelTensor
138+
:return: The updated node features.
139+
:rtype: torch.Tensor
140+
"""
141+
142+
# TODO: edge_attr is not used in the message function
143+
return self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr)
144+
145+
def message(self, x_i, x_j):
146+
"""
147+
Compute the message to be passed between nodes and edges.
148+
149+
:param x_i: The node features of the recipient nodes.
150+
:type x_i: torch.Tensor | LabelTensor
151+
:param x_j: The node features of the sender nodes.
152+
:type x_j: torch.Tensor | LabelTensor
153+
:return: The message to be passed.
154+
:rtype: torch.Tensor
155+
"""
156+
return self.message_net(torch.cat((x_i, x_j), dim=-1))
157+
158+
def update(self, message, x):
159+
"""
160+
Update the node features with the received messages.
161+
162+
:param torch.Tensor message: The message to be passed.
163+
:param x: The node features.
164+
:type x: torch.Tensor | LabelTensor
165+
:return: The updated node features.
166+
:rtype: torch.Tensor
167+
"""
168+
return self.update_net(torch.cat((x, message), dim=-1))

0 commit comments

Comments
 (0)