1
1
# # Pushforward
2
2
3
- struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2 } <: PushforwardExtras
3
+ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E1! } <: PushforwardExtras
4
4
y_prototype:: Y
5
5
jvp_exe:: E1
6
- jvp_exe!:: E2
6
+ jvp_exe!:: E1!
7
7
end
8
8
9
9
function DI. prepare_pushforward (f, :: AutoFastDifferentiation , x, dx)
70
70
71
71
# # Pullback
72
72
73
- struct FastDifferentiationOneArgPullbackExtras{E1,E2 } <: PullbackExtras
73
+ struct FastDifferentiationOneArgPullbackExtras{E1,E1! } <: PullbackExtras
74
74
vjp_exe:: E1
75
- vjp_exe!:: E2
75
+ vjp_exe!:: E1!
76
76
end
77
77
78
78
function DI. prepare_pullback (f, :: AutoFastDifferentiation , x, dy)
@@ -133,10 +133,10 @@ end
133
133
134
134
# # Derivative
135
135
136
- struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2 } <: DerivativeExtras
136
+ struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E1! } <: DerivativeExtras
137
137
y_prototype:: Y
138
138
der_exe:: E1
139
- der_exe!:: E2
139
+ der_exe!:: E1!
140
140
end
141
141
142
142
function DI. prepare_derivative (f, :: AutoFastDifferentiation , x)
@@ -190,13 +190,12 @@ end
190
190
191
191
# # Gradient
192
192
193
- struct FastDifferentiationOneArgGradientExtras{E1,E2 } <: GradientExtras
193
+ struct FastDifferentiationOneArgGradientExtras{E1,E1! } <: GradientExtras
194
194
jac_exe:: E1
195
- jac_exe!:: E2
195
+ jac_exe!:: E1!
196
196
end
197
197
198
198
function DI. prepare_gradient (f, backend:: AutoFastDifferentiation , x)
199
- y_prototype = f (x)
200
199
x_var = make_variables (:x , size (x)... )
201
200
y_var = f (x_var)
202
201
@@ -241,10 +240,10 @@ end
241
240
242
241
# # Jacobian
243
242
244
- struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2 } <: JacobianExtras
243
+ struct FastDifferentiationOneArgJacobianExtras{Y,E1,E1! } <: JacobianExtras
245
244
y_prototype:: Y
246
245
jac_exe:: E1
247
- jac_exe!:: E2
246
+ jac_exe!:: E1!
248
247
end
249
248
250
249
function DI. prepare_jacobian (
@@ -307,34 +306,29 @@ end
307
306
308
307
# # Second derivative
309
308
310
- struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1! ,E2,E2!} < :
309
+ struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,D ,E2,E2!} < :
311
310
SecondDerivativeExtras
312
311
y_prototype:: Y
313
- der_exe:: E1
314
- der_exe!:: E1!
312
+ derivative_extras:: D
315
313
der2_exe:: E2
316
314
der2_exe!:: E2!
317
315
end
318
316
319
- function DI. prepare_second_derivative (f, :: AutoFastDifferentiation , x)
317
+ function DI. prepare_second_derivative (f, backend :: AutoFastDifferentiation , x)
320
318
y_prototype = f (x)
321
319
x_var = only (make_variables (:x ))
322
320
y_var = f (x_var)
323
321
324
322
x_vec_var = monovec (x_var)
325
323
y_vec_var = y_var isa Number ? monovec (y_var) : vec (y_var)
326
324
327
- der_vec_var = derivative (y_vec_var, x_var)
328
325
der2_vec_var = derivative (y_vec_var, x_var, x_var)
329
-
330
- der_exe = make_function (der_vec_var, x_vec_var; in_place= false )
331
- der_exe! = make_function (der_vec_var, x_vec_var; in_place= true )
332
-
333
326
der2_exe = make_function (der2_vec_var, x_vec_var; in_place= false )
334
327
der2_exe! = make_function (der2_vec_var, x_vec_var; in_place= true )
335
328
329
+ derivative_extras = DI. prepare_derivative (f, backend, x)
336
330
return FastDifferentiationAllocatingSecondDerivativeExtras (
337
- y_prototype, der_exe, der_exe! , der2_exe, der2_exe!
331
+ y_prototype, derivative_extras , der2_exe, der2_exe!
338
332
)
339
333
end
340
334
@@ -364,20 +358,13 @@ end
364
358
365
359
function DI. value_derivative_and_second_derivative (
366
360
f,
367
- :: AutoFastDifferentiation ,
361
+ backend :: AutoFastDifferentiation ,
368
362
x,
369
363
extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
370
364
)
371
- y = f (x)
372
- if extras. y_prototype isa Number
373
- der = only (extras. der_exe (monovec (x)))
374
- der2 = only (extras. der2_exe (monovec (x)))
375
- return y, der, der2
376
- else
377
- der = reshape (extras. der_exe (monovec (x)), size (extras. y_prototype))
378
- der2 = reshape (extras. der2_exe (monovec (x)), size (extras. y_prototype))
379
- return y, der, der2
380
- end
365
+ y, der = DI. value_and_derivative (f, backend, x, extras. derivative_extras)
366
+ der2 = DI. second_derivative (f, backend, x, extras)
367
+ return y, der, der2
381
368
end
382
369
383
370
function DI. value_derivative_and_second_derivative! (
@@ -388,17 +375,16 @@ function DI.value_derivative_and_second_derivative!(
388
375
x,
389
376
extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
390
377
)
391
- y = f (x)
392
- extras. der_exe! (vec (der), monovec (x))
393
- extras. der2_exe! (vec (der2), monovec (x))
378
+ y, _ = DI. value_and_derivative! (f, der, backend, x, extras. derivative_extras)
379
+ DI. second_derivative! (f, der2, backend, x, extras)
394
380
return y, der, der2
395
381
end
396
382
397
383
# # HVP
398
384
399
- struct FastDifferentiationHVPExtras{E1 ,E2} <: HVPExtras
400
- hvp_exe:: E1
401
- hvp_exe!:: E2
385
+ struct FastDifferentiationHVPExtras{E2 ,E2! } <: HVPExtras
386
+ hvp_exe:: E2
387
+ hvp_exe!:: E2!
402
388
end
403
389
404
390
function DI. prepare_hvp (f, :: AutoFastDifferentiation , x, v)
@@ -428,24 +414,30 @@ end
428
414
429
415
# # Hessian
430
416
431
- struct FastDifferentiationHessianExtras{E1,E2} <: HessianExtras
432
- hess_exe:: E1
433
- hess_exe!:: E2
417
+ struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras
418
+ gradient_extras:: G
419
+ hess_exe:: E2
420
+ hess_exe!:: E2!
434
421
end
435
422
436
423
function DI. prepare_hessian (
437
424
f, backend:: Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}} , x
438
425
)
439
- x_vec_var = make_variables (:x , size (x)... )
440
- y_vec_var = f (x_vec_var)
426
+ x_var = make_variables (:x , size (x)... )
427
+ y_var = f (x_var)
428
+
429
+ x_vec_var = vec (x_var)
430
+
441
431
hess_var = if backend isa AutoSparse
442
- sparse_hessian (y_vec_var, vec ( x_vec_var) )
432
+ sparse_hessian (y_var, x_vec_var)
443
433
else
444
- hessian (y_vec_var, vec ( x_vec_var) )
434
+ hessian (y_var, x_vec_var)
445
435
end
446
- hess_exe = make_function (hess_var, vec (x_vec_var); in_place= false )
447
- hess_exe! = make_function (hess_var, vec (x_vec_var); in_place= true )
448
- return FastDifferentiationHessianExtras (hess_exe, hess_exe!)
436
+ hess_exe = make_function (hess_var, x_vec_var; in_place= false )
437
+ hess_exe! = make_function (hess_var, x_vec_var; in_place= true )
438
+
439
+ gradient_extras = DI. prepare_gradient (f, maybe_dense_ad (backend), x)
440
+ return FastDifferentiationHessianExtras (gradient_extras, hess_exe, hess_exe!)
449
441
end
450
442
451
443
function DI. hessian (
@@ -467,3 +459,29 @@ function DI.hessian!(
467
459
extras. hess_exe! (hess, vec (x))
468
460
return hess
469
461
end
462
+
463
+ function DI. value_gradient_and_hessian (
464
+ f,
465
+ backend:: Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}} ,
466
+ x,
467
+ extras:: FastDifferentiationHessianExtras ,
468
+ )
469
+ y, grad = DI. value_and_gradient (f, maybe_dense_ad (backend), x, extras. gradient_extras)
470
+ hess = DI. hessian (f, backend, x, extras)
471
+ return y, grad, hess
472
+ end
473
+
474
+ function DI. value_gradient_and_hessian! (
475
+ f,
476
+ grad,
477
+ hess,
478
+ backend:: Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}} ,
479
+ x,
480
+ extras:: FastDifferentiationHessianExtras ,
481
+ )
482
+ y, _ = DI. value_and_gradient! (
483
+ f, grad, maybe_dense_ad (backend), x, extras. gradient_extras
484
+ )
485
+ DI. hessian! (f, hess, backend, x, extras)
486
+ return y, grad, hess
487
+ end
0 commit comments