From 0163b35bc919296f5f7775df66d3bcafed35758f Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:26:04 +0000 Subject: [PATCH 1/3] Remove `Zygote`; fix https://github.com/TuringLang/Turing.jl/issues/2504 --- docs/src/api.md | 1 - src/Turing.jl | 1 - src/essential/Essential.jl | 3 +-- test/Project.toml | 2 -- test/essential/ad.jl | 20 +------------------- test/test_utils/ad_utils.jl | 18 +----------------- test/test_utils/test_utils.jl | 9 --------- 7 files changed, 3 insertions(+), 51 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index afedf59bb7..3066a7fad9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -88,7 +88,6 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au |:----------------- |:------------------------------------ |:---------------------- | | `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend | | `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | -| `AutoZygote` | [`ADTypes.AutoZygote`](@extref) | Zygote.jl backend | | `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend | ### Debugging diff --git a/src/Turing.jl b/src/Turing.jl index 6318e2bd52..c6ca82d340 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -106,7 +106,6 @@ export @model, # modelling externalsampler, AutoForwardDiff, # ADTypes AutoReverseDiff, - AutoZygote, AutoMooncake, setprogress!, # debugging Flat, diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index c04c7e862b..cfa064c651 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake +using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake using AdvancedPS: AdvancedPS @@ -20,7 +20,6 @@ include("container.jl") export @model, @varname, AutoForwardDiff, - AutoZygote, AutoReverseDiff, AutoMooncake, @logprob_str, diff --git a/test/Project.toml b/test/Project.toml index e96d505e73..885ca1c7f7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -36,7 +36,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "5" @@ -75,5 +74,4 @@ StableRNGs = "1" StatsBase = "0.33, 0.34" StatsFuns = "0.9.5, 1" TimerOutputs = "0.5" -Zygote = "0.5.4, 0.6" julia = "1.10" diff --git a/test/essential/ad.jl b/test/essential/ad.jl index 0c497a2cbd..674976f729 100644 --- a/test/essential/ad.jl +++ b/test/essential/ad.jl @@ -11,7 +11,6 @@ using ReverseDiff using Test: @test, @testset using Turing using Turing: SampleFromPrior -using Zygote function test_model_ad(model, f, syms::Vector{Symbol}) # Set up VI. @@ -87,20 +86,6 @@ end vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext() ) x = map(x -> Float64(x), vi[SampleFromPrior()]) - - zygoteℓ = LogDensityProblemsAD.ADgradient(Turing.AutoZygote(), ℓ) - if isdefined(Base, :get_extension) - @test zygoteℓ isa - Base.get_extension( - LogDensityProblemsAD, :LogDensityProblemsADZygoteExt - ).ZygoteGradientLogDensity - else - @test zygoteℓ isa - LogDensityProblemsAD.LogDensityProblemsADZygoteExt.ZygoteGradientLogDensity - end - @test zygoteℓ.ℓ === ℓ - ∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2] - @test sort(∇E2) ≈ grad_FWAD atol = 1e-9 end @testset "general AD tests" begin @@ -135,11 +120,10 @@ end test_model_ad(wishart_ad(), logp3, [:v]) end - @testset "Simplex Zygote and ReverseDiff (with and without caching) AD" begin + @testset "Simplex ReverseDiff (with and without caching) AD" begin @model function dir() return theta ~ Dirichlet(1 ./ fill(4, 4)) end - sample(dir(), HMC(0.01, 1; adtype=AutoZygote()), 1000) sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000) sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=true)), 1000) end @@ -149,14 +133,12 @@ end end sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000) - sample(wishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000) @model function invwishart() return theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4)) end sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000) - sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000) end @testset "Hessian test" begin @model function tst(x, ::Type{TV}=Vector{Float64}) where {TV} diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 2c01dc524e..7de93547ae 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -8,7 +8,6 @@ using Mooncake: Mooncake using Test: Test using Turing: Turing using Turing: DynamicPPL -using Zygote: Zygote export ADTypeCheckContext, adbackends @@ -31,9 +30,6 @@ const eltypes_by_adtype = Dict( ReverseDiff.TrackedVector, ), Turing.AutoMooncake => (Mooncake.CoDual,), - # Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the - # two by element type. However, we have other checks for Zygote, see check_adtype. - Turing.AutoZygote => (Zygote.Dual,), ) """ @@ -90,7 +86,6 @@ For instance, evaluating a model with would throw an error if within the model a type associated with e.g. ReverseDiff was encountered. -As a current short-coming, this context can not distinguish between ForwardDiff and Zygote. """ struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext @@ -134,20 +129,9 @@ end Check that the element types in `vi` are compatible with the ADType of `context`. -When Zygote is being used, we also more explicitly check that `adtype(context)` is -`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't -discriminate between the two based on element type alone. This function will still fail to -catch cases where Zygote is supposed to be used, but ForwardDiff is used instead. - -Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or -`WrongADBackendError` if Zygote is used unexpectedly. +Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. """ function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) - Zygote.hook(vi) do _ - if !(adtype(context) <: Turing.AutoZygote) - throw(WrongADBackendError(Turing.AutoZygote, adtype(context))) - end - end valids = valid_eltypes(context) for val in vi[:] diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl index bf9f2b9b8d..243a80b881 100644 --- a/test/test_utils/test_utils.jl +++ b/test/test_utils/test_utils.jl @@ -7,7 +7,6 @@ using ReverseDiff: ReverseDiff using Test: @test, @testset, @test_throws using Turing: Turing using Turing: DynamicPPL -using Zygote: Zygote # Check that the ADTypeCheckContext works as expected. @testset "ADTypeCheckContext" begin @@ -16,20 +15,12 @@ using Zygote: Zygote adtypes = ( Turing.AutoForwardDiff(), Turing.AutoReverseDiff(), - Turing.AutoZygote(), # TODO: Mooncake # Turing.AutoMooncake(config=nothing), ) for actual_adtype in adtypes sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) for expected_adtype in adtypes - if ( - actual_adtype == Turing.AutoForwardDiff() && - expected_adtype == Turing.AutoZygote() - ) - # TODO(mhauru) We are currently unable to check this case. - continue - end contextualised_tm = DynamicPPL.contextualize( tm, ADTypeCheckContext(expected_adtype, tm.context) ) From 0139b026f7e087174c5f6a5d49bdcf15b1fc1284 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 12 Mar 2025 12:33:13 +0000 Subject: [PATCH 2/3] Update test/test_utils/ad_utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/test_utils/ad_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 7de93547ae..2a94355c0b 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -132,7 +132,6 @@ Check that the element types in `vi` are compatible with the ADType of `context` Throw an `IncompatibleADTypeError` if an incompatible element type is encountered. """ function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) - valids = valid_eltypes(context) for val in vi[:] valtype = typeof(val) From a23985c6c94795d9aa1fb6a066e38593f10f8568 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 14 Mar 2025 12:54:13 +0000 Subject: [PATCH 3/3] Add HISTORY.md entry about removing support for Zygote --- HISTORY.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 6251974075..62ca1d350c 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -6,6 +6,12 @@ 0.37 removes the old Gibbs constructors deprecated in 0.36. +### Remove Zygote support + +Zygote is no longer officially supported as an automatic differentiation backend, and `AutoZygote` is no longer exported. You can continue to use Zygote by importing `AutoZygote` from ADTypes and it may well continue to work, but it is no longer tested and no effort will be expended to fix it if something breaks. + +[Mooncake](https://github.com/compintell/Mooncake.jl/) is the recommended replacement for Zygote. + ### DynamicPPL 0.35 Turing.jl v0.37 uses DynamicPPL v0.35, which brings with it several breaking changes: