Skip to content

Better scenarios with weird arrays #523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using DifferentiationInterface:
outer,
unwrap,
with_contexts
import ForwardDiff.DiffResults as DR
using ForwardDiff.DiffResults:
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
using ForwardDiff:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ function DI.value_and_gradient!(
f::F, grad, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
) where {F,C}
fc = with_contexts(f, contexts...)
result = MutableDiffResult(zero(eltype(x)), (grad,))
result = DiffResult(zero(eltype(x)), (grad,))
result = gradient!(result, fc, x)
return DiffResults.value(result), DiffResults.gradient(result)
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
return y, grad
end

function DI.value_and_gradient(
Expand All @@ -176,7 +178,7 @@ function DI.value_and_gradient(
fc = with_contexts(f, contexts...)
result = GradientResult(x)
result = gradient!(result, fc, x)
return DiffResults.value(result), DiffResults.gradient(result)
return DR.value(result), DR.gradient(result)
end

function DI.gradient!(
Expand Down Expand Up @@ -213,9 +215,11 @@ function DI.value_and_gradient!(
contexts::Vararg{Context,C},
) where {F,C}
fc = with_contexts(f, contexts...)
result = MutableDiffResult(zero(eltype(x)), (grad,))
result = DiffResult(zero(eltype(x)), (grad,))
result = gradient!(result, fc, x, prep.config)
return DiffResults.value(result), DiffResults.gradient(result)
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
return y, grad
end

function DI.value_and_gradient(
Expand All @@ -224,7 +228,7 @@ function DI.value_and_gradient(
fc = with_contexts(f, contexts...)
result = GradientResult(x)
result = gradient!(result, fc, x, prep.config)
return DiffResults.value(result), DiffResults.gradient(result)
return DR.value(result), DR.gradient(result)
end

function DI.gradient!(
Expand Down Expand Up @@ -255,9 +259,11 @@ function DI.value_and_jacobian!(
) where {F,C}
fc = with_contexts(f, contexts...)
y = fc(x)
result = MutableDiffResult(y, (jac,))
result = DiffResult(y, (jac,))
result = jacobian!(result, fc, x)
return DiffResults.value(result), DiffResults.jacobian(result)
y = DR.value(result)
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
return y, jac
end

function DI.value_and_jacobian(
Expand Down Expand Up @@ -302,9 +308,11 @@ function DI.value_and_jacobian!(
) where {F,C}
fc = with_contexts(f, contexts...)
y = fc(x)
result = MutableDiffResult(y, (jac,))
result = DiffResult(y, (jac,))
result = jacobian!(result, fc, x, prep.config)
return DiffResults.value(result), DiffResults.jacobian(result)
y = DR.value(result)
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
return y, jac
end

function DI.value_and_jacobian(
Expand Down Expand Up @@ -457,11 +465,12 @@ function DI.value_gradient_and_hessian!(
f::F, grad, hess, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
) where {F,C}
fc = with_contexts(f, contexts...)
result = MutableDiffResult(one(eltype(x)), (grad, hess))
result = DiffResult(one(eltype(x)), (grad, hess))
result = hessian!(result, fc, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
return (y, grad, hess)
end

function DI.value_gradient_and_hessian(
Expand All @@ -470,9 +479,7 @@ function DI.value_gradient_and_hessian(
fc = with_contexts(f, contexts...)
result = HessianResult(x)
result = hessian!(result, fc, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
return (DR.value(result), DR.gradient(result), DR.hessian(result))
end

### Prepared
Expand Down Expand Up @@ -527,11 +534,12 @@ function DI.value_gradient_and_hessian!(
contexts::Vararg{Context,C},
) where {F,C}
fc = with_contexts(f, contexts...)
result = MutableDiffResult(one(eltype(x)), (grad, hess))
result = DiffResult(one(eltype(x)), (grad, hess))
result = hessian!(result, fc, x, prep.manual_result_config)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
return (y, grad, hess)
end

function DI.value_gradient_and_hessian(
Expand All @@ -540,7 +548,5 @@ function DI.value_gradient_and_hessian(
fc = with_contexts(f, contexts...)
result = HessianResult(x)
result = hessian!(result, fc, x, prep.auto_result_config)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
return (DR.value(result), DR.gradient(result), DR.hessian(result))
end
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ using DifferentiationInterface:
NoPullbackPrep,
unwrap,
with_contexts
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
import ReverseDiff.DiffResults as DR
using ReverseDiff.DiffResults:
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
using LinearAlgebra: dot, mul!
using ReverseDiff:
CompiledGradient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,12 @@ end
function DI.value_and_gradient!(
f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
y = f(x) # TODO: ReverseDiff#251
result = DiffResult(y, (grad,))
result = gradient!(result, prep.tape, x)
return DiffResults.value(result), DiffResults.derivative(result)
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
return y, grad
end

function DI.value_and_gradient(
Expand Down Expand Up @@ -115,10 +117,12 @@ function DI.value_and_gradient!(
f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad,))
y = fc(x) # TODO: ReverseDiff#251
result = DiffResult(y, (grad,))
result = gradient!(result, fc, x)
return DiffResults.value(result), DiffResults.derivative(result)
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
return y, grad
end

function DI.value_and_gradient(
Expand Down Expand Up @@ -162,9 +166,11 @@ function DI.value_and_jacobian!(
f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x
)
y = f(x)
result = MutableDiffResult(y, (jac,))
result = DiffResult(y, (jac,))
result = jacobian!(result, prep.tape, x)
return DiffResults.value(result), DiffResults.derivative(result)
y = DR.value(result)
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
return y, jac
end

function DI.value_and_jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
Expand All @@ -190,9 +196,11 @@ function DI.value_and_jacobian!(
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x)
result = MutableDiffResult(y, (jac,))
result = DiffResult(y, (jac,))
result = jacobian!(result, fc, x)
return DiffResults.value(result), DiffResults.derivative(result)
y = DR.value(result)
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
return y, jac
end

function DI.value_and_jacobian(
Expand Down Expand Up @@ -220,46 +228,49 @@ end

### Without contexts

struct ReverseDiffHessianPrep{T} <: HessianPrep
tape::T
struct ReverseDiffHessianGradientPrep{GT,HT} <: HessianPrep
gradient_tape::GT
hessian_tape::HT
end

function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x) where {Compile}
tape = HessianTape(f, x)
gradient_tape = GradientTape(f, x)
hessian_tape = HessianTape(f, x)
if Compile
tape = compile(tape)
gradient_tape = compile(gradient_tape)
hessian_tape = compile(hessian_tape)
end
return ReverseDiffHessianPrep(tape)
return ReverseDiffHessianGradientPrep(gradient_tape, hessian_tape)
end

function DI.hessian!(_f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
return hessian!(hess, prep.tape, x)
function DI.hessian!(_f, hess, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x)
return hessian!(hess, prep.hessian_tape, x)
end

function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
return hessian!(prep.tape, x)
function DI.hessian(_f, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x)
return hessian!(prep.hessian_tape, x)
end

function DI.value_gradient_and_hessian!(
f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
f, grad, hess, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad, hess))
result = hessian!(result, prep.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
y = f(x) # TODO: ReverseDiff#251
result = DiffResult(y, (grad, hess))
result = hessian!(result, prep.hessian_tape, x)
y = DR.value(result)
# grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
grad = gradient!(grad, prep.gradient_tape, x) # TODO: ReverseDiff#251
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
return y, grad, hess
end

function DI.value_gradient_and_hessian(
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
f, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x
)
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = hessian!(result, prep.tape, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
result = DiffResult(y, (similar(x), similar(x, length(x), length(x))))
result = hessian!(result, prep.hessian_tape, x)
return (DR.value(result), DR.gradient(result), DR.hessian(result))
end

### With contexts
Expand All @@ -286,22 +297,22 @@ function DI.value_gradient_and_hessian!(
f, grad, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (grad, hess))
y = fc(x) # TODO: ReverseDiff#251
result = DiffResult(y, (grad, hess))
result = hessian!(result, fc, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
y = DR.value(result)
# grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
grad = gradient!(grad, fc, x) # TODO: ReverseDiff#251
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
return y, grad, hess
end

function DI.value_gradient_and_hessian(
f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
) where {C}
fc = with_contexts(f, contexts...)
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
y = fc(x) # TODO: ReverseDiff#251
result = HessianResult(x)
result = hessian!(result, fc, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
return (DR.value(result), DR.gradient(result), DR.hessian(result))
end
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/Zygote/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ test_differentiation(
if VERSION >= v"1.10"
test_differentiation(
AutoZygote(),
vcat(component_scenarios(), gpu_scenarios(), static_scenarios());
vcat(component_scenarios(), gpu_scenarios());
second_order=false,
logging=LOGGING,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ function comp_to_num_pullback(x, dy)
end

function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy::Number)
nb_args = 1
f = comp_to_num
y = f(x)
dy_from_dx = comp_to_num_pushforward(x, dx)
dx_from_dy = comp_to_num_pullback(x, dy)
grad = comp_to_num_gradient(x)
Expand Down
Loading
Loading