from abc import ABCMeta, abstractmethod
import torch
- from torch.nn.modules.loss import _Loss

- from ..solver import SolverInterface
- from ...utils import check_consistency
- from ...loss.loss_interface import LossInterface
- from ...problem import InverseProblem
+ from ..supervised_solver import SupervisedSolverInterface
from ...condition import (
    InputTargetCondition,
    InputEquationCondition,
    DomainEquationCondition,
)


- class PINNInterface(SolverInterface, metaclass=ABCMeta):
+ class PINNInterface(SupervisedSolverInterface, metaclass=ABCMeta):
    """
    Base class for Physics-Informed Neural Network (PINN) solvers, implementing
    the :class:`~pina.solver.solver.SolverInterface` class.
@@ -32,7 +28,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
        DomainEquationCondition,
    )

-     def __init__(self, problem, loss=None, **kwargs):
+     def __init__(self, **kwargs):
        """
        Initialization of the :class:`PINNInterface` class.

@@ -41,28 +37,13 @@ def __init__(self, problem, loss=None, **kwargs):
            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
            Default is ``None``.
        :param kwargs: Additional keyword arguments to be passed to the
-             :class:`~pina.solver.solver.SolverInterface` class.
+             :class:`~pina.solver.supervised_solver.SupervisedSolverInterface`
+             class.
        """
+         kwargs["use_lt"] = True
+         super().__init__(**kwargs)

-         if loss is None:
-             loss = torch.nn.MSELoss()
-
-         super().__init__(problem=problem, use_lt=True, **kwargs)
-
-         # check consistency
-         check_consistency(loss, (LossInterface, _Loss), subclass=False)
-
-         # assign variables
-         self._loss_fn = loss
-
-         # inverse problem handling
-         if isinstance(self.problem, InverseProblem):
-             self._params = self.problem.unknown_parameters
-             self._clamp_params = self._clamp_inverse_problem_params
-         else:
-             self._params = None
-             self._clamp_params = lambda: None
-
+         # current condition name
        self.__metric = None

    def optimization_cycle(self, batch, loss_residuals=None):
@@ -103,8 +84,6 @@ def optimization_cycle(self, batch, loss_residuals=None):
            )
            # append loss
            condition_loss[condition_name] = loss
-         # clamp unknown parameters in InverseProblem (if needed)
-         self._clamp_params()
        return condition_loss

    @torch.set_grad_enabled(True)
@@ -135,7 +114,6 @@ def test_step(self, batch):
        """
        return super().test_step(batch, loss_residuals=self._residual_loss)

-     @abstractmethod
    def loss_data(self, input, target):
        """
        Compute the data loss for the PINN solver by evaluating the loss
@@ -147,7 +125,12 @@ def loss_data(self, input, target):
            network's output.
        :return: The supervised loss, averaged over the number of observations.
        :rtype: LabelTensor
+         :raises NotImplementedError: If the method is not implemented.
        """
+         raise NotImplementedError(
+             "PINN is being used in a supervised learning context, but the "
+             "'loss_data' method has not been implemented."
+         )

    @abstractmethod
    def loss_phys(self, samples, equation):
@@ -196,26 +179,6 @@ def _residual_loss(self, samples, equation):
        residuals = self.compute_residual(samples, equation)
        return self._loss_fn(residuals, torch.zeros_like(residuals))

-     def _clamp_inverse_problem_params(self):
-         """
-         Clamps the parameters of the inverse problem solver to specified ranges.
-         """
-         for v in self._params:
-             self._params[v].data.clamp_(
-                 self.problem.unknown_parameter_domain.range_[v][0],
-                 self.problem.unknown_parameter_domain.range_[v][1],
-             )
-
-     @property
-     def loss(self):
-         """
-         The loss used for training.
-
-         :return: The loss function used for training.
-         :rtype: torch.nn.Module
-         """
-         return self._loss_fn
-
    @property
    def current_condition_name(self):
        """
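For context on how the refactored interface is meant to be consumed, below is a minimal sketch of a concrete solver built on top of it. It is not part of this commit: the class name MyPINN and the import path are hypothetical, the keyword arguments forwarded through **kwargs are assumed to be consumed by SupervisedSolverInterface, and forward() is assumed to behave as in the other PINA solvers. Only loss_data, loss_phys, compute_residual, and _loss_fn come from the code shown in the diff.

import torch

# Import path assumed from the relative imports above; adjust to the package layout.
from pina.solver.physics_informed_solver import PINNInterface


class MyPINN(PINNInterface):
    """Hypothetical concrete PINN built on the refactored interface (sketch)."""

    def __init__(self, **kwargs):
        # PINNInterface injects use_lt=True before delegating to
        # SupervisedSolverInterface, which handles the remaining kwargs.
        super().__init__(**kwargs)

    def loss_data(self, input, target):
        # Supervised branch: without this override, the base class now
        # raises NotImplementedError when a data condition is present.
        return self._loss_fn(self.forward(input), target)

    def loss_phys(self, samples, equation):
        # Physics branch: drive the equation residual towards zero,
        # mirroring the _residual_loss helper shown above.
        residuals = self.compute_residual(samples, equation)
        return self._loss_fn(residuals, torch.zeros_like(residuals))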