1
1
# # Docstrings
2
2
3
3
"""
4
- prepare_hvp(f, backend, x, v ) -> extras
4
+ prepare_hvp(f, backend, x, dx ) -> extras
5
5
6
6
Create an `extras` object that can be given to [`hvp`](@ref) and its variants.
7
7
@@ -11,7 +11,7 @@ Create an `extras` object that can be given to [`hvp`](@ref) and its variants.
11
11
function prepare_hvp end
12
12
13
13
"""
14
- prepare_hvp_same_point(f, backend, x, v ) -> extras_same
14
+ prepare_hvp_same_point(f, backend, x, dx ) -> extras_same
15
15
16
16
Create an `extras_same` object that can be given to [`hvp`](@ref) and its variants _if they are applied at the same point `x`_.
17
17
@@ -21,16 +21,16 @@ Create an `extras_same` object that can be given to [`hvp`](@ref) and its varian
21
21
function prepare_hvp_same_point end
22
22
23
"""
    hvp(f, backend, x, dx, [extras]) -> p

Compute the Hessian-vector product of `f` at point `x` with seed `dx`.
"""
function hvp end
29
29
30
"""
    hvp!(f, p, backend, x, dx, [extras]) -> p

Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `p`.
"""
function hvp! end
36
36
@@ -45,181 +45,168 @@ abstract type HVPExtras <: Extras end
45
45
46
46
# Field-less subtype of `HVPExtras`: an extras object that carries no data.
struct NoHVPExtras <: HVPExtras
end
47
47
48
- #=
49
- Source: https://arxiv.org/abs/2403.14606 (section 8.1)
48
# Callable wrapper bundling a function `f` with an AD `backend`.
# Calling an instance at `x` evaluates `gradient(f, backend, x)`, so it can be
# handed to an outer operator in place of an anonymous closure.
struct InnerGradient{F,B}
    f::F
    backend::B
end
52
+
53
# Evaluate the gradient of the wrapped function at `x` using the wrapped backend.
function (ig::InnerGradient)(x)
    @compat (; f, backend) = ig
    return gradient(f, backend, x)
end
57
+
58
# Callable wrapper bundling a function `f`, an AD `backend` and a fixed seed `dx`.
# Calling an instance at `x` evaluates `pushforward(f, backend, x, dx)`.
struct InnerPushforwardFixedSeed{F,B,DX}
    f::F
    backend::B
    dx::DX
end
50
63
51
- By order of preference:
52
- - forward on reverse
53
- - reverse on forward
54
- - reverse on reverse
55
- - forward on forward
56
- =#
64
# Evaluate the pushforward of the wrapped function at `x` along the stored seed `dx`.
function (ipfs::InnerPushforwardFixedSeed)(x)
    @compat (; f, backend, dx) = ipfs
    return pushforward(f, backend, x, dx)
end
57
68
58
# Extras for the forward-over-forward mode: the inner gradient callable and the
# extras prepared for the outer pushforward applied to it.
struct ForwardOverForwardHVPExtras{IG<:InnerGradient,E<:PushforwardExtras} <: HVPExtras
    inner_gradient::IG
    outer_pushforward_extras::E
end
62
73
63
# Extras for the forward-over-reverse mode: the inner gradient callable and the
# extras prepared for the outer pushforward applied to it.
struct ForwardOverReverseHVPExtras{IG<:InnerGradient,E<:PushforwardExtras} <: HVPExtras
    inner_gradient::IG
    outer_pushforward_extras::E
end
67
78
68
# Extras for the reverse-over-forward mode. Only the outer gradient extras are kept:
# the inner pushforward callable depends on the seed and is rebuilt in each `hvp` call.
struct ReverseOverForwardHVPExtras{E<:GradientExtras} <: HVPExtras
    outer_gradient_extras::E
end
72
82
73
# Extras for the reverse-over-reverse mode: the inner gradient callable and the
# extras prepared for the outer pullback applied to it.
struct ReverseOverReverseHVPExtras{IG<:InnerGradient,E<:PullbackExtras} <: HVPExtras
    inner_gradient::IG
    outer_pullback_extras::E
end
77
87
78
# Single-backend entry point: wrap `backend` as `SecondOrder(backend, backend)` and
# defer to the `SecondOrder` preparation method.
function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F}
    return prepare_hvp(f, SecondOrder(backend, backend), x, dx)
end
81
91
82
# Select the mode-specific preparation method by appending `hvp_mode(backend)` as a
# trailing dispatch argument.
function prepare_hvp(f::F, backend::SecondOrder, x, dx) where {F}
    return prepare_hvp(f, backend, x, dx, hvp_mode(backend))
end
85
95
86
# Forward-over-forward preparation: prepare the outer pushforward of the inner gradient.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
    # pushforward of many pushforwards in theory, but pushforward of gradient in practice
    inner_gradient = InnerGradient(f, nested(inner(backend)))
    outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
    return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
end
95
102
96
# Forward-over-reverse preparation: prepare the outer pushforward of the inner gradient.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
    # pushforward of gradient
    inner_gradient = InnerGradient(f, nested(inner(backend)))
    outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
    return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
end
105
109
106
# Reverse-over-forward preparation: prepare the outer gradient of the inner pushforward.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
    # gradient of pushforward
    # the inner callable captures the seed `dx`, so it is only used for preparation
    # here and is not stored in the returned extras
    inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
    outer_gradient_extras = prepare_gradient(inner_pushforward, outer(backend), x)
    return ReverseOverForwardHVPExtras(outer_gradient_extras)
end
117
+
118
# Reverse-over-reverse preparation: prepare the outer pullback of the inner gradient.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
    # pullback of the gradient
    inner_gradient = InnerGradient(f, nested(inner(backend)))
    outer_pullback_extras = prepare_pullback(inner_gradient, outer(backend), x, dx)
    return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
end
129
124
130
125
# # Preparation (same point)
131
126
132
127
# Generic same-point preparation from existing extras: returns `extras` unchanged.
function prepare_hvp_same_point(
    f::F, backend::AbstractADType, x, dx, extras::HVPExtras
) where {F}
    return extras
end
137
132
138
# Same-point preparation from scratch: run the regular preparation first, then
# pass the resulting extras to the five-argument method.
function prepare_hvp_same_point(f::F, backend::AbstractADType, x, dx) where {F}
    extras = prepare_hvp(f, backend, x, dx)
    return prepare_hvp_same_point(f, backend, x, dx, extras)
end
142
137
143
138
# # One argument
144
139
145
# Unprepared call: build the extras on the fly, then dispatch to the prepared method.
function hvp(f::F, backend::AbstractADType, x, dx) where {F}
    return hvp(f, backend, x, dx, prepare_hvp(f, backend, x, dx))
end
148
143
149
# Unprepared in-place call: build the extras on the fly, then dispatch to the
# prepared method.
function hvp!(f::F, p, backend::AbstractADType, x, dx) where {F}
    return hvp!(f, p, backend, x, dx, prepare_hvp(f, backend, x, dx))
end
152
147
153
# Single-backend call with extras: wrap `backend` as `SecondOrder(backend, backend)`
# and defer to the `SecondOrder` methods.
function hvp(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
    return hvp(f, SecondOrder(backend, backend), x, dx, extras)
end
156
151
157
152
# Forward-over-forward HVP: outer pushforward of the stored inner gradient along `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward(inner_gradient, outer(backend), x, dx, outer_pushforward_extras)
end
165
158
166
159
# Forward-over-reverse HVP: outer pushforward of the stored inner gradient along `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward(inner_gradient, outer(backend), x, dx, outer_pushforward_extras)
end
174
165
175
166
# Reverse-over-forward HVP: outer gradient of the inner pushforward with seed `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
) where {F}
    @compat (; outer_gradient_extras) = extras
    # the inner callable depends on the current seed, so it is rebuilt on every call
    inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
    return gradient(inner_pushforward, outer(backend), x, outer_gradient_extras)
end
182
173
183
174
# Reverse-over-reverse HVP: outer pullback of the stored inner gradient with seed `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pullback_extras) = extras
    return pullback(inner_gradient, outer(backend), x, dx, outer_pullback_extras)
end
189
180
190
# Single-backend in-place call with extras: wrap `backend` as
# `SecondOrder(backend, backend)` and defer to the `SecondOrder` methods.
function hvp!(f::F, p, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
    return hvp!(f, p, SecondOrder(backend, backend), x, dx, extras)
end
193
184
194
185
# In-place forward-over-forward HVP: outer `pushforward!` of the stored inner
# gradient along `dx`, with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
end
202
191
203
192
# In-place forward-over-reverse HVP: outer `pushforward!` of the stored inner
# gradient along `dx`, with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
end
211
198
212
199
# In-place reverse-over-forward HVP: outer `gradient!` of the inner pushforward with
# seed `dx`, with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
) where {F}
    @compat (; outer_gradient_extras) = extras
    # the inner callable depends on the current seed, so it is rebuilt on every call
    inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
    return gradient!(inner_pushforward, p, outer(backend), x, outer_gradient_extras)
end
219
206
220
207
# In-place reverse-over-reverse HVP: outer `pullback!` of the stored inner gradient
# with seed `dx`, with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pullback_extras) = extras
    return pullback!(inner_gradient, p, outer(backend), x, dx, outer_pullback_extras)
end
0 commit comments