Skip to content

Commit 9901fc4

Browse files
Fixed GetJac! for FiniteDiff.jl
1 parent 4316ee9 commit 9901fc4

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/DifferentiationOperators.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ _GetDeriv(ADmode::Val{:Zygote}; kwargs...) = throw("GetDeriv() not available for
1111
_GetGrad(ADmode::Val{:Zygote}; order::Int=-1, kwargs...) = (Func::Function,p;Kwargs...) -> Zygote.gradient(Func, p; kwargs...)[1]
1212
_GetJac(ADmode::Val{:Zygote}; order::Int=-1, kwargs...) = (Func::Function,p;Kwargs...) -> Zygote.jacobian(Func, p; kwargs...)[1]
1313
_GetHess(ADmode::Val{:Zygote}; order::Int=-1, kwargs...) = (Func::Function,p;Kwargs...) -> Zygote.hessian(Func, p; kwargs...)
14+
_GetDoubleJac(ADmode::Val{:Zygote}; kwargs...) = throw("GetDoubleJac() not available for Zygote.jl") # Zygote does not support mutating arrays
1415

1516
_GetDeriv(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_derivative
1617
_GetGrad(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_gradient
1718
_GetJac(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_jacobian
1819
_GetHess(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_hessian
20+
_GetDoubleJac(ADmode::Val{:FiniteDiff}; kwargs...) = throw("GetDoubleJac() not available for FiniteDiff.jl")
1921

2022
_GetDeriv(ADmode::Val{:FiniteDifferences}; kwargs...) = throw("GetDeriv() not available for FiniteDifferences.jl")
2123
_GetGrad(ADmode::Val{:FiniteDifferences}; order::Int=3, kwargs...) = (Func::Function,p;Kwargs...) -> FiniteDifferences.grad(central_fdm(order,1), Func, p; kwargs...)[1]
@@ -33,7 +35,18 @@ _GetMatrixJac!(ADmode::Val{:ReverseDiff}; kwargs...) = _GetJac!(ADmode; kwargs..
3335

3436
#_GetDeriv!(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_derivative!
3537
_GetGrad!(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_gradient!
36-
_GetJac!(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_jacobian!
38+
function _GetJac!(ADmode::Val{:FiniteDiff}; kwargs...)
39+
function FiniteDiff__finite_difference_jacobian!(Y::AbstractArray{<:Number}, F::Function, X, args...; kwargs...)
40+
# in-place FiniteDiff operators assume that function itself is also in-place
41+
if MaximalNumberOfArguments(F) > 1
42+
FiniteDiff.finite_difference_jacobian!(Y, F, X, args...; kwargs...)
43+
else
44+
# Use fake method
45+
(Y[:] .= vec(_GetJac(ADmode; kwargs...)(F, X, args...)))
46+
# FiniteDiff.finite_difference_jacobian!(Y, (Res,x)->copyto!(Res,F(x)), args...; kwargs...)
47+
end
48+
end
49+
end
3750
_GetHess!(ADmode::Val{:FiniteDiff}; kwargs...) = FiniteDiff.finite_difference_hessian!
3851
_GetMatrixJac!(ADmode::Val{:FiniteDiff}; kwargs...) = _GetJac!(ADmode; kwargs...)
3952

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using SafeTestsets
2424
for ADmode [:Zygote, :ReverseDiff, :FiniteDifferences]
2525
MyTest(ADmode)
2626
end
27+
MyTest(:FiniteDiff; atol=0.2)
2728

2829

2930
function TestDoubleJac(ADmode::Symbol; atol::Real=2e-5, kwargs...)
@@ -65,4 +66,5 @@ end
6566
for ADmode [:Zygote, :ReverseDiff, :FiniteDifferences]
6667
MyInplaceTest(ADmode)
6768
end
69+
MyInplaceTest(:FiniteDiff; atol=0.2)
6870
end

0 commit comments

Comments
 (0)