|
| 1 | +"""Module for the DeepEnsemble physics solver.""" |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from .ensemble_solver_interface import DeepEnsembleSolverInterface |
| 6 | +from ..physics_informed_solver import PINNInterface |
| 7 | +from ...problem import InverseProblem |
| 8 | + |
| 9 | + |
| 10 | +class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface): |
| 11 | + r""" |
| 12 | + Deep Ensemble Physics Informed Solver class. This class implements a |
| 13 | + Deep Ensemble for Physics Informed Neural Networks using user |
| 14 | + specified ``model``s to solve a specific ``problem``. |
| 15 | +
|
| 16 | + An ensemble model is constructed by combining multiple models that solve |
| 17 | + the same type of problem. Mathematically, this creates an implicit |
| 18 | + distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible |
| 19 | + outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`. |
| 20 | + The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in |
| 21 | + the ensemble work collaboratively to capture different |
| 22 | + aspects of the data or task, with each model contributing a distinct |
| 23 | + prediction :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`. |
| 24 | + By aggregating these predictions, the ensemble |
| 25 | + model can achieve greater robustness and accuracy compared to individual |
| 26 | + models, leveraging the diversity of the models to reduce overfitting and |
| 27 | + improve generalization. Furthemore, statistical metrics can |
| 28 | + be computed, e.g. the ensemble mean and variance: |
| 29 | +
|
| 30 | + .. math:: |
| 31 | + \mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i} |
| 32 | +
|
| 33 | + .. math:: |
| 34 | + \mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r |
| 35 | + (\mathbf{y}_{i} - \mathbf{\mu})^2 |
| 36 | +
|
| 37 | + During training the PINN loss is minimized by each ensemble model: |
| 38 | +
|
| 39 | + .. math:: |
| 40 | + \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^4 |
| 41 | + \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) + |
| 42 | + \frac{1}{N}\sum_{i=1}^N |
| 43 | + \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)), |
| 44 | +
|
| 45 | + for the differential system: |
| 46 | + |
| 47 | + .. math:: |
| 48 | +
|
| 49 | + \begin{cases} |
| 50 | + \mathcal{A}[\mathbf{u}](\mathbf{x})=0\quad,\mathbf{x}\in\Omega\\ |
| 51 | + \mathcal{B}[\mathbf{u}](\mathbf{x})=0\quad, |
| 52 | + \mathbf{x}\in\partial\Omega |
| 53 | + \end{cases} |
| 54 | +
|
| 55 | + :math:`\mathcal{L}` indicates a specific loss function, typically the MSE: |
| 56 | +
|
| 57 | + .. math:: |
| 58 | + \mathcal{L}(v) = \| v \|^2_2. |
| 59 | +
|
| 60 | + .. seealso:: |
| 61 | +
|
| 62 | + **Original reference**: Zou, Z., Wang, Z., & Karniadakis, G. E. (2025). |
| 63 | + *Learning and discovering multiple solutions using physics-informed |
| 64 | + neural networks with random initialization and deep ensemble*. |
| 65 | + DOI: `arXiv:2503.06320 <https://arxiv.org/abs/2503.06320>`_. |
| 66 | +
|
| 67 | + .. warning:: |
| 68 | + This solver does not work with inverse problem. Hence in the ``problem`` |
| 69 | + definition must not inherit from |
| 70 | + :class:`~pina.problem.inverse_problem.InverseProblem`. |
| 71 | + """ |
| 72 | + |
| 73 | + def __init__( |
| 74 | + self, |
| 75 | + problem, |
| 76 | + models, |
| 77 | + loss=None, |
| 78 | + optimizers=None, |
| 79 | + schedulers=None, |
| 80 | + weighting=None, |
| 81 | + ensemble_dim=0, |
| 82 | + ): |
| 83 | + """ |
| 84 | + Initialization of the :class:`DeepEnsemblePINN` class. |
| 85 | +
|
| 86 | + :param AbstractProblem problem: The problem to be solved. |
| 87 | + :param torch.nn.Module models: The neural network models to be used. |
| 88 | + :param torch.nn.Module loss: The loss function to be minimized. |
| 89 | + If ``None``, the :class:`torch.nn.MSELoss` loss is used. |
| 90 | + Default is ``None``. |
| 91 | + :param Optimizer optimizer: The optimizer to be used. |
| 92 | + If ``None``, the :class:`torch.optim.Adam` optimizer is used. |
| 93 | + Default is ``None``. |
| 94 | + :param Scheduler scheduler: Learning rate scheduler. |
| 95 | + If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR` |
| 96 | + scheduler is used. Default is ``None``. |
| 97 | + :param WeightingInterface weighting: The weighting schema to be used. |
| 98 | + If ``None``, no weighting schema is used. Default is ``None``. |
| 99 | + :param int ensemble_dim: The dimension along which the ensemble |
| 100 | + outputs are stacked. Default is 0. |
| 101 | + :raises NotImplementedError: If an inverse problem is passed. |
| 102 | + """ |
| 103 | + if isinstance(problem, InverseProblem): |
| 104 | + raise NotImplementedError( |
| 105 | + "DeepEnsemblePINN can not be used to solve inverse problems." |
| 106 | + ) |
| 107 | + super().__init__( |
| 108 | + problem=problem, |
| 109 | + models=models, |
| 110 | + loss=loss, |
| 111 | + optimizers=optimizers, |
| 112 | + schedulers=schedulers, |
| 113 | + weighting=weighting, |
| 114 | + ensemble_dim=ensemble_dim, |
| 115 | + ) |
| 116 | + |
| 117 | + def loss_data(self, input, target): |
| 118 | + """ |
| 119 | + Compute the data loss for the ensemble PINN solver by evaluating |
| 120 | + the loss between the network's output and the true solution for each |
| 121 | + model. This method should not be overridden, if not intentionally. |
| 122 | +
|
| 123 | + :param input: The input to the neural network. |
| 124 | + :type input: LabelTensor | torch.Tensor | Graph | Data |
| 125 | + :param target: The target to compare with the network's output. |
| 126 | + :type target: LabelTensor | torch.Tensor | Graph | Data |
| 127 | + :return: The supervised loss, averaged over the number of observations. |
| 128 | + :rtype: torch.Tensor |
| 129 | + """ |
| 130 | + predictions = self.forward(input) |
| 131 | + loss = sum( |
| 132 | + self._loss_fn(predictions[idx], target) |
| 133 | + for idx in range(self.num_ensemble) |
| 134 | + ) |
| 135 | + return loss / self.num_ensemble |
| 136 | + |
| 137 | + def loss_phys(self, samples, equation): |
| 138 | + """ |
| 139 | + Computes the physics loss for the ensemble PINN solver by evaluating |
| 140 | + the loss between the network's output and the true solution for each |
| 141 | + model. This method should not be overridden, if not intentionally. |
| 142 | +
|
| 143 | + :param LabelTensor samples: The samples to evaluate the physics loss. |
| 144 | + :param EquationInterface equation: The governing equation. |
| 145 | + :return: The computed physics loss. |
| 146 | + :rtype: LabelTensor |
| 147 | + """ |
| 148 | + return self._residual_loss(samples, equation) |
| 149 | + |
| 150 | + def _residual_loss(self, samples, equation): |
| 151 | + """ |
| 152 | + Computes the physics loss for the physics-informed solver based on the |
| 153 | + provided samples and equation. This method should never be overridden |
| 154 | + by the user, if not intentionally, |
| 155 | + since it is used internally to compute validation loss. It overrides the |
| 156 | + :obj:`~pina.solver.physics_informed_solver.PINNInterface._residual_loss` |
| 157 | + method. |
| 158 | +
|
| 159 | + :param LabelTensor samples: The samples to evaluate the loss. |
| 160 | + :param EquationInterface equation: The governing equation. |
| 161 | + :return: The residual loss. |
| 162 | + :rtype: torch.Tensor |
| 163 | + """ |
| 164 | + loss = 0 |
| 165 | + predictions = self.forward(samples) |
| 166 | + for idx in range(self.num_ensemble): |
| 167 | + residuals = equation.residual(samples, predictions[idx]) |
| 168 | + target = torch.zeros_like(residuals, requires_grad=True) |
| 169 | + loss = loss + self._loss_fn(residuals, target) |
| 170 | + return loss / self.num_ensemble |
0 commit comments