From e2b19860bdb1b4a82c1c6fd968f71d117f63934a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 25 Sep 2024 07:58:30 +0200 Subject: [PATCH 1/2] Remove unneeded evaluation for ReverseDiff --- .../onearg.jl | 78 +++++-------------- .../twoarg.jl | 4 +- 2 files changed, 21 insertions(+), 61 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index 44b405ff7..e57857709 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -47,9 +47,7 @@ struct ReverseDiffGradientPrep{T} <: GradientPrep tape::T end -function DI.prepare_gradient( - f, ::AutoReverseDiff{Compile}, x::AbstractArray -) where {Compile} +function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile} tape = GradientTape(f, x) if Compile tape = compile(tape) @@ -58,36 +56,27 @@ function DI.prepare_gradient( end function DI.value_and_gradient!( - f, - grad::AbstractArray, - prep::ReverseDiffGradientPrep, - ::AutoReverseDiff, - x::AbstractArray, + f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x ) - y = f(x) # TODO: remove once ReverseDiff#251 is fixed - result = MutableDiffResult(y, (grad,)) + result = MutableDiffResult(zero(eltype(grad)), (grad,)) result = gradient!(result, prep.tape, x) return DiffResults.value(result), DiffResults.derivative(result) end function DI.value_and_gradient( - f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x::AbstractArray + f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff, x ) grad = similar(x) return DI.value_and_gradient!(f, grad, prep, backend, x) end function DI.gradient!( - _f, - grad::AbstractArray, - prep::ReverseDiffGradientPrep, - ::AutoReverseDiff, - x::AbstractArray, + _f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray ) return gradient!(grad, prep.tape, x) end -function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x::AbstractArray) +function DI.gradient(_f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x) return gradient!(prep.tape, x) end @@ -97,9 +86,7 @@ struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep tape::T end -function DI.prepare_jacobian( - f, ::AutoReverseDiff{Compile}, x::AbstractArray -) where {Compile} +function DI.prepare_jacobian(f, ::AutoReverseDiff{Compile}, x) where {Compile} tape = JacobianTape(f, x) if Compile tape = compile(tape) @@ -108,11 +95,7 @@ function DI.prepare_jacobian( end function DI.value_and_jacobian!( - f, - jac::AbstractMatrix, - prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, - x::AbstractArray, + f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x ) y = f(x) result = MutableDiffResult(y, (jac,)) @@ -120,25 +103,15 @@ function DI.value_and_jacobian!( return DiffResults.value(result), DiffResults.derivative(result) end -function DI.value_and_jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x::AbstractArray -) +function DI.value_and_jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x) return f(x), jacobian!(prep.tape, x) end -function DI.jacobian!( - _f, - jac::AbstractMatrix, - prep::ReverseDiffOneArgJacobianPrep, - ::AutoReverseDiff, - x::AbstractArray, -) +function DI.jacobian!(_f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x) return jacobian!(jac, prep.tape, x) end -function DI.jacobian( - f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x::AbstractArray -) +function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x) return jacobian!(prep.tape, x) end @@ -148,7 +121,7 @@ struct ReverseDiffHessianPrep{T} <: HessianPrep tape::T end -function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) where {Compile} +function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x) where {Compile} tape = HessianTape(f, x) if Compile tape = compile(tape) @@ -156,30 +129,18 @@ function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) whe return ReverseDiffHessianPrep(tape) end -function DI.hessian!( - _f, - hess::AbstractMatrix, - prep::ReverseDiffHessianPrep, - ::AutoReverseDiff, - x::AbstractArray, -) +function DI.hessian!(_f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x) return hessian!(hess, prep.tape, x) end -function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x::AbstractArray) +function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x) return hessian!(prep.tape, x) end function DI.value_gradient_and_hessian!( - f, - grad, - hess::AbstractMatrix, - prep::ReverseDiffHessianPrep, - ::AutoReverseDiff, - x::AbstractArray, + f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x ) - y = f(x) # TODO: remove once ReverseDiff#251 is fixed - result = MutableDiffResult(y, (grad, hess)) + result = MutableDiffResult(one(eltype(grad)), (grad, hess)) result = hessian!(result, prep.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) @@ -187,10 +148,11 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x::AbstractArray + f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x ) - y = f(x) # TODO: remove once ReverseDiff#251 is fixed - result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x)))) + result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) result = hessian!(result, prep.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 334a9e9c1..6153a6c97 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -81,9 +81,7 @@ struct ReverseDiffTwoArgJacobianPrep{T} <: JacobianPrep tape::T end -function DI.prepare_jacobian( - f!, y::AbstractArray, ::AutoReverseDiff{Compile}, x::AbstractArray -) where {Compile} +function DI.prepare_jacobian(f!, y, ::AutoReverseDiff{Compile}, x) where {Compile} tape = JacobianTape(f!, y, x) if Compile tape = compile(tape) From d6728828ea623433d00c29662183591a994f62b9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 25 Sep 2024 08:49:46 +0200 Subject: [PATCH 2/2] Undo fix --- .../ext/DifferentiationInterfaceReverseDiffExt/onearg.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index e57857709..d4b340d43 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -58,7 +58,8 @@ end function DI.value_and_gradient!( f, grad::AbstractArray, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x ) - result = MutableDiffResult(zero(eltype(grad)), (grad,)) + y = f(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (grad,)) result = gradient!(result, prep.tape, x) return DiffResults.value(result), DiffResults.derivative(result) end @@ -140,7 +141,8 @@ end function DI.value_gradient_and_hessian!( f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x ) - result = MutableDiffResult(one(eltype(grad)), (grad, hess)) + y = f(x) # TODO: remove once ReverseDiff#251 is fixed + result = MutableDiffResult(y, (grad, hess)) result = hessian!(result, prep.tape, x) return ( DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)