@@ -47,9 +47,7 @@ struct ReverseDiffGradientPrep{T} <: GradientPrep
4747 tape:: T
4848end
4949
50- function DI. prepare_gradient(
51- f, :: AutoReverseDiff{Compile} , x:: AbstractArray
52- ) where {Compile}
50+ function DI. prepare_gradient(f, :: AutoReverseDiff{Compile} , x) where {Compile}
5351 tape = GradientTape(f, x)
5452 if Compile
5553 tape = compile(tape)
@@ -58,11 +56,7 @@ function DI.prepare_gradient(
5856end
5957
6058function DI. value_and_gradient!(
61- f,
62- grad:: AbstractArray ,
63- prep:: ReverseDiffGradientPrep ,
64- :: AutoReverseDiff ,
65- x:: AbstractArray ,
59+ f, grad:: AbstractArray , prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x
6660)
6761 y = f(x) # TODO : remove once ReverseDiff#251 is fixed
6862 result = MutableDiffResult(y, (grad,))
@@ -71,23 +65,19 @@ function DI.value_and_gradient!(
7165end
7266
7367function DI. value_and_gradient(
74- f, prep:: ReverseDiffGradientPrep , backend:: AutoReverseDiff , x:: AbstractArray
68+ f, prep:: ReverseDiffGradientPrep , backend:: AutoReverseDiff , x
7569)
7670 grad = similar(x)
7771 return DI. value_and_gradient!(f, grad, prep, backend, x)
7872end
7973
8074function DI. gradient!(
81- _f,
82- grad:: AbstractArray ,
83- prep:: ReverseDiffGradientPrep ,
84- :: AutoReverseDiff ,
85- x:: AbstractArray ,
75+ _f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x:: AbstractArray
8676)
8777 return gradient!(grad, prep. tape, x)
8878end
8979
90- function DI. gradient(_f, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x:: AbstractArray )
80+ function DI. gradient(_f, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x)
9181 return gradient!(prep. tape, x)
9282end
9383
@@ -97,9 +87,7 @@ struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
9787 tape:: T
9888end
9989
100- function DI. prepare_jacobian(
101- f, :: AutoReverseDiff{Compile} , x:: AbstractArray
102- ) where {Compile}
90+ function DI. prepare_jacobian(f, :: AutoReverseDiff{Compile} , x) where {Compile}
10391 tape = JacobianTape(f, x)
10492 if Compile
10593 tape = compile(tape)
@@ -108,37 +96,23 @@ function DI.prepare_jacobian(
10896end
10997
11098function DI. value_and_jacobian!(
111- f,
112- jac:: AbstractMatrix ,
113- prep:: ReverseDiffOneArgJacobianPrep ,
114- :: AutoReverseDiff ,
115- x:: AbstractArray ,
99+ f, jac, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x
116100)
117101 y = f(x)
118102 result = MutableDiffResult(y, (jac,))
119103 result = jacobian!(result, prep. tape, x)
120104 return DiffResults. value(result), DiffResults. derivative(result)
121105end
122106
123- function DI. value_and_jacobian(
124- f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x:: AbstractArray
125- )
107+ function DI. value_and_jacobian(f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
126108 return f(x), jacobian!(prep. tape, x)
127109end
128110
129- function DI. jacobian!(
130- _f,
131- jac:: AbstractMatrix ,
132- prep:: ReverseDiffOneArgJacobianPrep ,
133- :: AutoReverseDiff ,
134- x:: AbstractArray ,
135- )
111+ function DI. jacobian!(_f, jac, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
136112 return jacobian!(jac, prep. tape, x)
137113end
138114
139- function DI. jacobian(
140- f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x:: AbstractArray
141- )
115+ function DI. jacobian(f, prep:: ReverseDiffOneArgJacobianPrep , :: AutoReverseDiff , x)
142116 return jacobian!(prep. tape, x)
143117end
144118
@@ -148,35 +122,24 @@ struct ReverseDiffHessianPrep{T} <: HessianPrep
148122 tape:: T
149123end
150124
151- function DI. prepare_hessian(f, :: AutoReverseDiff{Compile} , x:: AbstractArray ) where {Compile}
125+ function DI. prepare_hessian(f, :: AutoReverseDiff{Compile} , x) where {Compile}
152126 tape = HessianTape(f, x)
153127 if Compile
154128 tape = compile(tape)
155129 end
156130 return ReverseDiffHessianPrep(tape)
157131end
158132
159- function DI. hessian!(
160- _f,
161- hess:: AbstractMatrix ,
162- prep:: ReverseDiffHessianPrep ,
163- :: AutoReverseDiff ,
164- x:: AbstractArray ,
165- )
133+ function DI. hessian!(_f, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x)
166134 return hessian!(hess, prep. tape, x)
167135end
168136
169- function DI. hessian(_f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x:: AbstractArray )
137+ function DI. hessian(_f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x)
170138 return hessian!(prep. tape, x)
171139end
172140
173141function DI. value_gradient_and_hessian!(
174- f,
175- grad,
176- hess:: AbstractMatrix ,
177- prep:: ReverseDiffHessianPrep ,
178- :: AutoReverseDiff ,
179- x:: AbstractArray ,
142+ f, grad, hess, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
180143)
181144 y = f(x) # TODO : remove once ReverseDiff#251 is fixed
182145 result = MutableDiffResult(y, (grad, hess))
@@ -187,10 +150,11 @@ function DI.value_gradient_and_hessian!(
187150end
188151
189152function DI. value_gradient_and_hessian(
190- f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x:: AbstractArray
153+ f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
191154)
192- y = f(x) # TODO : remove once ReverseDiff#251 is fixed
193- result = MutableDiffResult(y, (similar(x), similar(x, length(x), length(x))))
155+ result = MutableDiffResult(
156+ one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
157+ )
194158 result = hessian!(result, prep. tape, x)
195159 return (
196160 DiffResults. value(result), DiffResults. gradient(result), DiffResults. hessian(result)
0 commit comments