Commit 03a8f31

Fix adaptive refinement

1 parent 6b355b4 commit 03a8f31

File tree: 3 files changed, +190 -163 lines

pina/callback/adaptive_refinement_callback.py

Lines changed: 148 additions & 145 deletions
@@ -34,148 +34,151 @@ def __init__(self, sample_every):
         Example:
             >>> r3_callback = R3Refinement(sample_every=5)
         """
-        raise NotImplementedError(
-            "R3Refinement callback is being refactored in the pina "
-            f"{importlib.metadata.metadata('pina-mathlab')['Version']} "
-            "version. Please use version 0.1 if R3Refinement is required."
-        )
-
-        # super().__init__()
-
-        # # sample every
-        # check_consistency(sample_every, int)
-        # self._sample_every = sample_every
-        # self._const_pts = None
-
-    # def _compute_residual(self, trainer):
-    #     """
-    #     Computes the residuals for a PINN object.
-
-    #     :return: the total loss, and pointwise loss.
-    #     :rtype: tuple
-    #     """
-
-    #     # extract the solver and device from trainer
-    #     solver = trainer.solver
-    #     device = trainer._accelerator_connector._accelerator_flag
-    #     precision = trainer.precision
-    #     if precision == "64-true":
-    #         precision = torch.float64
-    #     elif precision == "32-true":
-    #         precision = torch.float32
-    #     else:
-    #         raise RuntimeError(
-    #             "Currently R3Refinement is only implemented "
-    #             "for precision '32-true' and '64-true', set "
-    #             "Trainer precision to match one of the "
-    #             "available precisions."
-    #         )
-
-    #     # compute residual
-    #     res_loss = {}
-    #     tot_loss = []
-    #     for location in self._sampling_locations:
-    #         condition = solver.problem.conditions[location]
-    #         pts = solver.problem.input_pts[location]
-    #         # send points to correct device
-    #         pts = pts.to(device=device, dtype=precision)
-    #         pts = pts.requires_grad_(True)
-    #         pts.retain_grad()
-    #         # PINN loss: equation evaluated only for sampling locations
-    #         target = condition.equation.residual(pts, solver.forward(pts))
-    #         res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
-    #         tot_loss.append(torch.abs(target))
-
-    #     print(tot_loss)
-
-    #     return torch.vstack(tot_loss), res_loss
-
-    # def _r3_routine(self, trainer):
-    #     """
-    #     R3 refinement main routine.
-
-    #     :param Trainer trainer: PINA Trainer.
-    #     """
-    #     # compute residual (all device possible)
-    #     tot_loss, res_loss = self._compute_residual(trainer)
-    #     tot_loss = tot_loss.as_subclass(torch.Tensor)
-
-    #     # !!!!!! From now everything is performed on CPU !!!!!!
-
-    #     # average loss
-    #     avg = (tot_loss.mean()).to("cpu")
-    #     old_pts = {}  # points to be retained
-    #     for location in self._sampling_locations:
-    #         pts = trainer._model.problem.input_pts[location]
-    #         labels = pts.labels
-    #         pts = pts.cpu().detach().as_subclass(torch.Tensor)
-    #         residuals = res_loss[location].cpu()
-    #         mask = (residuals > avg).flatten()
-    #         if any(mask):  # append residuals greater than average
-    #             pts = (pts[mask]).as_subclass(LabelTensor)
-    #             pts.labels = labels
-    #             old_pts[location] = pts
-    #             numb_pts = self._const_pts[location] - len(old_pts[location])
-    #             # sample new points
-    #             trainer._model.problem.discretise_domain(
-    #                 numb_pts, "random", locations=[location]
-    #             )
-
-    #         else:  # if no res greater than average, samples all uniformly
-    #             numb_pts = self._const_pts[location]
-    #             # sample new points
-    #             trainer._model.problem.discretise_domain(
-    #                 numb_pts, "random", locations=[location]
-    #             )
-    #         # adding previous population points
-    #         trainer._model.problem.add_points(old_pts)
-
-    #     # update dataloader
-    #     trainer._create_or_update_loader()
-
-    # def on_train_start(self, trainer, _):
-    #     """
-    #     Callback function called at the start of training.
-
-    #     This method extracts the locations for sampling from the problem
-    #     conditions and calculates the total population.
-
-    #     :param trainer: The trainer object managing the training process.
-    #     :type trainer: pytorch_lightning.Trainer
-    #     :param _: Placeholder argument (not used).
-
-    #     :return: None
-    #     :rtype: None
-    #     """
-    #     # extract locations for sampling
-    #     problem = trainer.solver.problem
-    #     locations = []
-    #     for condition_name in problem.conditions:
-    #         condition = problem.conditions[condition_name]
-    #         if hasattr(condition, "location"):
-    #             locations.append(condition_name)
-    #     self._sampling_locations = locations
-
-    #     # extract total population
-    #     const_pts = {}  # for each location, store the pts to keep constant
-    #     for location in self._sampling_locations:
-    #         pts = trainer._model.problem.input_pts[location]
-    #         const_pts[location] = len(pts)
-    #     self._const_pts = const_pts
-
-    # def on_train_epoch_end(self, trainer, __):
-    #     """
-    #     Callback function called at the end of each training epoch.
-
-    #     This method triggers the R3 routine for refinement if the current
-    #     epoch is a multiple of `_sample_every`.
-
-    #     :param trainer: The trainer object managing the training process.
-    #     :type trainer: pytorch_lightning.Trainer
-    #     :param __: Placeholder argument (not used).
-
-    #     :return: None
-    #     :rtype: None
-    #     """
-    #     if trainer.current_epoch % self._sample_every == 0:
-    #         self._r3_routine(trainer)
+
+        super().__init__()
+
+        # sample every
+        check_consistency(sample_every, int)
+        self._sample_every = sample_every
+        self._const_pts = None
+        self._domains = None
+
+    def _compute_residual(self, trainer):
+        """
+        Computes the residuals for a PINN object.
+
+        :return: the total loss, and pointwise loss.
+        :rtype: tuple
+        """
+
+        # extract the solver and device from trainer
+        solver = trainer.solver
+        device = trainer._accelerator_connector._accelerator_flag
+        precision = trainer.precision
+        if precision == "64-true":
+            precision = torch.float64
+        elif precision == "32-true":
+            precision = torch.float32
+        else:
+            raise RuntimeError(
+                "Currently R3Refinement is only implemented "
+                "for precision '32-true' and '64-true', set "
+                "Trainer precision to match one of the "
+                "available precisions."
+            )
+
+        # compute residual
+        res_loss = {}
+        tot_loss = []
+        for condition in self._conditions:
+            pts = trainer.datamodule.train_dataset.conditions_dict[condition][
+                "input"
+            ]
+            equation = solver.problem.conditions[condition].equation
+            # send points to correct device
+            pts = pts.to(device=device, dtype=precision)
+            pts = pts.requires_grad_(True)
+            pts.retain_grad()
+            # PINN loss: equation evaluated only for sampling locations
+            target = equation.residual(pts, solver.forward(pts))
+            res_loss[condition] = torch.abs(target).as_subclass(torch.Tensor)
+            tot_loss.append(torch.abs(target))
+        return torch.vstack(tot_loss), res_loss
+
+    def _r3_routine(self, trainer):
+        """
+        R3 refinement main routine.
+
+        :param Trainer trainer: PINA Trainer.
+        """
+        # compute residual (all device possible)
+        tot_loss, res_loss = self._compute_residual(trainer)
+        tot_loss = tot_loss.as_subclass(torch.Tensor)
+
+        # !!!!!! From now everything is performed on CPU !!!!!!
+
+        # average loss
+        avg = (tot_loss.mean()).to("cpu")
+        new_pts = {}
+
+        dataset = trainer.datamodule.train_dataset
+        problem = trainer.solver.problem
+        for condition in self._conditions:
+            pts = dataset.conditions_dict[condition]["input"]
+            domain = problem.conditions[condition].domain
+            if not isinstance(domain, str):
+                domain = condition
+            labels = pts.labels
+            pts = pts.cpu().detach().as_subclass(torch.Tensor)
+            residuals = res_loss[condition].cpu()
+            mask = (residuals > avg).flatten()
+            if any(mask):  # append residuals greater than average
+                pts = (pts[mask]).as_subclass(LabelTensor)
+                pts.labels = labels
+                numb_pts = self._const_pts[condition] - len(pts)
+            else:  # if no res greater than average, samples all uniformly
+                numb_pts = self._const_pts[condition]
+                pts = None
+            problem.discretise_domain(numb_pts, "random", domains=[domain])
+            sampled_points = problem.discretised_domains[domain]
+            tmp = (
+                sampled_points
+                if pts is None
+                else LabelTensor.cat([pts, sampled_points])
+            )
+            new_pts[condition] = {"input": tmp}
+        dataset.update_data(new_pts)
+
+    def on_train_start(self, trainer, _):
+        """
+        Callback function called at the start of training.
+
+        This method extracts the locations for sampling from the problem
+        conditions and calculates the total population.
+
+        :param trainer: The trainer object managing the training process.
+        :type trainer: pytorch_lightning.Trainer
+        :param _: Placeholder argument (not used).
+
+        :return: None
+        :rtype: None
+        """
+        problem = trainer.solver.problem
+        if hasattr(problem, "domains"):
+            domains = problem.domains
+            self._domains = domains
+        else:
+            self._domains = {}
+            for name, data in problem.conditions.items():
+                if hasattr(data, "domain"):
+                    self._domains[name] = data.domain
+        self._conditions = []
+        for name, data in problem.conditions.items():
+            if hasattr(data, "domain"):
+                self._conditions.append(name)
+
+        # extract total population
+        const_pts = {}  # for each location, store the pts to keep constant
+        for condition in self._conditions:
+            pts = trainer.datamodule.train_dataset.conditions_dict[condition][
+                "input"
+            ]
+            const_pts[condition] = len(pts)
+        self._const_pts = const_pts
+
+    def on_train_epoch_end(self, trainer, __):
+        """
+        Callback function called at the end of each training epoch.
+
+        This method triggers the R3 routine for refinement if the current
+        epoch is a multiple of `_sample_every`.
+
+        :param trainer: The trainer object managing the training process.
+        :type trainer: pytorch_lightning.Trainer
+        :param __: Placeholder argument (not used).
+
+        :return: None
+        :rtype: None
+        """
+        if trainer.current_epoch % self._sample_every == 0:
+            self._r3_routine(trainer)

pina/data/dataset.py

Lines changed: 12 additions & 0 deletions
@@ -239,6 +239,18 @@ def input(self):
         """
         return {k: v["input"] for k, v in self.conditions_dict.items()}
 
+    def update_data(self, conditions_dict):
+        """
+        Update the dataset with new data.
+        This method is used to update the dataset with new data. It replaces
+        the current data with the new data provided in the conditions_dict
+        parameter.
+        :param dict conditions_dict: Dictionary containing the new data.
+        :type conditions_dict: dict
+        :return: None
+        """
+        self.conditions_dict = conditions_dict
+
 
 class PinaGraphDataset(PinaDataset):
     """

tests/test_callback/test_adaptive_refinement_callback.py

Lines changed: 30 additions & 18 deletions
@@ -19,27 +19,39 @@
 solver = PINN(problem=poisson_problem, model=model)
 
 
-# def test_r3constructor():
-#     R3Refinement(sample_every=10)
+def test_r3constructor():
+    R3Refinement(sample_every=10)
 
 
 # def test_r3refinment_routine():
 #     # make the trainer
-#     trainer = Trainer(solver=solver,
-#                       callback=[R3Refinement(sample_every=1)],
-#                       accelerator='cpu',
-#                       max_epochs=5)
+#     trainer = Trainer(
+#         solver=solver,
+#         callbacks=[R3Refinement(sample_every=1)],
+#         accelerator="cpu",
+#         max_epochs=5,
+#     )
 #     trainer.train()
 
-# def test_r3refinment_routine():
-#     model = FeedForward(len(poisson_problem.input_variables),
-#                         len(poisson_problem.output_variables))
-#     solver = PINN(problem=poisson_problem, model=model)
-#     trainer = Trainer(solver=solver,
-#                       callback=[R3Refinement(sample_every=1)],
-#                       accelerator='cpu',
-#                       max_epochs=5)
-#     before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
-#     trainer.train()
-#     after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()}
-#     assert before_n_points == after_n_points
+
+def test_r3refinment_routine():
+    model = FeedForward(
+        len(poisson_problem.input_variables),
+        len(poisson_problem.output_variables),
+    )
+    solver = PINN(problem=poisson_problem, model=model)
+    trainer = Trainer(
+        solver=solver,
+        callbacks=[R3Refinement(sample_every=1)],
+        accelerator="cpu",
+        max_epochs=5,
+    )
+    before_n_points = {
+        loc: len(pts) for loc, pts in trainer.solver.problem.input_pts.items()
+    }
+    trainer.train()
+    after_n_points = {
+        loc: len(pts)
+        for loc, pts in trainer.data_module.train_dataset.input.items()
+    }
+    assert before_n_points == after_n_points
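The re-enabled test above doubles as a usage example. For completeness, here is a sketch of how the callback might be attached in user code, assuming the same Poisson problem and model setup used in this test file; the import paths are inferred from this commit's file layout (pina/callback/adaptive_refinement_callback.py) and may differ in a given PINA release.

# imports below are assumptions based on this commit's file layout
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.callback.adaptive_refinement_callback import R3Refinement

# `poisson_problem` is assumed to be set up as in the test file above
model = FeedForward(
    len(poisson_problem.input_variables),
    len(poisson_problem.output_variables),
)
solver = PINN(problem=poisson_problem, model=model)

# refine the collocation points every 5 epochs during training
trainer = Trainer(
    solver=solver,
    callbacks=[R3Refinement(sample_every=5)],
    accelerator="cpu",
    max_epochs=50,
)
trainer.train()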
