Skip to content

Commit bb75cf8

Browse files
authored
Implement value_gradient_and_hessian (#305)
* Implement value_gradient_and_hessian * Fix extras * Fix sparse * Typo * Implement in extensions * Fixes * Fix * Fix ref * Fix * Fixes * Fix * Typos * Typo * Typo
1 parent af0b6ce commit bb75cf8

File tree

23 files changed

+403
-149
lines changed

23 files changed

+403
-149
lines changed

DifferentiationInterface/docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ hvp!
9393
prepare_hessian
9494
hessian
9595
hessian!
96+
value_gradient_and_hessian
97+
value_gradient_and_hessian!
9698
```
9799

98100
## Utilities

DifferentiationInterface/docs/src/operators.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ These operators are computed using the input `x` and a "seed" `v`, which lives e
4545

4646
Several variants of each operator are defined.
4747

48-
| out-of-place | in-place | out-of-place + primal | in-place + primal |
49-
| :-------------------------- | :--------------------------- | :----------------------------------------------- | :----------------------------------------------- |
50-
| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) |
48+
| out-of-place | in-place | out-of-place + primal | in-place + primal |
49+
| :-------------------------- | :--------------------------- | :----------------------------------------------- | :------------------------------------------------ |
50+
| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) |
5151
| [`second_derivative`](@ref) | [`second_derivative!`](@ref) | [`value_derivative_and_second_derivative`](@ref) | [`value_derivative_and_second_derivative!`](@ref) |
52-
| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
53-
| [`hessian`](@ref) | [`hessian!`](@ref) | NA | NA |
54-
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
55-
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
56-
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
57-
| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA |
52+
| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) |
53+
| [`hessian`](@ref) | [`hessian!`](@ref) | [`value_gradient_and_hessian`](@ref) | [`value_gradient_and_hessian!`](@ref) NA |
54+
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
55+
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
56+
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
57+
| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA |
5858

5959
## Mutation and signatures
6060

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ using DifferentiationInterface:
1010
JacobianExtras,
1111
PullbackExtras,
1212
PushforwardExtras,
13-
SecondDerivativeExtras
13+
SecondDerivativeExtras,
14+
maybe_dense_ad
1415
using FastDifferentiation:
1516
derivative,
1617
hessian,

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
## Pushforward
22

3-
struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras
3+
struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E1!} <: PushforwardExtras
44
y_prototype::Y
55
jvp_exe::E1
6-
jvp_exe!::E2
6+
jvp_exe!::E1!
77
end
88

99
function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx)
@@ -70,9 +70,9 @@ end
7070

7171
## Pullback
7272

73-
struct FastDifferentiationOneArgPullbackExtras{E1,E2} <: PullbackExtras
73+
struct FastDifferentiationOneArgPullbackExtras{E1,E1!} <: PullbackExtras
7474
vjp_exe::E1
75-
vjp_exe!::E2
75+
vjp_exe!::E1!
7676
end
7777

7878
function DI.prepare_pullback(f, ::AutoFastDifferentiation, x, dy)
@@ -133,10 +133,10 @@ end
133133

134134
## Derivative
135135

136-
struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras
136+
struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E1!} <: DerivativeExtras
137137
y_prototype::Y
138138
der_exe::E1
139-
der_exe!::E2
139+
der_exe!::E1!
140140
end
141141

142142
function DI.prepare_derivative(f, ::AutoFastDifferentiation, x)
@@ -190,13 +190,12 @@ end
190190

191191
## Gradient
192192

193-
struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras
193+
struct FastDifferentiationOneArgGradientExtras{E1,E1!} <: GradientExtras
194194
jac_exe::E1
195-
jac_exe!::E2
195+
jac_exe!::E1!
196196
end
197197

198198
function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x)
199-
y_prototype = f(x)
200199
x_var = make_variables(:x, size(x)...)
201200
y_var = f(x_var)
202201

@@ -241,10 +240,10 @@ end
241240

242241
## Jacobian
243242

244-
struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2} <: JacobianExtras
243+
struct FastDifferentiationOneArgJacobianExtras{Y,E1,E1!} <: JacobianExtras
245244
y_prototype::Y
246245
jac_exe::E1
247-
jac_exe!::E2
246+
jac_exe!::E1!
248247
end
249248

250249
function DI.prepare_jacobian(
@@ -307,34 +306,29 @@ end
307306

308307
## Second derivative
309308

310-
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1!,E2,E2!} <:
309+
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,D,E2,E2!} <:
311310
SecondDerivativeExtras
312311
y_prototype::Y
313-
der_exe::E1
314-
der_exe!::E1!
312+
derivative_extras::D
315313
der2_exe::E2
316314
der2_exe!::E2!
317315
end
318316

319-
function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x)
317+
function DI.prepare_second_derivative(f, backend::AutoFastDifferentiation, x)
320318
y_prototype = f(x)
321319
x_var = only(make_variables(:x))
322320
y_var = f(x_var)
323321

324322
x_vec_var = monovec(x_var)
325323
y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var)
326324

327-
der_vec_var = derivative(y_vec_var, x_var)
328325
der2_vec_var = derivative(y_vec_var, x_var, x_var)
329-
330-
der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
331-
der_exe! = make_function(der_vec_var, x_vec_var; in_place=true)
332-
333326
der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false)
334327
der2_exe! = make_function(der2_vec_var, x_vec_var; in_place=true)
335328

329+
derivative_extras = DI.prepare_derivative(f, backend, x)
336330
return FastDifferentiationAllocatingSecondDerivativeExtras(
337-
y_prototype, der_exe, der_exe!, der2_exe, der2_exe!
331+
y_prototype, derivative_extras, der2_exe, der2_exe!
338332
)
339333
end
340334

@@ -364,20 +358,13 @@ end
364358

365359
function DI.value_derivative_and_second_derivative(
366360
f,
367-
::AutoFastDifferentiation,
361+
backend::AutoFastDifferentiation,
368362
x,
369363
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
370364
)
371-
y = f(x)
372-
if extras.y_prototype isa Number
373-
der = only(extras.der_exe(monovec(x)))
374-
der2 = only(extras.der2_exe(monovec(x)))
375-
return y, der, der2
376-
else
377-
der = reshape(extras.der_exe(monovec(x)), size(extras.y_prototype))
378-
der2 = reshape(extras.der2_exe(monovec(x)), size(extras.y_prototype))
379-
return y, der, der2
380-
end
365+
y, der = DI.value_and_derivative(f, backend, x, extras.derivative_extras)
366+
der2 = DI.second_derivative(f, backend, x, extras)
367+
return y, der, der2
381368
end
382369

383370
function DI.value_derivative_and_second_derivative!(
@@ -388,17 +375,16 @@ function DI.value_derivative_and_second_derivative!(
388375
x,
389376
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
390377
)
391-
y = f(x)
392-
extras.der_exe!(vec(der), monovec(x))
393-
extras.der2_exe!(vec(der2), monovec(x))
378+
y, _ = DI.value_and_derivative!(f, der, backend, x, extras.derivative_extras)
379+
DI.second_derivative!(f, der2, backend, x, extras)
394380
return y, der, der2
395381
end
396382

397383
## HVP
398384

399-
struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras
400-
hvp_exe::E1
401-
hvp_exe!::E2
385+
struct FastDifferentiationHVPExtras{E2,E2!} <: HVPExtras
386+
hvp_exe::E2
387+
hvp_exe!::E2!
402388
end
403389

404390
function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v)
@@ -428,24 +414,30 @@ end
428414

429415
## Hessian
430416

431-
struct FastDifferentiationHessianExtras{E1,E2} <: HessianExtras
432-
hess_exe::E1
433-
hess_exe!::E2
417+
struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras
418+
gradient_extras::G
419+
hess_exe::E2
420+
hess_exe!::E2!
434421
end
435422

436423
function DI.prepare_hessian(
437424
f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x
438425
)
439-
x_vec_var = make_variables(:x, size(x)...)
440-
y_vec_var = f(x_vec_var)
426+
x_var = make_variables(:x, size(x)...)
427+
y_var = f(x_var)
428+
429+
x_vec_var = vec(x_var)
430+
441431
hess_var = if backend isa AutoSparse
442-
sparse_hessian(y_vec_var, vec(x_vec_var))
432+
sparse_hessian(y_var, x_vec_var)
443433
else
444-
hessian(y_vec_var, vec(x_vec_var))
434+
hessian(y_var, x_vec_var)
445435
end
446-
hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false)
447-
hess_exe! = make_function(hess_var, vec(x_vec_var); in_place=true)
448-
return FastDifferentiationHessianExtras(hess_exe, hess_exe!)
436+
hess_exe = make_function(hess_var, x_vec_var; in_place=false)
437+
hess_exe! = make_function(hess_var, x_vec_var; in_place=true)
438+
439+
gradient_extras = DI.prepare_gradient(f, maybe_dense_ad(backend), x)
440+
return FastDifferentiationHessianExtras(gradient_extras, hess_exe, hess_exe!)
449441
end
450442

451443
function DI.hessian(
@@ -467,3 +459,29 @@ function DI.hessian!(
467459
extras.hess_exe!(hess, vec(x))
468460
return hess
469461
end
462+
463+
function DI.value_gradient_and_hessian(
464+
f,
465+
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
466+
x,
467+
extras::FastDifferentiationHessianExtras,
468+
)
469+
y, grad = DI.value_and_gradient(f, maybe_dense_ad(backend), x, extras.gradient_extras)
470+
hess = DI.hessian(f, backend, x, extras)
471+
return y, grad, hess
472+
end
473+
474+
function DI.value_gradient_and_hessian!(
475+
f,
476+
grad,
477+
hess,
478+
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
479+
x,
480+
extras::FastDifferentiationHessianExtras,
481+
)
482+
y, _ = DI.value_and_gradient!(
483+
f, grad, maybe_dense_ad(backend), x, extras.gradient_extras
484+
)
485+
DI.hessian!(f, hess, backend, x, extras)
486+
return y, grad, hess
487+
end

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
## Pushforward
22

3-
struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras
3+
struct FastDifferentiationTwoArgPushforwardExtras{E1,E1!} <: PushforwardExtras
44
jvp_exe::E1
5-
jvp_exe!::E2
5+
jvp_exe!::E1!
66
end
77

88
function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx)
@@ -80,9 +80,9 @@ end
8080

8181
## Pullback
8282

83-
struct FastDifferentiationTwoArgPullbackExtras{E1,E2} <: PullbackExtras
83+
struct FastDifferentiationTwoArgPullbackExtras{E1,E1!} <: PullbackExtras
8484
vjp_exe::E1
85-
vjp_exe!::E2
85+
vjp_exe!::E1!
8686
end
8787

8888
function DI.prepare_pullback(f!, y, ::AutoFastDifferentiation, x, dy)
@@ -156,9 +156,9 @@ end
156156

157157
## Derivative
158158

159-
struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras
159+
struct FastDifferentiationTwoArgDerivativeExtras{E1,E1!} <: DerivativeExtras
160160
der_exe::E1
161-
der_exe!::E2
161+
der_exe!::E1!
162162
end
163163

164164
function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x)
@@ -216,9 +216,9 @@ end
216216

217217
## Jacobian
218218

219-
struct FastDifferentiationTwoArgJacobianExtras{E1,E2} <: JacobianExtras
219+
struct FastDifferentiationTwoArgJacobianExtras{E1,E1!} <: JacobianExtras
220220
jac_exe::E1
221-
jac_exe!::E2
221+
jac_exe!::E1!
222222
end
223223

224224
function DI.prepare_jacobian(

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,21 +153,39 @@ end
153153

154154
## Hessian
155155

156-
struct FiniteDiffHessianExtras{C} <: HessianExtras
157-
cache::C
156+
struct FiniteDiffHessianExtras{C1,C2} <: HessianExtras
157+
gradient_cache::C1
158+
hessian_cache::C2
158159
end
159160

160161
function DI.prepare_hessian(f, backend::AutoFiniteDiff, x)
161-
cache = HessianCache(x, fdhtype(backend))
162-
return FiniteDiffHessianExtras(cache)
162+
y = f(x)
163+
df = zero(y) .* x
164+
gradient_cache = GradientCache(df, x, fdtype(backend))
165+
hessian_cache = HessianCache(x, fdhtype(backend))
166+
return FiniteDiffHessianExtras(gradient_cache, hessian_cache)
163167
end
164168

165-
# cache cannot be reused because of https://github.com/JuliaDiff/FiniteDiff.jl/issues/185
166-
167169
function DI.hessian(f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
168-
return finite_difference_hessian(f, x, extras.cache)
170+
return finite_difference_hessian(f, x, extras.hessian_cache)
169171
end
170172

171173
function DI.hessian!(f, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras)
172-
return finite_difference_hessian!(hess, f, x, extras.cache)
174+
return finite_difference_hessian!(hess, f, x, extras.hessian_cache)
175+
end
176+
177+
function DI.value_gradient_and_hessian(
178+
f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras
179+
)
180+
grad = finite_difference_gradient(f, x, extras.gradient_cache)
181+
hess = finite_difference_hessian(f, x, extras.hessian_cache)
182+
return f(x), grad, hess
183+
end
184+
185+
function DI.value_gradient_and_hessian!(
186+
f, grad, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras
187+
)
188+
finite_difference_gradient!(grad, f, x, extras.gradient_cache)
189+
finite_difference_hessian!(hess, f, x, extras.hessian_cache)
190+
return f(x), grad, hess
173191
end

0 commit comments

Comments
 (0)