
Commit a84b105

Update PINNInterface Inheritance (#542)
1 parent c8ed625 commit a84b105

2 files changed: +46 -51

pina/solver/physics_informed_solver/pinn_interface.py

+13 -50
@@ -2,20 +2,16 @@
 
 from abc import ABCMeta, abstractmethod
 import torch
-from torch.nn.modules.loss import _Loss
 
-from ..solver import SolverInterface
-from ...utils import check_consistency
-from ...loss.loss_interface import LossInterface
-from ...problem import InverseProblem
+from ..supervised_solver import SupervisedSolverInterface
 from ...condition import (
     InputTargetCondition,
     InputEquationCondition,
     DomainEquationCondition,
 )
 
 
-class PINNInterface(SolverInterface, metaclass=ABCMeta):
+class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
     """
     Base class for Physics-Informed Neural Network (PINN) solvers, implementing
     the :class:`~pina.solver.solver.SolverInterface` class.
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
         DomainEquationCondition,
     )
 
-    def __init__(self, problem, loss=None, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initialization of the :class:`PINNInterface` class.
 
@@ -41,28 +37,13 @@ def __init__(self, problem, loss=None, **kwargs):
             If ``None``, the :class:`torch.nn.MSELoss` loss is used.
             Default is `None`.
         :param kwargs: Additional keyword arguments to be passed to the
-            :class:`~pina.solver.solver.SolverInterface` class.
+            :class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
+            class.
         """
+        kwargs["use_lt"] = True
+        super().__init__(**kwargs)
 
-        if loss is None:
-            loss = torch.nn.MSELoss()
-
-        super().__init__(problem=problem, use_lt=True, **kwargs)
-
-        # check consistency
-        check_consistency(loss, (LossInterface, _Loss), subclass=False)
-
-        # assign variables
-        self._loss_fn = loss
-
-        # inverse problem handling
-        if isinstance(self.problem, InverseProblem):
-            self._params = self.problem.unknown_parameters
-            self._clamp_params = self._clamp_inverse_problem_params
-        else:
-            self._params = None
-            self._clamp_params = lambda: None
-
+        # current condition name
         self.__metric = None
 
     def optimization_cycle(self, batch, loss_residuals=None):
@@ -103,8 +84,6 @@ def optimization_cycle(self, batch, loss_residuals=None):
             )
             # append loss
             condition_loss[condition_name] = loss
-        # clamp unknown parameters in InverseProblem (if needed)
-        self._clamp_params()
         return condition_loss
 
     @torch.set_grad_enabled(True)
@@ -135,7 +114,6 @@ def test_step(self, batch):
         """
         return super().test_step(batch, loss_residuals=self._residual_loss)
 
-    @abstractmethod
     def loss_data(self, input, target):
         """
         Compute the data loss for the PINN solver by evaluating the loss
@@ -147,7 +125,12 @@ def loss_data(self, input, target):
             network's output.
         :return: The supervised loss, averaged over the number of observations.
         :rtype: LabelTensor
+        :raises NotImplementedError: If the method is not implemented.
         """
+        raise NotImplementedError(
+            "PINN is being used in a supervised learning context, but the "
+            "'loss_data' method has not been implemented. "
+        )
 
     @abstractmethod
     def loss_phys(self, samples, equation):
@@ -196,26 +179,6 @@ def _residual_loss(self, samples, equation):
         residuals = self.compute_residual(samples, equation)
         return self._loss_fn(residuals, torch.zeros_like(residuals))
 
-    def _clamp_inverse_problem_params(self):
-        """
-        Clamps the parameters of the inverse problem solver to specified ranges.
-        """
-        for v in self._params:
-            self._params[v].data.clamp_(
-                self.problem.unknown_parameter_domain.range_[v][0],
-                self.problem.unknown_parameter_domain.range_[v][1],
-            )
-
-    @property
-    def loss(self):
-        """
-        The loss used for training.
-
-        :return: The loss function used for training.
-        :rtype: torch.nn.Module
-        """
-        return self._loss_fn
-
     @property
     def current_condition_name(self):
         """

pina/solver/solver.py

+33 -1
@@ -5,7 +5,7 @@
 import torch
 
 from torch._dynamo import OptimizedModule
-from ..problem import AbstractProblem
+from ..problem import AbstractProblem, InverseProblem
 from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
 from ..loss import WeightingInterface
 from ..loss.scalar_weighting import _NoWeighting
@@ -64,6 +64,14 @@ def __init__(self, problem, weighting, use_lt):
         self._pina_optimizers = None
         self._pina_schedulers = None
 
+        # inverse problem handling
+        if isinstance(self.problem, InverseProblem):
+            self._params = self.problem.unknown_parameters
+            self._clamp_params = self._clamp_inverse_problem_params
+        else:
+            self._params = None
+            self._clamp_params = lambda: None
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         """
@@ -231,14 +239,29 @@ def _optimization_cycle(self, batch, **kwargs):
             containing the condition name and the associated scalar loss.
         :rtype: dict
         """
+        # compute losses
         losses = self.optimization_cycle(batch)
+        # clamp unknown parameters in InverseProblem (if needed)
+        self._clamp_params()
+        # store log
         for name, value in losses.items():
             self.store_log(
                 f"{name}_loss", value.item(), self.get_batch_size(batch)
             )
+        # aggregate
         loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
         return loss
 
+    def _clamp_inverse_problem_params(self):
+        """
+        Clamps the parameters of the inverse problem solver to specified ranges.
+        """
+        for v in self._params:
+            self._params[v].data.clamp_(
+                self.problem.unknown_parameter_domain.range_[v][0],
+                self.problem.unknown_parameter_domain.range_[v][1],
+            )
+
     @staticmethod
     def _compile_modules(model):
         """
@@ -405,6 +428,15 @@ def configure_optimizers(self):
         :rtype: tuple[list[Optimizer], list[Scheduler]]
         """
         self.optimizer.hook(self.model.parameters())
+        if isinstance(self.problem, InverseProblem):
+            self.optimizer.instance.add_param_group(
+                {
+                    "params": [
+                        self._params[var]
+                        for var in self.problem.unknown_variables
+                    ]
+                }
+            )
         self.scheduler.hook(self.optimizer)
         return ([self.optimizer.instance], [self.scheduler.instance])