Commit d3867c6

Authored by samuelbx (Samuel Boïté)
[FIX] unstable TorchBackend.sqrtm() around repeated eigenvalues (#774)
* stable matrix sqrt using closed-form diff
* torch matrix sqrt gradcheck
* pre-commit
* edit RELEASES.md

Co-authored-by: Samuel Boïté <samuel.boite@polytechnique.edu>
1 parent: be211ac · commit: d3867c6
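For context on why the old implementation was unstable: the previous `sqrtm` differentiated straight through `torch.linalg.eigh`, and the backward pass of an eigendecomposition scales eigenvector gradients by factors of the form

\[ F_{ij} = \frac{1}{\lambda_j - \lambda_i}, \qquad i \neq j, \]

which diverge when eigenvalues are repeated or numerically close, so gradients of the matrix square root could come out as inf/NaN. The new `MatrixSqrtFunction` below instead uses a closed-form vector-Jacobian product that only divides by sums of eigenvalue square roots, s_i + s_j, which remain well behaved for repeated eigenvalues (the zero-eigenvalue case is masked explicitly).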

File tree: 3 files changed (+40, -6 lines)


RELEASES.md

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 #### Closed issues
 - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
 - Add test for build from source (PR #772, Issue #764)
+- Stable `ot.TorchBackend.sqrtm` around repeated eigvals (PR #774, Issue #773)

 ## 0.9.6.post1

ot/backend.py

Lines changed: 26 additions & 6 deletions

@@ -1938,6 +1938,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")

         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable

         # define a function that takes inputs val and grads
         # and returns a val tensor with proper gradients

@@ -1952,7 +1953,31 @@ def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)

+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
+                s = L.clamp_min(0).sqrt()
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
+                ghat = V.transpose(-2, -1) @ g_sym @ V
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
+
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction

     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):

@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)

     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)

     def eigh(self, a):
         return torch.linalg.eigh(a)
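A short derivation of the closed-form VJP implemented in `MatrixSqrtFunction.backward` above (standard matrix calculus; the notation here is not from the patch). Write the symmetrized input as A = V diag(s^2) V^T, so that Y = A^{1/2} = V diag(s) V^T. Differentiating Y^2 = A gives a Sylvester equation that diagonalizes in the eigenbasis:

\[ dA = Y\,dY + dY\,Y \quad\Longrightarrow\quad \big(V^\top dY\, V\big)_{ij} = \frac{\big(V^\top dA\, V\big)_{ij}}{s_i + s_j}. \]

Because the map X -> YX + XY is self-adjoint under the Frobenius inner product (for symmetric Y), the vector-Jacobian product applies the same inverse to the symmetrized incoming gradient \bar{Y}:

\[ \bar{A} = V \left[ \frac{\big(V^\top \bar{Y}_{\mathrm{sym}}\, V\big)_{ij}}{s_i + s_j} \right]_{ij} V^\top. \]

This is exactly the `ghat / d` step in `backward`: entries with s_i + s_j = 0 (zero eigenvalues) are zeroed by `masked_fill`, and no eigenvalue difference ever appears in a denominator, which is what keeps the gradient stable around repeated eigenvalues.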

test/test_backend.py

Lines changed: 13 additions & 0 deletions

@@ -822,6 +822,19 @@ def fun(a, b, d):
     assert nx.allclose(dl_db, b)


+def test_sqrtm_backward_torch():
+    if not torch:
+        pytest.skip("Torch not available")
+    nx = ot.backend.TorchBackend()
+    torch.manual_seed(42)
+    d = 5
+    A = torch.randn(d, d, dtype=torch.float64, device="cpu")
+    A = A @ A.T
+    A.requires_grad_(True)
+    func = lambda x: nx.sqrtm(x).sum()
+    assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
+
+
 def test_get_backend_none():
     a, b = np.zeros((2, 3)), None
     nx = get_backend(a, b)
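To make the motivation concrete, here is a minimal reproduction sketch (illustrative only, not part of the commit; the `naive_sqrtm` helper and the eigenvalue choice (2, 2, 5) are assumptions, and PyTorch must be installed). Differentiating through a plain `torch.linalg.eigh`-based square root divides by eigenvalue gaps, so with an exactly repeated eigenvalue its gradient is typically inf/NaN, while the backend's custom `sqrtm` stays finite:

# Illustrative sketch (not from the commit): compare gradients of a naive
# eigh-based square root with the backend's sqrtm on an SPD matrix that has
# an exactly repeated eigenvalue.
import torch

from ot.backend import TorchBackend


def naive_sqrtm(a):
    # old-style differentiable sqrt: relies on eigh's autograd, whose
    # backward pass divides by differences of eigenvalues
    L, V = torch.linalg.eigh(a)
    return (V * L.sqrt().unsqueeze(-2)) @ V.transpose(-2, -1)


nx = TorchBackend()

# diagonal SPD matrix with eigenvalues (2, 2, 5); the repetition is deliberate
A = torch.diag(torch.tensor([2.0, 2.0, 5.0], dtype=torch.float64))

A1 = A.clone().requires_grad_(True)
naive_sqrtm(A1).sum().backward()
print("naive eigh-based gradient finite:", torch.isfinite(A1.grad).all().item())

A2 = A.clone().requires_grad_(True)
nx.sqrtm(A2).sum().backward()
print("backend sqrtm gradient finite:", torch.isfinite(A2.grad).all().item())

Note that the `gradcheck` in the new test above exercises the generic case (a random `A @ A.T` almost surely has distinct eigenvalues); the repeated-eigenvalue behavior is what Issue #773, referenced in the commit title, is about.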
