Skip to content

Commit ca57ef0

Browse files
committed
Reimplement refinement
1 parent 90701f6 commit ca57ef0

File tree

7 files changed

+192
-200
lines changed

7 files changed

+192
-200
lines changed

pina/callback/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33
__all__ = [
44
"SwitchOptimizer",
5-
"R3Refinement",
65
"MetricTracker",
76
"PINAProgressBar",
87
"LinearWeightUpdate",
98
]
109

1110
from .optimizer_callback import SwitchOptimizer
12-
from .adaptive_refinement_callback import R3Refinement
1311
from .processing_callback import MetricTracker, PINAProgressBar
1412
from .linear_weight_update_callback import LinearWeightUpdate

pina/callback/adaptive_refinement_callback.py

Lines changed: 0 additions & 184 deletions
This file was deleted.

pina/callback/refinement/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
__all__ = ["R3Refinement"]
2+
3+
from .r3_refinement import R3Refinement
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Module for the R3Refinement callback."""
2+
3+
import torch
4+
from .refinement_interface import RefinementInterface
5+
from ...label_tensor import LabelTensor
6+
from ...utils import check_consistency
7+
8+
9+
class R3Refinement(RefinementInterface):
10+
"""
11+
PINA Implementation of an R3 Refinement Callback.
12+
"""
13+
14+
def __init__(self, sample_every):
15+
"""
16+
This callback implements the R3 (Retain-Resample-Release) routine for
17+
sampling new points based on adaptive search.
18+
The algorithm incrementally accumulates collocation points in regions
19+
of high PDE residuals, and releases those with low residuals.
20+
Points are sampled uniformly in all regions where sampling is needed.
21+
22+
.. seealso::
23+
24+
Original Reference: Daw, Arka, et al. *Mitigating Propagation
25+
Failures in Physics-informed Neural Networks
26+
using Retain-Resample-Release (R3) Sampling. (2023)*.
27+
DOI: `10.48550/arXiv.2207.02338
28+
<https://doi.org/10.48550/arXiv.2207.02338>`_
29+
30+
:param int sample_every: Frequency for sampling.
31+
:raises ValueError: If `sample_every` is not an integer.
32+
33+
Example:
34+
>>> r3_callback = R3Refinement(sample_every=5)
35+
"""
36+
37+
super().__init__(sample_every=sample_every)
38+
self.const_pts = None
39+
40+
def sample(self, condition_name, condition):
41+
avg_res, res = self.per_point_residual([condition_name])
42+
pts = self.dataset.conditions_dict[condition_name]["input"]
43+
domain = condition.domain
44+
labels = pts.labels
45+
pts = pts.cpu().detach().as_subclass(torch.Tensor)
46+
residuals = res[condition_name]
47+
mask = (residuals > avg_res).flatten()
48+
if any(mask): # append residuals greater than average
49+
pts = (pts[mask]).as_subclass(LabelTensor)
50+
pts.labels = labels
51+
numb_pts = self.const_pts[condition_name] - len(pts)
52+
else:
53+
numb_pts = self.const_pts[condition_name]
54+
pts = None
55+
self.problem.discretise_domain(numb_pts, "random", domains=[domain])
56+
sampled_points = self.problem.discretised_domains[domain]
57+
tmp = (
58+
sampled_points
59+
if pts is None
60+
else LabelTensor.cat([pts, sampled_points])
61+
)
62+
return tmp
63+
64+
def on_train_start(self, trainer, _):
65+
"""
66+
Callback function called at the start of training.
67+
68+
This method extracts the locations for sampling from the problem
69+
conditions and calculates the total population.
70+
71+
:param trainer: The trainer object managing the training process.
72+
:type trainer: pytorch_lightning.Trainer
73+
:param _: Placeholder argument (not used).
74+
"""
75+
super().on_train_start(trainer, _)
76+
self.const_pts = {}
77+
for condition in self.conditions:
78+
pts = self.dataset.conditions_dict[condition]["input"]
79+
self.const_pts[condition] = len(pts)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
RefinementInterface class for handling the refinement of points in a neural
3+
network training process.
4+
"""
5+
6+
import torch
7+
from abc import ABCMeta
8+
from lightning.pytorch import Callback
9+
from torch_geometric.data.feature_store import abstractmethod
10+
from torch_geometric.nn.conv import point_transformer_conv
11+
from ...condition.domain_equation_condition import DomainEquationCondition
12+
13+
14+
class RefinementInterface(Callback):
15+
"""
16+
Interface class of Refinement
17+
"""
18+
19+
def __init__(self, sample_every):
20+
"""
21+
Initializes the RefinementInterface.
22+
23+
:param int sample_every: The number of epochs between each refinement.
24+
"""
25+
self.sample_every = sample_every
26+
self.conditions = None
27+
self.dataset = None
28+
self.solver = None
29+
30+
def on_train_start(self, trainer, _):
31+
"""
32+
Called when the training begins. It initializes the conditions and
33+
dataset.
34+
35+
:param lightning.pytorch.Trainer trainer: The trainer object.
36+
:param _: Unused argument.
37+
"""
38+
self.problem = trainer.solver.problem
39+
self.solver = trainer.solver
40+
self.conditions = {}
41+
for name, cond in self.problem.conditions.items():
42+
if isinstance(cond, DomainEquationCondition):
43+
self.conditions[name] = cond
44+
self.dataset = trainer.datamodule.train_dataset
45+
46+
@property
47+
def points(self):
48+
"""
49+
Returns the points of the dataset.
50+
"""
51+
return self.dataset.conditions_dict
52+
53+
def on_train_epoch_end(self, trainer, _):
54+
"""
55+
Performs the refinement at the end of each training epoch (if needed).
56+
57+
:param lightning.pytorch.Trainer trainer: The trainer object.
58+
:param _: Unused argument.
59+
"""
60+
if trainer.current_epoch % self.sample_every == 0:
61+
self.update()
62+
63+
def update(self):
64+
"""
65+
Performs the refinement of the points.
66+
"""
67+
new_points = {}
68+
for name, condition in self.conditions.items():
69+
new_points[name] = {"input": self.sample(name, condition)}
70+
self.dataset.update_data(new_points)
71+
72+
def per_point_residual(self, conditions_name=None):
73+
"""
74+
Computes the residuals for a PINN object.
75+
76+
:return: the total loss, and pointwise loss.
77+
:rtype: tuple
78+
"""
79+
# compute residual
80+
res_loss = {}
81+
tot_loss = []
82+
points = self.points
83+
if conditions_name is None:
84+
conditions_name = list(self.conditions.keys())
85+
for name in conditions_name:
86+
cond = self.conditions[name]
87+
cond_points = points[name]["input"]
88+
target = self._compute_residual(cond_points, cond.equation)
89+
res_loss[name] = torch.abs(target).as_subclass(torch.Tensor)
90+
tot_loss.append(torch.abs(target))
91+
return torch.vstack(tot_loss).tensor.mean(), res_loss
92+
93+
def _compute_residual(self, pts, equation):
94+
pts.requires_grad_(True)
95+
pts.retain_grad()
96+
return equation.residual(pts, self.solver.forward(pts))
97+
98+
@abstractmethod
99+
def sample(self, condition):
100+
"""
101+
Samples new points based on the condition.
102+
"""
103+
pass

0 commit comments

Comments
 (0)