@@ -1938,6 +1938,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")

         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable

         # define a function that takes inputs val and grads
         # and returns a val tensor with proper gradients
@@ -1952,7 +1953,31 @@ def backward(ctx, grad_output):
             # the gradients are grad
             return (None, None) + tuple(g * grad_output for g in ctx.grads)

+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
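+        # for a = V diag(L) V^T, sqrt(a) = V diag(sqrt(L)) V^T; the backward
+        # solves the Sylvester equation y dy + dy y = da in the eigenbasis
+        # (Daleckii-Krein), an entrywise division by s_i + s_j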
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
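+                # symmetrize to absorb numerical asymmetry in the input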
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
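+                # clamp small negative eigenvalues from round-off before sqrt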
+                s = L.clamp_min(0).sqrt()
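+                # y = V diag(s) V^T without materializing the diagonal matrix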
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
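+            # the backward is closed-form; once_differentiable marks it as
+            # not supporting double backward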
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
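+                # rotate the symmetrized cotangent into the eigenbasis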
+                ghat = V.transpose(-2, -1) @ g_sym @ V
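+                # entrywise Sylvester solve: divide by s_i + s_j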
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
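+                # zero out the 0/0 entries where both eigenvalues are zero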
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
+
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction

     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)

     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)

     def eigh(self, a):
         return torch.linalg.eigh(a)