84
84
function DI. value_and_gradient! (
85
85
f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x
86
86
)
87
- y = f (x) # TODO : remove once ReverseDiff#251 is fixed
88
- result = MutableDiffResult (y, (grad,))
87
+ y = f (x) # TODO : ReverseDiff#251
88
+ result = DiffResult (y, (grad,))
89
89
result = gradient! (result, prep. tape, x)
90
- return DiffResults. value (result), DiffResults. derivative (result)
90
+ y = DR. value (result)
91
+ grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
92
+ return y, grad
91
93
end
92
94
93
95
function DI. value_and_gradient (
@@ -115,10 +117,12 @@ function DI.value_and_gradient!(
115
117
f, grad, :: NoGradientPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
116
118
) where {C}
117
119
fc = with_contexts (f, contexts... )
118
- y = fc (x) # TODO : remove once ReverseDiff#251 is fixed
119
- result = MutableDiffResult (y, (grad,))
120
+ y = fc (x) # TODO : ReverseDiff#251
121
+ result = DiffResult (y, (grad,))
120
122
result = gradient! (result, fc, x)
121
- return DiffResults. value (result), DiffResults. derivative (result)
123
+ y = DR. value (result)
124
+ grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
125
+ return y, grad
122
126
end
123
127
124
128
function DI. value_and_gradient (
@@ -162,9 +166,11 @@ function DI.value_and_jacobian!(
162
166
f, jac, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x
163
167
)
164
168
y = f (x)
165
- result = MutableDiffResult (y, (jac,))
169
+ result = DiffResult (y, (jac,))
166
170
result = jacobian! (result, prep. tape, x)
167
- return DiffResults. value (result), DiffResults. derivative (result)
171
+ y = DR. value (result)
172
+ jac === DR. jacobian (result) || copyto! (jac, DR. jacobian (result))
173
+ return y, jac
168
174
end
169
175
170
176
function DI. value_and_jacobian (f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
@@ -190,9 +196,11 @@ function DI.value_and_jacobian!(
190
196
) where {C}
191
197
fc = with_contexts (f, contexts... )
192
198
y = fc (x)
193
- result = MutableDiffResult (y, (jac,))
199
+ result = DiffResult (y, (jac,))
194
200
result = jacobian! (result, fc, x)
195
- return DiffResults. value (result), DiffResults. derivative (result)
201
+ y = DR. value (result)
202
+ jac === DR. jacobian (result) || copyto! (jac, DR. jacobian (result))
203
+ return y, jac
196
204
end
197
205
198
206
function DI. value_and_jacobian (
@@ -220,46 +228,49 @@ end
220
228
221
229
# ## Without contexts
222
230
223
- struct ReverseDiffHessianPrep{T} <: HessianPrep
224
- tape:: T
231
+ struct ReverseDiffHessianGradientPrep{GT,HT} <: HessianPrep
232
+ gradient_tape:: GT
233
+ hessian_tape:: HT
225
234
end
226
235
227
236
function DI. prepare_hessian (f, :: AutoReverseDiff{Compile} , x) where {Compile}
228
- tape = HessianTape (f, x)
237
+ gradient_tape = GradientTape (f, x)
238
+ hessian_tape = HessianTape (f, x)
229
239
if Compile
230
- tape = compile (tape)
240
+ gradient_tape = compile (gradient_tape)
241
+ hessian_tape = compile (hessian_tape)
231
242
end
232
- return ReverseDiffHessianPrep (tape )
243
+ return ReverseDiffHessianGradientPrep (gradient_tape, hessian_tape )
233
244
end
234
245
235
- function DI. hessian! (_f, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x)
236
- return hessian! (hess, prep. tape , x)
246
+ function DI. hessian! (_f, hess, prep:: ReverseDiffHessianGradientPrep , :: AutoReverseDiff , x)
247
+ return hessian! (hess, prep. hessian_tape , x)
237
248
end
238
249
239
- function DI. hessian (_f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x)
240
- return hessian! (prep. tape , x)
250
+ function DI. hessian (_f, prep:: ReverseDiffHessianGradientPrep , :: AutoReverseDiff , x)
251
+ return hessian! (prep. hessian_tape , x)
241
252
end
242
253
243
254
function DI. value_gradient_and_hessian! (
244
- f, grad, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
255
+ f, grad, hess, prep:: ReverseDiffHessianGradientPrep , :: AutoReverseDiff , x
245
256
)
246
- y = f (x) # TODO : remove once ReverseDiff#251 is fixed
247
- result = MutableDiffResult (y, (grad, hess))
248
- result = hessian! (result, prep. tape, x)
249
- return (
250
- DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
251
- )
257
+ y = f (x) # TODO : ReverseDiff#251
258
+ result = DiffResult (y, (grad, hess))
259
+ result = hessian! (result, prep. hessian_tape, x)
260
+ y = DR. value (result)
261
+ # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
262
+ grad = gradient! (grad, prep. gradient_tape, x) # TODO : ReverseDiff#251
263
+ hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
264
+ return y, grad, hess
252
265
end
253
266
254
267
function DI. value_gradient_and_hessian (
255
- f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
268
+ f, prep:: ReverseDiffHessianGradientPrep , :: AutoReverseDiff , x
256
269
)
257
270
y = f (x) # TODO : remove once ReverseDiff#251 is fixed
258
- result = MutableDiffResult (y, (similar (x), similar (x, length (x), length (x))))
259
- result = hessian! (result, prep. tape, x)
260
- return (
261
- DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
262
- )
271
+ result = DiffResult (y, (similar (x), similar (x, length (x), length (x))))
272
+ result = hessian! (result, prep. hessian_tape, x)
273
+ return (DR. value (result), DR. gradient (result), DR. hessian (result))
263
274
end
264
275
265
276
# ## With contexts
@@ -286,22 +297,22 @@ function DI.value_gradient_and_hessian!(
286
297
f, grad, hess, :: NoHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
287
298
) where {C}
288
299
fc = with_contexts (f, contexts... )
289
- y = fc (x) # TODO : remove once ReverseDiff#251 is fixed
290
- result = MutableDiffResult (y, (grad, hess))
300
+ y = fc (x) # TODO : ReverseDiff#251
301
+ result = DiffResult (y, (grad, hess))
291
302
result = hessian! (result, fc, x)
292
- return (
293
- DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
294
- )
303
+ y = DR. value (result)
304
+ # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
305
+ grad = gradient! (grad, fc, x) # TODO : ReverseDiff#251
306
+ hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
307
+ return y, grad, hess
295
308
end
296
309
297
310
function DI. value_gradient_and_hessian (
298
311
f, :: NoHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
299
312
) where {C}
300
313
fc = with_contexts (f, contexts... )
301
- y = fc (x) # TODO : remove once ReverseDiff#251 is fixed
302
- result = MutableDiffResult (y, ( similar (x), similar (x, length (x), length (x))) )
314
+ y = fc (x) # TODO : ReverseDiff#251
315
+ result = HessianResult (x )
303
316
result = hessian! (result, fc, x)
304
- return (
305
- DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
306
- )
317
+ return (DR. value (result), DR. gradient (result), DR. hessian (result))
307
318
end
0 commit comments