Skip to content

Commit b78b6cd

Browse files
committed
radial field
1 parent 426da3e commit b78b6cd

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Module for the Radial Field Network block."""
2+
3+
import torch
4+
from torch_geometric.nn import MessagePassing
5+
from ....utils import check_consistency
6+
7+
8+
class RadialFieldNetworkBlock(MessagePassing):
9+
"""
10+
Implementation of the Radial Field Network block.
11+
12+
This block is used to perform message-passing between nodes and edges in a
13+
graph neural network, following the scheme proposed by Köhler et al. (2020).
14+
It serves as an inner block in a larger graph neural network architecture.
15+
16+
The message between two nodes connected by an edge is computed by applying a
17+
linear transformation to the sender node features and the edge features,
18+
followed by a non-linear activation function. Messages are then aggregated
19+
using an aggregation scheme (e.g., sum, mean, min, max, or product).
20+
21+
The update step is performed by a simple addition of the incoming messages
22+
to the node features.
23+
24+
.. seealso::
25+
26+
**Original reference** Köhler, J., Klein, L., & Noé, F. (2020, November).
27+
Equivariant flows: exact likelihood generative learning for symmetric densities.
28+
In International conference on machine learning (pp. 5361-5370). PMLR.
29+
"""
30+
31+
32+
33+
def __init__(
34+
self,
35+
node_feature_dim,
36+
hidden_dim,
37+
edge_feature_dim,
38+
activation=torch.nn.ReLU,
39+
aggr="add",
40+
node_dim=-2,
41+
flow="source_to_target",
42+
):
43+
"""
44+
Initialization of the :class:`RadialFieldNetworkBlock` class.
45+
46+
:param int node_feature_dim: The dimension of the node features.
47+
:param int edge_feature_dim: The dimension of the edge features.
48+
:param torch.nn.Module activation: The activation function.
49+
Default is :class:`torch.nn.Tanh`.
50+
:param str aggr: The aggregation scheme to use for message passing.
51+
Available options are "add", "mean", "min", "max", "mul".
52+
See :class:`torch_geometric.nn.MessagePassing` for more details.
53+
Default is "add".
54+
:param int node_dim: The axis along which to propagate. Default is -2.
55+
:param str flow: The direction of message passing. Available options
56+
are "source_to_target" and "target_to_source".
57+
The "source_to_target" flow means that messages are sent from
58+
the source node to the target node, while the "target_to_source"
59+
flow means that messages are sent from the target node to the
60+
source node. See :class:`torch_geometric.nn.MessagePassing` for more
61+
details. Default is "source_to_target".
62+
:raises ValueError: If `node_feature_dim` is not a positive integer.
63+
:raises ValueError: If `edge_feature_dim` is not a positive integer.
64+
"""
65+
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
66+
67+
# Check consistency
68+
check_consistency(node_feature_dim, int)
69+
check_consistency(edge_feature_dim, int)
70+
71+
# Check values
72+
if node_feature_dim <= 0:
73+
raise ValueError(
74+
"`node_feature_dim` must be a positive integer,"
75+
f" got {node_feature_dim}."
76+
)
77+
78+
if edge_feature_dim <= 0:
79+
raise ValueError(
80+
"`edge_feature_dim` must be a positive integer,"
81+
f" got {edge_feature_dim}."
82+
)
83+
84+
85+
# Initialize parameters
86+
self.node_feature_dim = node_feature_dim
87+
self.edge_feature_dim = edge_feature_dim
88+
self.hidden_dim = hidden_dim
89+
self.activation = activation
90+
self.layer = lambda i,o: torch.nn.Linear(
91+
in_features=i,
92+
out_features=o,
93+
bias=True,
94+
)
95+
# Layer for processing node features
96+
self.radial_field = torch.nn.Sequential([self.layer(1,self.hidden_dim),
97+
torch.nn.ReLU,
98+
self.layer(self.hidden_dim,1)]
99+
)
100+
101+
102+
def forward(self, x, edge_index):
103+
"""
104+
Forward pass of the block, triggering the message-passing routine.
105+
106+
:param x: The node features.
107+
:type x: torch.Tensor | LabelTensor
108+
:param torch.Tensor edge_index: The edge indices.
109+
:return: The updated node features.
110+
:rtype: torch.Tensor
111+
"""
112+
return self.propagate(edge_index=edge_index, x=x)
113+
114+
def message(self, x_j, x_i):
115+
"""
116+
Compute the message to be passed between nodes and edges.
117+
118+
:param x_j: Concatenation of the node position and the
119+
node features of the sender nodes.
120+
:type x_j: torch.Tensor | LabelTensor
121+
:param edge_attr: The edge attributes.
122+
:type edge_attr: torch.Tensor | LabelTensor
123+
:return: The message to be passed.
124+
:rtype: torch.Tensor
125+
"""
126+
r = torch.norm(x_i-x_j)*(x_i-x_j)
127+
128+
129+
return self.activation(self.radial_field(r))
130+
131+
132+
def update(self, message, x):
133+
"""
134+
Update the node features with the received messages.
135+
136+
:param torch.Tensor message: The message to be passed.
137+
:param x: The node features.
138+
:type x: torch.Tensor | LabelTensor
139+
:return: The updated node features.
140+
:rtype: torch.Tensor
141+
"""
142+
return x + message

0 commit comments

Comments
 (0)