Skip to content

Commit f246da4

Browse files
fix utils and trainer doc
1 parent 3db6105 commit f246da4

2 files changed

Lines changed: 161 additions & 91 deletions

File tree

pina/trainer.py

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Trainer module."""
1+
"""Module for the Trainer."""
22

33
import sys
44
import torch
@@ -10,8 +10,11 @@
1010

1111
class Trainer(lightning.pytorch.Trainer):
1212
"""
13-
PINA custom Trainer class which allows to customize standard Lightning
14-
Trainer class for PINNs training.
13+
PINA custom Trainer class to extend the standard Lightning functionality.
14+
15+
This class enables specific features or behaviors required by the PINA
16+
framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
17+
to better support the training process in PINA.
1518
"""
1619

1720
def __init__(
@@ -29,43 +32,35 @@ def __init__(
2932
**kwargs,
3033
):
3134
"""
32-
Initialize the Trainer class for by calling Lightning costructor and
33-
adding many other functionalities.
35+
Initialization of the :class:`Trainer` class.
3436
35-
:param solver: A pina:class:`SolverInterface` solver for the
36-
differential problem.
37-
:type solver: SolverInterface
38-
:param batch_size: How many samples per batch to load.
39-
If ``batch_size=None`` all
40-
samples are loaded and data are not batched, defaults to None.
41-
:type batch_size: int | None
42-
:param train_size: Percentage of elements in the train dataset.
43-
:type train_size: float
44-
:param test_size: Percentage of elements in the test dataset.
45-
:type test_size: float
46-
:param val_size: Percentage of elements in the val dataset.
47-
:type val_size: float
48-
:param compile: if True model is compiled before training,
49-
default False. For Windows users compilation is always disabled.
50-
:type compile: bool
51-
:param automatic_batching: if True automatic PyTorch batching is
52-
performed. Please avoid using automatic batching when batch_size is
53-
large, default False.
54-
:type automatic_batching: bool
55-
:param num_workers: Number of worker threads for data loading.
56-
Default 0 (serial loading).
57-
:type num_workers: int
58-
:param pin_memory: Whether to use pinned memory for faster data
59-
transfer to GPU. Default False.
60-
:type pin_memory: bool
61-
:param shuffle: Whether to shuffle the data for training. Default True.
62-
:type pin_memory: bool
37+
:param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
38+
solver used to solve a :class:`~pina.problem.AbstractProblem`.
39+
:param int batch_size: The number of samples per batch to load.
40+
If ``None``, all samples are loaded and data is not batched.
41+
Default is ``None``.
42+
:param float train_size: The percentage of elements to include in the
43+
training dataset. Default is ``1.0``.
44+
:param float test_size: The percentage of elements to include in the
45+
test dataset. Default is ``0.0``.
46+
:param float val_size: The percentage of elements to include in the
47+
validation dataset. Default is ``0.0``.
48+
:param bool compile: If ``True``, the model is compiled before training.
49+
Default is ``False``. For Windows users, it is always disabled.
50+
:param bool automatic_batching: If ``True``, automatic PyTorch batching
51+
is performed. Avoid using automatic batching when ``batch_size`` is
52+
large. Default is ``False``.
53+
:param int num_workers: The number of worker threads for data loading.
54+
Default is ``0`` (serial loading).
55+
:param bool pin_memory: Whether to use pinned memory for faster data
56+
transfer to GPU. Default is ``False``.
57+
:param bool shuffle: Whether to shuffle the data during training.
58+
Default is ``True``.
6359
6460
:Keyword Arguments:
65-
The additional keyword arguments specify the training setup
66-
and can be choosen from the `pytorch-lightning
67-
Trainer API <https://lightning.ai/docs/pytorch/stable/common/
68-
trainer.html#trainer-class-api>`_
61+
Additional keyword arguments that specify the training setup.
62+
These can be selected from the pytorch-lightning Trainer API
63+
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>_.
6964
"""
7065
# check consistency for init types
7166
self._check_input_consistency(
@@ -139,6 +134,10 @@ def __init__(
139134
}
140135

141136
def _move_to_device(self):
137+
"""
138+
Moves the ``unknown_parameters`` of an instance of
139+
:class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
140+
"""
142141
device = self._accelerator_connector._parallel_devices[0]
143142
# move parameters to device
144143
pb = self.solver.problem
@@ -160,9 +159,25 @@ def _create_datamodule(
160159
shuffle,
161160
):
162161
"""
163-
This method is used here because is resampling is needed
164-
during training, there is no need to define to touch the
165-
trainer dataloader, just call the method.
162+
This method is designed to handle the creation of a data module when
163+
resampling is needed during training. Instead of manually defining and
164+
modifying the trainer's dataloaders, this method is called to
165+
automatically configure the data module.
166+
167+
:param float train_size: The percentage of elements to include in the
168+
training dataset.
169+
:param float test_size: The percentage of elements to include in the
170+
test dataset.
171+
:param float val_size: The percentage of elements to include in the
172+
validation dataset.
173+
:param int batch_size: The number of samples per batch to load.
174+
:param bool automatic_batching: Whether to perform automatic batching
175+
with PyTorch.
176+
:param bool pin_memory: Whether to use pinned memory for faster data
177+
transfer to GPU.
178+
:param int num_workers: The number of worker threads for data loading.
179+
:param bool shuffle: Whether to shuffle the data during training.
180+
:raises RuntimeError: If not all conditions are sampled.
166181
"""
167182
if not self.solver.problem.are_all_domains_discretised:
168183
error_message = "\n".join(
@@ -193,33 +208,52 @@ def _create_datamodule(
193208

194209
def train(self, **kwargs):
195210
"""
196-
Train the solver method.
211+
Manage the training process of the solver.
197212
"""
198213
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
199214

200215
def test(self, **kwargs):
201216
"""
202-
Test the solver method.
217+
Manage the test process of the solver.
203218
"""
204219
return super().test(self.solver, datamodule=self.data_module, **kwargs)
205220

206221
@property
207222
def solver(self):
208223
"""
209-
Returning trainer solver.
224+
Get the solver.
225+
226+
:return: The solver.
227+
:rtype: SolverInterface
210228
"""
211229
return self._solver
212230

213231
@solver.setter
214232
def solver(self, solver):
233+
"""
234+
Set the solver.
235+
236+
:param SolverInterface solver: The solver to set.
237+
"""
215238
self._solver = solver
216239

217240
@staticmethod
218241
def _check_input_consistency(
219242
solver, train_size, test_size, val_size, automatic_batching, compile
220243
):
221244
"""
222-
Check the consistency of the input parameters."
245+
Verifies the consistency of the parameters for the solver configuration.
246+
247+
:param SolverInterface solver: The solver.
248+
:param float train_size: The percentage of elements to include in the
249+
training dataset.
250+
:param float test_size: The percentage of elements to include in the
251+
test dataset.
252+
:param float val_size: The percentage of elements to include in the
253+
validation dataset.
254+
:param bool automatic_batching: Whether to perform automatic batching
255+
with PyTorch.
256+
:param bool compile: If ``True``, the model is compiled before training.
223257
"""
224258

225259
check_consistency(solver, SolverInterface)
@@ -236,8 +270,14 @@ def _check_consistency_and_set_defaults(
236270
pin_memory, num_workers, shuffle, batch_size
237271
):
238272
"""
239-
Check the consistency of the input parameters and set the default
240-
values.
273+
Checks the consistency of input parameters and sets default values
274+
for missing or invalid parameters.
275+
276+
:param bool pin_memory: Whether to use pinned memory for faster data
277+
transfer to GPU.
278+
:param int num_workers: The number of worker threads for data loading.
279+
:param bool shuffle: Whether to shuffle the data during training.
280+
:param int batch_size: The number of samples per batch to load.
241281
"""
242282
if pin_memory is not None:
243283
check_consistency(pin_memory, bool)

pina/utils.py

Lines changed: 75 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Utils module."""
1+
"""Module for utility functions."""
22

33
import types
44
from functools import reduce
@@ -12,35 +12,36 @@ def custom_warning_format(
1212
message, category, filename, lineno, file=None, line=None
1313
):
1414
"""
15-
Depewarning custom format.
15+
Custom warning formatting function.
1616
1717
:param str message: The warning message.
18-
:param class category: The warning category.
19-
:param str filename: The filename where the warning was raised.
20-
:param int lineno: The line number where the warning was raised.
21-
:param str file: The file object where the warning was raised.
22-
:param inr line: The line where the warning was raised.
18+
:param Warning category: The warning category.
19+
:param str filename: The filename where the warning is raised.
20+
:param int lineno: The line number where the warning is raised.
21+
:param str file: The file object where the warning is raised.
22+
Default is None.
23+
:param int line: The line where the warning is raised.
2324
:return: The formatted warning message.
2425
:rtype: str
2526
"""
2627
return f"{filename}: {category.__name__}: {message}\n"
2728

2829

2930
def check_consistency(object_, object_instance, subclass=False):
30-
"""Helper function to check object inheritance consistency.
31-
Given a specific ``'object'`` we check if the object is
32-
instance of a specific ``'object_instance'``, or in case
33-
``'subclass=True'`` we check if the object is subclass
34-
if the ``'object_instance'``.
35-
36-
:param (iterable or class object) object: The object to check the
37-
inheritance
38-
:param Object object_instance: The parent class from where the object
39-
is expected to inherit
40-
:param str object_name: The name of the object
41-
:param bool subclass: Check if is a subclass and not instance
42-
:raises ValueError: If the object does not inherit from the
43-
specified class
31+
"""
32+
Check if an object maintains inheritance consistency.
33+
34+
This function checks whether a given object is an instance of a specified
35+
class or, if ``subclass=True``, whether it is a subclass of the specified
36+
class.
37+
38+
:param object: The object to check.
39+
:type object: Iterable | Object
40+
:param Object object_instance: The expected parent class.
41+
:param bool subclass: If True, checks whether ``object_`` is a subclass
42+
of ``object_instance`` instead of an instance. Default is ``False``.
43+
:raises ValueError: If ``object_`` does not inherit from ``object_instance``
44+
as expected.
4445
"""
4546
if not isinstance(object_, (list, set, tuple)):
4647
object_ = [object_]
@@ -59,18 +60,28 @@ def check_consistency(object_, object_instance, subclass=False):
5960

6061
def labelize_forward(forward, input_variables, output_variables):
6162
"""
62-
Wrapper decorator to allow users to enable or disable the use of
63-
LabelTensors during the forward pass.
64-
65-
:param forward: The torch.nn.Module forward function.
66-
:type forward: Callable
67-
:param input_variables: The problem input variables.
68-
:type input_variables: list[str] | tuple[str]
69-
:param output_variables: The problem output variables.
70-
:type output_variables: list[str] | tuple[str]
63+
Decorator to enable or disable the use of :class:`~pina.LabelTensor`
64+
during the forward pass.
65+
66+
:param Callable forward: The forward function of a :class:`torch.nn.Module`.
67+
:param list[str] input_variables: The names of the input variables of a
68+
:class:`~pina.problem.AbstractProblem`.
69+
:param list[str] output_variables: The names of the output variables of a
70+
:class:`~pina.problem.AbstractProblem`.
71+
:return: The decorated forward function.
72+
:rtype: Callable
7173
"""
7274

7375
def wrapper(x):
76+
"""
77+
Decorated forward function.
78+
79+
:param LabelTensor x: The labelized input of the forward pass of an
80+
instance of :class:`torch.nn.Module`.
81+
:return: The labelized output of the forward pass of an instance of
82+
:class:`torch.nn.Module`.
83+
:rtype: LabelTensor
84+
"""
7485
x = x.extract(input_variables)
7586
output = forward(x)
7687
# keep it like this, directly using LabelTensor(...) raises errors
@@ -82,15 +93,32 @@ def wrapper(x):
8293
return wrapper
8394

8495

85-
def merge_tensors(tensors): # name to be changed
86-
"""TODO"""
96+
def merge_tensors(tensors):
97+
"""
98+
Merge a list of :class:`~pina.LabelTensor` instances into a single
99+
:class:`~pina.LabelTensor` tensor, by applying iteratively the cartesian
100+
product.
101+
102+
:param list[LabelTensor] tensors: The list of tensors to merge.
103+
:raises ValueError: If the list of tensors is empty.
104+
:return: The merged tensor.
105+
:rtype: LabelTensor
106+
"""
87107
if tensors:
88108
return reduce(merge_two_tensors, tensors[1:], tensors[0])
89109
raise ValueError("Expected at least one tensor")
90110

91111

92112
def merge_two_tensors(tensor1, tensor2):
93-
"""TODO"""
113+
"""
114+
Merge two :class:`~pina.LabelTensor` instances into a single
115+
:class:`~pina.LabelTensor` tensor, by applying the cartesian product.
116+
117+
:param LabelTensor tensor1: The first tensor to merge.
118+
:param LabelTensor tensor2: The second tensor to merge.
119+
:return: The merged tensor.
120+
:rtype: LabelTensor
121+
"""
94122
n1 = tensor1.shape[0]
95123
n2 = tensor2.shape[0]
96124

@@ -102,12 +130,14 @@ def merge_two_tensors(tensor1, tensor2):
102130

103131

104132
def torch_lhs(n, dim):
105-
"""Latin Hypercube Sampling torch routine.
106-
Sampling in range $[0, 1)^d$.
133+
"""
134+
The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)`$.
107135
108-
:param int n: number of samples
109-
:param int dim: dimensions of latin hypercube
110-
:return: samples
136+
:param int n: The number of points to sample.
137+
:param int dim: The number of dimensions of the sampling space.
138+
:raises TypeError: If `n` or `dim` are not integers.
139+
:raises ValueError: If `dim` is less than 1.
140+
:return: The sampled points.
111141
:rtype: torch.tensor
112142
"""
113143

@@ -137,22 +167,22 @@ def torch_lhs(n, dim):
137167

138168
def is_function(f):
139169
"""
140-
Checks whether the given object `f` is a function or lambda.
170+
Check if the given object is a function or a lambda.
141171
142-
:param object f: The object to be checked.
143-
:return: `True` if `f` is a function, `False` otherwise.
172+
:param Object f: The object to be checked.
173+
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
144174
:rtype: bool
145175
"""
146176
return isinstance(f, (types.FunctionType, types.LambdaType))
147177

148178

149179
def chebyshev_roots(n):
150180
"""
151-
Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
181+
Compute the roots of the Chebyshev polynomial of degree ``n``.
152182
153-
:param int n: number of roots
154-
:return: roots
155-
:rtype: torch.tensor
183+
:param int n: The number of roots to return.
184+
:return: The roots of the Chebyshev polynomials.
185+
:rtype: torch.Tensor
156186
"""
157187
pi = torch.acos(torch.zeros(1)).item() * 2
158188
k = torch.arange(n)

0 commit comments

Comments
 (0)