Skip to content

Commit ed1b64c

Browse files
authored
Refactoring solvers (#541)
* Refactoring solvers * Simplify logic compile * Improve and update doc * Create SupervisedSolverInterface * Specialize SupervisedSolver and ReducedOrderModelSolver * Create EnsembleSolverInterface + EnsembleSupervisedSolver * Create tests ensemble solvers * formatter * codacy * fix issues + speedup test
1 parent 4e16d0a commit ed1b64c

37 files changed

+1513
-509
lines changed

docs/source/_rst/_code.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,19 @@ Solvers
6868
SolverInterface <solver/solver_interface.rst>
6969
SingleSolverInterface <solver/single_solver_interface.rst>
7070
MultiSolverInterface <solver/multi_solver_interface.rst>
71+
SupervisedSolverInterface <solver/supervised_solver/supervised_solver_interface>
72+
DeepEnsembleSolverInterface <solver/ensemble_solver/ensemble_solver_interface>
7173
PINNInterface <solver/physics_informed_solver/pinn_interface.rst>
7274
PINN <solver/physics_informed_solver/pinn.rst>
7375
GradientPINN <solver/physics_informed_solver/gradient_pinn.rst>
7476
CausalPINN <solver/physics_informed_solver/causal_pinn.rst>
7577
CompetitivePINN <solver/physics_informed_solver/competitive_pinn.rst>
7678
SelfAdaptivePINN <solver/physics_informed_solver/self_adaptive_pinn.rst>
7779
RBAPINN <solver/physics_informed_solver/rba_pinn.rst>
78-
SupervisedSolver <solver/supervised.rst>
79-
ReducedOrderModelSolver <solver/reduced_order_model.rst>
80+
DeepEnsemblePINN <solver/ensemble_solver/ensemble_pinn>
81+
SupervisedSolver <solver/supervised_solver/supervised.rst>
82+
DeepEnsembleSupervisedSolver <solver/ensemble_solver/ensemble_supervised>
83+
ReducedOrderModelSolver <solver/supervised_solver/reduced_order_model.rst>
8084
GAROM <solver/garom.rst>
8185

8286

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DeepEnsemblePINN
2+
==================
3+
.. currentmodule:: pina.solver.ensemble_solver.ensemble_pinn
4+
5+
.. autoclass:: DeepEnsemblePINN
6+
:show-inheritance:
7+
:members:
8+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DeepEnsembleSolverInterface
2+
=============================
3+
.. currentmodule:: pina.solver.ensemble_solver.ensemble_solver_interface
4+
5+
.. autoclass:: DeepEnsembleSolverInterface
6+
:show-inheritance:
7+
:members:
8+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
DeepEnsembleSupervisedSolver
2+
=============================
3+
.. currentmodule:: pina.solver.ensemble_solver.ensemble_supervised
4+
5+
.. autoclass:: DeepEnsembleSupervisedSolver
6+
:show-inheritance:
7+
:members:
8+

docs/source/_rst/solver/reduced_order_model.rst renamed to docs/source/_rst/solver/supervised_solver/reduced_order_model.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
ReducedOrderModelSolver
22
==========================
3-
.. currentmodule:: pina.solver.reduced_order_model
3+
.. currentmodule:: pina.solver.supervised_solver.reduced_order_model
44

55
.. autoclass:: ReducedOrderModelSolver
66
:members:

docs/source/_rst/solver/supervised.rst renamed to docs/source/_rst/solver/supervised_solver/supervised.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
SupervisedSolver
22
===================
3-
.. currentmodule:: pina.solver.supervised
3+
.. currentmodule:: pina.solver.supervised_solver.supervised
44

55
.. autoclass:: SupervisedSolver
66
:members:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
SupervisedSolverInterface
2+
==========================
3+
.. currentmodule:: pina.solver.supervised_solver.supervised_solver_interface
4+
5+
.. autoclass:: SupervisedSolverInterface
6+
:show-inheritance:
7+
:members:
8+

pina/solver/__init__.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,33 @@
1111
"CompetitivePINN",
1212
"SelfAdaptivePINN",
1313
"RBAPINN",
14+
"SupervisedSolverInterface",
1415
"SupervisedSolver",
1516
"ReducedOrderModelSolver",
17+
"DeepEnsembleSolverInterface",
18+
"DeepEnsembleSupervisedSolver",
19+
"DeepEnsemblePINN",
1620
"GAROM",
1721
]
1822

1923
from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
20-
from .physics_informed_solver import *
21-
from .supervised import SupervisedSolver
22-
from .reduced_order_model import ReducedOrderModelSolver
24+
from .physics_informed_solver import (
25+
PINNInterface,
26+
PINN,
27+
GradientPINN,
28+
CausalPINN,
29+
CompetitivePINN,
30+
SelfAdaptivePINN,
31+
RBAPINN,
32+
)
33+
from .supervised_solver import (
34+
SupervisedSolverInterface,
35+
SupervisedSolver,
36+
ReducedOrderModelSolver,
37+
)
38+
from .ensemble_solver import (
39+
DeepEnsembleSolverInterface,
40+
DeepEnsembleSupervisedSolver,
41+
DeepEnsemblePINN,
42+
)
2343
from .garom import GAROM
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Module for the Ensemble solver classes."""
2+
3+
__all__ = [
4+
"DeepEnsembleSolverInterface",
5+
"DeepEnsembleSupervisedSolver",
6+
"DeepEnsemblePINN",
7+
]
8+
9+
from .ensemble_solver_interface import DeepEnsembleSolverInterface
10+
from .ensemble_supervised import DeepEnsembleSupervisedSolver
11+
from .ensemble_pinn import DeepEnsemblePINN
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""Module for the DeepEnsemble physics solver."""
2+
3+
import torch
4+
5+
from .ensemble_solver_interface import DeepEnsembleSolverInterface
6+
from ..physics_informed_solver import PINNInterface
7+
from ...problem import InverseProblem
8+
9+
10+
class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface):
11+
r"""
12+
Deep Ensemble Physics Informed Solver class. This class implements a
13+
Deep Ensemble for Physics Informed Neural Networks using user
14+
specified ``model``s to solve a specific ``problem``.
15+
16+
An ensemble model is constructed by combining multiple models that solve
17+
the same type of problem. Mathematically, this creates an implicit
18+
distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible
19+
outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`.
20+
The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in
21+
the ensemble work collaboratively to capture different
22+
aspects of the data or task, with each model contributing a distinct
23+
prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`.
24+
By aggregating these predictions, the ensemble
25+
model can achieve greater robustness and accuracy compared to individual
26+
models, leveraging the diversity of the models to reduce overfitting and
27+
improve generalization. Furthemore, statistical metrics can
28+
be computed, e.g. the ensemble mean and variance:
29+
30+
.. math::
31+
\mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i}
32+
33+
.. math::
34+
\mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r
35+
(\mathbf{y}_{i} - \mathbf{\mu})^2
36+
37+
During training the PINN loss is minimized by each ensemble model:
38+
39+
.. math::
40+
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^4
41+
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
42+
\frac{1}{N}\sum_{i=1}^N
43+
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
44+
45+
for the differential system:
46+
47+
.. math::
48+
49+
\begin{cases}
50+
\mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\
51+
\mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad,
52+
\mathbf{x}\in\partial\Omega
53+
\end{cases}
54+
55+
:math:`\mathcal{L}` indicates a specific loss function, typically the MSE:
56+
57+
.. math::
58+
\mathcal{L}(v) = \| v \|^2_2.
59+
60+
.. seealso::
61+
62+
**Original reference**: Zou, Z., Wang, Z., & Karniadakis, G. E. (2025).
63+
*Learning and discovering multiple solutions using physics-informed
64+
neural networks with random initialization and deep ensemble*.
65+
DOI: `arXiv:2503.06320 <https://arxiv.org/abs/2503.06320>`_.
66+
67+
.. warning::
68+
This solver does not work with inverse problem. Hence in the ``problem``
69+
definition must not inherit from
70+
:class:`~pina.problem.inverse_problem.InverseProblem`.
71+
"""
72+
73+
def __init__(
74+
self,
75+
problem,
76+
models,
77+
loss=None,
78+
optimizers=None,
79+
schedulers=None,
80+
weighting=None,
81+
ensemble_dim=0,
82+
):
83+
"""
84+
Initialization of the :class:`DeepEnsemblePINN` class.
85+
86+
:param AbstractProblem problem: The problem to be solved.
87+
:param torch.nn.Module models: The neural network models to be used.
88+
:param torch.nn.Module loss: The loss function to be minimized.
89+
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
90+
Default is ``None``.
91+
:param Optimizer optimizer: The optimizer to be used.
92+
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
93+
Default is ``None``.
94+
:param Scheduler scheduler: Learning rate scheduler.
95+
If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
96+
scheduler is used. Default is ``None``.
97+
:param WeightingInterface weighting: The weighting schema to be used.
98+
If ``None``, no weighting schema is used. Default is ``None``.
99+
:param int ensemble_dim: The dimension along which the ensemble
100+
outputs are stacked. Default is 0.
101+
:raises NotImplementedError: If an inverse problem is passed.
102+
"""
103+
if isinstance(problem, InverseProblem):
104+
raise NotImplementedError(
105+
"DeepEnsemblePINN can not be used to solve inverse problems."
106+
)
107+
super().__init__(
108+
problem=problem,
109+
models=models,
110+
loss=loss,
111+
optimizers=optimizers,
112+
schedulers=schedulers,
113+
weighting=weighting,
114+
ensemble_dim=ensemble_dim,
115+
)
116+
117+
def loss_data(self, input, target):
118+
"""
119+
Compute the data loss for the ensemble PINN solver by evaluating
120+
the loss between the network's output and the true solution for each
121+
model. This method should not be overridden, if not intentionally.
122+
123+
:param input: The input to the neural network.
124+
:type input: LabelTensor | torch.Tensor | Graph | Data
125+
:param target: The target to compare with the network's output.
126+
:type target: LabelTensor | torch.Tensor | Graph | Data
127+
:return: The supervised loss, averaged over the number of observations.
128+
:rtype: torch.Tensor
129+
"""
130+
predictions = self.forward(input)
131+
loss = sum(
132+
self._loss_fn(predictions[idx], target)
133+
for idx in range(self.num_ensemble)
134+
)
135+
return loss / self.num_ensemble
136+
137+
def loss_phys(self, samples, equation):
138+
"""
139+
Computes the physics loss for the ensemble PINN solver by evaluating
140+
the loss between the network's output and the true solution for each
141+
model. This method should not be overridden, if not intentionally.
142+
143+
:param LabelTensor samples: The samples to evaluate the physics loss.
144+
:param EquationInterface equation: The governing equation.
145+
:return: The computed physics loss.
146+
:rtype: LabelTensor
147+
"""
148+
return self._residual_loss(samples, equation)
149+
150+
def _residual_loss(self, samples, equation):
151+
"""
152+
Computes the physics loss for the physics-informed solver based on the
153+
provided samples and equation. This method should never be overridden
154+
by the user, if not intentionally,
155+
since it is used internally to compute validation loss. It overrides the
156+
:obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss`
157+
method.
158+
159+
:param LabelTensor samples: The samples to evaluate the loss.
160+
:param EquationInterface equation: The governing equation.
161+
:return: The residual loss.
162+
:rtype: torch.Tensor
163+
"""
164+
loss = 0
165+
predictions = self.forward(samples)
166+
for idx in range(self.num_ensemble):
167+
residuals = equation.residual(samples, predictions[idx])
168+
target = torch.zeros_like(residuals, requires_grad=True)
169+
loss = loss + self._loss_fn(residuals, target)
170+
return loss / self.num_ensemble

0 commit comments

Comments
 (0)