Skip to content

Commit 2cb0ead

Browse files
GiovanniCanalidario-coscia
authored andcommitted
support built-in equations in system
1 parent de47d69 commit 2cb0ead

File tree

2 files changed

+105
-27
lines changed

2 files changed

+105
-27
lines changed

pina/equation/system_equation.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,51 @@
88

99
class SystemEquation(EquationInterface):
1010
"""
11-
Implementation of the System of Equations. Every ``equation`` passed to a
12-
:class:`~pina.condition.condition.Condition` object must be either a
13-
:class:`~pina.equation.equation.Equation` or a
14-
:class:`~pina.equation.system_equation.SystemEquation` instance.
11+
Implementation of the System of Equations, to be passed to a
12+
:class:`~pina.condition.condition.Condition` object.
13+
14+
Unlike the :class:`~pina.equation.equation.Equation` class, which represents
15+
a single equation, the :class:`SystemEquation` class allows multiple
16+
equations to be grouped together into a system. This is particularly useful
17+
when dealing with multi-component outputs or coupled physical models, where
18+
the residual must be computed collectively across several constraints.
19+
20+
Each equation in the system must be either:
21+
- An instance of :class:`~pina.equation.equation.Equation`;
22+
- A callable function.
23+
24+
The residuals from each equation are computed independently and then
25+
aggregated using an optional reduction strategy (e.g., ``mean``, ``sum``).
26+
The resulting residual is returned as a single :class:`~pina.LabelTensor`.
27+
28+
:Example:
29+
30+
>>> from pina.equation import SystemEquation, FixedValue, FixedGradient
31+
>>> from pina import LabelTensor
32+
>>> import torch
33+
>>> pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
34+
>>> pts.requires_grad = True
35+
>>> output_ = torch.pow(pts, 2)
36+
>>> output_.labels = ["u", "v"]
37+
>>> system_equation = SystemEquation(
38+
... [
39+
... FixedValue(value=1.0, components=["u"]),
40+
... FixedGradient(value=0.0, components=["v"],d=["y"]),
41+
... ],
42+
... reduction="mean",
43+
... )
44+
>>> residual = system_equation.residual(pts, output_)
45+
1546
"""
1647

1748
def __init__(self, list_equation, reduction=None):
1849
"""
1950
Initialization of the :class:`SystemEquation` class.
2051
21-
:param Callable equation: A ``torch`` callable function used to compute
22-
the residual of a mathematical equation.
52+
:param list_equation: A list containing either callable functions or
53+
instances of :class:`~pina.equation.equation.Equation`, used to
54+
compute the residuals of mathematical equations.
55+
:type list_equation: list[Callable] | list[Equation]
2356
:param str reduction: The reduction method to aggregate the residuals of
2457
each equation. Available options are: ``None``, ``mean``, ``sum``,
2558
``callable``.
@@ -32,9 +65,10 @@ def __init__(self, list_equation, reduction=None):
3265
check_consistency([list_equation], list)
3366

3467
# equations definition
35-
self.equations = []
36-
for _, equation in enumerate(list_equation):
37-
self.equations.append(Equation(equation))
68+
self.equations = [
69+
equation if isinstance(equation, Equation) else Equation(equation)
70+
for equation in list_equation
71+
]
3872

3973
# possible reduction
4074
if reduction == "mean":
@@ -45,7 +79,7 @@ def __init__(self, list_equation, reduction=None):
4579
self.reduction = reduction
4680
else:
4781
raise NotImplementedError(
48-
"Only mean and sum reductions implemented."
82+
"Only mean and sum reductions are currenly supported."
4983
)
5084

5185
def residual(self, input_, output_, params_=None):

tests/test_equations/test_system_equation.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pina.equation import SystemEquation
1+
from pina.equation import SystemEquation, FixedValue, FixedGradient
22
from pina.operator import grad, laplacian
33
from pina import LabelTensor
44
import torch
@@ -24,34 +24,78 @@ def foo():
2424
pass
2525

2626

27-
def test_constructor():
28-
SystemEquation([eq1, eq2])
29-
SystemEquation([eq1, eq2], reduction="sum")
27+
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
28+
def test_constructor(reduction):
29+
30+
# Constructor with callable functions
31+
SystemEquation([eq1, eq2], reduction=reduction)
32+
33+
# Constructor with Equation instances
34+
SystemEquation(
35+
[
36+
FixedValue(value=0.0, components=["u1"]),
37+
FixedGradient(value=0.0, components=["u2"]),
38+
],
39+
reduction=reduction,
40+
)
41+
42+
# Constructor with mixed types
43+
SystemEquation(
44+
[
45+
FixedValue(value=0.0, components=["u1"]),
46+
eq1,
47+
],
48+
reduction=reduction,
49+
)
50+
51+
# Non-standard reduction not implemented
3052
with pytest.raises(NotImplementedError):
3153
SystemEquation([eq1, eq2], reduction="foo")
54+
55+
# Invalid input type
3256
with pytest.raises(ValueError):
3357
SystemEquation(foo)
3458

3559

36-
def test_residual():
60+
@pytest.mark.parametrize("reduction", [None, "mean", "sum"])
61+
def test_residual(reduction):
3762

63+
# Generate random points and output
3864
pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
3965
pts.requires_grad = True
4066
u = torch.pow(pts, 2)
4167
u.labels = ["u1", "u2"]
4268

43-
eq_1 = SystemEquation([eq1, eq2], reduction="mean")
44-
res = eq_1.residual(pts, u)
45-
assert res.shape == torch.Size([10])
69+
# System with callable functions
70+
system_eq = SystemEquation([eq1, eq2], reduction=reduction)
71+
res = system_eq.residual(pts, u)
72+
73+
# Checks on the shape of the residual
74+
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
75+
assert res.shape == shape
4676

47-
eq_1 = SystemEquation([eq1, eq2], reduction="sum")
48-
res = eq_1.residual(pts, u)
49-
assert res.shape == torch.Size([10])
77+
# System with Equation instances
78+
system_eq = SystemEquation(
79+
[
80+
FixedValue(value=0.0, components=["u1"]),
81+
FixedGradient(value=0.0, components=["u2"]),
82+
],
83+
reduction=reduction,
84+
)
5085

51-
eq_1 = SystemEquation([eq1, eq2], reduction=None)
52-
res = eq_1.residual(pts, u)
53-
assert res.shape == torch.Size([10, 3])
86+
# Checks on the shape of the residual
87+
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
88+
assert res.shape == shape
89+
90+
# System with mixed types
91+
system_eq = SystemEquation(
92+
[
93+
FixedValue(value=0.0, components=["u1"]),
94+
eq1,
95+
],
96+
reduction=reduction,
97+
)
5498

55-
eq_1 = SystemEquation([eq1, eq2])
56-
res = eq_1.residual(pts, u)
57-
assert res.shape == torch.Size([10, 3])
99+
# Checks on the shape of the residual
100+
shape = torch.Size([10, 3]) if reduction is None else torch.Size([10])
101+
assert res.shape == shape

0 commit comments

Comments
 (0)