Skip to content

Commit 3c1fed9

Browse files
annaivagnesdario-coscia
authored andcommitted
add singular values in PODBlock
1 parent e3d4c2f commit 3c1fed9

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

pina/model/block/pod_block.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self, rank, scale_coefficients=True):
3030
super().__init__()
3131
self.__scale_coefficients = scale_coefficients
3232
self._basis = None
33+
self._singular_values = None
3334
self._scaler = None
3435
self._rank = rank
3536

@@ -70,6 +71,19 @@ def basis(self):
7071

7172
return self._basis[: self.rank]
7273

74+
@property
75+
def singular_values(self):
76+
"""
77+
The singular values of the POD basis.
78+
79+
:return: The singular values.
80+
:rtype: torch.Tensor
81+
"""
82+
if self._singular_values is None:
83+
return None
84+
85+
return self._singular_values[: self.rank]
86+
7387
@property
7488
def scaler(self):
7589
"""
@@ -136,15 +150,19 @@ def _fit_pod(self, X, randomized):
136150
"This may slow down computations.",
137151
ResourceWarning,
138152
)
139-
self._basis = torch.svd(X.T)[0].T
153+
u, s, v = torch.svd(X.T)
140154
else:
141155
if randomized:
142156
warnings.warn(
143157
"Considering a randomized algorithm to compute the POD basis"
144158
)
145-
self._basis = torch.svd_lowrank(X.T, q=X.shape[0])[0].T
159+
u, s, v = torch.svd_lowrank(X.T, q=X.shape[0])
160+
146161
else:
147-
self._basis = torch.svd(X.T)[0].T
162+
u, s, v = torch.svd(X.T)
163+
self._basis = u.T
164+
self._singular_values = s
165+
148166

149167
def forward(self, X):
150168
"""

tests/test_blocks/test_pod.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def test_fit(rank, scale):
2323
assert pod._basis == None
2424
assert pod.basis == None
2525
assert pod._scaler == None
26+
assert pod._singular_values == None
27+
assert pod.singular_values == None
2628
assert pod.rank == rank
2729
assert pod.scale_coefficients == scale
2830

@@ -37,6 +39,8 @@ def test_fit(rank, scale, randomized):
3739
dof = toy_snapshots.shape[1]
3840
assert pod.basis.shape == (rank, dof)
3941
assert pod._basis.shape == (n_snap, dof)
42+
assert pod.singular_values.shape == (rank,)
43+
assert pod._singular_values.shape == (n_snap,)
4044
if scale is True:
4145
assert pod._scaler["mean"].shape == (n_snap,)
4246
assert pod._scaler["std"].shape == (n_snap,)

0 commit comments

Comments
 (0)