1- """Trainer module ."""
1+ """Module for the Trainer ."""
22
33import sys
44import torch
1010
1111class Trainer (lightning .pytorch .Trainer ):
1212 """
13- PINA custom Trainer class which allows to customize standard Lightning
14- Trainer class for PINNs training.
13+ PINA custom Trainer class to extend the standard Lightning functionality.
14+
15+ This class enables specific features or behaviors required by the PINA
16+ framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
17+ to better support the training process in PINA.
1518 """
1619
1720 def __init__ (
@@ -29,43 +32,35 @@ def __init__(
2932 ** kwargs ,
3033 ):
3134 """
32- Initialize the Trainer class for by calling Lightning costructor and
33- adding many other functionalities.
35+ Initialization of the :class:`Trainer` class.
3436
35- :param solver: A pina:class:`SolverInterface` solver for the
36- differential problem.
37- :type solver: SolverInterface
38- :param batch_size: How many samples per batch to load.
39- If ``batch_size=None`` all
40- samples are loaded and data are not batched, defaults to None.
41- :type batch_size: int | None
42- :param train_size: Percentage of elements in the train dataset.
43- :type train_size: float
44- :param test_size: Percentage of elements in the test dataset.
45- :type test_size: float
46- :param val_size: Percentage of elements in the val dataset.
47- :type val_size: float
48- :param compile: if True model is compiled before training,
49- default False. For Windows users compilation is always disabled.
50- :type compile: bool
51- :param automatic_batching: if True automatic PyTorch batching is
52- performed. Please avoid using automatic batching when batch_size is
53- large, default False.
54- :type automatic_batching: bool
55- :param num_workers: Number of worker threads for data loading.
56- Default 0 (serial loading).
57- :type num_workers: int
58- :param pin_memory: Whether to use pinned memory for faster data
59- transfer to GPU. Default False.
60- :type pin_memory: bool
61- :param shuffle: Whether to shuffle the data for training. Default True.
62- :type pin_memory: bool
37+ :param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
38+ solver used to solve a :class:`~pina.problem.AbstractProblem`.
39+ :param int batch_size: The number of samples per batch to load.
40+ If ``None``, all samples are loaded and data is not batched.
41+ Default is ``None``.
42+ :param float train_size: The percentage of elements to include in the
43+ training dataset. Default is ``1.0``.
44+ :param float test_size: The percentage of elements to include in the
45+ test dataset. Default is ``0.0``.
46+ :param float val_size: The percentage of elements to include in the
47+ validation dataset. Default is ``0.0``.
48+ :param bool compile: If ``True``, the model is compiled before training.
49+ Default is ``False``. For Windows users, it is always disabled.
50+ :param bool automatic_batching: If ``True``, automatic PyTorch batching
51+ is performed. Avoid using automatic batching when ``batch_size`` is
52+ large. Default is ``False``.
53+ :param int num_workers: The number of worker threads for data loading.
54+ Default is ``0`` (serial loading).
55+ :param bool pin_memory: Whether to use pinned memory for faster data
56+ transfer to GPU. Default is ``False``.
57+ :param bool shuffle: Whether to shuffle the data during training.
58+ Default is ``True``.
6359
6460 :Keyword Arguments:
65- The additional keyword arguments specify the training setup
66- and can be choosen from the `pytorch-lightning
67- Trainer API <https://lightning.ai/docs/pytorch/stable/common/
68- trainer.html#trainer-class-api>`_
61+ Additional keyword arguments that specify the training setup.
62+ These can be selected from the `pytorch-lightning Trainer API
63+ <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
6964 """
7065 # check consistency for init types
7166 self ._check_input_consistency (
@@ -139,6 +134,10 @@ def __init__(
139134 }
140135
141136 def _move_to_device (self ):
137+ """
138+ Move the ``unknown_parameters`` of an instance of
139+ :class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
140+ """
142141 device = self ._accelerator_connector ._parallel_devices [0 ]
143142 # move parameters to device
144143 pb = self .solver .problem
@@ -160,9 +159,25 @@ def _create_datamodule(
160159 shuffle ,
161160 ):
162161 """
163- This method is used here because is resampling is needed
164- during training, there is no need to define to touch the
165- trainer dataloader, just call the method.
162+ This method is designed to handle the creation of a data module when
163+ resampling is needed during training. Instead of manually defining and
164+ modifying the trainer's dataloaders, this method is called to
165+ automatically configure the data module.
166+
167+ :param float train_size: The percentage of elements to include in the
168+ training dataset.
169+ :param float test_size: The percentage of elements to include in the
170+ test dataset.
171+ :param float val_size: The percentage of elements to include in the
172+ validation dataset.
173+ :param int batch_size: The number of samples per batch to load.
174+ :param bool automatic_batching: Whether to perform automatic batching
175+ with PyTorch.
176+ :param bool pin_memory: Whether to use pinned memory for faster data
177+ transfer to GPU.
178+ :param int num_workers: The number of worker threads for data loading.
179+ :param bool shuffle: Whether to shuffle the data during training.
180+ :raises RuntimeError: If not all conditions are sampled.
166181 """
167182 if not self .solver .problem .are_all_domains_discretised :
168183 error_message = "\n " .join (
@@ -193,33 +208,52 @@ def _create_datamodule(
193208
194209 def train (self , ** kwargs ):
195210 """
196- Train the solver method .
211+ Manage the training process of the solver .
197212 """
198213 return super ().fit (self .solver , datamodule = self .data_module , ** kwargs )
199214
200215 def test (self , ** kwargs ):
201216 """
202- Test the solver method .
217+ Manage the test process of the solver .
203218 """
204219 return super ().test (self .solver , datamodule = self .data_module , ** kwargs )
205220
206221 @property
207222 def solver (self ):
208223 """
209- Returning trainer solver.
224+ Get the solver.
225+
226+ :return: The solver.
227+ :rtype: SolverInterface
210228 """
211229 return self ._solver
212230
213231 @solver .setter
214232 def solver (self , solver ):
233+ """
234+ Set the solver.
235+
236+ :param SolverInterface solver: The solver to set.
237+ """
215238 self ._solver = solver
216239
217240 @staticmethod
218241 def _check_input_consistency (
219242 solver , train_size , test_size , val_size , automatic_batching , compile
220243 ):
221244 """
222- Check the consistency of the input parameters."
245+ Verify the consistency of the parameters for the solver configuration.
246+
247+ :param SolverInterface solver: The solver.
248+ :param float train_size: The percentage of elements to include in the
249+ training dataset.
250+ :param float test_size: The percentage of elements to include in the
251+ test dataset.
252+ :param float val_size: The percentage of elements to include in the
253+ validation dataset.
254+ :param bool automatic_batching: Whether to perform automatic batching
255+ with PyTorch.
256+ :param bool compile: If ``True``, the model is compiled before training.
223257 """
224258
225259 check_consistency (solver , SolverInterface )
@@ -236,8 +270,14 @@ def _check_consistency_and_set_defaults(
236270 pin_memory , num_workers , shuffle , batch_size
237271 ):
238272 """
239- Check the consistency of the input parameters and set the default
240- values.
273+ Check the consistency of input parameters and set default values
274+ for missing or invalid parameters.
275+
276+ :param bool pin_memory: Whether to use pinned memory for faster data
277+ transfer to GPU.
278+ :param int num_workers: The number of worker threads for data loading.
279+ :param bool shuffle: Whether to shuffle the data during training.
280+ :param int batch_size: The number of samples per batch to load.
241281 """
242282 if pin_memory is not None :
243283 check_consistency (pin_memory , bool )
0 commit comments