
Message Passing Module #516


Draft · wants to merge 6 commits into dev

Conversation

@dario-coscia (Collaborator) commented Mar 21, 2025

Description

This PR fixes #515.

Here is a tentative roadmap.

Checklist

  • Code follows the project’s Code Style Guidelines
  • Tests have been added or updated
  • Documentation has been updated if necessary
  • Pull request is linked to an open issue

@dario-coscia dario-coscia added enhancement New feature or request pr-to-fix Label for PR that needs modification labels Mar 21, 2025
@dario-coscia dario-coscia linked an issue Mar 21, 2025 that may be closed by this pull request
github-actions bot (Contributor) commented Mar 21, 2025


Code Coverage Summary

Filename                                                     Stmts    Miss  Cover    Missing
---------------------------------------------------------  -------  ------  -------  ---------------------------------------------------------------------------------------------------------------------
__init__.py                                                      7       0  100.00%
collector.py                                                    39       1  97.44%   46
graph.py                                                       114      11  90.35%   99-100, 112, 124, 126, 142, 144, 166, 169, 182, 271
label_tensor.py                                                251      32  87.25%   81, 121, 144-148, 165, 177, 182, 188-193, 273, 280, 332, 334, 348, 444-447, 490, 537, 629, 649-651, 664-673, 688, 710
operator.py                                                     68       5  92.65%   250-268, 457
operators.py                                                     6       6  0.00%    3-12
plotter.py                                                       1       1  0.00%    3
trainer.py                                                      75       6  92.00%   195-204, 293, 314, 318, 322
utils.py                                                        56       8  85.71%   113, 150, 153, 156, 192-195
adaptive_function/__init__.py                                    3       0  100.00%
adaptive_function/adaptive_function.py                          55       0  100.00%
adaptive_function/adaptive_function_interface.py                51       6  88.24%   98, 141, 148-151
adaptive_functions/__init__.py                                   6       6  0.00%    3-12
callback/__init__.py                                             5       0  100.00%
callback/adaptive_refinement_callback.py                         8       1  87.50%   37
callback/linear_weight_update_callback.py                       28       1  96.43%   63
callback/optimizer_callback.py                                  22       1  95.45%   34
callback/processing_callback.py                                 49       5  89.80%   42-43, 73, 168, 171
callbacks/__init__.py                                            6       6  0.00%    3-12
condition/__init__.py                                            7       0  100.00%
condition/condition.py                                          35       8  77.14%   23, 127-128, 131-132, 135-136, 151
condition/condition_interface.py                                37       4  89.19%   31, 76, 100, 122
condition/data_condition.py                                     26       1  96.15%   56
condition/domain_equation_condition.py                          19       0  100.00%
condition/input_equation_condition.py                           44       1  97.73%   129
condition/input_target_condition.py                             44       1  97.73%   125
data/__init__.py                                                 3       0  100.00%
data/data_module.py                                            204      22  89.22%   42-53, 133, 173, 194, 233, 314-318, 324-328, 400, 467, 547, 638, 640
data/dataset.py                                                 80       6  92.50%   42, 123-126, 291
domain/__init__.py                                              10       0  100.00%
domain/cartesian.py                                            112      10  91.07%   37, 47, 75-76, 92, 97, 103, 246, 256, 264
domain/difference_domain.py                                     25       2  92.00%   54, 87
domain/domain_interface.py                                      20       5  75.00%   37-41
domain/ellipsoid.py                                            104      24  76.92%   52, 56, 127, 250-257, 269-282, 286-287, 290, 295
domain/exclusion_domain.py                                      28       1  96.43%   86
domain/intersection_domain.py                                   28       1  96.43%   85
domain/operation_interface.py                                   26       1  96.15%   88
domain/simplex.py                                               72      14  80.56%   62, 207-225, 246-247, 251, 256
domain/union_domain.py                                          25       2  92.00%   43, 114
equation/__init__.py                                             4       0  100.00%
equation/equation.py                                            11       0  100.00%
equation/equation_factory.py                                    24      10  58.33%   37, 62-75, 97-110, 132-145
equation/equation_interface.py                                   4       0  100.00%
equation/system_equation.py                                     22       0  100.00%
geometry/__init__.py                                             7       7  0.00%    3-15
loss/__init__.py                                                 7       0  100.00%
loss/loss_interface.py                                          17       2  88.24%   45, 51
loss/lp_loss.py                                                 15       0  100.00%
loss/ntk_weighting.py                                           26       0  100.00%
loss/power_loss.py                                              15       0  100.00%
loss/scalar_weighting.py                                        16       0  100.00%
loss/weighting_interface.py                                      6       0  100.00%
model/__init__.py                                               10       0  100.00%
model/average_neural_operator.py                                31       2  93.55%   73, 82
model/deeponet.py                                               93      13  86.02%   187-190, 209, 240, 283, 293, 303, 313, 323, 333, 488, 498
model/feed_forward.py                                           89      11  87.64%   58, 195, 200, 278-292
model/fourier_neural_operator.py                                78      10  87.18%   96-100, 110, 155-159, 218, 220, 242, 342
model/graph_neural_operator.py                                  40       2  95.00%   58, 60
model/kernel_neural_operator.py                                 34       6  82.35%   83-84, 103-104, 123-124
model/low_rank_neural_operator.py                               27       2  92.59%   89, 98
model/multi_feed_forward.py                                     12       5  58.33%   25-31
model/spline.py                                                 89      37  58.43%   30, 41-66, 69, 128-132, 135, 159-177, 180
model/block/__init__.py                                         12       0  100.00%
model/block/average_neural_operator_block.py                    12       0  100.00%
model/block/convolution.py                                      64      13  79.69%   77, 81, 85, 91, 97, 111, 114, 151, 161, 171, 181, 191, 201
model/block/convolution_2d.py                                  146      27  81.51%   155, 162, 282, 314, 379-433, 456
model/block/embedding.py                                        48       7  85.42%   93, 143-146, 155, 168
model/block/fourier_block.py                                    31       0  100.00%
model/block/gno_block.py                                        22       4  81.82%   73-77, 87
model/block/integral.py                                         18       4  77.78%   22-25, 71
model/block/low_rank_block.py                                   24       0  100.00%
model/block/orthogonal.py                                       37       0  100.00%
model/block/pod_block.py                                        65       9  86.15%   54-57, 69, 99, 134-139, 170, 195
model/block/rbf_block.py                                       179      25  86.03%   18, 42, 53, 64, 75, 86, 97, 223, 280, 282, 298, 301, 329, 335, 363, 367, 511-524
model/block/residual.py                                         46       0  100.00%
model/block/spectral.py                                         83       4  95.18%   132, 140, 262, 270
model/block/stride.py                                           28       7  75.00%   55, 58, 61, 67, 72-74
model/block/utils_convolution.py                                22       3  86.36%   58-60
model/block/message_passing/__init__.py                          3       3  0.00%    3-9
model/block/message_passing/deep_tensor_network_block.py        27      27  0.00%    3-152
model/block/message_passing/egnn_block.py                       28      28  0.00%    3-97
model/block/message_passing/interaction_network_block.py        30      30  0.00%    3-168
model/block/message_passing/radial_field_network_block.py       21      21  0.00%    3-135
model/block/message_passing/schnet_block.py                     23      23  0.00%    3-154
model/layers/__init__.py                                         6       6  0.00%    3-12
optim/__init__.py                                                5       0  100.00%
optim/optimizer_interface.py                                     7       0  100.00%
optim/scheduler_interface.py                                     7       0  100.00%
optim/torch_optimizer.py                                        14       0  100.00%
optim/torch_scheduler.py                                        19       2  89.47%   5-6
problem/__init__.py                                              6       0  100.00%
problem/abstract_problem.py                                    104      14  86.54%   52, 61, 101-106, 135, 147, 165, 239, 243, 272
problem/inverse_problem.py                                      22       0  100.00%
problem/parametric_problem.py                                    8       1  87.50%   29
problem/spatial_problem.py                                       8       0  100.00%
problem/time_dependent_problem.py                                8       0  100.00%
problem/zoo/__init__.py                                          8       0  100.00%
problem/zoo/advection.py                                        33       7  78.79%   36-38, 52, 108-110
problem/zoo/allen_cahn.py                                       20       6  70.00%   20-22, 34-36
problem/zoo/diffusion_reaction.py                               29       5  82.76%   94-104
problem/zoo/helmholtz.py                                        30       6  80.00%   36-42, 103-107
problem/zoo/inverse_poisson_2d_square.py                        31       0  100.00%
problem/zoo/poisson_2d_square.py                                19       3  84.21%   65-70
problem/zoo/supervised_problem.py                               11       0  100.00%
solver/__init__.py                                               6       0  100.00%
solver/garom.py                                                107       2  98.13%   129-130
solver/solver.py                                               188      10  94.68%   192, 215, 287, 290-291, 350, 432, 515, 556, 562
solver/ensemble_solver/__init__.py                               4       0  100.00%
solver/ensemble_solver/ensemble_pinn.py                         23       1  95.65%   104
solver/ensemble_solver/ensemble_solver_interface.py             27       0  100.00%
solver/ensemble_solver/ensemble_supervised.py                    9       0  100.00%
solver/physics_informed_solver/__init__.py                       8       0  100.00%
solver/physics_informed_solver/causal_pinn.py                   47       3  93.62%   157, 166-167
solver/physics_informed_solver/competitive_pinn.py              58       0  100.00%
solver/physics_informed_solver/gradient_pinn.py                 17       0  100.00%
solver/physics_informed_solver/pinn.py                          18       0  100.00%
solver/physics_informed_solver/pinn_interface.py                47       1  97.87%   130
solver/physics_informed_solver/rba_pinn.py                      35       3  91.43%   155-158
solver/physics_informed_solver/self_adaptive_pinn.py            90       3  96.67%   315-318
solver/supervised_solver/__init__.py                             4       0  100.00%
solver/supervised_solver/reduced_order_model.py                 24       1  95.83%   137
solver/supervised_solver/supervised.py                           7       0  100.00%
solver/supervised_solver/supervised_solver_interface.py         25       0  100.00%
solvers/__init__.py                                              6       6  0.00%    3-12
solvers/pinns/__init__.py                                        6       6  0.00%    3-12
TOTAL                                                         4596     627  86.36%

Results for commit: 1c6bef4

Minimum allowed coverage is 80.123%

♻️ This comment has been updated with latest results

@dario-coscia (Collaborator, Author)

Hi @AleDinve @GiovanniCanali ! How is it going with this?

@GiovanniCanali (Collaborator)

> Hi @AleDinve @GiovanniCanali ! How is it going with this?

Hi @dario-coscia, I need to fix some minor issues with InteractionNetwork, and then I will fix EGNN. Also, tests will be implemented. @AleDinve agreed to take care of the remaining classes.

@AleDinve (Collaborator)

> Hi @AleDinve @GiovanniCanali ! How is it going with this?

> Hi @dario-coscia, I need to fix some minor issues with InteractionNetwork, and then I will fix EGNN. Also, tests will be implemented. @AleDinve agreed to take care of the remaining classes.

Yes, I confirm, I will have a tentative implementation of the classes assigned to me by the weekend.

@dario-coscia (Collaborator, Author) left a comment

Very good @GiovanniCanali and @AleDinve! I made a few comments on the implementation of the various blocks.

I think we should consider adding to utils.py a simple function that checks integer types and values. For example (very minimalistic):

def check_values(value, positive=True, strict=True):
    if positive and strict:
        assert value > 0
    elif positive:
        assert value >= 0
    ...

This would remove a lot of duplicated validation code from the blocks.
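A slightly fuller version of this idea, raising `ValueError` instead of using bare asserts (a sketch only; the name and signature are illustrative, not from the codebase):

```python
def check_values(value, positive=True, strict=True):
    """Validate that value is an int satisfying the sign constraint."""
    if not isinstance(value, int):
        raise ValueError(f"Expected int, got {type(value).__name__}")
    if positive and strict and value <= 0:
        raise ValueError(f"Expected a positive integer, got {value}")
    if positive and not strict and value < 0:
        raise ValueError(f"Expected a non-negative integer, got {value}")
    return value

# hypothetical use inside a block constructor
node_feature_dim = check_values(16)   # ok, returns 16
# check_values(-1)                    # would raise ValueError
```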

:return: The message to be passed.
:rtype: torch.Tensor
"""
return self.message_net(torch.cat((x_i, x_j), dim=-1))
@dario-coscia (Collaborator, Author):

Here the edge_attr is needed as well.
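A possible shape of the fix (a sketch, assuming the edge features are simply concatenated with the sender and receiver features before the message MLP; `build_message_input` is an illustrative helper, not part of the codebase):

```python
import torch

def build_message_input(x_i, x_j, edge_attr):
    # Concatenate receiver, sender, and edge features along the last dim,
    # mirroring torch.cat((x_i, x_j, edge_attr), dim=-1) inside message().
    return torch.cat((x_i, x_j, edge_attr), dim=-1)

x_i = torch.randn(4, 8)        # receiver node features, one row per edge
x_j = torch.randn(4, 8)        # sender node features, one row per edge
edge_attr = torch.randn(4, 3)  # edge features
print(build_message_input(x_i, x_j, edge_attr).shape)  # torch.Size([4, 19])
```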

.. seealso::

**Original reference** Köhler, J., Klein, L., & Noé, F. (2020, November).
Equivariant flows: exact likelihood generative learning for symmetric densities.
@dario-coscia (Collaborator, Author):

Lines should be 80 characters max.


:param x: The node features.
:type x: torch.Tensor | LabelTensor
:param torch.Tensor edge_index: The edge indices. In the original formulation,
@dario-coscia (Collaborator, Author):

Lines should be 80 characters max; also, LabelTensor is ok.

:return: The message to be passed.
:rtype: torch.Tensor
"""
r = torch.norm(x_i-x_j)
@dario-coscia (Collaborator, Author):

torch.norm(x_i - x_j) computes the global norm (a scalar), which is probably not what we want in message passing.

We want per-edge norms, use:

r = torch.norm(x_i - x_j, dim=1, keepdim=True)
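The difference is easy to verify: without `dim`, `torch.norm` reduces over the whole tensor; with `dim=1, keepdim=True` it returns one norm per edge, which can then be broadcast against per-edge features.

```python
import torch

x_i = torch.randn(5, 3)  # 5 edges, 3-dim positions (receiver side)
x_j = torch.randn(5, 3)  # 5 edges, 3-dim positions (sender side)

global_norm = torch.norm(x_i - x_j)                     # scalar: norm of the whole tensor
per_edge = torch.norm(x_i - x_j, dim=1, keepdim=True)   # shape (5, 1): one norm per edge

print(global_norm.shape)  # torch.Size([])
print(per_edge.shape)     # torch.Size([5, 1])
```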

the messages are aggregated from all nodes, not only from the neighbours.
:return: The updated node features.
:rtype: torch.Tensor
"""
@dario-coscia (Collaborator, Author):

We need to avoid self loops, as in the original paper:

from torch_geometric.utils import remove_self_loops

edge_index, _ = remove_self_loops(edge_index)
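For reference, `remove_self_loops` keeps only the columns of `edge_index` whose source and target indices differ; a plain-torch sketch of the same operation:

```python
import torch

edge_index = torch.tensor([[0, 1, 1, 2],
                           [0, 2, 1, 0]])  # columns (0,0) and (1,1) are self-loops

# keep only edges whose source differs from their target
mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]
print(edge_index)
# tensor([[1, 2],
#         [2, 0]])
```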


.. seealso::

**Original reference** Schütt, K., Kindermans, P. J., Sauceda Felix, H. E., Chmiela, S., Tkatchenko, A., & Müller, K. R. (2017).
@dario-coscia (Collaborator, Author):

Lines should be 80 characters max.

flow="source_to_target",
):
"""
Initialization of the :class:`RadialFieldNetworkBlock` class.
@dario-coscia (Collaborator, Author):

SchNet?

:return: The updated node features.
:rtype: torch.Tensor
"""
return self.propagate(edge_index=edge_index, x=x, pos=pos)
@dario-coscia (Collaborator, Author):

Same comment as above regarding self loops.

"""
Compute the message to be passed between nodes and edges.

:param x_j: Concatenation of the node position and the
@dario-coscia (Collaborator, Author):

Missing documentation for pos.

details. Default is "source_to_target".
:raises ValueError: If `node_feature_dim` is not a positive integer.
:raises ValueError: If `edge_feature_dim` is not a positive integer.
"""
@dario-coscia (Collaborator, Author):

ok for me!

3 participants