Skip to content

Commit 1e9e920

Browse files
author
Monthly Tag bot
committed
clean code
1 parent abb8718 commit 1e9e920

File tree

3 files changed

+127
-112
lines changed

3 files changed

+127
-112
lines changed
Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
"""Module for the R3Refinement callback."""
22

33
import torch
4+
import torch.nn as nn
5+
from torch.nn.modules.loss import _Loss
46
from .refinement_interface import RefinementInterface
57
from ...label_tensor import LabelTensor
68
from ...utils import check_consistency
9+
from ...loss import LossInterface
710

811

912
class R3Refinement(RefinementInterface):
1013
"""
1114
PINA Implementation of an R3 Refinement Callback.
1215
"""
1316

14-
def __init__(self, sample_every):
17+
def __init__(
18+
self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
19+
):
1520
"""
1621
This callback implements the R3 (Retain-Resample-Release) routine for
1722
sampling new points based on adaptive search.
@@ -33,47 +38,33 @@ def __init__(self, sample_every):
3338
Example:
3439
>>> r3_callback = R3Refinement(sample_every=5)
3540
"""
41+
super().__init__(sample_every, condition_to_update)
42+
# check consistency loss
43+
check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
44+
self.loss_fn = residual_loss(reduction="none")
3645

37-
super().__init__(sample_every=sample_every)
38-
self.const_pts = None
39-
40-
def sample(self, condition_name, condition):
41-
avg_res, res = self.per_point_residual([condition_name])
42-
pts = self.dataset.conditions_dict[condition_name]["input"]
43-
domain = condition.domain
44-
labels = pts.labels
45-
pts = pts.cpu().detach().as_subclass(torch.Tensor)
46-
residuals = res[condition_name]
47-
mask = (residuals > avg_res).flatten()
48-
if any(mask): # append residuals greater than average
49-
pts = (pts[mask]).as_subclass(LabelTensor)
50-
pts.labels = labels
51-
numb_pts = self.const_pts[condition_name] - len(pts)
52-
else:
53-
numb_pts = self.const_pts[condition_name]
54-
pts = None
55-
self.problem.discretise_domain(numb_pts, "random", domains=[domain])
56-
sampled_points = self.problem.discretised_domains[domain]
57-
tmp = (
58-
sampled_points
59-
if pts is None
60-
else LabelTensor.cat([pts, sampled_points])
46+
def sample(self, current_points, condition_name, solver):
47+
# Compute residuals for the given condition (average over fields)
48+
condition = solver.problem.conditions[condition_name]
49+
target = solver.compute_residual(
50+
current_points.requires_grad_(True), condition.equation
51+
)
52+
residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
53+
dim=tuple(range(1, target.ndim))
6154
)
62-
return tmp
63-
64-
def on_train_start(self, trainer, _):
65-
"""
66-
Callback function called at the start of training.
6755

68-
This method extracts the locations for sampling from the problem
69-
conditions and calculates the total population.
56+
# Prepare new points
57+
labels = current_points.labels
58+
domain_name = solver.problem.conditions[condition_name].domain
59+
domain = solver.problem.domains[domain_name]
60+
num_old_points = self.initial_population_size[condition_name]
61+
mask = (residuals > residuals.mean()).flatten()
7062

71-
:param trainer: The trainer object managing the training process.
72-
:type trainer: pytorch_lightning.Trainer
73-
:param _: Placeholder argument (not used).
74-
"""
75-
super().on_train_start(trainer, _)
76-
self.const_pts = {}
77-
for condition in self.conditions:
78-
pts = self.dataset.conditions_dict[condition]["input"]
79-
self.const_pts[condition] = len(pts)
63+
if mask.any(): # Use high-residual points
64+
pts = current_points[mask]
65+
pts.labels = labels
66+
retain_pts = len(pts)
67+
samples = domain.sample(num_old_points - retain_pts, "random")
68+
return LabelTensor.cat([pts, samples])
69+
else:
70+
return domain.sample(num_old_points, "random")

pina/callback/refinement/refinement_interface.py

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,101 +3,121 @@
33
network training process.
44
"""
55

6-
import torch
7-
from abc import ABCMeta
6+
from abc import ABCMeta, abstractmethod
87
from lightning.pytorch import Callback
9-
from torch_geometric.data.feature_store import abstractmethod
10-
from torch_geometric.nn.conv import point_transformer_conv
11-
from ...condition.domain_equation_condition import DomainEquationCondition
8+
from ...utils import check_consistency
9+
from ...solver.physics_informed_solver import PINNInterface
1210

1311

1412
class RefinementInterface(Callback, metaclass=ABCMeta):
1513
"""
16-
Interface class of Refinement
14+
Interface class of Refinement approaches.
1715
"""
1816

19-
def __init__(self, sample_every):
17+
def __init__(self, sample_every, condition_to_update=None):
2018
"""
2119
Initializes the RefinementInterface.
2220
2321
:param int sample_every: The number of epochs between each refinement.
2422
"""
23+
# check consistency of the input
24+
check_consistency(sample_every, int)
25+
if condition_to_update is not None:
26+
if not isinstance(condition_to_update, (list, tuple)):
27+
raise ValueError(
28+
"'condition_to_update' must be iter of strings."
29+
)
30+
check_consistency(condition_to_update, str)
31+
# store
2532
self.sample_every = sample_every
26-
self.conditions = None
27-
self.dataset = None
28-
self.solver = None
33+
self._condition_to_update = condition_to_update
34+
self._dataset = None
35+
self._initial_population_size = None
2936

30-
def on_train_start(self, trainer, _):
37+
def on_train_start(self, trainer, solver):
3138
"""
3239
Called when the training begins. It initializes the conditions and
3340
dataset.
3441
3542
:param lightning.pytorch.Trainer trainer: The trainer object.
3643
:param _: Unused argument.
3744
"""
38-
self.problem = trainer.solver.problem
39-
self.solver = trainer.solver
40-
self.conditions = {}
41-
for name, cond in self.problem.conditions.items():
42-
if isinstance(cond, DomainEquationCondition):
43-
self.conditions[name] = cond
44-
self.dataset = trainer.datamodule.train_dataset
45+
# check we have valid conditions names
46+
if self._condition_to_update is None:
47+
self._condition_to_update = list(solver.problem.conditions.keys())
4548

46-
@property
47-
def points(self):
48-
"""
49-
Returns the points of the dataset.
50-
"""
51-
return self.dataset.conditions_dict
49+
for cond in self._condition_to_update:
50+
if cond not in solver.problem.conditions:
51+
raise RuntimeError(
52+
f"Condition '{cond}' not found in "
53+
f"{list(solver.problem.conditions.keys())}."
54+
)
55+
if not hasattr(solver.problem.conditions[cond], "domain"):
56+
raise RuntimeError(
57+
f"Condition '{cond}' does not contain a domain to "
58+
"sample from."
59+
)
60+
# check solver
61+
if not isinstance(solver, PINNInterface):
62+
raise RuntimeError(
63+
f"Refinment strategies are currently implemented only "
64+
"for physics informed based solvers. Please use a Solver "
65+
"inheriting from 'PINNInterface'."
66+
)
67+
# store dataset
68+
self._dataset = trainer.datamodule.train_dataset
69+
# compute initial population size
70+
self._initial_population_size = self._compute_population_size(
71+
self._condition_to_update
72+
)
73+
return super().on_train_epoch_start(trainer, solver)
5274

53-
def on_train_epoch_end(self, trainer, _):
75+
def on_train_epoch_end(self, trainer, solver):
5476
"""
5577
Performs the refinement at the end of each training epoch (if needed).
5678
5779
:param lightning.pytorch.Trainer trainer: The trainer object.
5880
:param _: Unused argument.
5981
"""
6082
if trainer.current_epoch % self.sample_every == 0:
61-
self.update()
83+
self._update_points(solver)
84+
return super().on_train_epoch_end(trainer, solver)
6285

63-
def update(self):
86+
@abstractmethod
87+
def sample(self, current_points, condition_name, solver):
6488
"""
65-
Performs the refinement of the points.
89+
Samples new points based on the condition.
6690
"""
67-
new_points = {}
68-
for name, condition in self.conditions.items():
69-
new_points[name] = {"input": self.sample(name, condition)}
70-
self.dataset.update_data(new_points)
91+
pass
7192

72-
def per_point_residual(self, conditions_name=None):
93+
@property
94+
def dataset(self):
7395
"""
74-
Computes the residuals for a PINN object.
75-
76-
:return: the total loss, and pointwise loss.
77-
:rtype: tuple
96+
Returns the dataset for training.
7897
"""
79-
# compute residual
80-
res_loss = {}
81-
tot_loss = []
82-
points = self.points
83-
if conditions_name is None:
84-
conditions_name = list(self.conditions.keys())
85-
for name in conditions_name:
86-
cond = self.conditions[name]
87-
cond_points = points[name]["input"]
88-
target = self._compute_residual(cond_points, cond.equation)
89-
res_loss[name] = torch.abs(target).as_subclass(torch.Tensor)
90-
tot_loss.append(torch.abs(target))
91-
return torch.vstack(tot_loss).tensor.mean(), res_loss
98+
return self._dataset
9299

93-
def _compute_residual(self, pts, equation):
94-
pts.requires_grad_(True)
95-
pts.retain_grad()
96-
return equation.residual(pts, self.solver.forward(pts))
100+
@property
101+
def initial_population_size(self):
102+
"""
103+
Returns the dataset for training.
104+
"""
105+
return self._initial_population_size
97106

98-
@abstractmethod
99-
def sample(self, condition):
107+
def _update_points(self, solver):
100108
"""
101-
Samples new points based on the condition.
109+
Performs the refinement of the points.
102110
"""
103-
pass
111+
new_points = {}
112+
for name in self._condition_to_update:
113+
current_points = self.dataset.conditions_dict[name]["input"]
114+
new_points[name] = {
115+
"input": self.sample(current_points, name, solver)
116+
}
117+
self.dataset.update_data(new_points)
118+
119+
def _compute_population_size(self, conditions):
120+
return {
121+
cond: len(self.dataset.conditions_dict[cond]["input"])
122+
for cond in conditions
123+
}
Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import pytest
2+
3+
from torch.nn import MSELoss
4+
15
from pina.solver import PINN
26
from pina.trainer import Trainer
37
from pina.model import FeedForward
@@ -7,28 +11,27 @@
711

812
# make the problem
913
poisson_problem = Poisson()
10-
boundaries = ["g1", "g2", "g3", "g4"]
11-
n = 10
12-
poisson_problem.discretise_domain(n, "grid", domains=boundaries)
13-
poisson_problem.discretise_domain(n, "grid", domains="D")
14+
poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"])
15+
poisson_problem.discretise_domain(10, "grid", domains="D")
1416
model = FeedForward(
1517
len(poisson_problem.input_variables), len(poisson_problem.output_variables)
1618
)
17-
18-
# make the solver
1919
solver = PINN(problem=poisson_problem, model=model)
2020

2121

22-
def test_r3constructor():
22+
def test_constructor():
23+
# good constructor
2324
R3Refinement(sample_every=10)
25+
R3Refinement(sample_every=10, residual_loss=MSELoss)
26+
R3Refinement(sample_every=10, condition_to_update=["D"])
27+
# wrong constructor
28+
with pytest.raises(ValueError):
29+
R3Refinement(sample_every="str")
30+
with pytest.raises(ValueError):
31+
R3Refinement(sample_every=10, condition_to_update=3)
2432

2533

26-
def test_r3refinement_routine():
27-
model = FeedForward(
28-
len(poisson_problem.input_variables),
29-
len(poisson_problem.output_variables),
30-
)
31-
solver = PINN(problem=poisson_problem, model=model)
34+
def test_sample():
3235
trainer = Trainer(
3336
solver=solver,
3437
callbacks=[R3Refinement(sample_every=1)],
@@ -43,4 +46,5 @@ def test_r3refinement_routine():
4346
loc: len(pts)
4447
for loc, pts in trainer.data_module.train_dataset.input.items()
4548
}
49+
assert before_n_points == trainer.callbacks[0].initial_population_size
4650
assert before_n_points == after_n_points

0 commit comments

Comments
 (0)