From 0aedf084c20bceaa4b65e183b3b70ed2e4ec7f76 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 4 Apr 2025 03:06:25 +0100 Subject: [PATCH 1/5] Implement AD testing (with DITest) --- Project.toml | 3 + ...namicPPLDifferentiationInterfaceTestExt.jl | 68 +++++++++++++++++++ src/test_utils.jl | 4 ++ test/Project.toml | 1 + test/ad.jl | 14 ++-- 5 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 ext/DynamicPPLDifferentiationInterfaceTestExt.jl diff --git a/Project.toml b/Project.toml index e49d11908..fa1175b9f 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -35,6 +36,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] +DynamicPPLDifferentiationInterfaceTestExt = ["DifferentiationInterface", "DifferentiationInterfaceTest"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] @@ -52,6 +54,7 @@ ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" DifferentiationInterface = "0.6.41" +DifferentiationInterfaceTest = "0.9.6" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" diff --git a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl new file mode 100644 index 000000000..d80696d53 --- /dev/null +++ b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl @@ -0,0 +1,68 @@ +module DynamicPPLDifferentiationInterfaceTestExt + +using DynamicPPL: + DynamicPPL, + ADTypes, + LogDensityProblems, + Model, + AbstractVarInfo, + VarInfo, + LogDensityFunction +import DifferentiationInterface as DI +import 
"""
    REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff()

"""
    DynamicPPL.TestUtils.AD.make_scenario(
        model::Model,
        adtype::ADTypes.AbstractADType;
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
    )

Construct a `DifferentiationInterfaceTest.Scenario{:gradient,:out}` that checks the
gradient of the log density of `model` at `params`.

`adtype` is only used to decide how the function under test is presented to
DifferentiationInterface (as a closure, or as `DynamicPPL.logdensity_at` plus constant
contexts); the returned scenario does not itself carry a backend.

# Keywords
- `varinfo`: the `AbstractVarInfo` the log density is evaluated with.
- `params`: the point at which the gradient is taken. Defaults to `varinfo[:]`.
- `reference_adtype`: backend used to compute the expected gradient when
  `expected_grad` is not supplied.
- `expected_grad`: a precomputed expected gradient. If `nothing`, it is computed with
  `reference_adtype` via `LogDensityProblems.logdensity_and_gradient`.

# Returns
The scenario, with the expected gradient stored as `res1` and named after `model.f`.
"""
function DynamicPPL.TestUtils.AD.make_scenario(
    model::Model,
    adtype::ADTypes.AbstractADType;
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
    expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
)
    # Re-collect so that the element type is as narrow as the values allow.
    params = map(identity, params)
    context = DynamicPPL.DefaultContext()
    tweaked_adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context)

    # Closure vs. constant contexts is a performance optimisation, see:
    # https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658049143
    f, di_contexts = if DynamicPPL.use_closure(tweaked_adtype)
        (x -> DynamicPPL.logdensity_at(x, model, varinfo, context), ())
    else
        (
            DynamicPPL.logdensity_at,
            (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context)),
        )
    end

    # Calculate ground truth to compare against. The reference density must be built
    # on the same `varinfo` that `params` belongs to; otherwise the reference gradient
    # may be evaluated with a different VarInfo than the function under test.
    grad_true = if expected_grad === nothing
        ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
        LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2]
    else
        expected_grad
    end

    return DIT.Scenario{:gradient,:out}(
        f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)"
    )
end
include("test_utils/sampler.jl") +module AD + function make_scenario end +end + end diff --git a/test/Project.toml b/test/Project.toml index 9fa3fd872..fa382b08a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" diff --git a/test/ad.jl b/test/ad.jl index a4f3dbfa7..c8df6fd92 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,5 @@ using DynamicPPL: LogDensityFunction +import DifferentiationInterfaceTest as DIT @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -27,7 +28,7 @@ using DynamicPPL: LogDensityFunction x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)[2] @testset "$adtype" for adtype in test_adtypes @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" @@ -56,10 +57,13 @@ using DynamicPPL: LogDensityFunction ref_ldf, adtype ) else - ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype) - logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x) - @test grad ≈ ref_grad - @test logp ≈ ref_logp + scen = DynamicPPL.TestUtils.AD.make_scenario( + m, adtype; varinfo=varinfo, expected_grad=ref_grad + ) + tadtype = DynamicPPL.tweak_adtype( + adtype, m, varinfo, DefaultContext() + ) + DIT.test_differentiation(tadtype, [scen]; scenario_intact=false) end end end From 
4539a80e46b32a6c84fff8b1766c7bd8a16cad8b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 4 Apr 2025 03:13:42 +0100 Subject: [PATCH 2/5] Fix 1.10 extensions --- Project.toml | 2 +- ext/DynamicPPLDifferentiationInterfaceTestExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fa1175b9f..8cf8a9a51 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] -DynamicPPLDifferentiationInterfaceTestExt = ["DifferentiationInterface", "DifferentiationInterfaceTest"] +DynamicPPLDifferentiationInterfaceTestExt = ["DifferentiationInterfaceTest"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLJETExt = ["JET"] diff --git a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl index d80696d53..38f7fd129 100644 --- a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl +++ b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl @@ -5,10 +5,10 @@ using DynamicPPL: ADTypes, LogDensityProblems, Model, + DI, # DifferentiationInterface AbstractVarInfo, VarInfo, LogDensityFunction -import DifferentiationInterface as DI import DifferentiationInterfaceTest as DIT """ From 70e1aa9fc4f3cee2010914c04cf487fe30289534 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 4 Apr 2025 11:11:05 +0100 Subject: [PATCH 3/5] Make interface more consistent --- ...namicPPLDifferentiationInterfaceTestExt.jl | 22 +++++++++++++++++-- src/test_utils.jl | 2 +- test/ad.jl | 6 +---- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl index 38f7fd129..194e22393 100644 --- a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl +++ b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl @@ -20,7 +20,7 @@ it's the default AD backend used in Turing.jl. 
"""
    DynamicPPL.TestUtils.AD.run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
        varinfo::AbstractVarInfo=VarInfo(model),
        params::Vector{<:Real}=varinfo[:],
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
        kwargs...,
    )

Test the gradient of the log density of `model` computed with `adtype`, using
DifferentiationInterfaceTest.

A scenario is built with `make_scenario` and handed to `DIT.test_differentiation`
together with the backend returned by `DynamicPPL.tweak_adtype`.

# Keywords
- `varinfo`, `params`, `reference_adtype`, `expected_grad`: forwarded unchanged to
  `make_scenario`.
- `kwargs...`: forwarded to `DIT.test_differentiation`. They are splatted after
  `scenario_intact=false`, so a caller-supplied `scenario_intact` takes precedence.

# Returns
Whatever `DIT.test_differentiation` returns.
"""
function DynamicPPL.TestUtils.AD.run_ad(
    model::Model,
    adtype::ADTypes.AbstractADType;
    varinfo::AbstractVarInfo=VarInfo(model),
    params::Vector{<:Real}=varinfo[:],
    reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
    expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
    kwargs...,
)
    # Forward every scenario-related keyword. Previously `params` and
    # `reference_adtype` were accepted here but dropped, so caller-supplied values
    # were silently replaced by the defaults inside `make_scenario`.
    scenario = make_scenario(
        model,
        adtype;
        varinfo=varinfo,
        params=params,
        reference_adtype=reference_adtype,
        expected_grad=expected_grad,
    )
    tweaked_adtype = DynamicPPL.tweak_adtype(
        adtype, model, varinfo, DynamicPPL.DefaultContext()
    )
    return DIT.test_differentiation(
        tweaked_adtype, [scenario]; scenario_intact=false, kwargs...
    )
end
ext/DynamicPPLDifferentiationInterfaceTestExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl index 8602ef1c3..431ea3fb8 100644 --- a/ext/DynamicPPLDifferentiationInterfaceTestExt.jl +++ b/ext/DynamicPPLDifferentiationInterfaceTestExt.jl @@ -44,6 +44,7 @@ function make_scenario( params = map(identity, params) context = DynamicPPL.DefaultContext() adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context) + # Below is a performance optimisation, see: https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658049143 if DynamicPPL.use_closure(adtype) f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context) di_contexts = ()