From 97389465f954ca1e56a3157480a35c9cf255b499 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 15:28:04 +0100 Subject: [PATCH 1/4] Fix LKJ numerical stability with PDMats --- Project.toml | 6 ++++-- src/bijectors/corr.jl | 4 +--- src/utils.jl | 12 ++++-------- test/bijectors/corr.jl | 15 +++++++++++++++ 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 6a8a8ce6..e42ab3d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.7" +version = "0.15.8" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -14,6 +14,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" @@ -37,8 +38,8 @@ BijectorsEnzymeCoreExt = "EnzymeCore" BijectorsForwardDiffExt = "ForwardDiff" BijectorsLazyArraysExt = "LazyArrays" BijectorsMooncakeExt = "Mooncake" -BijectorsReverseDiffExt = "ReverseDiff" BijectorsReverseDiffChainRulesExt = ["ChainRules", "ReverseDiff"] +BijectorsReverseDiffExt = "ReverseDiff" BijectorsTrackerExt = "Tracker" BijectorsZygoteExt = "Zygote" @@ -59,6 +60,7 @@ LazyArrays = "2" LogExpFunctions = "0.3.3" MappedArrays = "0.2.2, 0.3, 0.4" Mooncake = "0.4.95" +PDMats = "0.11.35" Reexport = "0.2, 1" ReverseDiff = "1" Roots = "1.3.15, 2" diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 085324f0..37905cb7 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -136,9 +136,7 @@ function logabsdetjac(b::VecCorrBijector, x) end function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y) - U_logJ = _inv_link_chol_lkj(y) - # workaround for `Tracker.TrackedTuple` not supporting iteration - U, logJ = U_logJ[1], U_logJ[2] + U, logJ = _inv_link_chol_lkj(y) K = size(U, 1) for j in 2:(K - 1) logJ += (K - j) * log(U[j, j]) diff --git a/src/utils.jl b/src/utils.jl index 9fd6c65c..0aa09ffc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,5 @@ +using PDMats: PDMat + # `permutedims` seems to work better with AD (cf. KernelFunctions.jl) aT_b(a::AbstractVector{<:Real}, b::AbstractMatrix{<:Real}) = permutedims(a) * b # `permutedims` can't be used here since scalar output is desired @@ -11,14 +13,8 @@ _vec(x::Real) = x lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) -function pd_from_lower(X) - L = lower_triangular(X) - return L * L' -end -function pd_from_upper(X) - U = upper_triangular(X) - return U' * U -end +pd_from_lower(X) = PDMat(Cholesky(LowerTriangular(X))) +pd_from_upper(X) = PDMat(Cholesky(UpperTriangular(X))) # HACK: Allows us to define custom chain rules while we wait for upstream fixes. transpose_eager(X::AbstractMatrix) = permutedims(X) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 8a423bc3..19a0570b 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -1,5 +1,6 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector +using Random: Xoshiro @testset "CorrBijector & VecCorrBijector" begin for d in [1, 2, 5] @@ -43,6 +44,20 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector @test size(dist_unconstrained) == size(x) @test dist_unconstrained isa MatrixDistribution end + + @testset "Pathological samples for invlink" begin + # see https://github.com/TuringLang/Bijectors.jl/issues/387 + d = LKJ(3, 3.0) + for i in 1:100 + rng = Xoshiro(i) + y = randn(rng, 3) * 15 + f_inv = inverse(bijector(d)) + x = f_inv(y) + @test logpdf(d, x) isa Float64 # used to crash. + x, _ = with_logabsdet_jacobian(f_inv, y) + @test logpdf(d, x) isa Float64 + end + end end @testset "VecCholeskyBijector" begin From a6d08f4085aa78afdb2177a2afb4e5ac5b0f84c4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 15:34:30 +0100 Subject: [PATCH 2/4] Add `cholesky_{upper,lower}` methods --- src/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 0aa09ffc..bc615954 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -31,6 +31,7 @@ rather than `LowerTriangular`. that returns a `Matrix` rather than `LowerTriangular`. """ cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X, :L)).L)) +cholesky_lower(X::PDMat) = X.chol.L cholesky_lower(X::Cholesky) = X.L """ @@ -44,6 +45,7 @@ rather than `UpperTriangular`. that returns a `Matrix` rather than `UpperTriangular`. """ cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U)) +cholesky_upper(X::PDMat) = X.chol.U cholesky_upper(X::Cholesky) = X.U """ From 1fcca65b5a5e61f102790ab0b0ca7221e9f15485 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 15:51:12 +0100 Subject: [PATCH 3/4] use a new method instead to avoid messing with other bits of code --- src/bijectors/corr.jl | 4 ++-- src/utils.jl | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 37905cb7..3fcba323 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -77,7 +77,7 @@ function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y) for j in 2:(K - 1) logJ += (K - j) * log(U[j, j]) end - return pd_from_upper(U), logJ + return pdmat_from_upper(U), logJ end logabsdetjac(::Inverse{CorrBijector}, Y) = _logabsdetjac_inv_corr(Y) @@ -141,7 +141,7 @@ function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y) for j in 2:(K - 1) logJ += (K - j) * log(U[j, j]) end - return pd_from_upper(U), logJ + return pdmat_from_upper(U), logJ end function logabsdetjac(::Inverse{VecCorrBijector}, y) diff --git a/src/utils.jl b/src/utils.jl index bc615954..bb18a66d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,12 +9,24 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x +# using PDMats.PDMat improves numerical stability of downstream operations, +# most notably taking the determinant +pdmat_from_lower(X::AbstractMatrix) = PDMat(Cholesky(LowerTriangular(X))) +pdmat_from_upper(X::AbstractMatrix) = PDMat(Cholesky(UpperTriangular(X))) + # # Because `ReverseDiff` does not play well with structural matrices. lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) - -pd_from_lower(X) = PDMat(Cholesky(LowerTriangular(X))) -pd_from_upper(X) = PDMat(Cholesky(UpperTriangular(X))) +# TODO: Replace remaining uses of `pd_from_{lower,upper}` with +# `pdmat_from_{lower,upper}`. +function pd_from_lower(X) + L = lower_triangular(X) + return L * L' +end +function pd_from_upper(X) + U = upper_triangular(X) + return U' * U +end # HACK: Allows us to define custom chain rules while we wait for upstream fixes. transpose_eager(X::AbstractMatrix) = permutedims(X) From 699a61c7fe8248fa6b1fee3910af3856776656e7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 19 Jun 2025 15:52:22 +0100 Subject: [PATCH 4/4] Use public `LinearAlgebra.cholesky` interface Co-authored-by: David Widmann --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index bb18a66d..702f5be6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -43,7 +43,7 @@ rather than `LowerTriangular`. that returns a `Matrix` rather than `LowerTriangular`. """ cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X, :L)).L)) -cholesky_lower(X::PDMat) = X.chol.L +cholesky_lower(X::PDMat) = cholesky_lower(cholesky(X)) cholesky_lower(X::Cholesky) = X.L """ @@ -57,7 +57,7 @@ rather than `UpperTriangular`. that returns a `Matrix` rather than `UpperTriangular`. """ cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U)) -cholesky_upper(X::PDMat) = X.chol.U +cholesky_upper(X::PDMat) = cholesky_upper(cholesky(X)) cholesky_upper(X::Cholesky) = X.U """