diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index ba059ddbc..bfe4621b2 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -238,7 +238,8 @@ Callbacks
 
     Processing callback
    Optimizer callback
-    Refinment callback
+    R3 Refinement callback
+    Refinement Interface callback
     Weighting callback
 
 Losses and Weightings
diff --git a/docs/source/_rst/callback/adaptive_refinment_callback.rst b/docs/source/_rst/callback/refinement/r3_refinement.rst
similarity index 63%
rename from docs/source/_rst/callback/adaptive_refinment_callback.rst
rename to docs/source/_rst/callback/refinement/r3_refinement.rst
index 8afad6571..eb3bfebf2 100644
--- a/docs/source/_rst/callback/adaptive_refinment_callback.rst
+++ b/docs/source/_rst/callback/refinement/r3_refinement.rst
@@ -1,7 +1,7 @@
 Refinments callbacks
 =======================
 
-.. currentmodule:: pina.callback.adaptive_refinement_callback
+.. currentmodule:: pina.callback.refinement
 .. autoclass:: R3Refinement
    :members:
    :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/callback/refinement/refinement_interface.rst b/docs/source/_rst/callback/refinement/refinement_interface.rst
new file mode 100644
index 000000000..5e02f2dc3
--- /dev/null
+++ b/docs/source/_rst/callback/refinement/refinement_interface.rst
@@ -0,0 +1,7 @@
+Refinement Interface
+=======================
+
+.. currentmodule:: pina.callback.refinement
+.. autoclass:: RefinementInterface
+   :members:
+   :show-inheritance:
\ No newline at end of file
diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py
index 421071a2c..dc1164e47 100644
--- a/pina/callback/__init__.py
+++ b/pina/callback/__init__.py
@@ -2,13 +2,13 @@
 
 __all__ = [
     "SwitchOptimizer",
-    "R3Refinement",
     "MetricTracker",
     "PINAProgressBar",
     "LinearWeightUpdate",
+    "R3Refinement",
 ]
 
 from .optimizer_callback import SwitchOptimizer
-from .adaptive_refinement_callback import R3Refinement
 from .processing_callback import MetricTracker, PINAProgressBar
 from .linear_weight_update_callback import LinearWeightUpdate
+from .refinement import R3Refinement
diff --git a/pina/callback/adaptive_refinement_callback.py b/pina/callback/adaptive_refinement_callback.py
deleted file mode 100644
index 84ac0cfcc..000000000
--- a/pina/callback/adaptive_refinement_callback.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""Module for the R3Refinement callback."""
-
-import importlib.metadata
-import torch
-from lightning.pytorch.callbacks import Callback
-from ..label_tensor import LabelTensor
-from ..utils import check_consistency
-
-
-class R3Refinement(Callback):
-    """
-    PINA Implementation of an R3 Refinement Callback.
-    """
-
-    def __init__(self, sample_every):
-        """
-        This callback implements the R3 (Retain-Resample-Release) routine for
-        sampling new points based on adaptive search.
-        The algorithm incrementally accumulates collocation points in regions
-        of high PDE residuals, and releases those with low residuals.
-        Points are sampled uniformly in all regions where sampling is needed.
-
-        .. seealso::
-
-            Original Reference: Daw, Arka, et al. *Mitigating Propagation
-            Failures in Physics-informed Neural Networks
-            using Retain-Resample-Release (R3) Sampling. (2023)*.
-            DOI: `10.48550/arXiv.2207.02338
-            <https://doi.org/10.48550/arXiv.2207.02338>`_
-
-        :param int sample_every: Frequency for sampling.
-        :raises ValueError: If `sample_every` is not an integer.
- - Example: - >>> r3_callback = R3Refinement(sample_every=5) - """ - raise NotImplementedError( - "R3Refinement callback is being refactored in the pina " - f"{importlib.metadata.metadata('pina-mathlab')['Version']} " - "version. Please use version 0.1 if R3Refinement is required." - ) - - # super().__init__() - - # # sample every - # check_consistency(sample_every, int) - # self._sample_every = sample_every - # self._const_pts = None - - # def _compute_residual(self, trainer): - # """ - # Computes the residuals for a PINN object. - - # :return: the total loss, and pointwise loss. - # :rtype: tuple - # """ - - # # extract the solver and device from trainer - # solver = trainer.solver - # device = trainer._accelerator_connector._accelerator_flag - # precision = trainer.precision - # if precision == "64-true": - # precision = torch.float64 - # elif precision == "32-true": - # precision = torch.float32 - # else: - # raise RuntimeError( - # "Currently R3Refinement is only implemented " - # "for precision '32-true' and '64-true', set " - # "Trainer precision to match one of the " - # "available precisions." - # ) - - # # compute residual - # res_loss = {} - # tot_loss = [] - # for location in self._sampling_locations: - # condition = solver.problem.conditions[location] - # pts = solver.problem.input_pts[location] - # # send points to correct device - # pts = pts.to(device=device, dtype=precision) - # pts = pts.requires_grad_(True) - # pts.retain_grad() - # # PINN loss: equation evaluated only for sampling locations - # target = condition.equation.residual(pts, solver.forward(pts)) - # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) - # tot_loss.append(torch.abs(target)) - - # print(tot_loss) - - # return torch.vstack(tot_loss), res_loss - - # def _r3_routine(self, trainer): - # """ - # R3 refinement main routine. - - # :param Trainer trainer: PINA Trainer. - # """ - # # compute residual (all device possible) - # tot_loss, res_loss = self._compute_residual(trainer) - # tot_loss = tot_loss.as_subclass(torch.Tensor) - - # # !!!!!! From now everything is performed on CPU !!!!!! - - # # average loss - # avg = (tot_loss.mean()).to("cpu") - # old_pts = {} # points to be retained - # for location in self._sampling_locations: - # pts = trainer._model.problem.input_pts[location] - # labels = pts.labels - # pts = pts.cpu().detach().as_subclass(torch.Tensor) - # residuals = res_loss[location].cpu() - # mask = (residuals > avg).flatten() - # if any(mask): # append residuals greater than average - # pts = (pts[mask]).as_subclass(LabelTensor) - # pts.labels = labels - # old_pts[location] = pts - # numb_pts = self._const_pts[location] - len(old_pts[location]) - # # sample new points - # trainer._model.problem.discretise_domain( - # numb_pts, "random", locations=[location] - # ) - - # else: # if no res greater than average, samples all uniformly - # numb_pts = self._const_pts[location] - # # sample new points - # trainer._model.problem.discretise_domain( - # numb_pts, "random", locations=[location] - # ) - # # adding previous population points - # trainer._model.problem.add_points(old_pts) - - # # update dataloader - # trainer._create_or_update_loader() - - # def on_train_start(self, trainer, _): - # """ - # Callback function called at the start of training. - - # This method extracts the locations for sampling from the problem - # conditions and calculates the total population. - - # :param trainer: The trainer object managing the training process. 
-    #     :type trainer: pytorch_lightning.Trainer
-    #     :param _: Placeholder argument (not used).
-
-    #     :return: None
-    #     :rtype: None
-    #     """
-    #     # extract locations for sampling
-    #     problem = trainer.solver.problem
-    #     locations = []
-    #     for condition_name in problem.conditions:
-    #         condition = problem.conditions[condition_name]
-    #         if hasattr(condition, "location"):
-    #             locations.append(condition_name)
-    #     self._sampling_locations = locations
-
-    #     # extract total population
-    #     const_pts = {}  # for each location, store the pts to keep constant
-    #     for location in self._sampling_locations:
-    #         pts = trainer._model.problem.input_pts[location]
-    #         const_pts[location] = len(pts)
-    #     self._const_pts = const_pts
-
-    # def on_train_epoch_end(self, trainer, __):
-    #     """
-    #     Callback function called at the end of each training epoch.
-
-    #     This method triggers the R3 routine for refinement if the current
-    #     epoch is a multiple of `_sample_every`.
-
-    #     :param trainer: The trainer object managing the training process.
-    #     :type trainer: pytorch_lightning.Trainer
-    #     :param __: Placeholder argument (not used).
-
-    #     :return: None
-    #     :rtype: None
-    #     """
-    #     if trainer.current_epoch % self._sample_every == 0:
-    #         self._r3_routine(trainer)
diff --git a/pina/callback/refinement/__init__.py b/pina/callback/refinement/__init__.py
new file mode 100644
index 000000000..396fcabaa
--- /dev/null
+++ b/pina/callback/refinement/__init__.py
@@ -0,0 +1,11 @@
+"""
+Module for PINA Refinement callbacks.
+"""
+
+__all__ = [
+    "RefinementInterface",
+    "R3Refinement",
+]
+
+from .refinement_interface import RefinementInterface
+from .r3_refinement import R3Refinement
diff --git a/pina/callback/refinement/r3_refinement.py b/pina/callback/refinement/r3_refinement.py
new file mode 100644
index 000000000..c90b2953e
--- /dev/null
+++ b/pina/callback/refinement/r3_refinement.py
@@ -0,0 +1,88 @@
+"""Module for the R3Refinement callback."""
+
+import torch
+from torch import nn
+from torch.nn.modules.loss import _Loss
+from .refinement_interface import RefinementInterface
+from ...label_tensor import LabelTensor
+from ...utils import check_consistency
+from ...loss import LossInterface
+
+
+class R3Refinement(RefinementInterface):
+    """
+    PINA Implementation of an R3 Refinement Callback.
+    """
+
+    def __init__(
+        self, sample_every, residual_loss=nn.L1Loss, condition_to_update=None
+    ):
+        """
+        This callback implements the R3 (Retain-Resample-Release) routine for
+        sampling new points based on adaptive search.
+        The algorithm incrementally accumulates collocation points in regions
+        of high PDE residuals, and releases those with low residuals.
+        Points are sampled uniformly in all regions where sampling is needed.
+
+        .. seealso::
+
+            Original Reference: Daw, Arka, et al. *Mitigating Propagation
+            Failures in Physics-informed Neural Networks
+            using Retain-Resample-Release (R3) Sampling. (2023)*.
+            DOI: `10.48550/arXiv.2207.02338
+            <https://doi.org/10.48550/arXiv.2207.02338>`_
+
+        :param int sample_every: The number of epochs between two refinements.
+        :param residual_loss: The loss used to compute the pointwise residuals.
+        :type residual_loss: LossInterface | ~torch.nn.modules.loss._Loss
+        :param condition_to_update: The conditions to update during the
+            refinement process. If None, all conditions with a domain will
+            be updated. Default is None.
+        :type condition_to_update: list(str) | tuple(str) | str
+        :raises ValueError: If the condition_to_update is not a string or
+            iterable of strings.
+        :raises TypeError: If the residual_loss is not a subclass of
+            torch.nn.Module.
+
+
+        Example:
+            >>> r3_callback = R3Refinement(sample_every=5)
+        """
+        super().__init__(sample_every, condition_to_update)
+        # check consistency loss
+        check_consistency(residual_loss, (LossInterface, _Loss), subclass=True)
+        self.loss_fn = residual_loss(reduction="none")
+
+    def sample(self, current_points, condition_name, solver):
+        """
+        Sample new points based on the R3 refinement strategy.
+
+        :param current_points: Current points in the domain.
+        :param condition_name: Name of the condition to update.
+        :param PINNInterface solver: The solver object.
+        :return: New points sampled based on the R3 strategy.
+        :rtype: LabelTensor
+        """
+        # Compute residuals for the given condition (average over fields)
+        condition = solver.problem.conditions[condition_name]
+        target = solver.compute_residual(
+            current_points.requires_grad_(True), condition.equation
+        )
+        residuals = self.loss_fn(target, torch.zeros_like(target)).mean(
+            dim=tuple(range(1, target.ndim))
+        )
+
+        # Prepare new points
+        labels = current_points.labels
+        domain_name = solver.problem.conditions[condition_name].domain
+        domain = solver.problem.domains[domain_name]
+        num_old_points = self.initial_population_size[condition_name]
+        mask = (residuals > residuals.mean()).flatten()
+
+        if mask.any():  # Use high-residual points
+            pts = current_points[mask]
+            pts.labels = labels
+            retain_pts = len(pts)
+            samples = domain.sample(num_old_points - retain_pts, "random")
+            return LabelTensor.cat([pts, samples])
+        return domain.sample(num_old_points, "random")
diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py
new file mode 100644
index 000000000..adc6e4e7c
--- /dev/null
+++ b/pina/callback/refinement/refinement_interface.py
@@ -0,0 +1,155 @@
+"""
+RefinementInterface class for handling the refinement of points in a neural
+network training process.
+"""
+
+from abc import ABCMeta, abstractmethod
+from lightning.pytorch import Callback
+from ...utils import check_consistency
+from ...solver.physics_informed_solver import PINNInterface
+
+
+class RefinementInterface(Callback, metaclass=ABCMeta):
+    """
+    Interface class for refinement approaches.
+    """
+
+    def __init__(self, sample_every, condition_to_update=None):
+        """
+        Initializes the RefinementInterface.
+
+        :param int sample_every: The number of epochs between each refinement.
+        :param condition_to_update: The conditions to update during the
+            refinement process. If None, all conditions with a domain will be
+            updated. Default is None.
+        :type condition_to_update: list(str) | tuple(str) | str
+
+        """
+        # check consistency of the input
+        check_consistency(sample_every, int)
+        if condition_to_update is not None:
+            if isinstance(condition_to_update, str):
+                condition_to_update = [condition_to_update]
+            if not isinstance(condition_to_update, (list, tuple)):
+                raise ValueError(
+                    "'condition_to_update' must be an iterable of strings."
+                )
+            check_consistency(condition_to_update, str)
+        # store
+        self.sample_every = sample_every
+        self._condition_to_update = condition_to_update
+        self._dataset = None
+        self._initial_population_size = None
+
+    def on_train_start(self, trainer, solver):
+        """
+        Called when the training begins. It initializes the conditions and
+        dataset.
+
+        :param ~lightning.pytorch.trainer.trainer.Trainer trainer: The trainer
+            object.
+        :param ~pina.solver.solver.SolverInterface solver: The solver
+            object associated with the trainer.
+        :raises RuntimeError: If the solver is not a PINNInterface.
+        :raises RuntimeError: If the conditions do not have a domain to sample
+            from.
+        """
+        # check we have valid conditions names
+        if self._condition_to_update is None:
+            self._condition_to_update = [
+                name
+                for name, cond in solver.problem.conditions.items()
+                if hasattr(cond, "domain")
+            ]
+
+        for cond in self._condition_to_update:
+            if cond not in solver.problem.conditions:
+                raise RuntimeError(
+                    f"Condition '{cond}' not found in "
+                    f"{list(solver.problem.conditions.keys())}."
+                )
+            if not hasattr(solver.problem.conditions[cond], "domain"):
+                raise RuntimeError(
+                    f"Condition '{cond}' does not contain a domain to "
+                    "sample from."
+                )
+        # check solver
+        if not isinstance(solver, PINNInterface):
+            raise RuntimeError(
+                "Refinement strategies are currently implemented only "
+                "for physics-informed solvers. Please use a Solver "
+                "inheriting from 'PINNInterface'."
+            )
+        # store dataset
+        self._dataset = trainer.datamodule.train_dataset
+        # compute initial population size
+        self._initial_population_size = self._compute_population_size(
+            self._condition_to_update
+        )
+        return super().on_train_start(trainer, solver)
+
+    def on_train_epoch_end(self, trainer, solver):
+        """
+        Performs the refinement at the end of each training epoch (if needed).
+
+        :param Trainer trainer: The trainer object.
+        :param PINNInterface solver: The solver object.
+        """
+        if (trainer.current_epoch % self.sample_every == 0) and (
+            trainer.current_epoch != 0
+        ):
+            self._update_points(solver)
+        return super().on_train_epoch_end(trainer, solver)
+
+    @abstractmethod
+    def sample(self, current_points, condition_name, solver):
+        """
+        Samples new points based on the condition.
+
+        :param current_points: Current points in the domain.
+        :param condition_name: Name of the condition to update.
+        :param PINNInterface solver: The solver object.
+        :return: New points sampled according to the refinement strategy.
+        :rtype: LabelTensor
+        """
+
+    @property
+    def dataset(self):
+        """
+        Returns the dataset for training.
+        """
+        return self._dataset
+
+    @property
+    def initial_population_size(self):
+        """
+        Returns the initial number of points for each condition.
+        """
+        return self._initial_population_size
+
+    def _update_points(self, solver):
+        """
+        Performs the refinement of the points.
+
+        :param PINNInterface solver: The solver object.
+        """
+        new_points = {}
+        for name in self._condition_to_update:
+            current_points = self.dataset.conditions_dict[name]["input"]
+            new_points[name] = {
+                "input": self.sample(current_points, name, solver)
+            }
+        self.dataset.update_data(new_points)
+
+    def _compute_population_size(self, conditions):
+        """
+        Computes the number of points in the dataset for each condition.
+
+        :param conditions: List of conditions to compute the number of points.
+        :return: Dictionary with the population size for each condition.
+        :rtype: dict
+        """
+        return {
+            cond: len(self.dataset.conditions_dict[cond]["input"])
+            for cond in conditions
+        }
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index 8d58be4c3..386c3c53c 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -239,6 +239,22 @@ def input(self):
         """
         return {k: v["input"] for k, v in self.conditions_dict.items()}
 
+    def update_data(self, new_conditions_dict):
+        """
+        Update the dataset with new data.
+        For each condition in ``new_conditions_dict``, the current data is
+        replaced by the data provided; conditions not already in the dataset
+        are added.
+
+        :param dict new_conditions_dict: Dictionary containing the new data.
+ :return: None + """ + for condition, data in new_conditions_dict.items(): + if condition in self.conditions_dict: + self.conditions_dict[condition].update(data) + else: + self.conditions_dict[condition] = data + class PinaGraphDataset(PinaDataset): """ diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index dcabef13a..7866c7f7b 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -1,45 +1,58 @@ +import pytest + +from torch.nn import MSELoss + from pina.solver import PINN from pina.trainer import Trainer from pina.model import FeedForward from pina.problem.zoo import Poisson2DSquareProblem as Poisson -from pina.callback import R3Refinement +from pina.callback.refinement import R3Refinement # make the problem poisson_problem = Poisson() -boundaries = ["g1", "g2", "g3", "g4"] -n = 10 -poisson_problem.discretise_domain(n, "grid", domains=boundaries) -poisson_problem.discretise_domain(n, "grid", domains="D") +poisson_problem.discretise_domain(10, "grid", domains=["g1", "g2", "g3", "g4"]) +poisson_problem.discretise_domain(10, "grid", domains="D") model = FeedForward( len(poisson_problem.input_variables), len(poisson_problem.output_variables) ) - -# make the solver solver = PINN(problem=poisson_problem, model=model) -# def test_r3constructor(): -# R3Refinement(sample_every=10) - - -# def test_r3refinment_routine(): -# # make the trainer -# trainer = Trainer(solver=solver, -# callback=[R3Refinement(sample_every=1)], -# accelerator='cpu', -# max_epochs=5) -# trainer.train() - -# def test_r3refinment_routine(): -# model = FeedForward(len(poisson_problem.input_variables), -# len(poisson_problem.output_variables)) -# solver = PINN(problem=poisson_problem, model=model) -# trainer = Trainer(solver=solver, -# callback=[R3Refinement(sample_every=1)], -# accelerator='cpu', -# max_epochs=5) -# before_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()} -# trainer.train() -# after_n_points = {loc : len(pts) for loc, pts in trainer.solver.problem.input_pts.items()} -# assert before_n_points == after_n_points +def test_constructor(): + # good constructor + R3Refinement(sample_every=10) + R3Refinement(sample_every=10, residual_loss=MSELoss) + R3Refinement(sample_every=10, condition_to_update=["D"]) + # wrong constructor + with pytest.raises(ValueError): + R3Refinement(sample_every="str") + with pytest.raises(ValueError): + R3Refinement(sample_every=10, condition_to_update=3) + + +@pytest.mark.parametrize( + "condition_to_update", [["D", "g1"], ["D", "g1", "g2", "g3", "g4"]] +) +def test_sample(condition_to_update): + trainer = Trainer( + solver=solver, + callbacks=[ + R3Refinement( + sample_every=1, condition_to_update=condition_to_update + ) + ], + accelerator="cpu", + max_epochs=5, + ) + before_n_points = { + loc: len(trainer.solver.problem.input_pts[loc]) + for loc in condition_to_update + } + trainer.train() + after_n_points = { + loc: len(trainer.data_module.train_dataset.input[loc]) + for loc in condition_to_update + } + assert before_n_points == trainer.callbacks[0].initial_population_size + assert before_n_points == after_n_points
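A minimal usage sketch of the refactored refinement API (illustrative, not part of the patch). It assumes a "solver" built as in the test above, i.e. a PINN on a problem whose domains have already been discretised. The class UniformRandomRefinement below is hypothetical and exists only to show that a custom strategy is obtained by subclassing RefinementInterface and implementing sample(); every call it makes (problem.conditions[...].domain, problem.domains[...], domain.sample(..., "random"), initial_population_size) mirrors what R3Refinement.sample does in this diff.

from torch import nn

from pina.trainer import Trainer
from pina.callback.refinement import R3Refinement, RefinementInterface


class UniformRandomRefinement(RefinementInterface):
    """Illustrative strategy: redraw all points of a condition uniformly."""

    def sample(self, current_points, condition_name, solver):
        # Resample the condition's domain from scratch, keeping the
        # population size equal to the one sampled before training started.
        domain_name = solver.problem.conditions[condition_name].domain
        domain = solver.problem.domains[domain_name]
        return domain.sample(
            self.initial_population_size[condition_name], "random"
        )


# R3 refinement every 5 epochs on the interior condition "D" only,
# using a mean-squared residual instead of the default L1 loss.
trainer = Trainer(
    solver=solver,  # a PINN solver, set up as in the test above
    callbacks=[
        R3Refinement(
            sample_every=5,
            residual_loss=nn.MSELoss,
            condition_to_update=["D"],
        )
    ],
    accelerator="cpu",
    max_epochs=100,
)
trainer.train()

The same Trainer wiring works for any RefinementInterface subclass, e.g. callbacks=[UniformRandomRefinement(sample_every=10)].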