@@ -34,148 +34,151 @@ def __init__(self, sample_every):
3434 Example:
3535 >>> r3_callback = R3Refinement(sample_every=5)
3636 """
37- raise NotImplementedError (
38- "R3Refinement callback is being refactored in the pina "
39- f"{ importlib .metadata .metadata ('pina-mathlab' )['Version' ]} "
40- "version. Please use version 0.1 if R3Refinement is required."
41- )
42-
43- # super().__init__()
44-
45- # # sample every
46- # check_consistency(sample_every, int)
47- # self._sample_every = sample_every
48- # self._const_pts = None
49-
50- # def _compute_residual(self, trainer):
51- # """
52- # Computes the residuals for a PINN object.
53-
54- # :return: the total loss, and pointwise loss.
55- # :rtype: tuple
56- # """
57-
58- # # extract the solver and device from trainer
59- # solver = trainer.solver
60- # device = trainer._accelerator_connector._accelerator_flag
61- # precision = trainer.precision
62- # if precision == "64-true":
63- # precision = torch.float64
64- # elif precision == "32-true":
65- # precision = torch.float32
66- # else:
67- # raise RuntimeError(
68- # "Currently R3Refinement is only implemented "
69- # "for precision '32-true' and '64-true', set "
70- # "Trainer precision to match one of the "
71- # "available precisions."
72- # )
73-
74- # # compute residual
75- # res_loss = {}
76- # tot_loss = []
77- # for location in self._sampling_locations:
78- # condition = solver.problem.conditions[location]
79- # pts = solver.problem.input_pts[location]
80- # # send points to correct device
81- # pts = pts.to(device=device, dtype=precision)
82- # pts = pts.requires_grad_(True)
83- # pts.retain_grad()
84- # # PINN loss: equation evaluated only for sampling locations
85- # target = condition.equation.residual(pts, solver.forward(pts))
86- # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
87- # tot_loss.append(torch.abs(target))
88-
89- # print(tot_loss)
90-
91- # return torch.vstack(tot_loss), res_loss
92-
93- # def _r3_routine(self, trainer):
94- # """
95- # R3 refinement main routine.
96-
97- # :param Trainer trainer: PINA Trainer.
98- # """
99- # # compute residual (all device possible)
100- # tot_loss, res_loss = self._compute_residual(trainer)
101- # tot_loss = tot_loss.as_subclass(torch.Tensor)
102-
103- # # !!!!!! From now everything is performed on CPU !!!!!!
104-
105- # # average loss
106- # avg = (tot_loss.mean()).to("cpu")
107- # old_pts = {} # points to be retained
108- # for location in self._sampling_locations:
109- # pts = trainer._model.problem.input_pts[location]
110- # labels = pts.labels
111- # pts = pts.cpu().detach().as_subclass(torch.Tensor)
112- # residuals = res_loss[location].cpu()
113- # mask = (residuals > avg).flatten()
114- # if any(mask): # append residuals greater than average
115- # pts = (pts[mask]).as_subclass(LabelTensor)
116- # pts.labels = labels
117- # old_pts[location] = pts
118- # numb_pts = self._const_pts[location] - len(old_pts[location])
119- # # sample new points
120- # trainer._model.problem.discretise_domain(
121- # numb_pts, "random", locations=[location]
122- # )
123-
124- # else: # if no res greater than average, samples all uniformly
125- # numb_pts = self._const_pts[location]
126- # # sample new points
127- # trainer._model.problem.discretise_domain(
128- # numb_pts, "random", locations=[location]
129- # )
130- # # adding previous population points
131- # trainer._model.problem.add_points(old_pts)
132-
133- # # update dataloader
134- # trainer._create_or_update_loader()
135-
136- # def on_train_start(self, trainer, _):
137- # """
138- # Callback function called at the start of training.
139-
140- # This method extracts the locations for sampling from the problem
141- # conditions and calculates the total population.
142-
143- # :param trainer: The trainer object managing the training process.
144- # :type trainer: pytorch_lightning.Trainer
145- # :param _: Placeholder argument (not used).
146-
147- # :return: None
148- # :rtype: None
149- # """
150- # # extract locations for sampling
151- # problem = trainer.solver.problem
152- # locations = []
153- # for condition_name in problem.conditions:
154- # condition = problem.conditions[condition_name]
155- # if hasattr(condition, "location"):
156- # locations.append(condition_name)
157- # self._sampling_locations = locations
158-
159- # # extract total population
160- # const_pts = {} # for each location, store the pts to keep constant
161- # for location in self._sampling_locations:
162- # pts = trainer._model.problem.input_pts[location]
163- # const_pts[location] = len(pts)
164- # self._const_pts = const_pts
165-
166- # def on_train_epoch_end(self, trainer, __):
167- # """
168- # Callback function called at the end of each training epoch.
169-
170- # This method triggers the R3 routine for refinement if the current
171- # epoch is a multiple of `_sample_every`.
172-
173- # :param trainer: The trainer object managing the training process.
174- # :type trainer: pytorch_lightning.Trainer
175- # :param __: Placeholder argument (not used).
176-
177- # :return: None
178- # :rtype: None
179- # """
180- # if trainer.current_epoch % self._sample_every == 0:
181- # self._r3_routine(trainer)
37+
38+ super ().__init__ ()
39+
40+ # sample every
41+ check_consistency (sample_every , int )
42+ self ._sample_every = sample_every
43+ self ._const_pts = None
44+ self ._domains = None
45+
46+ def _compute_residual (self , trainer ):
47+ """
48+ Computes the residuals for a PINN object.
49+
50+ :return: the total loss, and pointwise loss.
51+ :rtype: tuple
52+ """
53+
54+ # extract the solver and device from trainer
55+ solver = trainer .solver
56+ device = trainer ._accelerator_connector ._accelerator_flag
57+ precision = trainer .precision
58+ if precision == "64-true" :
59+ precision = torch .float64
60+ elif precision == "32-true" :
61+ precision = torch .float32
62+ else :
63+ raise RuntimeError (
64+ "Currently R3Refinement is only implemented "
65+ "for precision '32-true' and '64-true', set "
66+ "Trainer precision to match one of the "
67+ "available precisions."
68+ )
69+
70+ # compute residual
71+ res_loss = {}
72+ tot_loss = []
73+ for condition in self ._conditions :
74+ pts = trainer .datamodule .train_dataset .conditions_dict [condition ][
75+ "input"
76+ ]
77+ equation = solver .problem .conditions [condition ].equation
78+ # send points to correct device
79+ pts = pts .to (device = device , dtype = precision )
80+ pts = pts .requires_grad_ (True )
81+ pts .retain_grad ()
82+ # PINN loss: equation evaluated only for sampling locations
83+ target = equation .residual (pts , solver .forward (pts ))
84+ res_loss [condition ] = torch .abs (target ).as_subclass (torch .Tensor )
85+ tot_loss .append (torch .abs (target ))
86+ return torch .vstack (tot_loss ), res_loss
87+
88+ def _r3_routine (self , trainer ):
89+ """
90+ R3 refinement main routine.
91+
92+ :param Trainer trainer: PINA Trainer.
93+ """
94+ # compute residual (all device possible)
95+ tot_loss , res_loss = self ._compute_residual (trainer )
96+ tot_loss = tot_loss .as_subclass (torch .Tensor )
97+
98+ # !!!!!! From now everything is performed on CPU !!!!!!
99+
100+ # average loss
101+ avg = (tot_loss .mean ()).to ("cpu" )
102+ new_pts = {}
103+
104+ dataset = trainer .datamodule .train_dataset
105+ problem = trainer .solver .problem
106+ for condition in self ._conditions :
107+ pts = dataset .conditions_dict [condition ]["input" ]
108+ domain = problem .conditions [condition ].domain
109+ if not isinstance (domain , str ):
110+ domain = condition
111+ labels = pts .labels
112+ pts = pts .cpu ().detach ().as_subclass (torch .Tensor )
113+ residuals = res_loss [condition ].cpu ()
114+ mask = (residuals > avg ).flatten ()
115+ if any (mask ): # append residuals greater than average
116+ pts = (pts [mask ]).as_subclass (LabelTensor )
117+ pts .labels = labels
118+ numb_pts = self ._const_pts [condition ] - len (pts )
119+ else : # if no res greater than average, samples all uniformly
120+ numb_pts = self ._const_pts [condition ]
121+ pts = None
122+ problem .discretise_domain (numb_pts , "random" , domains = [domain ])
123+ sampled_points = problem .discretised_domains [domain ]
124+ tmp = (
125+ sampled_points
126+ if pts is None
127+ else LabelTensor .cat ([pts , sampled_points ])
128+ )
129+ new_pts [condition ] = {"input" : tmp }
130+ dataset .update_data (new_pts )
131+
132+ def on_train_start (self , trainer , _ ):
133+ """
134+ Callback function called at the start of training.
135+
136+ This method extracts the locations for sampling from the problem
137+ conditions and calculates the total population.
138+
139+ :param trainer: The trainer object managing the training process.
140+ :type trainer: pytorch_lightning.Trainer
141+ :param _: Placeholder argument (not used).
142+
143+ :return: None
144+ :rtype: None
145+ """
146+ problem = trainer .solver .problem
147+ if hasattr (problem , "domains" ):
148+ domains = problem .domains
149+ self ._domains = domains
150+ else :
151+ self ._domains = {}
152+ for name , data in problem .conditions .items ():
153+ if hasattr (data , "domain" ):
154+ self ._domains [name ] = data .domain
155+ self ._conditions = []
156+ for name , data in problem .conditions .items ():
157+ if hasattr (data , "domain" ):
158+ self ._conditions .append (name )
159+
160+ # extract total population
161+ const_pts = {} # for each location, store the pts to keep constant
162+ for condition in self ._conditions :
163+ pts = trainer .datamodule .train_dataset .conditions_dict [condition ][
164+ "input"
165+ ]
166+ const_pts [condition ] = len (pts )
167+ self ._const_pts = const_pts
168+
169+ def on_train_epoch_end (self , trainer , __ ):
170+ """
171+ Callback function called at the end of each training epoch.
172+
173+ This method triggers the R3 routine for refinement if the current
174+ epoch is a multiple of `_sample_every`.
175+
176+ :param trainer: The trainer object managing the training process.
177+ :type trainer: pytorch_lightning.Trainer
178+ :param __: Placeholder argument (not used).
179+
180+ :return: None
181+ :rtype: None
182+ """
183+ if trainer .current_epoch % self ._sample_every == 0 :
184+ self ._r3_routine (trainer )
0 commit comments