From 1c19d56238f1cd181327bc9ba8b23fc785c6003c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 18 Jun 2025 18:05:32 +0200 Subject: [PATCH 1/2] feat: support forward-mode Mooncake [experimental] --- .../docs/src/explanation/backends.md | 4 +- .../DifferentiationInterfaceMooncakeExt.jl | 20 ++-- .../forward_onearg.jl | 92 ++++++++++++++ .../forward_twoarg.jl | 112 ++++++++++++++++++ .../src/DifferentiationInterface.jl | 2 + .../test/Back/Mooncake/test.jl | 6 +- 6 files changed, 227 insertions(+), 9 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 013d9e7c8..b04f35f6b 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -12,7 +12,7 @@ We support the following dense backend choices from [ADTypes.jl](https://github. - [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences) - [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff) - [`AutoGTPSA`](@extref ADTypes.AutoGTPSA) -- [`AutoMooncake`](@extref ADTypes.AutoMooncake) +- [`AutoMooncake`](@extref ADTypes.AutoMooncake) and [`AutoMooncakeForward`](@extref ADTypes.AutoMooncake) - [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff) - [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff) - [`AutoSymbolics`](@extref ADTypes.AutoSymbolics) @@ -48,6 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato | `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | `AutoGTPSA` | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | | `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | + | `AutoMooncakeForward` | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 | | `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -68,6 +69,7 @@ Moreover, each context type is supported by a specific subset of backends: | `AutoForwardDiff` | ✅ | ✅ | | `AutoGTPSA` | ✅ | ❌ | | `AutoMooncake` | ✅ | ✅ | +| `AutoMooncakeForward` | ✅ | ✅ | | `AutoPolyesterForwardDiff` | ✅ | ✅ | | `AutoReverseDiff` | ✅ | ❌ | | `AutoSymbolics` | ✅ | ✅ | diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 6253ea229..0d40a2590 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -1,29 +1,35 @@ module DifferentiationInterfaceMooncakeExt -using ADTypes: ADTypes, AutoMooncake +using ADTypes: ADTypes, AutoMooncake, AutoMooncakeForward import DifferentiationInterface as DI using Mooncake: CoDual, Config, + Dual, + prepare_derivative_cache, prepare_gradient_cache, prepare_pullback_cache, + primal, + tangent, tangent_type, + value_and_derivative!!, value_and_gradient!!, value_and_pullback!!, + zero_dual, zero_tangent, _copy_output, _copy_to_output!! -DI.check_available(::AutoMooncake) = true +const AnyAutoMooncake{C} = Union{AutoMooncake{C},AutoMooncakeForward{C}} -get_config(::AutoMooncake{Nothing}) = Config() -get_config(backend::AutoMooncake{<:Config}) = backend.config +DI.check_available(::AnyAutoMooncake) = true -# tangents need to be copied before returning, otherwise they are still aliased in the cache -mycopy(x::Union{Number,AbstractArray{<:Number}}) = copy(x) -mycopy(x) = deepcopy(x) +get_config(::AnyAutoMooncake{Nothing}) = Config() +get_config(backend::AnyAutoMooncake{<:Config}) = backend.config include("onearg.jl") include("twoarg.jl") +include("forward_onearg.jl") +include("forward_twoarg.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl new file mode 100644 index 000000000..c9228d972 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -0,0 +1,92 @@ +## Pushforward + +struct MooncakeOneArgPushforwardPrep{SIG,Tcache,DX} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + cache::Tcache + dx_righttype::DX +end + +function DI.prepare_pushforward_nokwarg( + strict::Val, + f::F, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + _sig = DI.signature(f, backend, x, tx, contexts...; strict) + config = get_config(backend) + # TODO: silence_debug_messages + cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config.debug_mode) + dx_righttype = zero_tangent(x) + prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype) + return prep +end + +function DI.value_and_pushforward( + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X} + DI.check_prep(f, prep, backend, x, tx, contexts...) + ys_and_ty = map(tx) do dx + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + y_dual = value_and_derivative!!( + prep.cache, + zero_dual(f), + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + y = primal(y_dual) + dy = _copy_output(tangent(y_dual)) + return y, dy + end + y = first(ys_and_ty[1]) + ty = last.(ys_and_ty) + return y, ty +end + +function DI.pushforward( + f::F, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + return DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)[2] +end + +function DI.value_and_pushforward!( + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...) + foreach(copyto!, ty, new_ty) + return y, ty +end + +function DI.pushforward!( + f::F, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f, prep, backend, x, tx, contexts...) + DI.value_and_pushforward!(f, ty, prep, backend, x, tx, contexts...) + return ty +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl new file mode 100644 index 000000000..f90524643 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -0,0 +1,112 @@ +## Pushforward + +struct MooncakeTwoArgPushforwardPrep{SIG,Tcache,DX,DY} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + cache::Tcache + dx_righttype::DX + dy_righttype::DY +end + +function DI.prepare_pushforward_nokwarg( + strict::Val, + f!::F, + y, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) + config = get_config(backend) + # TODO: silence_debug_messages + cache = prepare_derivative_cache( + f!, y, x, map(DI.unwrap, contexts)...; config.debug_mode + ) + dx_righttype = zero_tangent(x) + dy_righttype = zero_tangent(y) + prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype) + return prep +end + +function DI.value_and_pushforward( + f!::F, + y, + prep::MooncakeTwoArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + ty = map(tx) do dx + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + y_dual = zero_dual(y) + value_and_derivative!!( + prep.cache, + zero_dual(f!), + y_dual, + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + dy = _copy_output(tangent(y_dual)) + return dy + end + return y, ty +end + +function DI.pushforward( + f!::F, + y, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + return DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)[2] +end + +function DI.value_and_pushforward!( + f!::F, + y::Y, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x::X, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C,X,Y} + DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + foreach(tx, ty) do dx, dy + dx_righttype = + dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + dy_righttype = + dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) + value_and_derivative!!( + prep.cache, + zero_dual(f!), + Dual(y, dy_righttype), + Dual(x, dx_righttype), + map(zero_dual ∘ DI.unwrap, contexts)..., + ) + dy === dy_righttype || copyto!(dy, dy_righttype) + end + return y, ty +end + +function DI.pushforward!( + f!::F, + y, + ty::NTuple, + prep::MooncakeOneArgPushforwardPrep, + backend::AutoMooncakeForward, + x, + tx::NTuple, + contexts::Vararg{DI.Context,C}; +) where {F,C} + DI.check_prep(f!, y, ty, prep, backend, x, tx, contexts...) + DI.pushforward!(f!, y, ty, prep, backend, x, tx, contexts...) + return ty +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 32e699572..85f25fa28 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -28,6 +28,7 @@ using ADTypes: AutoForwardDiff, AutoGTPSA, AutoMooncake, + AutoMooncakeForward, AutoPolyesterForwardDiff, AutoReverseDiff, AutoSymbolics, @@ -115,6 +116,7 @@ export AutoFiniteDifferences export AutoForwardDiff export AutoGTPSA export AutoMooncake +export AutoMooncakeForward export AutoPolyesterForwardDiff export AutoReverseDiff export AutoSymbolics diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 8c9ab839a..7f48c3899 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -10,7 +10,11 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" -backends = [AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config())] +backends = [ + AutoMooncake(; config=nothing), + AutoMooncake(; config=Mooncake.Config()), + AutoMooncakeForward(; config=nothing); +] for backend in backends @test check_available(backend) From 2f9b365ef1007dad935274c96cacd965ba2ea97f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Jul 2025 10:33:59 +0200 Subject: [PATCH 2/2] Fix comma --- DifferentiationInterface/test/Back/Mooncake/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 7f48c3899..b695179f1 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -13,7 +13,7 @@ LOGGING = get(ENV, "CI", "false") == "false" backends = [ AutoMooncake(; config=nothing), AutoMooncake(; config=Mooncake.Config()), - AutoMooncakeForward(; config=nothing); + AutoMooncakeForward(; config=nothing), ] for backend in backends