Skip to content

Remove Zygote #2505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

0.37 removes the old Gibbs constructors deprecated in 0.36.

### Remove Zygote support

Zygote is no longer officially supported as an automatic differentiation backend, and `AutoZygote` is no longer exported. You can continue to use Zygote by importing `AutoZygote` from ADTypes and it may well continue to work, but it is no longer tested and no effort will be expended to fix it if something breaks.

[Mooncake](https://github.com/compintell/Mooncake.jl/) is the recommended replacement for Zygote.

### DynamicPPL 0.35

Turing.jl v0.37 uses DynamicPPL v0.35, which brings with it several breaking changes:
Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au
|:----------------- |:------------------------------------ |:---------------------- |
| `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend |
| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend |
| `AutoZygote` | [`ADTypes.AutoZygote`](@extref) | Zygote.jl backend |
| `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend |

### Debugging
Expand Down
1 change: 0 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ export @model, # modelling
externalsampler,
AutoForwardDiff, # ADTypes
AutoReverseDiff,
AutoZygote,
AutoMooncake,
setprogress!, # debugging
Flat,
Expand Down
3 changes: 1 addition & 2 deletions src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using Bijectors: PDMatDistribution
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake

using AdvancedPS: AdvancedPS

Expand All @@ -20,7 +20,6 @@ include("container.jl")
export @model,
@varname,
AutoForwardDiff,
AutoZygote,
AutoReverseDiff,
AutoMooncake,
@logprob_str,
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "5"
Expand Down Expand Up @@ -75,5 +74,4 @@ StableRNGs = "1"
StatsBase = "0.33, 0.34"
StatsFuns = "0.9.5, 1"
TimerOutputs = "0.5"
Zygote = "0.5.4, 0.6"
julia = "1.10"
19 changes: 1 addition & 18 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using Mooncake: Mooncake
using Test: Test
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

export ADTypeCheckContext, adbackends

Expand All @@ -31,9 +30,6 @@ const eltypes_by_adtype = Dict(
ReverseDiff.TrackedVector,
),
Turing.AutoMooncake => (Mooncake.CoDual,),
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the
# two by element type. However, we have other checks for Zygote, see check_adtype.
Turing.AutoZygote => (Zygote.Dual,),
)

"""
Expand Down Expand Up @@ -90,7 +86,6 @@ For instance, evaluating a model with
would throw an error if within the model a type associated with e.g. ReverseDiff was
encountered.

As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
DynamicPPL.AbstractContext
Expand Down Expand Up @@ -134,21 +129,9 @@ end

Check that the element types in `vi` are compatible with the ADType of `context`.

When Zygote is being used, we also more explicitly check that `adtype(context)` is
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't
discriminate between the two based on element type alone. This function will still fail to
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead.

Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or
`WrongADBackendError` if Zygote is used unexpectedly.
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
Zygote.hook(vi) do _
if !(adtype(context) <: Turing.AutoZygote)
throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
end
end

valids = valid_eltypes(context)
for val in vi[:]
valtype = typeof(val)
Expand Down
9 changes: 0 additions & 9 deletions test/test_utils/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using ReverseDiff: ReverseDiff
using Test: @test, @testset, @test_throws
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

# Check that the ADTypeCheckContext works as expected.
@testset "ADTypeCheckContext" begin
Expand All @@ -16,20 +15,12 @@ using Zygote: Zygote
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
Turing.AutoZygote(),
# TODO: Mooncake
# Turing.AutoMooncake(config=nothing),
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
if (
actual_adtype == Turing.AutoForwardDiff() &&
expected_adtype == Turing.AutoZygote()
)
# TODO(mhauru) We are currently unable to check this case.
continue
end
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
Expand Down
Loading