Skip to content

Commit cee3e4e

Browse files
authored
Better scenarios with weird arrays (#523)
* Improve matrix skipping
* Fix array of strings using fromJSON
* Put all tests on same nesting level
* Skip pre DIT
* Group tests together
* Single quote
* Deactivate docs
* ToJSON
* Go back to individual folders
* Skip pre
* Fix
* Test internals first
* Reactivate docs
* Better scenarios with weird arrays
* Don't test Zygote on staticarrays
1 parent e716262 commit cee3e4e

File tree

12 files changed

+217
-210
lines changed

12 files changed

+217
-210
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using DifferentiationInterface:
2121
outer,
2222
unwrap,
2323
with_contexts
24+
import ForwardDiff.DiffResults as DR
2425
using ForwardDiff.DiffResults:
2526
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
2627
using ForwardDiff:

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,11 @@ function DI.value_and_gradient!(
165165
f::F, grad, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
166166
) where {F,C}
167167
fc = with_contexts(f, contexts...)
168-
result = MutableDiffResult(zero(eltype(x)), (grad,))
168+
result = DiffResult(zero(eltype(x)), (grad,))
169169
result = gradient!(result, fc, x)
170-
return DiffResults.value(result), DiffResults.gradient(result)
170+
y = DR.value(result)
171+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
172+
return y, grad
171173
end
172174

173175
function DI.value_and_gradient(
@@ -176,7 +178,7 @@ function DI.value_and_gradient(
176178
fc = with_contexts(f, contexts...)
177179
result = GradientResult(x)
178180
result = gradient!(result, fc, x)
179-
return DiffResults.value(result), DiffResults.gradient(result)
181+
return DR.value(result), DR.gradient(result)
180182
end
181183

182184
function DI.gradient!(
@@ -213,9 +215,11 @@ function DI.value_and_gradient!(
213215
contexts::Vararg{Context,C},
214216
) where {F,C}
215217
fc = with_contexts(f, contexts...)
216-
result = MutableDiffResult(zero(eltype(x)), (grad,))
218+
result = DiffResult(zero(eltype(x)), (grad,))
217219
result = gradient!(result, fc, x, prep.config)
218-
return DiffResults.value(result), DiffResults.gradient(result)
220+
y = DR.value(result)
221+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
222+
return y, grad
219223
end
220224

221225
function DI.value_and_gradient(
@@ -224,7 +228,7 @@ function DI.value_and_gradient(
224228
fc = with_contexts(f, contexts...)
225229
result = GradientResult(x)
226230
result = gradient!(result, fc, x, prep.config)
227-
return DiffResults.value(result), DiffResults.gradient(result)
231+
return DR.value(result), DR.gradient(result)
228232
end
229233

230234
function DI.gradient!(
@@ -255,9 +259,11 @@ function DI.value_and_jacobian!(
255259
) where {F,C}
256260
fc = with_contexts(f, contexts...)
257261
y = fc(x)
258-
result = MutableDiffResult(y, (jac,))
262+
result = DiffResult(y, (jac,))
259263
result = jacobian!(result, fc, x)
260-
return DiffResults.value(result), DiffResults.jacobian(result)
264+
y = DR.value(result)
265+
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
266+
return y, jac
261267
end
262268

263269
function DI.value_and_jacobian(
@@ -302,9 +308,11 @@ function DI.value_and_jacobian!(
302308
) where {F,C}
303309
fc = with_contexts(f, contexts...)
304310
y = fc(x)
305-
result = MutableDiffResult(y, (jac,))
311+
result = DiffResult(y, (jac,))
306312
result = jacobian!(result, fc, x, prep.config)
307-
return DiffResults.value(result), DiffResults.jacobian(result)
313+
y = DR.value(result)
314+
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
315+
return y, jac
308316
end
309317

310318
function DI.value_and_jacobian(
@@ -457,11 +465,12 @@ function DI.value_gradient_and_hessian!(
457465
f::F, grad, hess, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
458466
) where {F,C}
459467
fc = with_contexts(f, contexts...)
460-
result = MutableDiffResult(one(eltype(x)), (grad, hess))
468+
result = DiffResult(one(eltype(x)), (grad, hess))
461469
result = hessian!(result, fc, x)
462-
return (
463-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
464-
)
470+
y = DR.value(result)
471+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
472+
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
473+
return (y, grad, hess)
465474
end
466475

467476
function DI.value_gradient_and_hessian(
@@ -470,9 +479,7 @@ function DI.value_gradient_and_hessian(
470479
fc = with_contexts(f, contexts...)
471480
result = HessianResult(x)
472481
result = hessian!(result, fc, x)
473-
return (
474-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
475-
)
482+
return (DR.value(result), DR.gradient(result), DR.hessian(result))
476483
end
477484

478485
### Prepared
@@ -527,11 +534,12 @@ function DI.value_gradient_and_hessian!(
527534
contexts::Vararg{Context,C},
528535
) where {F,C}
529536
fc = with_contexts(f, contexts...)
530-
result = MutableDiffResult(one(eltype(x)), (grad, hess))
537+
result = DiffResult(one(eltype(x)), (grad, hess))
531538
result = hessian!(result, fc, x, prep.manual_result_config)
532-
return (
533-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
534-
)
539+
y = DR.value(result)
540+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
541+
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
542+
return (y, grad, hess)
535543
end
536544

537545
function DI.value_gradient_and_hessian(
@@ -540,7 +548,5 @@ function DI.value_gradient_and_hessian(
540548
fc = with_contexts(f, contexts...)
541549
result = HessianResult(x)
542550
result = hessian!(result, fc, x, prep.auto_result_config)
543-
return (
544-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
545-
)
551+
return (DR.value(result), DR.gradient(result), DR.hessian(result))
546552
end

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ using DifferentiationInterface:
1515
NoPullbackPrep,
1616
unwrap,
1717
with_contexts
18-
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
18+
import ReverseDiff.DiffResults as DR
19+
using ReverseDiff.DiffResults:
20+
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
1921
using LinearAlgebra: dot, mul!
2022
using ReverseDiff:
2123
CompiledGradient,

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ end
8484
function DI.value_and_gradient!(
8585
f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x
8686
)
87-
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
88-
result = MutableDiffResult(y, (grad,))
87+
y = f(x) # TODO: ReverseDiff#251
88+
result = DiffResult(y, (grad,))
8989
result = gradient!(result, prep.tape, x)
90-
return DiffResults.value(result), DiffResults.derivative(result)
90+
y = DR.value(result)
91+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
92+
return y, grad
9193
end
9294

9395
function DI.value_and_gradient(
@@ -115,10 +117,12 @@ function DI.value_and_gradient!(
115117
f, grad, ::NoGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
116118
) where {C}
117119
fc = with_contexts(f, contexts...)
118-
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
119-
result = MutableDiffResult(y, (grad,))
120+
y = fc(x) # TODO: ReverseDiff#251
121+
result = DiffResult(y, (grad,))
120122
result = gradient!(result, fc, x)
121-
return DiffResults.value(result), DiffResults.derivative(result)
123+
y = DR.value(result)
124+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
125+
return y, grad
122126
end
123127

124128
function DI.value_and_gradient(
@@ -162,9 +166,11 @@ function DI.value_and_jacobian!(
162166
f, jac, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x
163167
)
164168
y = f(x)
165-
result = MutableDiffResult(y, (jac,))
169+
result = DiffResult(y, (jac,))
166170
result = jacobian!(result, prep.tape, x)
167-
return DiffResults.value(result), DiffResults.derivative(result)
171+
y = DR.value(result)
172+
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
173+
return y, jac
168174
end
169175

170176
function DI.value_and_jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff, x)
@@ -190,9 +196,11 @@ function DI.value_and_jacobian!(
190196
) where {C}
191197
fc = with_contexts(f, contexts...)
192198
y = fc(x)
193-
result = MutableDiffResult(y, (jac,))
199+
result = DiffResult(y, (jac,))
194200
result = jacobian!(result, fc, x)
195-
return DiffResults.value(result), DiffResults.derivative(result)
201+
y = DR.value(result)
202+
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
203+
return y, jac
196204
end
197205

198206
function DI.value_and_jacobian(
@@ -220,46 +228,49 @@ end
220228

221229
### Without contexts
222230

223-
struct ReverseDiffHessianPrep{T} <: HessianPrep
224-
tape::T
231+
struct ReverseDiffHessianGradientPrep{GT,HT} <: HessianPrep
232+
gradient_tape::GT
233+
hessian_tape::HT
225234
end
226235

227236
function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x) where {Compile}
228-
tape = HessianTape(f, x)
237+
gradient_tape = GradientTape(f, x)
238+
hessian_tape = HessianTape(f, x)
229239
if Compile
230-
tape = compile(tape)
240+
gradient_tape = compile(gradient_tape)
241+
hessian_tape = compile(hessian_tape)
231242
end
232-
return ReverseDiffHessianPrep(tape)
243+
return ReverseDiffHessianGradientPrep(gradient_tape, hessian_tape)
233244
end
234245

235-
function DI.hessian!(_f, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
236-
return hessian!(hess, prep.tape, x)
246+
function DI.hessian!(_f, hess, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x)
247+
return hessian!(hess, prep.hessian_tape, x)
237248
end
238249

239-
function DI.hessian(_f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x)
240-
return hessian!(prep.tape, x)
250+
function DI.hessian(_f, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x)
251+
return hessian!(prep.hessian_tape, x)
241252
end
242253

243254
function DI.value_gradient_and_hessian!(
244-
f, grad, hess, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
255+
f, grad, hess, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x
245256
)
246-
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
247-
result = MutableDiffResult(y, (grad, hess))
248-
result = hessian!(result, prep.tape, x)
249-
return (
250-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
251-
)
257+
y = f(x) # TODO: ReverseDiff#251
258+
result = DiffResult(y, (grad, hess))
259+
result = hessian!(result, prep.hessian_tape, x)
260+
y = DR.value(result)
261+
# grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
262+
grad = gradient!(grad, prep.gradient_tape, x) # TODO: ReverseDiff#251
263+
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
264+
return y, grad, hess
252265
end
253266

254267
function DI.value_gradient_and_hessian(
255-
f, prep::ReverseDiffHessianPrep, ::AutoReverseDiff, x
268+
f, prep::ReverseDiffHessianGradientPrep, ::AutoReverseDiff, x
256269
)
257270
y = f(x) # TODO: remove once ReverseDiff#251 is fixed
258-
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
259-
result = hessian!(result, prep.tape, x)
260-
return (
261-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
262-
)
271+
result = DiffResult(y, (similar(x), similar(x, length(x), length(x))))
272+
result = hessian!(result, prep.hessian_tape, x)
273+
return (DR.value(result), DR.gradient(result), DR.hessian(result))
263274
end
264275

265276
### With contexts
@@ -286,22 +297,22 @@ function DI.value_gradient_and_hessian!(
286297
f, grad, hess, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
287298
) where {C}
288299
fc = with_contexts(f, contexts...)
289-
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
290-
result = MutableDiffResult(y, (grad, hess))
300+
y = fc(x) # TODO: ReverseDiff#251
301+
result = DiffResult(y, (grad, hess))
291302
result = hessian!(result, fc, x)
292-
return (
293-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
294-
)
303+
y = DR.value(result)
304+
# grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
305+
grad = gradient!(grad, fc, x) # TODO: ReverseDiff#251
306+
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
307+
return y, grad, hess
295308
end
296309

297310
function DI.value_gradient_and_hessian(
298311
f, ::NoHessianPrep, ::AutoReverseDiff, x, contexts::Vararg{Context,C}
299312
) where {C}
300313
fc = with_contexts(f, contexts...)
301-
y = fc(x) # TODO: remove once ReverseDiff#251 is fixed
302-
result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
314+
y = fc(x) # TODO: ReverseDiff#251
315+
result = HessianResult(x)
303316
result = hessian!(result, fc, x)
304-
return (
305-
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
306-
)
317+
return (DR.value(result), DR.gradient(result), DR.hessian(result))
307318
end

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ test_differentiation(
4040
if VERSION >= v"1.10"
4141
test_differentiation(
4242
AutoZygote(),
43-
vcat(component_scenarios(), gpu_scenarios(), static_scenarios());
43+
vcat(component_scenarios(), gpu_scenarios());
4444
second_order=false,
4545
logging=LOGGING,
4646
)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ function comp_to_num_pullback(x, dy)
2525
end
2626

2727
function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy::Number)
28-
nb_args = 1
2928
f = comp_to_num
30-
y = f(x)
3129
dy_from_dx = comp_to_num_pushforward(x, dx)
3230
dx_from_dy = comp_to_num_pullback(x, dy)
3331
grad = comp_to_num_gradient(x)

0 commit comments

Comments
 (0)