@@ -34,148 +34,151 @@ def __init__(self, sample_every):
34
34
Example:
35
35
>>> r3_callback = R3Refinement(sample_every=5)
36
36
"""
37
- raise NotImplementedError (
38
- "R3Refinement callback is being refactored in the pina "
39
- f"{ importlib .metadata .metadata ('pina-mathlab' )['Version' ]} "
40
- "version. Please use version 0.1 if R3Refinement is required."
41
- )
42
-
43
- # super().__init__()
44
-
45
- # # sample every
46
- # check_consistency(sample_every, int)
47
- # self._sample_every = sample_every
48
- # self._const_pts = None
49
-
50
- # def _compute_residual(self, trainer):
51
- # """
52
- # Computes the residuals for a PINN object.
53
-
54
- # :return: the total loss, and pointwise loss.
55
- # :rtype: tuple
56
- # """
57
-
58
- # # extract the solver and device from trainer
59
- # solver = trainer.solver
60
- # device = trainer._accelerator_connector._accelerator_flag
61
- # precision = trainer.precision
62
- # if precision == "64-true":
63
- # precision = torch.float64
64
- # elif precision == "32-true":
65
- # precision = torch.float32
66
- # else:
67
- # raise RuntimeError(
68
- # "Currently R3Refinement is only implemented "
69
- # "for precision '32-true' and '64-true', set "
70
- # "Trainer precision to match one of the "
71
- # "available precisions."
72
- # )
73
-
74
- # # compute residual
75
- # res_loss = {}
76
- # tot_loss = []
77
- # for location in self._sampling_locations:
78
- # condition = solver.problem.conditions[location]
79
- # pts = solver.problem.input_pts[location]
80
- # # send points to correct device
81
- # pts = pts.to(device=device, dtype=precision)
82
- # pts = pts.requires_grad_(True)
83
- # pts.retain_grad()
84
- # # PINN loss: equation evaluated only for sampling locations
85
- # target = condition.equation.residual(pts, solver.forward(pts))
86
- # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
87
- # tot_loss.append(torch.abs(target))
88
-
89
- # print(tot_loss)
90
-
91
- # return torch.vstack(tot_loss), res_loss
92
-
93
- # def _r3_routine(self, trainer):
94
- # """
95
- # R3 refinement main routine.
96
-
97
- # :param Trainer trainer: PINA Trainer.
98
- # """
99
- # # compute residual (all device possible)
100
- # tot_loss, res_loss = self._compute_residual(trainer)
101
- # tot_loss = tot_loss.as_subclass(torch.Tensor)
102
-
103
- # # !!!!!! From now everything is performed on CPU !!!!!!
104
-
105
- # # average loss
106
- # avg = (tot_loss.mean()).to("cpu")
107
- # old_pts = {} # points to be retained
108
- # for location in self._sampling_locations:
109
- # pts = trainer._model.problem.input_pts[location]
110
- # labels = pts.labels
111
- # pts = pts.cpu().detach().as_subclass(torch.Tensor)
112
- # residuals = res_loss[location].cpu()
113
- # mask = (residuals > avg).flatten()
114
- # if any(mask): # append residuals greater than average
115
- # pts = (pts[mask]).as_subclass(LabelTensor)
116
- # pts.labels = labels
117
- # old_pts[location] = pts
118
- # numb_pts = self._const_pts[location] - len(old_pts[location])
119
- # # sample new points
120
- # trainer._model.problem.discretise_domain(
121
- # numb_pts, "random", locations=[location]
122
- # )
123
-
124
- # else: # if no res greater than average, samples all uniformly
125
- # numb_pts = self._const_pts[location]
126
- # # sample new points
127
- # trainer._model.problem.discretise_domain(
128
- # numb_pts, "random", locations=[location]
129
- # )
130
- # # adding previous population points
131
- # trainer._model.problem.add_points(old_pts)
132
-
133
- # # update dataloader
134
- # trainer._create_or_update_loader()
135
-
136
- # def on_train_start(self, trainer, _):
137
- # """
138
- # Callback function called at the start of training.
139
-
140
- # This method extracts the locations for sampling from the problem
141
- # conditions and calculates the total population.
142
-
143
- # :param trainer: The trainer object managing the training process.
144
- # :type trainer: pytorch_lightning.Trainer
145
- # :param _: Placeholder argument (not used).
146
-
147
- # :return: None
148
- # :rtype: None
149
- # """
150
- # # extract locations for sampling
151
- # problem = trainer.solver.problem
152
- # locations = []
153
- # for condition_name in problem.conditions:
154
- # condition = problem.conditions[condition_name]
155
- # if hasattr(condition, "location"):
156
- # locations.append(condition_name)
157
- # self._sampling_locations = locations
158
-
159
- # # extract total population
160
- # const_pts = {} # for each location, store the pts to keep constant
161
- # for location in self._sampling_locations:
162
- # pts = trainer._model.problem.input_pts[location]
163
- # const_pts[location] = len(pts)
164
- # self._const_pts = const_pts
165
-
166
- # def on_train_epoch_end(self, trainer, __):
167
- # """
168
- # Callback function called at the end of each training epoch.
169
-
170
- # This method triggers the R3 routine for refinement if the current
171
- # epoch is a multiple of `_sample_every`.
172
-
173
- # :param trainer: The trainer object managing the training process.
174
- # :type trainer: pytorch_lightning.Trainer
175
- # :param __: Placeholder argument (not used).
176
-
177
- # :return: None
178
- # :rtype: None
179
- # """
180
- # if trainer.current_epoch % self._sample_every == 0:
181
- # self._r3_routine(trainer)
37
+
38
+ super ().__init__ ()
39
+
40
+ # sample every
41
+ check_consistency (sample_every , int )
42
+ self ._sample_every = sample_every
43
+ self ._const_pts = None
44
+ self ._domains = None
45
+
46
+ def _compute_residual (self , trainer ):
47
+ """
48
+ Computes the residuals for a PINN object.
49
+
50
+ :return: the total loss, and pointwise loss.
51
+ :rtype: tuple
52
+ """
53
+
54
+ # extract the solver and device from trainer
55
+ solver = trainer .solver
56
+ device = trainer ._accelerator_connector ._accelerator_flag
57
+ precision = trainer .precision
58
+ if precision == "64-true" :
59
+ precision = torch .float64
60
+ elif precision == "32-true" :
61
+ precision = torch .float32
62
+ else :
63
+ raise RuntimeError (
64
+ "Currently R3Refinement is only implemented "
65
+ "for precision '32-true' and '64-true', set "
66
+ "Trainer precision to match one of the "
67
+ "available precisions."
68
+ )
69
+
70
+ # compute residual
71
+ res_loss = {}
72
+ tot_loss = []
73
+ for condition in self ._conditions :
74
+ pts = trainer .datamodule .train_dataset .conditions_dict [condition ][
75
+ "input"
76
+ ]
77
+ equation = solver .problem .conditions [condition ].equation
78
+ # send points to correct device
79
+ pts = pts .to (device = device , dtype = precision )
80
+ pts = pts .requires_grad_ (True )
81
+ pts .retain_grad ()
82
+ # PINN loss: equation evaluated only for sampling locations
83
+ target = equation .residual (pts , solver .forward (pts ))
84
+ res_loss [condition ] = torch .abs (target ).as_subclass (torch .Tensor )
85
+ tot_loss .append (torch .abs (target ))
86
+ return torch .vstack (tot_loss ), res_loss
87
+
88
+ def _r3_routine (self , trainer ):
89
+ """
90
+ R3 refinement main routine.
91
+
92
+ :param Trainer trainer: PINA Trainer.
93
+ """
94
+ # compute residual (all device possible)
95
+ tot_loss , res_loss = self ._compute_residual (trainer )
96
+ tot_loss = tot_loss .as_subclass (torch .Tensor )
97
+
98
+ # !!!!!! From now everything is performed on CPU !!!!!!
99
+
100
+ # average loss
101
+ avg = (tot_loss .mean ()).to ("cpu" )
102
+ new_pts = {}
103
+
104
+ dataset = trainer .datamodule .train_dataset
105
+ problem = trainer .solver .problem
106
+ for condition in self ._conditions :
107
+ pts = dataset .conditions_dict [condition ]["input" ]
108
+ domain = problem .conditions [condition ].domain
109
+ if not isinstance (domain , str ):
110
+ domain = condition
111
+ labels = pts .labels
112
+ pts = pts .cpu ().detach ().as_subclass (torch .Tensor )
113
+ residuals = res_loss [condition ].cpu ()
114
+ mask = (residuals > avg ).flatten ()
115
+ if any (mask ): # append residuals greater than average
116
+ pts = (pts [mask ]).as_subclass (LabelTensor )
117
+ pts .labels = labels
118
+ numb_pts = self ._const_pts [condition ] - len (pts )
119
+ else : # if no res greater than average, samples all uniformly
120
+ numb_pts = self ._const_pts [condition ]
121
+ pts = None
122
+ problem .discretise_domain (numb_pts , "random" , domains = [domain ])
123
+ sampled_points = problem .discretised_domains [domain ]
124
+ tmp = (
125
+ sampled_points
126
+ if pts is None
127
+ else LabelTensor .cat ([pts , sampled_points ])
128
+ )
129
+ new_pts [condition ] = {"input" : tmp }
130
+ dataset .update_data (new_pts )
131
+
132
+ def on_train_start (self , trainer , _ ):
133
+ """
134
+ Callback function called at the start of training.
135
+
136
+ This method extracts the locations for sampling from the problem
137
+ conditions and calculates the total population.
138
+
139
+ :param trainer: The trainer object managing the training process.
140
+ :type trainer: pytorch_lightning.Trainer
141
+ :param _: Placeholder argument (not used).
142
+
143
+ :return: None
144
+ :rtype: None
145
+ """
146
+ problem = trainer .solver .problem
147
+ if hasattr (problem , "domains" ):
148
+ domains = problem .domains
149
+ self ._domains = domains
150
+ else :
151
+ self ._domains = {}
152
+ for name , data in problem .conditions .items ():
153
+ if hasattr (data , "domain" ):
154
+ self ._domains [name ] = data .domain
155
+ self ._conditions = []
156
+ for name , data in problem .conditions .items ():
157
+ if hasattr (data , "domain" ):
158
+ self ._conditions .append (name )
159
+
160
+ # extract total population
161
+ const_pts = {} # for each location, store the pts to keep constant
162
+ for condition in self ._conditions :
163
+ pts = trainer .datamodule .train_dataset .conditions_dict [condition ][
164
+ "input"
165
+ ]
166
+ const_pts [condition ] = len (pts )
167
+ self ._const_pts = const_pts
168
+
169
+ def on_train_epoch_end (self , trainer , __ ):
170
+ """
171
+ Callback function called at the end of each training epoch.
172
+
173
+ This method triggers the R3 routine for refinement if the current
174
+ epoch is a multiple of `_sample_every`.
175
+
176
+ :param trainer: The trainer object managing the training process.
177
+ :type trainer: pytorch_lightning.Trainer
178
+ :param __: Placeholder argument (not used).
179
+
180
+ :return: None
181
+ :rtype: None
182
+ """
183
+ if trainer .current_epoch % self ._sample_every == 0 :
184
+ self ._r3_routine (trainer )
0 commit comments