From 90701f606961bbcf1ad92418752a1154dd627719 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 8 May 2025 12:10:51 +0200 Subject: [PATCH 01/12] Fix adaptive refinement --- pina/callback/adaptive_refinement_callback.py | 293 +++++++++--------- pina/data/dataset.py | 12 + .../test_adaptive_refinement_callback.py | 48 +-- 3 files changed, 190 insertions(+), 163 deletions(-) diff --git a/pina/callback/adaptive_refinement_callback.py b/pina/callback/adaptive_refinement_callback.py index 84ac0cfcc..bdcd9f773 100644 --- a/pina/callback/adaptive_refinement_callback.py +++ b/pina/callback/adaptive_refinement_callback.py @@ -34,148 +34,151 @@ def __init__(self, sample_every): Example: >>> r3_callback = R3Refinement(sample_every=5) """ - raise NotImplementedError( - "R3Refinement callback is being refactored in the pina " - f"{importlib.metadata.metadata('pina-mathlab')['Version']} " - "version. Please use version 0.1 if R3Refinement is required." - ) - - # super().__init__() - - # # sample every - # check_consistency(sample_every, int) - # self._sample_every = sample_every - # self._const_pts = None - - # def _compute_residual(self, trainer): - # """ - # Computes the residuals for a PINN object. - - # :return: the total loss, and pointwise loss. - # :rtype: tuple - # """ - - # # extract the solver and device from trainer - # solver = trainer.solver - # device = trainer._accelerator_connector._accelerator_flag - # precision = trainer.precision - # if precision == "64-true": - # precision = torch.float64 - # elif precision == "32-true": - # precision = torch.float32 - # else: - # raise RuntimeError( - # "Currently R3Refinement is only implemented " - # "for precision '32-true' and '64-true', set " - # "Trainer precision to match one of the " - # "available precisions." - # ) - - # # compute residual - # res_loss = {} - # tot_loss = [] - # for location in self._sampling_locations: - # condition = solver.problem.conditions[location] - # pts = solver.problem.input_pts[location] - # # send points to correct device - # pts = pts.to(device=device, dtype=precision) - # pts = pts.requires_grad_(True) - # pts.retain_grad() - # # PINN loss: equation evaluated only for sampling locations - # target = condition.equation.residual(pts, solver.forward(pts)) - # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) - # tot_loss.append(torch.abs(target)) - - # print(tot_loss) - - # return torch.vstack(tot_loss), res_loss - - # def _r3_routine(self, trainer): - # """ - # R3 refinement main routine. - - # :param Trainer trainer: PINA Trainer. - # """ - # # compute residual (all device possible) - # tot_loss, res_loss = self._compute_residual(trainer) - # tot_loss = tot_loss.as_subclass(torch.Tensor) - - # # !!!!!! From now everything is performed on CPU !!!!!! - - # # average loss - # avg = (tot_loss.mean()).to("cpu") - # old_pts = {} # points to be retained - # for location in self._sampling_locations: - # pts = trainer._model.problem.input_pts[location] - # labels = pts.labels - # pts = pts.cpu().detach().as_subclass(torch.Tensor) - # residuals = res_loss[location].cpu() - # mask = (residuals > avg).flatten() - # if any(mask): # append residuals greater than average - # pts = (pts[mask]).as_subclass(LabelTensor) - # pts.labels = labels - # old_pts[location] = pts - # numb_pts = self._const_pts[location] - len(old_pts[location]) - # # sample new points - # trainer._model.problem.discretise_domain( - # numb_pts, "random", locations=[location] - # ) - - # else: # if no res greater than average, samples all uniformly - # numb_pts = self._const_pts[location] - # # sample new points - # trainer._model.problem.discretise_domain( - # numb_pts, "random", locations=[location] - # ) - # # adding previous population points - # trainer._model.problem.add_points(old_pts) - - # # update dataloader - # trainer._create_or_update_loader() - - # def on_train_start(self, trainer, _): - # """ - # Callback function called at the start of training. - - # This method extracts the locations for sampling from the problem - # conditions and calculates the total population. - - # :param trainer: The trainer object managing the training process. - # :type trainer: pytorch_lightning.Trainer - # :param _: Placeholder argument (not used). - - # :return: None - # :rtype: None - # """ - # # extract locations for sampling - # problem = trainer.solver.problem - # locations = [] - # for condition_name in problem.conditions: - # condition = problem.conditions[condition_name] - # if hasattr(condition, "location"): - # locations.append(condition_name) - # self._sampling_locations = locations - - # # extract total population - # const_pts = {} # for each location, store the pts to keep constant - # for location in self._sampling_locations: - # pts = trainer._model.problem.input_pts[location] - # const_pts[location] = len(pts) - # self._const_pts = const_pts - - # def on_train_epoch_end(self, trainer, __): - # """ - # Callback function called at the end of each training epoch. - - # This method triggers the R3 routine for refinement if the current - # epoch is a multiple of `_sample_every`. - - # :param trainer: The trainer object managing the training process. - # :type trainer: pytorch_lightning.Trainer - # :param __: Placeholder argument (not used). - - # :return: None - # :rtype: None - # """ - # if trainer.current_epoch % self._sample_every == 0: - # self._r3_routine(trainer) + + super().__init__() + + # sample every + check_consistency(sample_every, int) + self._sample_every = sample_every + self._const_pts = None + self._domains = None + + def _compute_residual(self, trainer): + """ + Computes the residuals for a PINN object. + + :return: the total loss, and pointwise loss. + :rtype: tuple + """ + + # extract the solver and device from trainer + solver = trainer.solver + device = trainer._accelerator_connector._accelerator_flag + precision = trainer.precision + if precision == "64-true": + precision = torch.float64 + elif precision == "32-true": + precision = torch.float32 + else: + raise RuntimeError( + "Currently R3Refinement is only implemented " + "for precision '32-true' and '64-true', set " + "Trainer precision to match one of the " + "available precisions." + ) + + # compute residual + res_loss = {} + tot_loss = [] + for condition in self._conditions: + pts = trainer.datamodule.train_dataset.conditions_dict[condition][ + "input" + ] + equation = solver.problem.conditions[condition].equation + # send points to correct device + pts = pts.to(device=device, dtype=precision) + pts = pts.requires_grad_(True) + pts.retain_grad() + # PINN loss: equation evaluated only for sampling locations + target = equation.residual(pts, solver.forward(pts)) + res_loss[condition] = torch.abs(target).as_subclass(torch.Tensor) + tot_loss.append(torch.abs(target)) + return torch.vstack(tot_loss), res_loss + + def _r3_routine(self, trainer): + """ + R3 refinement main routine. + + :param Trainer trainer: PINA Trainer. + """ + # compute residual (all device possible) + tot_loss, res_loss = self._compute_residual(trainer) + tot_loss = tot_loss.as_subclass(torch.Tensor) + + # !!!!!! From now everything is performed on CPU !!!!!! + + # average loss + avg = (tot_loss.mean()).to("cpu") + new_pts = {} + + dataset = trainer.datamodule.train_dataset + problem = trainer.solver.problem + for condition in self._conditions: + pts = dataset.conditions_dict[condition]["input"] + domain = problem.conditions[condition].domain + if not isinstance(domain, str): + domain = condition + labels = pts.labels + pts = pts.cpu().detach().as_subclass(torch.Tensor) + residuals = res_loss[condition].cpu() + mask = (residuals > avg).flatten() + if any(mask): # append residuals greater than average + pts = (pts[mask]).as_subclass(LabelTensor) + pts.labels = labels + numb_pts = self._const_pts[condition] - len(pts) + else: # if no res greater than average, samples all uniformly + numb_pts = self._const_pts[condition] + pts = None + problem.discretise_domain(numb_pts, "random", domains=[domain]) + sampled_points = problem.discretised_domains[domain] + tmp = ( + sampled_points + if pts is None + else LabelTensor.cat([pts, sampled_points]) + ) + new_pts[condition] = {"input": tmp} + dataset.update_data(new_pts) + + def on_train_start(self, trainer, _): + """ + Callback function called at the start of training. + + This method extracts the locations for sampling from the problem + conditions and calculates the total population. + + :param trainer: The trainer object managing the training process. + :type trainer: pytorch_lightning.Trainer + :param _: Placeholder argument (not used). + + :return: None + :rtype: None + """ + problem = trainer.solver.problem + if hasattr(problem, "domains"): + domains = problem.domains + self._domains = domains + else: + self._domains = {} + for name, data in problem.conditions.items(): + if hasattr(data, "domain"): + self._domains[name] = data.domain + self._conditions = [] + for name, data in problem.conditions.items(): + if hasattr(data, "domain"): + self._conditions.append(name) + + # extract total population + const_pts = {} # for each location, store the pts to keep constant + for condition in self._conditions: + pts = trainer.datamodule.train_dataset.conditions_dict[condition][ + "input" + ] + const_pts[condition] = len(pts) + self._const_pts = const_pts + + def on_train_epoch_end(self, trainer, __): + """ + Callback function called at the end of each training epoch. + + This method triggers the R3 routine for refinement if the current + epoch is a multiple of `_sample_every`. + + :param trainer: The trainer object managing the training process. + :type trainer: pytorch_lightning.Trainer + :param __: Placeholder argument (not used). + + :return: None + :rtype: None + """ + if trainer.current_epoch % self._sample_every == 0: + self._r3_routine(trainer) diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 8d58be4c3..52e31addf 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -239,6 +239,18 @@ def input(self): """ return {k: v["input"] for k, v in self.conditions_dict.items()} + def update_data(self, conditions_dict): + """ + Update the dataset with new data. + This method is used to update the dataset with new data. It replaces + the current data with the new data provided in the conditions_dict + parameter. + :param dict conditions_dict: Dictionary containing the new data. + :type conditions_dict: dict + :return: None + """ + self.conditions_dict = conditions_dict + class PinaGraphDataset(PinaDataset): """ diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index dcabef13a..c5c6bf844 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -19,27 +19,39 @@ solver = PINN(problem=poisson_problem, model=model) -# def test_r3constructor(): -# R3Refinement(sample_every=10) +def test_r3constructor(): + R3Refinement(sample_every=10) # def test_r3refinment_routine(): # # make the trainer -# trainer = Trainer(solver=solver, -# callback=[R3Refinement(sample_every=1)], -# accelerator='cpu', -# max_epochs=5) +# trainer = Trainer( +# solver=solver, +# callbacks=[R3Refinement(sample_every=1)], +# accelerator="cpu", +# max_epochs=5, +# ) # trainer.train() -# def test_r3refinment_routine(): -# model = FeedForward(len(poisson_problem.input_variables), -# len(poisson_problem.output_variables)) -# solver = PINN(problem=poisson_problem, model=model) -# trainer = Trainer(solver=solver, -# callback=[R3Refinement(sample_every=1)], -# accelerator='cpu', -# max_epochs=5) -# before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()} -# trainer.train() -# after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()} -# assert before_n_points == after_n_points + +def test_r3refinment_routine(): + model = FeedForward( + len(poisson_problem.input_variables), + len(poisson_problem.output_variables), + ) + solver = PINN(problem=poisson_problem, model=model) + trainer = Trainer( + solver=solver, + callbacks=[R3Refinement(sample_every=1)], + accelerator="cpu", + max_epochs=5, + ) + before_n_points = { + loc: len(pts) for loc, pts in trainer.solver.problem.input_pts.items() + } + trainer.train() + after_n_points = { + loc: len(pts) + for loc, pts in trainer.data_module.train_dataset.input.items() + } + assert before_n_points == after_n_points From eb18138dc10db45dc45dae38778c08dabb369825 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 13 May 2025 14:17:53 +0200 Subject: [PATCH 02/12] Reimplement refinement --- pina/callback/__init__.py | 2 - pina/callback/adaptive_refinement_callback.py | 184 ------------------ pina/callback/refinement/__init__.py | 7 + pina/callback/refinement/r3_refinement.py | 79 ++++++++ .../refinement/refinement_interface.py | 103 ++++++++++ pina/data/dataset.py | 8 +- .../test_adaptive_refinement_callback.py | 13 +- 7 files changed, 196 insertions(+), 200 deletions(-) delete mode 100644 pina/callback/adaptive_refinement_callback.py create mode 100644 pina/callback/refinement/__init__.py create mode 100644 pina/callback/refinement/r3_refinement.py create mode 100644 pina/callback/refinement/refinement_interface.py diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index 421071a2c..f55b0b725 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -2,13 +2,11 @@ __all__ = [ "SwitchOptimizer", - "R3Refinement", "MetricTracker", "PINAProgressBar", "LinearWeightUpdate", ] from .optimizer_callback import SwitchOptimizer -from .adaptive_refinement_callback import R3Refinement from .processing_callback import MetricTracker, PINAProgressBar from .linear_weight_update_callback import LinearWeightUpdate diff --git a/pina/callback/adaptive_refinement_callback.py b/pina/callback/adaptive_refinement_callback.py deleted file mode 100644 index bdcd9f773..000000000 --- a/pina/callback/adaptive_refinement_callback.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Module for the R3Refinement callback.""" - -import importlib.metadata -import torch -from lightning.pytorch.callbacks import Callback -from ..label_tensor import LabelTensor -from ..utils import check_consistency - - -class R3Refinement(Callback): - """ - PINA Implementation of an R3 Refinement Callback. - """ - - def __init__(self, sample_every): - """ - This callback implements the R3 (Retain-Resample-Release) routine for - sampling new points based on adaptive search. - The algorithm incrementally accumulates collocation points in regions - of high PDE residuals, and releases those with low residuals. - Points are sampled uniformly in all regions where sampling is needed. - - .. seealso:: - - Original Reference: Daw, Arka, et al. *Mitigating Propagation - Failures in Physics-informed Neural Networks - using Retain-Resample-Release (R3) Sampling. (2023)*. - DOI: `10.48550/arXiv.2207.02338 - `_ - - :param int sample_every: Frequency for sampling. - :raises ValueError: If `sample_every` is not an integer. - - Example: - >>> r3_callback = R3Refinement(sample_every=5) - """ - - super().__init__() - - # sample every - check_consistency(sample_every, int) - self._sample_every = sample_every - self._const_pts = None - self._domains = None - - def _compute_residual(self, trainer): - """ - Computes the residuals for a PINN object. - - :return: the total loss, and pointwise loss. - :rtype: tuple - """ - - # extract the solver and device from trainer - solver = trainer.solver - device = trainer._accelerator_connector._accelerator_flag - precision = trainer.precision - if precision == "64-true": - precision = torch.float64 - elif precision == "32-true": - precision = torch.float32 - else: - raise RuntimeError( - "Currently R3Refinement is only implemented " - "for precision '32-true' and '64-true', set " - "Trainer precision to match one of the " - "available precisions." - ) - - # compute residual - res_loss = {} - tot_loss = [] - for condition in self._conditions: - pts = trainer.datamodule.train_dataset.conditions_dict[condition][ - "input" - ] - equation = solver.problem.conditions[condition].equation - # send points to correct device - pts = pts.to(device=device, dtype=precision) - pts = pts.requires_grad_(True) - pts.retain_grad() - # PINN loss: equation evaluated only for sampling locations - target = equation.residual(pts, solver.forward(pts)) - res_loss[condition] = torch.abs(target).as_subclass(torch.Tensor) - tot_loss.append(torch.abs(target)) - return torch.vstack(tot_loss), res_loss - - def _r3_routine(self, trainer): - """ - R3 refinement main routine. - - :param Trainer trainer: PINA Trainer. - """ - # compute residual (all device possible) - tot_loss, res_loss = self._compute_residual(trainer) - tot_loss = tot_loss.as_subclass(torch.Tensor) - - # !!!!!! From now everything is performed on CPU !!!!!! - - # average loss - avg = (tot_loss.mean()).to("cpu") - new_pts = {} - - dataset = trainer.datamodule.train_dataset - problem = trainer.solver.problem - for condition in self._conditions: - pts = dataset.conditions_dict[condition]["input"] - domain = problem.conditions[condition].domain - if not isinstance(domain, str): - domain = condition - labels = pts.labels - pts = pts.cpu().detach().as_subclass(torch.Tensor) - residuals = res_loss[condition].cpu() - mask = (residuals > avg).flatten() - if any(mask): # append residuals greater than average - pts = (pts[mask]).as_subclass(LabelTensor) - pts.labels = labels - numb_pts = self._const_pts[condition] - len(pts) - else: # if no res greater than average, samples all uniformly - numb_pts = self._const_pts[condition] - pts = None - problem.discretise_domain(numb_pts, "random", domains=[domain]) - sampled_points = problem.discretised_domains[domain] - tmp = ( - sampled_points - if pts is None - else LabelTensor.cat([pts, sampled_points]) - ) - new_pts[condition] = {"input": tmp} - dataset.update_data(new_pts) - - def on_train_start(self, trainer, _): - """ - Callback function called at the start of training. - - This method extracts the locations for sampling from the problem - conditions and calculates the total population. - - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param _: Placeholder argument (not used). - - :return: None - :rtype: None - """ - problem = trainer.solver.problem - if hasattr(problem, "domains"): - domains = problem.domains - self._domains = domains - else: - self._domains = {} - for name, data in problem.conditions.items(): - if hasattr(data, "domain"): - self._domains[name] = data.domain - self._conditions = [] - for name, data in problem.conditions.items(): - if hasattr(data, "domain"): - self._conditions.append(name) - - # extract total population - const_pts = {} # for each location, store the pts to keep constant - for condition in self._conditions: - pts = trainer.datamodule.train_dataset.conditions_dict[condition][ - "input" - ] - const_pts[condition] = len(pts) - self._const_pts = const_pts - - def on_train_epoch_end(self, trainer, __): - """ - Callback function called at the end of each training epoch. - - This method triggers the R3 routine for refinement if the current - epoch is a multiple of `_sample_every`. - - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param __: Placeholder argument (not used). - - :return: None - :rtype: None - """ - if trainer.current_epoch % self._sample_every == 0: - self._r3_routine(trainer) diff --git a/pina/callback/refinement/__init__.py b/pina/callback/refinement/__init__.py new file mode 100644 index 000000000..c2d37c349 --- /dev/null +++ b/pina/callback/refinement/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "RefinementInterface", + "R3Refinement", +] + +from .refinement_interface import RefinementInterface +from .r3_refinement import R3Refinement diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py new file mode 100644 index 000000000..6dcc1f30e --- /dev/null +++ b/pina/callback/refinement/r3_refinement.py @@ -0,0 +1,79 @@ +"""Module for the R3Refinement callback.""" + +import torch +from .refinement_interface import RefinementInterface +from ...label_tensor import LabelTensor +from ...utils import check_consistency + + +class R3Refinement(RefinementInterface): + """ + PINA Implementation of an R3 Refinement Callback. + """ + + def __init__(self, sample_every): + """ + This callback implements the R3 (Retain-Resample-Release) routine for + sampling new points based on adaptive search. + The algorithm incrementally accumulates collocation points in regions + of high PDE residuals, and releases those with low residuals. + Points are sampled uniformly in all regions where sampling is needed. + + .. seealso:: + + Original Reference: Daw, Arka, et al. *Mitigating Propagation + Failures in Physics-informed Neural Networks + using Retain-Resample-Release (R3) Sampling. (2023)*. + DOI: `10.48550/arXiv.2207.02338 + `_ + + :param int sample_every: Frequency for sampling. + :raises ValueError: If `sample_every` is not an integer. + + Example: + >>> r3_callback = R3Refinement(sample_every=5) + """ + + super().__init__(sample_every=sample_every) + self.const_pts = None + + def sample(self, condition_name, condition): + avg_res, res = self.per_point_residual([condition_name]) + pts = self.dataset.conditions_dict[condition_name]["input"] + domain = condition.domain + labels = pts.labels + pts = pts.cpu().detach().as_subclass(torch.Tensor) + residuals = res[condition_name] + mask = (residuals > avg_res).flatten() + if any(mask): # append residuals greater than average + pts = (pts[mask]).as_subclass(LabelTensor) + pts.labels = labels + numb_pts = self.const_pts[condition_name] - len(pts) + else: + numb_pts = self.const_pts[condition_name] + pts = None + self.problem.discretise_domain(numb_pts, "random", domains=[domain]) + sampled_points = self.problem.discretised_domains[domain] + tmp = ( + sampled_points + if pts is None + else LabelTensor.cat([pts, sampled_points]) + ) + return tmp + + def on_train_start(self, trainer, _): + """ + Callback function called at the start of training. + + This method extracts the locations for sampling from the problem + conditions and calculates the total population. + + :param trainer: The trainer object managing the training process. + :type trainer: pytorch_lightning.Trainer + :param _: Placeholder argument (not used). + """ + super().on_train_start(trainer, _) + self.const_pts = {} + for condition in self.conditions: + pts = self.dataset.conditions_dict[condition]["input"] + self.const_pts[condition] = len(pts) diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py new file mode 100644 index 000000000..6ef1b0897 --- /dev/null +++ b/pina/callback/refinement/refinement_interface.py @@ -0,0 +1,103 @@ +""" +RefinementInterface class for handling the refinement of points in a neural +network training process. +""" + +import torch +from abc import ABCMeta +from lightning.pytorch import Callback +from torch_geometric.data.feature_store import abstractmethod +from torch_geometric.nn.conv import point_transformer_conv +from ...condition.domain_equation_condition import DomainEquationCondition + + +class RefinementInterface(Callback, metaclass=ABCMeta): + """ + Interface class of Refinement + """ + + def __init__(self, sample_every): + """ + Initializes the RefinementInterface. + + :param int sample_every: The number of epochs between each refinement. + """ + self.sample_every = sample_every + self.conditions = None + self.dataset = None + self.solver = None + + def on_train_start(self, trainer, _): + """ + Called when the training begins. It initializes the conditions and + dataset. + + :param lightning.pytorch.Trainer trainer: The trainer object. + :param _: Unused argument. + """ + self.problem = trainer.solver.problem + self.solver = trainer.solver + self.conditions = {} + for name, cond in self.problem.conditions.items(): + if isinstance(cond, DomainEquationCondition): + self.conditions[name] = cond + self.dataset = trainer.datamodule.train_dataset + + @property + def points(self): + """ + Returns the points of the dataset. + """ + return self.dataset.conditions_dict + + def on_train_epoch_end(self, trainer, _): + """ + Performs the refinement at the end of each training epoch (if needed). + + :param lightning.pytorch.Trainer trainer: The trainer object. + :param _: Unused argument. + """ + if trainer.current_epoch % self.sample_every == 0: + self.update() + + def update(self): + """ + Performs the refinement of the points. + """ + new_points = {} + for name, condition in self.conditions.items(): + new_points[name] = {"input": self.sample(name, condition)} + self.dataset.update_data(new_points) + + def per_point_residual(self, conditions_name=None): + """ + Computes the residuals for a PINN object. + + :return: the total loss, and pointwise loss. + :rtype: tuple + """ + # compute residual + res_loss = {} + tot_loss = [] + points = self.points + if conditions_name is None: + conditions_name = list(self.conditions.keys()) + for name in conditions_name: + cond = self.conditions[name] + cond_points = points[name]["input"] + target = self._compute_residual(cond_points, cond.equation) + res_loss[name] = torch.abs(target).as_subclass(torch.Tensor) + tot_loss.append(torch.abs(target)) + return torch.vstack(tot_loss).tensor.mean(), res_loss + + def _compute_residual(self, pts, equation): + pts.requires_grad_(True) + pts.retain_grad() + return equation.residual(pts, self.solver.forward(pts)) + + @abstractmethod + def sample(self, condition): + """ + Samples new points based on the condition. + """ + pass diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 52e31addf..5d45ac672 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -239,7 +239,7 @@ def input(self): """ return {k: v["input"] for k, v in self.conditions_dict.items()} - def update_data(self, conditions_dict): + def update_data(self, new_conditions_dict): """ Update the dataset with new data. This method is used to update the dataset with new data. It replaces @@ -249,7 +249,11 @@ def update_data(self, conditions_dict): :type conditions_dict: dict :return: None """ - self.conditions_dict = conditions_dict + for condition, data in new_conditions_dict.items(): + if condition in self.conditions_dict: + self.conditions_dict[condition].update(data) + else: + self.conditions_dict[condition] = data class PinaGraphDataset(PinaDataset): diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index c5c6bf844..65b178834 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -2,7 +2,7 @@ from pina.trainer import Trainer from pina.model import FeedForward from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.callback import R3Refinement +from pina.callback.refinement import R3Refinement # make the problem @@ -23,17 +23,6 @@ def test_r3constructor(): R3Refinement(sample_every=10) -# def test_r3refinment_routine(): -# # make the trainer -# trainer = Trainer( -# solver=solver, -# callbacks=[R3Refinement(sample_every=1)], -# accelerator="cpu", -# max_epochs=5, -# ) -# trainer.train() - - def test_r3refinment_routine(): model = FeedForward( len(poisson_problem.input_variables), From 518264d1055fe02b2fa0d4ec09c7a62f434c80ed Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 13 May 2025 15:03:07 +0200 Subject: [PATCH 03/12] Fix doc --- docs/source/_rst/_code.rst | 3 ++- .../r3_refinement.rst} | 2 +- .../_rst/callback/refinement/refinement_interface.rst | 7 +++++++ 3 files changed, 10 insertions(+), 2 deletions(-) rename docs/source/_rst/callback/{adaptive_refinment_callback.rst => refinement/r3_refinement.rst} (63%) create mode 100644 docs/source/_rst/callback/refinement/refinement_interface.rst diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index ba059ddbc..bfe4621b2 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -238,7 +238,8 @@ Callbacks Processing callback Optimizer callback - Refinment callback + R3 Refinment callback + Refinment Interface callback Weighting callback Losses and Weightings diff --git a/docs/source/_rst/callback/adaptive_refinment_callback.rst b/docs/source/_rst/callback/refinement/r3_refinement.rst similarity index 63% rename from docs/source/_rst/callback/adaptive_refinment_callback.rst rename to docs/source/_rst/callback/refinement/r3_refinement.rst index 8afad6571..eb3bfebf2 100644 --- a/docs/source/_rst/callback/adaptive_refinment_callback.rst +++ b/docs/source/_rst/callback/refinement/r3_refinement.rst @@ -1,7 +1,7 @@ Refinments callbacks ======================= -.. currentmodule:: pina.callback.adaptive_refinement_callback +.. currentmodule:: pina.callback.refinement .. autoclass:: R3Refinement :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/callback/refinement/refinement_interface.rst b/docs/source/_rst/callback/refinement/refinement_interface.rst new file mode 100644 index 000000000..5e02f2dc3 --- /dev/null +++ b/docs/source/_rst/callback/refinement/refinement_interface.rst @@ -0,0 +1,7 @@ +Refinement Interface +======================= + +.. currentmodule:: pina.callback.refinement +.. autoclass:: RefinementInterface + :members: + :show-inheritance: \ No newline at end of file From 1a056422e3c5d7b5aac13e56318a582ed1917b34 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Tue, 13 May 2025 15:37:33 +0200 Subject: [PATCH 04/12] Fix doc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pina/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 5d45ac672..895ee096d 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -243,10 +243,10 @@ def update_data(self, new_conditions_dict): """ Update the dataset with new data. This method is used to update the dataset with new data. It replaces - the current data with the new data provided in the conditions_dict + the current data with the new data provided in the new_conditions_dict parameter. - :param dict conditions_dict: Dictionary containing the new data. - :type conditions_dict: dict + :param dict new_conditions_dict: Dictionary containing the new data. + :type new_conditions_dict: dict :return: None """ for condition, data in new_conditions_dict.items(): From abb8718984af4c6ca22a8e3256c403e8635fb2f6 Mon Sep 17 00:00:00 2001 From: Filippo Olivo Date: Tue, 13 May 2025 15:38:12 +0200 Subject: [PATCH 05/12] Fix test Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/test_callback/test_adaptive_refinement_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index 65b178834..8693529d7 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -23,7 +23,7 @@ def test_r3constructor(): R3Refinement(sample_every=10) -def test_r3refinment_routine(): +def test_r3refinement_routine(): model = FeedForward( len(poisson_problem.input_variables), len(poisson_problem.output_variables), From 1e9e92076f1017cd279da90b59e89b82a6d22865 Mon Sep 17 00:00:00 2001 From: Monthly Tag bot Date: Wed, 14 May 2025 13:57:56 +0200 Subject: [PATCH 06/12] clean code --- pina/callback/refinement/r3_refinement.py | 73 +++++----- .../refinement/refinement_interface.py | 136 ++++++++++-------- .../test_adaptive_refinement_callback.py | 30 ++-- 3 files changed, 127 insertions(+), 112 deletions(-) diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py index 6dcc1f30e..61bc2a584 100644 --- a/pina/callback/refinement/r3_refinement.py +++ b/pina/callback/refinement/r3_refinement.py @@ -1,9 +1,12 @@ """Module for the R3Refinement callback.""" import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss from .refinement_interface import RefinementInterface from ...label_tensor import LabelTensor from ...utils import check_consistency +from ...loss import LossInterface class R3Refinement(RefinementInterface): @@ -11,7 +14,9 @@ class R3Refinement(RefinementInterface): PINA Implementation of an R3 Refinement Callback. """ - def __init__(self, sample_every): + def __init__( + self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None + ): """ This callback implements the R3 (Retain-Resample-Release) routine for sampling new points based on adaptive search. @@ -33,47 +38,33 @@ def __init__(self, sample_every): Example: >>> r3_callback = R3Refinement(sample_every=5) """ + super().__init__(sample_every, condition_to_update) + # check consistency loss + check_consistency(residual_loss, (LossInterface, _Loss), subclass=True) + self.loss_fn = residual_loss(reduction="none") - super().__init__(sample_every=sample_every) - self.const_pts = None - - def sample(self, condition_name, condition): - avg_res, res = self.per_point_residual([condition_name]) - pts = self.dataset.conditions_dict[condition_name]["input"] - domain = condition.domain - labels = pts.labels - pts = pts.cpu().detach().as_subclass(torch.Tensor) - residuals = res[condition_name] - mask = (residuals > avg_res).flatten() - if any(mask): # append residuals greater than average - pts = (pts[mask]).as_subclass(LabelTensor) - pts.labels = labels - numb_pts = self.const_pts[condition_name] - len(pts) - else: - numb_pts = self.const_pts[condition_name] - pts = None - self.problem.discretise_domain(numb_pts, "random", domains=[domain]) - sampled_points = self.problem.discretised_domains[domain] - tmp = ( - sampled_points - if pts is None - else LabelTensor.cat([pts, sampled_points]) + def sample(self, current_points, condition_name, solver): + # Compute residuals for the given condition (average over fields) + condition = solver.problem.conditions[condition_name] + target = solver.compute_residual( + current_points.requires_grad_(True), condition.equation + ) + residuals = self.loss_fn(target, torch.zeros_like(target)).mean( + dim=tuple(range(1, target.ndim)) ) - return tmp - - def on_train_start(self, trainer, _): - """ - Callback function called at the start of training. - This method extracts the locations for sampling from the problem - conditions and calculates the total population. + # Prepare new points + labels = current_points.labels + domain_name = solver.problem.conditions[condition_name].domain + domain = solver.problem.domains[domain_name] + num_old_points = self.initial_population_size[condition_name] + mask = (residuals > residuals.mean()).flatten() - :param trainer: The trainer object managing the training process. - :type trainer: pytorch_lightning.Trainer - :param _: Placeholder argument (not used). - """ - super().on_train_start(trainer, _) - self.const_pts = {} - for condition in self.conditions: - pts = self.dataset.conditions_dict[condition]["input"] - self.const_pts[condition] = len(pts) + if mask.any(): # Use high-residual points + pts = current_points[mask] + pts.labels = labels + retain_pts = len(pts) + samples = domain.sample(num_old_points - retain_pts, "random") + return LabelTensor.cat([pts, samples]) + else: + return domain.sample(num_old_points, "random") diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index 6ef1b0897..aa4c9ae55 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -3,31 +3,38 @@ network training process. """ -import torch -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from lightning.pytorch import Callback -from torch_geometric.data.feature_store import abstractmethod -from torch_geometric.nn.conv import point_transformer_conv -from ...condition.domain_equation_condition import DomainEquationCondition +from ...utils import check_consistency +from ...solver.physics_informed_solver import PINNInterface class RefinementInterface(Callback, metaclass=ABCMeta): """ - Interface class of Refinement + Interface class of Refinement approaches. """ - def __init__(self, sample_every): + def __init__(self, sample_every, condition_to_update=None): """ Initializes the RefinementInterface. :param int sample_every: The number of epochs between each refinement. """ + # check consistency of the input + check_consistency(sample_every, int) + if condition_to_update is not None: + if not isinstance(condition_to_update, (list, tuple)): + raise ValueError( + "'condition_to_update' must be iter of strings." + ) + check_consistency(condition_to_update, str) + # store self.sample_every = sample_every - self.conditions = None - self.dataset = None - self.solver = None + self._condition_to_update = condition_to_update + self._dataset = None + self._initial_population_size = None - def on_train_start(self, trainer, _): + def on_train_start(self, trainer, solver): """ Called when the training begins. It initializes the conditions and dataset. @@ -35,22 +42,37 @@ def on_train_start(self, trainer, _): :param lightning.pytorch.Trainer trainer: The trainer object. :param _: Unused argument. """ - self.problem = trainer.solver.problem - self.solver = trainer.solver - self.conditions = {} - for name, cond in self.problem.conditions.items(): - if isinstance(cond, DomainEquationCondition): - self.conditions[name] = cond - self.dataset = trainer.datamodule.train_dataset + # check we have valid conditions names + if self._condition_to_update is None: + self._condition_to_update = list(solver.problem.conditions.keys()) - @property - def points(self): - """ - Returns the points of the dataset. - """ - return self.dataset.conditions_dict + for cond in self._condition_to_update: + if cond not in solver.problem.conditions: + raise RuntimeError( + f"Condition '{cond}' not found in " + f"{list(solver.problem.conditions.keys())}." + ) + if not hasattr(solver.problem.conditions[cond], "domain"): + raise RuntimeError( + f"Condition '{cond}' does not contain a domain to " + "sample from." + ) + # check solver + if not isinstance(solver, PINNInterface): + raise RuntimeError( + f"Refinment strategies are currently implemented only " + "for physics informed based solvers. Please use a Solver " + "inheriting from 'PINNInterface'." + ) + # store dataset + self._dataset = trainer.datamodule.train_dataset + # compute initial population size + self._initial_population_size = self._compute_population_size( + self._condition_to_update + ) + return super().on_train_epoch_start(trainer, solver) - def on_train_epoch_end(self, trainer, _): + def on_train_epoch_end(self, trainer, solver): """ Performs the refinement at the end of each training epoch (if needed). @@ -58,46 +80,44 @@ def on_train_epoch_end(self, trainer, _): :param _: Unused argument. """ if trainer.current_epoch % self.sample_every == 0: - self.update() + self._update_points(solver) + return super().on_train_epoch_end(trainer, solver) - def update(self): + @abstractmethod + def sample(self, current_points, condition_name, solver): """ - Performs the refinement of the points. + Samples new points based on the condition. """ - new_points = {} - for name, condition in self.conditions.items(): - new_points[name] = {"input": self.sample(name, condition)} - self.dataset.update_data(new_points) + pass - def per_point_residual(self, conditions_name=None): + @property + def dataset(self): """ - Computes the residuals for a PINN object. - - :return: the total loss, and pointwise loss. - :rtype: tuple + Returns the dataset for training. """ - # compute residual - res_loss = {} - tot_loss = [] - points = self.points - if conditions_name is None: - conditions_name = list(self.conditions.keys()) - for name in conditions_name: - cond = self.conditions[name] - cond_points = points[name]["input"] - target = self._compute_residual(cond_points, cond.equation) - res_loss[name] = torch.abs(target).as_subclass(torch.Tensor) - tot_loss.append(torch.abs(target)) - return torch.vstack(tot_loss).tensor.mean(), res_loss + return self._dataset - def _compute_residual(self, pts, equation): - pts.requires_grad_(True) - pts.retain_grad() - return equation.residual(pts, self.solver.forward(pts)) + @property + def initial_population_size(self): + """ + Returns the dataset for training. + """ + return self._initial_population_size - @abstractmethod - def sample(self, condition): + def _update_points(self, solver): """ - Samples new points based on the condition. + Performs the refinement of the points. """ - pass + new_points = {} + for name in self._condition_to_update: + current_points = self.dataset.conditions_dict[name]["input"] + new_points[name] = { + "input": self.sample(current_points, name, solver) + } + self.dataset.update_data(new_points) + + def _compute_population_size(self, conditions): + return { + cond: len(self.dataset.conditions_dict[cond]["input"]) + for cond in conditions + } diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index 8693529d7..9d1e63735 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -1,3 +1,7 @@ +import pytest + +from torch.nn import MSELoss + from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward @@ -7,28 +11,27 @@ # make the problem poisson_problem = Poisson() -boundaries = ["g1", "g2", "g3", "g4"] -n = 10 -poisson_problem.discretise_domain(n, "grid", domains=boundaries) -poisson_problem.discretise_domain(n, "grid", domains="D") +poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"]) +poisson_problem.discretise_domain(10, "grid", domains="D") model = FeedForward( len(poisson_problem.input_variables), len(poisson_problem.output_variables) ) - -# make the solver solver = PINN(problem=poisson_problem, model=model) -def test_r3constructor(): +def test_constructor(): + # good constructor R3Refinement(sample_every=10) + R3Refinement(sample_every=10, residual_loss=MSELoss) + R3Refinement(sample_every=10, condition_to_update=["D"]) + # wrong constructor + with pytest.raises(ValueError): + R3Refinement(sample_every="str") + with pytest.raises(ValueError): + R3Refinement(sample_every=10, condition_to_update=3) -def test_r3refinement_routine(): - model = FeedForward( - len(poisson_problem.input_variables), - len(poisson_problem.output_variables), - ) - solver = PINN(problem=poisson_problem, model=model) +def test_sample(): trainer = Trainer( solver=solver, callbacks=[R3Refinement(sample_every=1)], @@ -43,4 +46,5 @@ def test_r3refinement_routine(): loc: len(pts) for loc, pts in trainer.data_module.train_dataset.input.items() } + assert before_n_points == trainer.callbacks[0].initial_population_size assert before_n_points == after_n_points From 360943e405c3c7fac6030e63096b5c3939112e34 Mon Sep 17 00:00:00 2001 From: Monthly Tag bot Date: Wed, 14 May 2025 14:43:45 +0200 Subject: [PATCH 07/12] fix tests --- .../refinement/refinement_interface.py | 6 +++++- .../test_adaptive_refinement_callback.py | 18 +++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index aa4c9ae55..da068894e 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -7,6 +7,7 @@ from lightning.pytorch import Callback from ...utils import check_consistency from ...solver.physics_informed_solver import PINNInterface +from ...condition import DomainEquationCondition class RefinementInterface(Callback, metaclass=ABCMeta): @@ -44,7 +45,10 @@ def on_train_start(self, trainer, solver): """ # check we have valid conditions names if self._condition_to_update is None: - self._condition_to_update = list(solver.problem.conditions.keys()) + self._condition_to_update = [ + name for name, cond in solver.problem.conditions.items() + if hasattr(cond, "domain") + ] for cond in self._condition_to_update: if cond not in solver.problem.conditions: diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index 9d1e63735..7866c7f7b 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -31,20 +31,28 @@ def test_constructor(): R3Refinement(sample_every=10, condition_to_update=3) -def test_sample(): +@pytest.mark.parametrize( + "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]] +) +def test_sample(condition_to_update): trainer = Trainer( solver=solver, - callbacks=[R3Refinement(sample_every=1)], + callbacks=[ + R3Refinement( + sample_every=1, condition_to_update=condition_to_update + ) + ], accelerator="cpu", max_epochs=5, ) before_n_points = { - loc: len(pts) for loc, pts in trainer.solver.problem.input_pts.items() + loc: len(trainer.solver.problem.input_pts[loc]) + for loc in condition_to_update } trainer.train() after_n_points = { - loc: len(pts) - for loc, pts in trainer.data_module.train_dataset.input.items() + loc: len(trainer.data_module.train_dataset.input[loc]) + for loc in condition_to_update } assert before_n_points == trainer.callbacks[0].initial_population_size assert before_n_points == after_n_points From 36d49247c5f464e45705a4a457424cb9663787a2 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 14 May 2025 15:43:02 +0200 Subject: [PATCH 08/12] Fix codacy --- pina/callback/refinement/__init__.py | 4 ++++ pina/callback/refinement/r3_refinement.py | 5 ++--- pina/callback/refinement/refinement_interface.py | 7 +++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pina/callback/refinement/__init__.py b/pina/callback/refinement/__init__.py index c2d37c349..396fcabaa 100644 --- a/pina/callback/refinement/__init__.py +++ b/pina/callback/refinement/__init__.py @@ -1,3 +1,7 @@ +""" +Module for Pina Refinement callbacks. +""" + __all__ = [ "RefinementInterface", "R3Refinement", diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py index 61bc2a584..c204f1aa0 100644 --- a/pina/callback/refinement/r3_refinement.py +++ b/pina/callback/refinement/r3_refinement.py @@ -1,7 +1,7 @@ """Module for the R3Refinement callback.""" import torch -import torch.nn as nn +from torch import nn from torch.nn.modules.loss import _Loss from .refinement_interface import RefinementInterface from ...label_tensor import LabelTensor @@ -66,5 +66,4 @@ def sample(self, current_points, condition_name, solver): retain_pts = len(pts) samples = domain.sample(num_old_points - retain_pts, "random") return LabelTensor.cat([pts, samples]) - else: - return domain.sample(num_old_points, "random") + return domain.sample(num_old_points, "random") diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index da068894e..a68ae4ec8 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -7,7 +7,6 @@ from lightning.pytorch import Callback from ...utils import check_consistency from ...solver.physics_informed_solver import PINNInterface -from ...condition import DomainEquationCondition class RefinementInterface(Callback, metaclass=ABCMeta): @@ -46,7 +45,8 @@ def on_train_start(self, trainer, solver): # check we have valid conditions names if self._condition_to_update is None: self._condition_to_update = [ - name for name, cond in solver.problem.conditions.items() + name + for name, cond in solver.problem.conditions.items() if hasattr(cond, "domain") ] @@ -64,7 +64,7 @@ def on_train_start(self, trainer, solver): # check solver if not isinstance(solver, PINNInterface): raise RuntimeError( - f"Refinment strategies are currently implemented only " + "Refinment strategies are currently implemented only " "for physics informed based solvers. Please use a Solver " "inheriting from 'PINNInterface'." ) @@ -92,7 +92,6 @@ def sample(self, current_points, condition_name, solver): """ Samples new points based on the condition. """ - pass @property def dataset(self): From 0666cf41fe8f5f059fac73b906774f9f7b2f5f0d Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 14 May 2025 16:48:09 +0200 Subject: [PATCH 09/12] Fix docstring --- pina/callback/refinement/r3_refinement.py | 21 ++++++++++- .../refinement/refinement_interface.py | 35 ++++++++++++++++--- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py index c204f1aa0..7434ead24 100644 --- a/pina/callback/refinement/r3_refinement.py +++ b/pina/callback/refinement/r3_refinement.py @@ -33,7 +33,17 @@ def __init__( `_ :param int sample_every: Frequency for sampling. - :raises ValueError: If `sample_every` is not an integer. + :param loss: Loss function + :type loss: LossInterface | ~torch.nn.modules.loss._Loss + :param condition_to_update: The conditions to update during the + refinement process. If None, all conditions with a conditions will + be updated. Default is None. + :type condition_to_update: list(str) | tuple(str) | str + :raises ValueError: If the condition_to_update is not a string or + iterable of strings. + :raises TypeError: If the residual_loss is not a subclass of + torch.nn.Module. + Example: >>> r3_callback = R3Refinement(sample_every=5) @@ -44,6 +54,15 @@ def __init__( self.loss_fn = residual_loss(reduction="none") def sample(self, current_points, condition_name, solver): + """ + Sample new points based on the R3 refinement strategy. + + :param current_points: Current points in the domain. + :param condition_name: Name of the condition to update. + :param solver: The solver object. + :return: New points sampled based on the R3 strategy. + :rtype: LabelTensor + """ # Compute residuals for the given condition (average over fields) condition = solver.problem.conditions[condition_name] target = solver.compute_residual( diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index a68ae4ec8..d4e79c940 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -19,10 +19,17 @@ def __init__(self, sample_every, condition_to_update=None): Initializes the RefinementInterface. :param int sample_every: The number of epochs between each refinement. + :param condition_to_update: The conditions to update during the + refinement process. If None, all conditions with a domain will be + updated. Default is None. + :type condition_to_update: list(str) | tuple(str) | str + """ # check consistency of the input check_consistency(sample_every, int) if condition_to_update is not None: + if isinstance(condition_to_update, str): + condition_to_update = [condition_to_update] if not isinstance(condition_to_update, (list, tuple)): raise ValueError( "'condition_to_update' must be iter of strings." @@ -39,8 +46,13 @@ def on_train_start(self, trainer, solver): Called when the training begins. It initializes the conditions and dataset. - :param lightning.pytorch.Trainer trainer: The trainer object. - :param _: Unused argument. + :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer + object. + :param ~pina.solver.solver.SolverInterface solver: The solver + object associated with the trainer. + :raises RuntimeError: If the solver is not a PINNInterface. + :raises RuntimeError: If the conditions do not have a domain to sample + from. """ # check we have valid conditions names if self._condition_to_update is None: @@ -80,8 +92,8 @@ def on_train_epoch_end(self, trainer, solver): """ Performs the refinement at the end of each training epoch (if needed). - :param lightning.pytorch.Trainer trainer: The trainer object. - :param _: Unused argument. + :param ~lightning.pytorch.trainer.trainer.Trainer: The trainer object. + :param PINNInterface solver: The solver object. """ if trainer.current_epoch % self.sample_every == 0: self._update_points(solver) @@ -91,6 +103,12 @@ def on_train_epoch_end(self, trainer, solver): def sample(self, current_points, condition_name, solver): """ Samples new points based on the condition. + + :param current_points: Current points in the domain. + :param condition_name: Name of the condition to update. + :param solver: The solver object. + :return: New points sampled based on the R3 strategy. + :rtype: LabelTensor """ @property @@ -110,6 +128,8 @@ def initial_population_size(self): def _update_points(self, solver): """ Performs the refinement of the points. + + :param PINNInterface solver: The solver object. """ new_points = {} for name in self._condition_to_update: @@ -120,6 +140,13 @@ def _update_points(self, solver): self.dataset.update_data(new_points) def _compute_population_size(self, conditions): + """ + Computes the number of points in the dataset for each condition. + + :param conditions: List of conditions to compute the number of points. + :return: Dictionary with the population size for each condition. + :rtype: dict + """ return { cond: len(self.dataset.conditions_dict[cond]["input"]) for cond in conditions From 9e69796b5d93611c3033e5fc6464b416118b9b2f Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 14 May 2025 18:54:49 +0200 Subject: [PATCH 10/12] add white space --- pina/data/dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 895ee096d..9c0214022 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -245,6 +245,7 @@ def update_data(self, new_conditions_dict): This method is used to update the dataset with new data. It replaces the current data with the new data provided in the new_conditions_dict parameter. + :param dict new_conditions_dict: Dictionary containing the new data. :type new_conditions_dict: dict :return: None From eb0b17ca00e4fe28308bca0e033da7560fd2fddf Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 14 May 2025 18:56:47 +0200 Subject: [PATCH 11/12] update doc --- pina/callback/refinement/refinement_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index d4e79c940..997a26b70 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -121,7 +121,7 @@ def dataset(self): @property def initial_population_size(self): """ - Returns the dataset for training. + Returns the dataset for training size. """ return self._initial_population_size From 2a316a89d7b0b20dbe0300b4a12ec01c1d384fa5 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 15 May 2025 14:28:30 +0200 Subject: [PATCH 12/12] Fixes --- pina/callback/__init__.py | 2 ++ pina/callback/refinement/r3_refinement.py | 2 +- pina/callback/refinement/refinement_interface.py | 6 ++++-- pina/data/dataset.py | 1 - 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index f55b0b725..dc1164e47 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -5,8 +5,10 @@ "MetricTracker", "PINAProgressBar", "LinearWeightUpdate", + "R3Refinement", ] from .optimizer_callback import SwitchOptimizer from .processing_callback import MetricTracker, PINAProgressBar from .linear_weight_update_callback import LinearWeightUpdate +from .refinement import R3Refinement diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py index 7434ead24..c90b2953e 100644 --- a/pina/callback/refinement/r3_refinement.py +++ b/pina/callback/refinement/r3_refinement.py @@ -59,7 +59,7 @@ def sample(self, current_points, condition_name, solver): :param current_points: Current points in the domain. :param condition_name: Name of the condition to update. - :param solver: The solver object. + :param PINNInterface solver: The solver object. :return: New points sampled based on the R3 strategy. :rtype: LabelTensor """ diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index 997a26b70..adc6e4e7c 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -95,7 +95,9 @@ def on_train_epoch_end(self, trainer, solver): :param ~lightning.pytorch.trainer.trainer.Trainer: The trainer object. :param PINNInterface solver: The solver object. """ - if trainer.current_epoch % self.sample_every == 0: + if (trainer.current_epoch % self.sample_every == 0) and ( + trainer.current_epoch != 0 + ): self._update_points(solver) return super().on_train_epoch_end(trainer, solver) @@ -106,7 +108,7 @@ def sample(self, current_points, condition_name, solver): :param current_points: Current points in the domain. :param condition_name: Name of the condition to update. - :param solver: The solver object. + :param PINNInterface solver: The solver object. :return: New points sampled based on the R3 strategy. :rtype: LabelTensor """ diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 9c0214022..386c3c53c 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -247,7 +247,6 @@ def update_data(self, new_conditions_dict): parameter. :param dict new_conditions_dict: Dictionary containing the new data. - :type new_conditions_dict: dict :return: None """ for condition, data in new_conditions_dict.items():