Skip to content

Commit 23b3560

Browse files
dario-cosciaGiovanniCanali
authored andcommitted
remove interface start egnn
1 parent 64ed499 commit 23b3560

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch_geometric.nn import MessagePassing
4+
from torch_geometric.utils import degree
5+
from ....utils import check_consistency
6+
7+
8+
class EnEquivariantGraphBlock(MessagePassing):
9+
def __init__(self,
10+
channels_h,
11+
channels_m,
12+
channels_a,
13+
aggr: str = 'add',
14+
hidden_channels: int = 64,
15+
**kwargs):
16+
super().__init__(aggr=aggr, **kwargs)
17+
18+
self.phi_e = nn.Sequential(
19+
nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels),
20+
nn.LayerNorm(hidden_channels),
21+
nn.SiLU(),
22+
nn.Linear(hidden_channels, channels_m),
23+
nn.LayerNorm(channels_m),
24+
nn.SiLU()
25+
)
26+
self.phi_x = nn.Sequential(
27+
nn.Linear(channels_m, hidden_channels),
28+
nn.LayerNorm(hidden_channels),
29+
nn.SiLU(),
30+
nn.Linear(hidden_channels, 1),
31+
)
32+
self.phi_h = nn.Sequential(
33+
nn.Linear(channels_h + channels_m, hidden_channels),
34+
nn.LayerNorm(hidden_channels),
35+
nn.SiLU(),
36+
nn.Linear(hidden_channels, channels_h),
37+
)
38+
39+
def forward(self, x, h, edge_attr, edge_index, c=None):
40+
if c is None:
41+
c = degree(edge_index[0], x.shape[0]).unsqueeze(-1)
42+
return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c)
43+
44+
def message(self, x_i, x_j, h_i, h_j, edge_attr):
45+
mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1))
46+
mx_ij = (x_i - x_j) * self.phi_x(mh_ij)
47+
return torch.cat((mx_ij, mh_ij), dim=-1)
48+
49+
def update(self, aggr_out, x, h, edge_attr, c):
50+
m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:]
51+
h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
52+
x_l1 = x + (m_x / c)
53+
return x_l1, h_l1
54+
55+
@property
56+
def edge_function(self):
57+
return self._edge_function
58+
59+
@property
60+
def attribute_function(self):
61+
return self._attribute_function

pina/model/block/message_passing/message_passing_interface.py

Whitespace-only changes.

0 commit comments

Comments
 (0)