22
33from abc import ABCMeta , abstractmethod
44import torch
5- from torch .nn .modules .loss import _Loss
65
7- from ..solver import SolverInterface
8- from ...utils import check_consistency
9- from ...loss .loss_interface import LossInterface
10- from ...problem import InverseProblem
6+ from ..supervised_solver import SupervisedSolverInterface
117from ...condition import (
128 InputTargetCondition ,
139 InputEquationCondition ,
1410 DomainEquationCondition ,
1511)
1612
1713
18- class PINNInterface (SolverInterface , metaclass = ABCMeta ):
14+ class PINNInterface (SupervisedSolverInterface , metaclass = ABCMeta ):
1915 """
2016 Base class for Physics-Informed Neural Network (PINN) solvers, implementing
2117 the :class:`~pina.solver.solver.SolverInterface` class.
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
3228 DomainEquationCondition ,
3329 )
3430
35- def __init__ (self , problem , loss = None , ** kwargs ):
31+ def __init__ (self , ** kwargs ):
3632 """
3733 Initialization of the :class:`PINNInterface` class.
3834
@@ -41,28 +37,13 @@ def __init__(self, problem, loss=None, **kwargs):
4137 If ``None``, the :class:`torch.nn.MSELoss` loss is used.
4238 Default is `None`.
4339 :param kwargs: Additional keyword arguments to be passed to the
44- :class:`~pina.solver.solver.SolverInterface` class.
40+ :class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
41+ class.
4542 """
43+ kwargs ["use_lt" ] = True
44+ super ().__init__ (** kwargs )
4645
47- if loss is None :
48- loss = torch .nn .MSELoss ()
49-
50- super ().__init__ (problem = problem , use_lt = True , ** kwargs )
51-
52- # check consistency
53- check_consistency (loss , (LossInterface , _Loss ), subclass = False )
54-
55- # assign variables
56- self ._loss_fn = loss
57-
58- # inverse problem handling
59- if isinstance (self .problem , InverseProblem ):
60- self ._params = self .problem .unknown_parameters
61- self ._clamp_params = self ._clamp_inverse_problem_params
62- else :
63- self ._params = None
64- self ._clamp_params = lambda : None
65-
46+ # current condition name
6647 self .__metric = None
6748
6849 def optimization_cycle (self , batch , loss_residuals = None ):
@@ -103,8 +84,6 @@ def optimization_cycle(self, batch, loss_residuals=None):
10384 )
10485 # append loss
10586 condition_loss [condition_name ] = loss
106- # clamp unknown parameters in InverseProblem (if needed)
107- self ._clamp_params ()
10887 return condition_loss
10988
11089 @torch .set_grad_enabled (True )
@@ -135,7 +114,6 @@ def test_step(self, batch):
135114 """
136115 return super ().test_step (batch , loss_residuals = self ._residual_loss )
137116
138- @abstractmethod
139117 def loss_data (self , input , target ):
140118 """
141119 Compute the data loss for the PINN solver by evaluating the loss
@@ -147,7 +125,12 @@ def loss_data(self, input, target):
147125 network's output.
148126 :return: The supervised loss, averaged over the number of observations.
149127 :rtype: LabelTensor
128+ :raises NotImplementedError: If the method is not implemented.
150129 """
130+ raise NotImplementedError (
131+ "PINN is being used in a supervised learning context, but the "
132+ "'loss_data' method has not been implemented. "
133+ )
151134
152135 @abstractmethod
153136 def loss_phys (self , samples , equation ):
@@ -196,26 +179,6 @@ def _residual_loss(self, samples, equation):
196179 residuals = self .compute_residual (samples , equation )
197180 return self ._loss_fn (residuals , torch .zeros_like (residuals ))
198181
199- def _clamp_inverse_problem_params (self ):
200- """
201- Clamps the parameters of the inverse problem solver to specified ranges.
202- """
203- for v in self ._params :
204- self ._params [v ].data .clamp_ (
205- self .problem .unknown_parameter_domain .range_ [v ][0 ],
206- self .problem .unknown_parameter_domain .range_ [v ][1 ],
207- )
208-
209- @property
210- def loss (self ):
211- """
212- The loss used for training.
213-
214- :return: The loss function used for training.
215- :rtype: torch.nn.Module
216- """
217- return self ._loss_fn
218-
219182 @property
220183 def current_condition_name (self ):
221184 """
0 commit comments