Skip to content

Commit 1fcca65

Browse files
committed
use a new method instead to avoid messing with other bits of code
1 parent a6d08f4 commit 1fcca65

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/bijectors/corr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y)
7777
for j in 2:(K - 1)
7878
logJ += (K - j) * log(U[j, j])
7979
end
80-
return pd_from_upper(U), logJ
80+
return pdmat_from_upper(U), logJ
8181
end
8282

8383
logabsdetjac(::Inverse{CorrBijector}, Y) = _logabsdetjac_inv_corr(Y)
@@ -141,7 +141,7 @@ function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y)
141141
for j in 2:(K - 1)
142142
logJ += (K - j) * log(U[j, j])
143143
end
144-
return pd_from_upper(U), logJ
144+
return pdmat_from_upper(U), logJ
145145
end
146146

147147
function logabsdetjac(::Inverse{VecCorrBijector}, y)

src/utils.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,24 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b)
99
_vec(x::AbstractArray{<:Real}) = vec(x)
1010
_vec(x::Real) = x
1111

12+
# using PDMats.PDMat improves numerical stability of downstream operations,
13+
# most notably taking the determinant
14+
pdmat_from_lower(X::AbstractMatrix) = PDMat(Cholesky(LowerTriangular(X)))
15+
pdmat_from_upper(X::AbstractMatrix) = PDMat(Cholesky(UpperTriangular(X)))
16+
1217
# # Because `ReverseDiff` does not play well with structural matrices.
1318
lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A))
1419
upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))
15-
16-
pd_from_lower(X) = PDMat(Cholesky(LowerTriangular(X)))
17-
pd_from_upper(X) = PDMat(Cholesky(UpperTriangular(X)))
20+
# TODO: Replace remaining uses of `pd_from_{lower,upper}` with
21+
# `pdmat_from_{lower,upper}`.
22+
function pd_from_lower(X)
23+
L = lower_triangular(X)
24+
return L * L'
25+
end
26+
function pd_from_upper(X)
27+
U = upper_triangular(X)
28+
return U' * U
29+
end
1830

1931
# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
2032
transpose_eager(X::AbstractMatrix) = permutedims(X)

0 commit comments

Comments
 (0)