Skip to content

Remove Zygote #2505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au
|:----------------- |:------------------------------------ |:---------------------- |
| `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend |
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
| `AutoZygote` | [`ADTypes.AutoZygote`](@extref) | Zygote.jl backend |
| `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend |

### Debugging
Expand Down
1 change: 0 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ export @model, # modelling
externalsampler,
AutoForwardDiff, # ADTypes
AutoReverseDiff,
AutoZygote,
AutoMooncake,
setprogress!, # debugging
Flat,
Expand Down
3 changes: 1 addition & 2 deletions src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake

using AdvancedPS: AdvancedPS

Expand All @@ -20,7 +20,6 @@ include("container.jl")
export @model,
@varname,
AutoForwardDiff,
AutoZygote,
AutoReverseDiff,
AutoMooncake,
@logprob_str,
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "5"
Expand Down Expand Up @@ -75,5 +74,4 @@ StableRNGs = "1"
StatsBase = "0.33, 0.34"
StatsFuns = "0.9.5, 1"
TimerOutputs = "0.5"
Zygote = "0.5.4, 0.6"
julia = "1.10"
20 changes: 1 addition & 19 deletions test/essential/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using ReverseDiff
using Test: @test, @testset
using Turing
using Turing: SampleFromPrior
using Zygote

function test_model_ad(model, f, syms::Vector{Symbol})
# Set up VI.
Expand Down Expand Up @@ -87,20 +86,6 @@ end
vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()
)
x = map(x -> Float64(x), vi[SampleFromPrior()])

zygoteℓ = LogDensityProblemsAD.ADgradient(Turing.AutoZygote(), ℓ)
if isdefined(Base, :get_extension)
@test zygoteℓ isa
Base.get_extension(
LogDensityProblemsAD, :LogDensityProblemsADZygoteExt
).ZygoteGradientLogDensity
else
@test zygoteℓ isa
LogDensityProblemsAD.LogDensityProblemsADZygoteExt.ZygoteGradientLogDensity
end
@test zygoteℓ.ℓ === ℓ
∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2]
@test sort(∇E2) ≈ grad_FWAD atol = 1e-9
end

@testset "general AD tests" begin
Expand Down Expand Up @@ -135,11 +120,10 @@ end

test_model_ad(wishart_ad(), logp3, [:v])
end
@testset "Simplex Zygote and ReverseDiff (with and without caching) AD" begin
@testset "Simplex ReverseDiff (with and without caching) AD" begin
@model function dir()
return theta ~ Dirichlet(1 ./ fill(4, 4))
end
sample(dir(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=true)), 1000)
end
Expand All @@ -149,14 +133,12 @@ end
end

sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(wishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)

@model function invwishart()
return theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4))
end

sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
end
@testset "Hessian test" begin
@model function tst(x, ::Type{TV}=Vector{Float64}) where {TV}
Expand Down
18 changes: 1 addition & 17 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using Mooncake: Mooncake
using Test: Test
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

export ADTypeCheckContext, adbackends

Expand All @@ -31,9 +30,6 @@ const eltypes_by_adtype = Dict(
ReverseDiff.TrackedVector,
),
Turing.AutoMooncake => (Mooncake.CoDual,),
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the
# two by element type. However, we have other checks for Zygote, see check_adtype.
Turing.AutoZygote => (Zygote.Dual,),
)

"""
Expand Down Expand Up @@ -90,7 +86,6 @@ For instance, evaluating a model with
would throw an error if within the model a type associated with e.g. ReverseDiff was
encountered.

As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
DynamicPPL.AbstractContext
Expand Down Expand Up @@ -134,20 +129,9 @@ end

Check that the element types in `vi` are compatible with the ADType of `context`.

When Zygote is being used, we also more explicitly check that `adtype(context)` is
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't
discriminate between the two based on element type alone. This function will still fail to
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead.

Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or
`WrongADBackendError` if Zygote is used unexpectedly.
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
Zygote.hook(vi) do _
if !(adtype(context) <: Turing.AutoZygote)
throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
end
end

valids = valid_eltypes(context)
for val in vi[:]
Expand Down
9 changes: 0 additions & 9 deletions test/test_utils/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using ReverseDiff: ReverseDiff
using Test: @test, @testset, @test_throws
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

# Check that the ADTypeCheckContext works as expected.
@testset "ADTypeCheckContext" begin
Expand All @@ -16,20 +15,12 @@ using Zygote: Zygote
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
Turing.AutoZygote(),
# TODO: Mooncake
# Turing.AutoMooncake(config=nothing),
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
if (
actual_adtype == Turing.AutoForwardDiff() &&
expected_adtype == Turing.AutoZygote()
)
# TODO(mhauru) We are currently unable to check this case.
continue
end
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
Expand Down
Loading