|
3 | 3 | network training process.
|
4 | 4 | """
|
5 | 5 |
|
6 |
| -import torch |
7 |
| -from abc import ABCMeta |
| 6 | +from abc import ABCMeta, abstractmethod |
8 | 7 | from lightning.pytorch import Callback
|
9 |
| -from torch_geometric.data.feature_store import abstractmethod |
10 |
| -from torch_geometric.nn.conv import point_transformer_conv |
11 |
| -from ...condition.domain_equation_condition import DomainEquationCondition |
| 8 | +from ...utils import check_consistency |
| 9 | +from ...solver.physics_informed_solver import PINNInterface |
12 | 10 |
|
13 | 11 |
|
14 | 12 | class RefinementInterface(Callback, metaclass=ABCMeta):
|
15 | 13 | """
|
16 |
| - Interface class of Refinement |
| 14 | + Interface class of Refinement approaches. |
17 | 15 | """
|
18 | 16 |
|
19 |
| - def __init__(self, sample_every): |
| 17 | + def __init__(self, sample_every, condition_to_update=None): |
20 | 18 | """
|
21 | 19 | Initializes the RefinementInterface.
|
22 | 20 |
|
23 | 21 | :param int sample_every: The number of epochs between each refinement.
|
24 | 22 | """
|
| 23 | + # check consistency of the input |
| 24 | + check_consistency(sample_every, int) |
| 25 | + if condition_to_update is not None: |
| 26 | + if not isinstance(condition_to_update, (list, tuple)): |
| 27 | + raise ValueError( |
| 28 | + "'condition_to_update' must be iter of strings." |
| 29 | + ) |
| 30 | + check_consistency(condition_to_update, str) |
| 31 | + # store |
25 | 32 | self.sample_every = sample_every
|
26 |
| - self.conditions = None |
27 |
| - self.dataset = None |
28 |
| - self.solver = None |
| 33 | + self._condition_to_update = condition_to_update |
| 34 | + self._dataset = None |
| 35 | + self._initial_population_size = None |
29 | 36 |
|
30 |
| - def on_train_start(self, trainer, _): |
| 37 | + def on_train_start(self, trainer, solver): |
31 | 38 | """
|
32 | 39 | Called when the training begins. It initializes the conditions and
|
33 | 40 | dataset.
|
34 | 41 |
|
35 | 42 | :param lightning.pytorch.Trainer trainer: The trainer object.
|
36 | 43 | :param _: Unused argument.
|
37 | 44 | """
|
38 |
| - self.problem = trainer.solver.problem |
39 |
| - self.solver = trainer.solver |
40 |
| - self.conditions = {} |
41 |
| - for name, cond in self.problem.conditions.items(): |
42 |
| - if isinstance(cond, DomainEquationCondition): |
43 |
| - self.conditions[name] = cond |
44 |
| - self.dataset = trainer.datamodule.train_dataset |
| 45 | + # check we have valid conditions names |
| 46 | + if self._condition_to_update is None: |
| 47 | + self._condition_to_update = list(solver.problem.conditions.keys()) |
45 | 48 |
|
46 |
| - @property |
47 |
| - def points(self): |
48 |
| - """ |
49 |
| - Returns the points of the dataset. |
50 |
| - """ |
51 |
| - return self.dataset.conditions_dict |
| 49 | + for cond in self._condition_to_update: |
| 50 | + if cond not in solver.problem.conditions: |
| 51 | + raise RuntimeError( |
| 52 | + f"Condition '{cond}' not found in " |
| 53 | + f"{list(solver.problem.conditions.keys())}." |
| 54 | + ) |
| 55 | + if not hasattr(solver.problem.conditions[cond], "domain"): |
| 56 | + raise RuntimeError( |
| 57 | + f"Condition '{cond}' does not contain a domain to " |
| 58 | + "sample from." |
| 59 | + ) |
| 60 | + # check solver |
| 61 | + if not isinstance(solver, PINNInterface): |
| 62 | + raise RuntimeError( |
| 63 | + f"Refinment strategies are currently implemented only " |
| 64 | + "for physics informed based solvers. Please use a Solver " |
| 65 | + "inheriting from 'PINNInterface'." |
| 66 | + ) |
| 67 | + # store dataset |
| 68 | + self._dataset = trainer.datamodule.train_dataset |
| 69 | + # compute initial population size |
| 70 | + self._initial_population_size = self._compute_population_size( |
| 71 | + self._condition_to_update |
| 72 | + ) |
| 73 | + return super().on_train_epoch_start(trainer, solver) |
52 | 74 |
|
53 |
| - def on_train_epoch_end(self, trainer, _): |
| 75 | + def on_train_epoch_end(self, trainer, solver): |
54 | 76 | """
|
55 | 77 | Performs the refinement at the end of each training epoch (if needed).
|
56 | 78 |
|
57 | 79 | :param lightning.pytorch.Trainer trainer: The trainer object.
|
58 | 80 | :param _: Unused argument.
|
59 | 81 | """
|
60 | 82 | if trainer.current_epoch % self.sample_every == 0:
|
61 |
| - self.update() |
| 83 | + self._update_points(solver) |
| 84 | + return super().on_train_epoch_end(trainer, solver) |
62 | 85 |
|
63 |
| - def update(self): |
| 86 | + @abstractmethod |
| 87 | + def sample(self, current_points, condition_name, solver): |
64 | 88 | """
|
65 |
| - Performs the refinement of the points. |
| 89 | + Samples new points based on the condition. |
66 | 90 | """
|
67 |
| - new_points = {} |
68 |
| - for name, condition in self.conditions.items(): |
69 |
| - new_points[name] = {"input": self.sample(name, condition)} |
70 |
| - self.dataset.update_data(new_points) |
| 91 | + pass |
71 | 92 |
|
72 |
| - def per_point_residual(self, conditions_name=None): |
| 93 | + @property |
| 94 | + def dataset(self): |
73 | 95 | """
|
74 |
| - Computes the residuals for a PINN object. |
75 |
| -
|
76 |
| - :return: the total loss, and pointwise loss. |
77 |
| - :rtype: tuple |
| 96 | + Returns the dataset for training. |
78 | 97 | """
|
79 |
| - # compute residual |
80 |
| - res_loss = {} |
81 |
| - tot_loss = [] |
82 |
| - points = self.points |
83 |
| - if conditions_name is None: |
84 |
| - conditions_name = list(self.conditions.keys()) |
85 |
| - for name in conditions_name: |
86 |
| - cond = self.conditions[name] |
87 |
| - cond_points = points[name]["input"] |
88 |
| - target = self._compute_residual(cond_points, cond.equation) |
89 |
| - res_loss[name] = torch.abs(target).as_subclass(torch.Tensor) |
90 |
| - tot_loss.append(torch.abs(target)) |
91 |
| - return torch.vstack(tot_loss).tensor.mean(), res_loss |
| 98 | + return self._dataset |
92 | 99 |
|
93 |
| - def _compute_residual(self, pts, equation): |
94 |
| - pts.requires_grad_(True) |
95 |
| - pts.retain_grad() |
96 |
| - return equation.residual(pts, self.solver.forward(pts)) |
| 100 | + @property |
| 101 | + def initial_population_size(self): |
| 102 | + """ |
| 103 | + Returns the dataset for training. |
| 104 | + """ |
| 105 | + return self._initial_population_size |
97 | 106 |
|
98 |
| - @abstractmethod |
99 |
| - def sample(self, condition): |
| 107 | + def _update_points(self, solver): |
100 | 108 | """
|
101 |
| - Samples new points based on the condition. |
| 109 | + Performs the refinement of the points. |
102 | 110 | """
|
103 |
| - pass |
| 111 | + new_points = {} |
| 112 | + for name in self._condition_to_update: |
| 113 | + current_points = self.dataset.conditions_dict[name]["input"] |
| 114 | + new_points[name] = { |
| 115 | + "input": self.sample(current_points, name, solver) |
| 116 | + } |
| 117 | + self.dataset.update_data(new_points) |
| 118 | + |
| 119 | + def _compute_population_size(self, conditions): |
| 120 | + return { |
| 121 | + cond: len(self.dataset.conditions_dict[cond]["input"]) |
| 122 | + for cond in conditions |
| 123 | + } |
0 commit comments