Skip to content

Commit fa551e5

Browse files
authored
Rename p to dg for Hessian-vector product (#318)
1 parent b41db2f commit fa551e5

File tree

3 files changed

+24
-24
lines changed

3 files changed

+24
-24
lines changed

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ Create an `extras_same` object that can be given to [`hvp`](@ref) and its varian
2121
function prepare_hvp_same_point end
2222

2323
"""
24-
hvp(f, backend, x, dx, [extras]) -> p
24+
hvp(f, backend, x, dx, [extras]) -> dg
2525
2626
Compute the Hessian-vector product of `f` at point `x` with seed `dx`.
2727
"""
2828
function hvp end
2929

3030
"""
31-
hvp!(f, p, backend, x, dx, [extras]) -> p
31+
hvp!(f, dg, backend, x, dx, [extras]) -> dg
3232
33-
Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `p`.
33+
Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `dg`.
3434
"""
3535
function hvp! end
3636

@@ -141,8 +141,8 @@ function hvp(f::F, backend::AbstractADType, x, dx) where {F}
141141
return hvp(f, backend, x, dx, prepare_hvp(f, backend, x, dx))
142142
end
143143

144-
function hvp!(f::F, p, backend::AbstractADType, x, dx) where {F}
145-
return hvp!(f, p, backend, x, dx, prepare_hvp(f, backend, x, dx))
144+
function hvp!(f::F, dg, backend::AbstractADType, x, dx) where {F}
145+
return hvp!(f, dg, backend, x, dx, prepare_hvp(f, backend, x, dx))
146146
end
147147

148148
function hvp(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
@@ -178,35 +178,35 @@ function hvp(
178178
return pullback(inner_gradient, outer(backend), x, dx, outer_pullback_extras)
179179
end
180180

181-
function hvp!(f::F, p, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
182-
return hvp!(f, p, SecondOrder(backend, backend), x, dx, extras)
181+
function hvp!(f::F, dg, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
182+
return hvp!(f, dg, SecondOrder(backend, backend), x, dx, extras)
183183
end
184184

185185
function hvp!(
186-
f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
186+
f::F, dg, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
187187
) where {F}
188188
@compat (; inner_gradient, outer_pushforward_extras) = extras
189-
return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
189+
return pushforward!(inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras)
190190
end
191191

192192
function hvp!(
193-
f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
193+
f::F, dg, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
194194
) where {F}
195195
@compat (; inner_gradient, outer_pushforward_extras) = extras
196-
return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
196+
return pushforward!(inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras)
197197
end
198198

199199
function hvp!(
200-
f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
200+
f::F, dg, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
201201
) where {F}
202202
@compat (; outer_gradient_extras) = extras
203203
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
204-
return gradient!(inner_pushforward, p, outer(backend), x, outer_gradient_extras)
204+
return gradient!(inner_pushforward, dg, outer(backend), x, outer_gradient_extras)
205205
end
206206

207207
function hvp!(
208-
f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
208+
f::F, dg, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
209209
) where {F}
210210
@compat (; inner_gradient, outer_pullback_extras) = extras
211-
return pullback!(inner_gradient, p, outer(backend), x, dx, outer_pullback_extras)
211+
return pullback!(inner_gradient, dg, outer(backend), x, dx, outer_pullback_extras)
212212
end

DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ function run_benchmark!(
946946
# benchmark
947947
extras = prepare_hvp(f, ba, x, dx)
948948
bench0 = @be prepare_hvp(f, ba, x, dx) samples = 1 evals = 1
949-
bench1 = @be (p=mysimilar(x), ext=deepcopy(extras)) hvp!(f, _.p, ba, x, dx, _.ext) evals = 1
949+
bench1 = @be (dg=mysimilar(x), ext=deepcopy(extras)) hvp!(f, _.dg, ba, x, dx, _.ext) evals = 1
950950
# count
951951
cc = CallCounter(f)
952952
extras = prepare_hvp(cc, ba, x, dx)

DifferentiationInterfaceTest/src/tests/correctness.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ function test_correctness(
910910
ref_backend,
911911
)
912912
@compat (; f, x, dx) = new_scen = deepcopy(scen)
913-
p_true = if ref_backend isa AbstractADType
913+
dg_true = if ref_backend isa AbstractADType
914914
hvp(f, ref_backend, x, dx)
915915
else
916916
new_scen.ref(x, dx)
@@ -921,14 +921,14 @@ function test_correctness(
921921
(prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),),
922922
(prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),),
923923
])
924-
p1 = hvp(f, ba, x, dx, extras_tup...)
924+
dg1 = hvp(f, ba, x, dx, extras_tup...)
925925

926926
let ()(x, y) = isapprox(x, y; atol, rtol)
927927
@testset "Extras type" begin
928928
@test isempty(extras_tup) || only(extras_tup) isa HVPExtras
929929
end
930930
@testset "HVP value" begin
931-
@test p1 p_true
931+
@test dg1 dg_true
932932
end
933933
end
934934
end
@@ -945,7 +945,7 @@ function test_correctness(
945945
ref_backend,
946946
)
947947
@compat (; f, x, dx) = new_scen = deepcopy(scen)
948-
p_true = if ref_backend isa AbstractADType
948+
dg_true = if ref_backend isa AbstractADType
949949
hvp(f, ref_backend, x, dx)
950950
else
951951
new_scen.ref(x, dx)
@@ -956,16 +956,16 @@ function test_correctness(
956956
(prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)),),
957957
(prepare_hvp_same_point(f, ba, x, mycopy_random(dx)),),
958958
])
959-
p1_in = mysimilar(x)
960-
p1 = hvp!(f, p1_in, ba, x, dx, extras_tup...)
959+
dg1_in = mysimilar(x)
960+
dg1 = hvp!(f, dg1_in, ba, x, dx, extras_tup...)
961961

962962
let ()(x, y) = isapprox(x, y; atol, rtol)
963963
@testset "Extras type" begin
964964
@test isempty(extras_tup) || only(extras_tup) isa HVPExtras
965965
end
966966
@testset "HVP value" begin
967-
@test p1_in p_true
968-
@test p1 p_true
967+
@test dg1_in dg_true
968+
@test dg1 dg_true
969969
end
970970
end
971971
end

0 commit comments

Comments
 (0)