Message Passing Module #516
base: dev
Conversation
Code Coverage Summary
Results for commit: 1c6bef4
Minimum allowed coverage is
♻️ This comment has been updated with latest results
force-pushed from b17fd16 to 23b3560
Hi @AleDinve @GiovanniCanali! How is it going with this?
Hi @dario-coscia, I need to fix some minor issues with InteractionNetwork, and then I will fix EGNN. Also, tests will be implemented. @AleDinve agreed to take care of the remaining classes.
Yes, I confirm, I will have a tentative implementation of the classes assigned to me by the weekend.
force-pushed from 5727bcc to 426da3e
force-pushed from f38ccbd to 1c6bef4
Very good @GiovanniCanali and @AleDinve! I made a few comments on the implementation of the various blocks.
I think we should consider adding to utils.py
a simple function that checks integer types and values. For example (very minimalistic):
def check_values(value, positive=True, strict=True):
    # strictly positive: value must be greater than zero
    if positive and strict:
        assert value > 0
    elif positive:
        assert value >= 0
This would remove a lot of duplicated lines of code inside the blocks.
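A slightly fuller sketch of what such a helper might look like, raising ValueError to match the :raises ValueError: entries in the docstrings of this PR; the name, signature, and messages below are purely illustrative, not part of the current code:
def check_positive_integer(value, name="value", strict=True):
    """Raise ValueError if value is not a (strictly) positive integer."""
    # Reject non-integers (bool is a subclass of int, so exclude it explicitly).
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError(f"{name} must be an integer, got {type(value).__name__}.")
    # Enforce strict (> 0) or non-strict (>= 0) positivity.
    if strict and value <= 0:
        raise ValueError(f"{name} must be a positive integer, got {value}.")
    if not strict and value < 0:
        raise ValueError(f"{name} must be a non-negative integer, got {value}.")
Inside a block's __init__ this would replace several repeated checks, e.g. check_positive_integer(node_feature_dim, "node_feature_dim").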
:return: The message to be passed.
:rtype: torch.Tensor
"""
return self.message_net(torch.cat((x_i, x_j), dim=-1))
here the edge_attr is needed
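A minimal sketch of the change being asked for, assuming edge_attr is forwarded through self.propagate(..., edge_attr=edge_attr) and torch is already imported; the concatenation order is illustrative:
def message(self, x_i, x_j, edge_attr):
    # Concatenate sender features, receiver features, and edge features
    # along the last dimension before applying the message network.
    return self.message_net(torch.cat((x_i, x_j, edge_attr), dim=-1))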
.. seealso::

    **Original reference** Köhler, J., Klein, L., & Noé, F. (2020, November).
    Equivariant flows: exact likelihood generative learning for symmetric densities.
Max 80 characters per line.
:param x: The node features.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices. In the original formulation,
Max 80 characters per line; also, LabelTensor is fine here.
:return: The message to be passed.
:rtype: torch.Tensor
"""
r = torch.norm(x_i - x_j)
torch.norm(x_i - x_j) computes the global norm (a single scalar), which is probably not what we want in message passing. We want per-edge norms; use:
r = torch.norm(x_i - x_j, dim=1, keepdim=True)
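A quick illustration of the difference (shapes are assumed for the example):
import torch

x_i = torch.randn(5, 3)  # per-edge sender features, shape [num_edges, dim]
x_j = torch.randn(5, 3)  # per-edge receiver features, same shape

torch.norm(x_i - x_j).shape                       # torch.Size([]): one global scalar
torch.norm(x_i - x_j, dim=1, keepdim=True).shape  # torch.Size([5, 1]): one norm per edge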
the messages are aggregated from all nodes, not only from the neighbours.
:return: The updated node features.
:rtype: torch.Tensor
"""
We need to avoid self loops, as in the original paper:
from torch_geometric.utils import remove_self_loops
edge_index, _ = remove_self_loops(edge_index)
.. seealso::

    **Original reference** Schütt, K., Kindermans, P. J., Sauceda Felix, H. E., Chmiela, S., Tkatchenko, A., & Müller, K. R. (2017).
Max 80 characters per line.
flow="source_to_target", | ||
): | ||
""" | ||
Initialization of the :class:`RadialFieldNetworkBlock` class. |
SchNet? This docstring still says :class:`RadialFieldNetworkBlock`.
:return: The updated node features.
:rtype: torch.Tensor
"""
return self.propagate(edge_index=edge_index, x=x, pos=pos)
same comment as above for self loops
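For reference, a sketch of how the same fix could look in this forward method, mirroring the snippet suggested above:
from torch_geometric.utils import remove_self_loops

# Drop self loops before propagating, as in the original paper.
edge_index, _ = remove_self_loops(edge_index)
return self.propagate(edge_index=edge_index, x=x, pos=pos)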
""" | ||
Compute the message to be passed between nodes and edges. | ||
|
||
:param x_j: Concatenation of the node position and the |
missing doc for pos
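Something along these lines could be added to the docstring (the wording is only a suggestion):
:param torch.Tensor pos: The node positions.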
details. Default is "source_to_target".
:raises ValueError: If `node_feature_dim` is not a positive integer.
:raises ValueError: If `edge_feature_dim` is not a positive integer.
"""
ok for me!
Description
This PR fixes #515.
Here is a tentative roadmap:
- InteractionNetworkBlock: implements the interaction network (see Sec. 2 of https://arxiv.org/pdf/1704.01212)
- DeepTensorNetworkBlock: implements the Deep Tensor Neural Network block (see Sec. 2 of https://arxiv.org/pdf/1704.01212)
- EGNNBlock: implements the E(n) Equivariant Graph Neural Network block (see Sec. 3 of https://arxiv.org/pdf/2102.09844)
- RadialFieldBlock: implements the radial field network block (see Tab. 1 of https://arxiv.org/pdf/2102.09844)
- SchnetBlock: implements the SchNet block (see Tab. 1 of https://arxiv.org/pdf/2102.09844)

Checklist