Skip to content

Commit 9fb32cd

Browse files
authored
Replace closures with callable structs in second order (#315)
* Replace v with dx in hvp signature
* Replace closures with functors
* Fixes
1 parent e93d16a commit 9fb32cd

File tree

2 files changed: +117 -124 lines changed
Lines changed: 93 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Docstrings
22

33
"""
4-
prepare_hvp(f, backend, x, v) -> extras
4+
prepare_hvp(f, backend, x, dx) -> extras
55
66
Create an `extras` object that can be given to [`hvp`](@ref) and its variants.
77
@@ -11,7 +11,7 @@ Create an `extras` object that can be given to [`hvp`](@ref) and its variants.
1111
function prepare_hvp end
1212

1313
"""
14-
prepare_hvp_same_point(f, backend, x, v) -> extras_same
14+
prepare_hvp_same_point(f, backend, x, dx) -> extras_same
1515
1616
Create an `extras_same` object that can be given to [`hvp`](@ref) and its variants _if they are applied at the same point `x`_.
1717
@@ -21,16 +21,16 @@ Create an `extras_same` object that can be given to [`hvp`](@ref) and its varian
2121
function prepare_hvp_same_point end
2222

2323
"""
24-
hvp(f, backend, x, v, [extras]) -> p
24+
hvp(f, backend, x, dx, [extras]) -> p
2525
26-
Compute the Hessian-vector product of `f` at point `x` with seed `v`.
26+
Compute the Hessian-vector product of `f` at point `x` with seed `dx`.
2727
"""
2828
function hvp end
2929

3030
"""
31-
hvp!(f, p, backend, x, v, [extras]) -> p
31+
hvp!(f, p, backend, x, dx, [extras]) -> p
3232
33-
Compute the Hessian-vector product of `f` at point `x` with seed `v`, overwriting `p`.
33+
Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `p`.
3434
"""
3535
function hvp! end
3636

@@ -45,181 +45,168 @@ abstract type HVPExtras <: Extras end
4545

4646
struct NoHVPExtras <: HVPExtras end
4747

48-
#=
49-
Source: https://arxiv.org/abs/2403.14606 (section 8.1)
48+
struct InnerGradient{F,B}
49+
f::F
50+
backend::B
51+
end
52+
53+
function (ig::InnerGradient)(x)
54+
@compat (; f, backend) = ig
55+
return gradient(f, backend, x)
56+
end
57+
58+
struct InnerPushforwardFixedSeed{F,B,DX}
59+
f::F
60+
backend::B
61+
dx::DX
62+
end
5063

51-
By order of preference:
52-
- forward on reverse
53-
- reverse on forward
54-
- reverse on reverse
55-
- forward on forward
56-
=#
64+
function (ipfs::InnerPushforwardFixedSeed)(x)
65+
@compat (; f, backend, dx) = ipfs
66+
return pushforward(f, backend, x, dx)
67+
end
5768

58-
struct ForwardOverForwardHVPExtras{C,E} <: HVPExtras
59-
inner_gradient_closure::C
69+
struct ForwardOverForwardHVPExtras{IG<:InnerGradient,E<:PushforwardExtras} <: HVPExtras
70+
inner_gradient::IG
6071
outer_pushforward_extras::E
6172
end
6273

63-
struct ForwardOverReverseHVPExtras{C,E} <: HVPExtras
64-
inner_gradient_closure::C
74+
struct ForwardOverReverseHVPExtras{IG<:InnerGradient,E<:PushforwardExtras} <: HVPExtras
75+
inner_gradient::IG
6576
outer_pushforward_extras::E
6677
end
6778

68-
struct ReverseOverForwardHVPExtras{C,E} <: HVPExtras
69-
inner_pushforward_closure_generator::C
79+
struct ReverseOverForwardHVPExtras{E<:GradientExtras} <: HVPExtras
7080
outer_gradient_extras::E
7181
end
7282

73-
struct ReverseOverReverseHVPExtras{C,E} <: HVPExtras
74-
inner_gradient_closure::C
83+
struct ReverseOverReverseHVPExtras{IG<:InnerGradient,E<:PullbackExtras} <: HVPExtras
84+
inner_gradient::IG
7585
outer_pullback_extras::E
7686
end
7787

78-
function prepare_hvp(f::F, backend::AbstractADType, x, v) where {F}
79-
return prepare_hvp(f, SecondOrder(backend, backend), x, v)
88+
function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F}
89+
return prepare_hvp(f, SecondOrder(backend, backend), x, dx)
8090
end
8191

82-
function prepare_hvp(f::F, backend::SecondOrder, x, v) where {F}
83-
return prepare_hvp(f, backend, x, v, hvp_mode(backend))
92+
function prepare_hvp(f::F, backend::SecondOrder, x, dx) where {F}
93+
return prepare_hvp(f, backend, x, dx, hvp_mode(backend))
8494
end
8595

86-
function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ForwardOverForward) where {F}
96+
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
8797
# pushforward of many pushforwards in theory, but pushforward of gradient in practice
88-
inner_backend = nested(inner(backend))
89-
inner_gradient_closure(z) = gradient(f, inner_backend, z)
90-
outer_pushforward_extras = prepare_pushforward(
91-
inner_gradient_closure, outer(backend), x, v
92-
)
93-
return ForwardOverForwardHVPExtras(inner_gradient_closure, outer_pushforward_extras)
98+
inner_gradient = InnerGradient(f, nested(inner(backend)))
99+
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
100+
return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
94101
end
95102

96-
function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ForwardOverReverse) where {F}
103+
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
97104
# pushforward of gradient
98-
inner_backend = nested(inner(backend))
99-
inner_gradient_closure(z) = gradient(f, inner_backend, z)
100-
outer_pushforward_extras = prepare_pushforward(
101-
inner_gradient_closure, outer(backend), x, v
102-
)
103-
return ForwardOverReverseHVPExtras(inner_gradient_closure, outer_pushforward_extras)
105+
inner_gradient = InnerGradient(f, nested(inner(backend)))
106+
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
107+
return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
104108
end
105109

106-
function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ReverseOverForward) where {F}
110+
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
107111
# gradient of pushforward
108-
# uses v in the closure
109-
inner_backend = nested(inner(backend))
110-
function inner_pushforward_closure_generator(v)
111-
inner_pushforward_closure(z) = pushforward(f, inner_backend, z, v)
112-
return inner_pushforward_closure
113-
end
114-
outer_gradient_extras = prepare_gradient(
115-
inner_pushforward_closure_generator(v), outer(backend), x
116-
)
117-
return ReverseOverForwardHVPExtras(
118-
inner_pushforward_closure_generator, outer_gradient_extras
119-
)
120-
end
121-
122-
function prepare_hvp(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) where {F}
112+
# the inner functor captures dx, so it can't be stored and is rebuilt on each call
113+
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
114+
outer_gradient_extras = prepare_gradient(inner_pushforward, outer(backend), x)
115+
return ReverseOverForwardHVPExtras(outer_gradient_extras)
116+
end
117+
118+
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
123119
# pullback of the gradient
124-
inner_backend = nested(inner(backend))
125-
inner_gradient_closure(z) = gradient(f, inner_backend, z)
126-
outer_pullback_extras = prepare_pullback(inner_gradient_closure, outer(backend), x, v)
127-
return ReverseOverReverseHVPExtras(inner_gradient_closure, outer_pullback_extras)
120+
inner_gradient = InnerGradient(f, nested(inner(backend)))
121+
outer_pullback_extras = prepare_pullback(inner_gradient, outer(backend), x, dx)
122+
return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
128123
end
129124

130125
## Preparation (same point)
131126

132127
function prepare_hvp_same_point(
133-
f::F, backend::AbstractADType, x, v, extras::HVPExtras
128+
f::F, backend::AbstractADType, x, dx, extras::HVPExtras
134129
) where {F}
135130
return extras
136131
end
137132

138-
function prepare_hvp_same_point(f::F, backend::AbstractADType, x, v) where {F}
139-
extras = prepare_hvp(f, backend, x, v)
140-
return prepare_hvp_same_point(f, backend, x, v, extras)
133+
function prepare_hvp_same_point(f::F, backend::AbstractADType, x, dx) where {F}
134+
extras = prepare_hvp(f, backend, x, dx)
135+
return prepare_hvp_same_point(f, backend, x, dx, extras)
141136
end
142137

143138
## One argument
144139

145-
function hvp(f::F, backend::AbstractADType, x, v) where {F}
146-
return hvp(f, backend, x, v, prepare_hvp(f, backend, x, v))
140+
function hvp(f::F, backend::AbstractADType, x, dx) where {F}
141+
return hvp(f, backend, x, dx, prepare_hvp(f, backend, x, dx))
147142
end
148143

149-
function hvp!(f::F, p, backend::AbstractADType, x, v) where {F}
150-
return hvp!(f, p, backend, x, v, prepare_hvp(f, backend, x, v))
144+
function hvp!(f::F, p, backend::AbstractADType, x, dx) where {F}
145+
return hvp!(f, p, backend, x, dx, prepare_hvp(f, backend, x, dx))
151146
end
152147

153-
function hvp(f::F, backend::AbstractADType, x, v, extras::HVPExtras) where {F}
154-
return hvp(f, SecondOrder(backend, backend), x, v, extras)
148+
function hvp(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
149+
return hvp(f, SecondOrder(backend, backend), x, dx, extras)
155150
end
156151

157152
function hvp(
158-
f::F, backend::SecondOrder, x, v, extras::ForwardOverForwardHVPExtras
153+
f::F, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
159154
) where {F}
160-
@compat (; inner_gradient_closure, outer_pushforward_extras) = extras
161-
return pushforward(
162-
inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras
163-
)
155+
@compat (; inner_gradient, outer_pushforward_extras) = extras
156+
return pushforward(inner_gradient, outer(backend), x, dx, outer_pushforward_extras)
164157
end
165158

166159
function hvp(
167-
f::F, backend::SecondOrder, x, v, extras::ForwardOverReverseHVPExtras
160+
f::F, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
168161
) where {F}
169-
@compat (; inner_gradient_closure, outer_pushforward_extras) = extras
170-
return pushforward(
171-
inner_gradient_closure, outer(backend), x, v, outer_pushforward_extras
172-
)
162+
@compat (; inner_gradient, outer_pushforward_extras) = extras
163+
return pushforward(inner_gradient, outer(backend), x, dx, outer_pushforward_extras)
173164
end
174165

175166
function hvp(
176-
f::F, backend::SecondOrder, x, v, extras::ReverseOverForwardHVPExtras
167+
f::F, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
177168
) where {F}
178-
@compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras
179-
inner_pushforward_closure = inner_pushforward_closure_generator(v)
180-
return gradient(inner_pushforward_closure, outer(backend), x, outer_gradient_extras)
169+
@compat (; outer_gradient_extras) = extras
170+
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
171+
return gradient(inner_pushforward, outer(backend), x, outer_gradient_extras)
181172
end
182173

183174
function hvp(
184-
f::F, backend::SecondOrder, x, v, extras::ReverseOverReverseHVPExtras
175+
f::F, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
185176
) where {F}
186-
@compat (; inner_gradient_closure, outer_pullback_extras) = extras
187-
return pullback(inner_gradient_closure, outer(backend), x, v, outer_pullback_extras)
177+
@compat (; inner_gradient, outer_pullback_extras) = extras
178+
return pullback(inner_gradient, outer(backend), x, dx, outer_pullback_extras)
188179
end
189180

190-
function hvp!(f::F, p, backend::AbstractADType, x, v, extras::HVPExtras) where {F}
191-
return hvp!(f, p, SecondOrder(backend, backend), x, v, extras)
181+
function hvp!(f::F, p, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
182+
return hvp!(f, p, SecondOrder(backend, backend), x, dx, extras)
192183
end
193184

194185
function hvp!(
195-
f::F, p, backend::SecondOrder, x, v, extras::ForwardOverForwardHVPExtras
186+
f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
196187
) where {F}
197-
@compat (; inner_gradient_closure, outer_pushforward_extras) = extras
198-
return pushforward!(
199-
inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras
200-
)
188+
@compat (; inner_gradient, outer_pushforward_extras) = extras
189+
return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
201190
end
202191

203192
function hvp!(
204-
f::F, p, backend::SecondOrder, x, v, extras::ForwardOverReverseHVPExtras
193+
f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
205194
) where {F}
206-
@compat (; inner_gradient_closure, outer_pushforward_extras) = extras
207-
return pushforward!(
208-
inner_gradient_closure, p, outer(backend), x, v, outer_pushforward_extras
209-
)
195+
@compat (; inner_gradient, outer_pushforward_extras) = extras
196+
return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
210197
end
211198

212199
function hvp!(
213-
f::F, p, backend::SecondOrder, x, v, extras::ReverseOverForwardHVPExtras
200+
f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
214201
) where {F}
215-
@compat (; inner_pushforward_closure_generator, outer_gradient_extras) = extras
216-
inner_pushforward_closure = inner_pushforward_closure_generator(v)
217-
return gradient!(inner_pushforward_closure, p, outer(backend), x, outer_gradient_extras)
202+
@compat (; outer_gradient_extras) = extras
203+
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
204+
return gradient!(inner_pushforward, p, outer(backend), x, outer_gradient_extras)
218205
end
219206

220207
function hvp!(
221-
f::F, p, backend::SecondOrder, x, v, extras::ReverseOverReverseHVPExtras
208+
f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
222209
) where {F}
223-
@compat (; inner_gradient_closure, outer_pullback_extras) = extras
224-
return pullback!(inner_gradient_closure, p, outer(backend), x, v, outer_pullback_extras)
210+
@compat (; inner_gradient, outer_pullback_extras) = extras
211+
return pullback!(inner_gradient, p, outer(backend), x, dx, outer_pullback_extras)
225212
end

DifferentiationInterface/src/second_order/second_derivative.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,19 @@ abstract type SecondDerivativeExtras <: Extras end
4949

5050
struct NoSecondDerivativeExtras <: SecondDerivativeExtras end
5151

52-
struct ClosureSecondDerivativeExtras{C,E} <: SecondDerivativeExtras
53-
inner_derivative_closure::C
52+
struct InnerDerivative{F,B}
53+
f::F
54+
backend::B
55+
end
56+
57+
function (id::InnerDerivative)(x)
58+
@compat (; f, backend) = id
59+
return derivative(f, backend, x)
60+
end
61+
62+
struct ClosureSecondDerivativeExtras{ID<:InnerDerivative,E<:DerivativeExtras} <:
63+
SecondDerivativeExtras
64+
inner_derivative::ID
5465
outer_derivative_extras::E
5566
end
5667

@@ -59,12 +70,9 @@ function prepare_second_derivative(f::F, backend::AbstractADType, x) where {F}
5970
end
6071

6172
function prepare_second_derivative(f::F, backend::SecondOrder, x) where {F}
62-
inner_backend = nested(inner(backend))
63-
inner_derivative_closure(z) = derivative(f, inner_backend, z)
64-
outer_derivative_extras = prepare_derivative(
65-
inner_derivative_closure, outer(backend), x
66-
)
67-
return ClosureSecondDerivativeExtras(inner_derivative_closure, outer_derivative_extras)
73+
inner_derivative = InnerDerivative(f, nested(inner(backend)))
74+
outer_derivative_extras = prepare_derivative(inner_derivative, outer(backend), x)
75+
return ClosureSecondDerivativeExtras(inner_derivative, outer_derivative_extras)
6876
end
6977

7078
## One argument
@@ -100,8 +108,8 @@ end
100108
function second_derivative(
101109
f::F, backend::SecondOrder, x, extras::ClosureSecondDerivativeExtras
102110
) where {F}
103-
@compat (; inner_derivative_closure, outer_derivative_extras) = extras
104-
return derivative(inner_derivative_closure, outer(backend), x, outer_derivative_extras)
111+
@compat (; inner_derivative, outer_derivative_extras) = extras
112+
return derivative(inner_derivative, outer(backend), x, outer_derivative_extras)
105113
end
106114

107115
function value_derivative_and_second_derivative(
@@ -115,10 +123,10 @@ end
115123
function value_derivative_and_second_derivative(
116124
f::F, backend::SecondOrder, x, extras::ClosureSecondDerivativeExtras
117125
) where {F}
118-
@compat (; inner_derivative_closure, outer_derivative_extras) = extras
126+
@compat (; inner_derivative, outer_derivative_extras) = extras
119127
y = f(x)
120128
der, der2 = value_and_derivative(
121-
inner_derivative_closure, outer(backend), x, outer_derivative_extras
129+
inner_derivative, outer(backend), x, outer_derivative_extras
122130
)
123131
return y, der, der2
124132
end
@@ -132,10 +140,8 @@ end
132140
function second_derivative!(
133141
f::F, der2, backend::SecondOrder, x, extras::SecondDerivativeExtras
134142
) where {F}
135-
@compat (; inner_derivative_closure, outer_derivative_extras) = extras
136-
return derivative!(
137-
inner_derivative_closure, der2, outer(backend), x, outer_derivative_extras
138-
)
143+
@compat (; inner_derivative, outer_derivative_extras) = extras
144+
return derivative!(inner_derivative, der2, outer(backend), x, outer_derivative_extras)
139145
end
140146

141147
function value_derivative_and_second_derivative!(
@@ -149,10 +155,10 @@ end
149155
function value_derivative_and_second_derivative!(
150156
f::F, der, der2, backend::SecondOrder, x, extras::SecondDerivativeExtras
151157
) where {F}
152-
@compat (; inner_derivative_closure, outer_derivative_extras) = extras
158+
@compat (; inner_derivative, outer_derivative_extras) = extras
153159
y = f(x)
154160
new_der, _ = value_and_derivative!(
155-
inner_derivative_closure, der2, outer(backend), x, outer_derivative_extras
161+
inner_derivative, der2, outer(backend), x, outer_derivative_extras
156162
)
157163
return y, copyto!(der, new_der), der2
158164
end

0 commit comments

Comments
 (0)