Skip to content

Commit 0124a0e

Browse files
authored
fix: take absstep into account for FiniteDiff (#812)
* Add tests * Fix * Changelog
1 parent 272eeb5 commit 0124a0e

File tree

4 files changed

+61
-9
lines changed

4 files changed

+61
-9
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Fixed
1313

14+
- Take `absstep` into account for FiniteDiff ([#812])
1415
- Make basis work for `CuArray` ([#810])
1516

1617
## [0.7.0]
@@ -39,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3940
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
4041
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
4142

43+
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812
4244
[#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810
4345
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
4446
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function DI.prepare_pushforward_nokwarg(
2727
absstep = if isnothing(backend.absstep)
2828
relstep
2929
else
30-
backend.relstep
30+
backend.absstep
3131
end
3232
dir = backend.dir
3333
return FiniteDiffOneArgPushforwardPrep(_sig, cache, relstep, absstep, dir)
@@ -144,7 +144,7 @@ function DI.prepare_derivative_nokwarg(
144144
absstep = if isnothing(backend.absstep)
145145
relstep
146146
else
147-
backend.relstep
147+
backend.absstep
148148
end
149149
dir = backend.dir
150150
return FiniteDiffOneArgDerivativePrep(_sig, cache, relstep, absstep, dir)
@@ -269,7 +269,7 @@ function DI.prepare_gradient_nokwarg(
269269
absstep = if isnothing(backend.absstep)
270270
relstep
271271
else
272-
backend.relstep
272+
backend.absstep
273273
end
274274
dir = backend.dir
275275
return FiniteDiffGradientPrep(_sig, cache, relstep, absstep, dir)
@@ -359,7 +359,7 @@ function DI.prepare_jacobian_nokwarg(
359359
absstep = if isnothing(backend.absstep)
360360
relstep
361361
else
362-
backend.relstep
362+
backend.absstep
363363
end
364364
dir = backend.dir
365365
return FiniteDiffOneArgJacobianPrep(_sig, cache, relstep, absstep, dir)
@@ -465,8 +465,16 @@ function DI.prepare_hessian_nokwarg(
465465
else
466466
backend.relstep
467467
end
468-
absstep_g = isnothing(backend.absstep) ? relstep_g : backend.absstep
469-
absstep_h = isnothing(backend.absstep) ? relstep_h : backend.absstep
468+
absstep_g = if isnothing(backend.absstep)
469+
relstep_g
470+
else
471+
backend.absstep
472+
end
473+
absstep_h = if isnothing(backend.absstep)
474+
relstep_h
475+
else
476+
backend.absstep
477+
end
470478
return FiniteDiffHessianPrep(
471479
_sig, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h
472480
)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function DI.prepare_pushforward_nokwarg(
3131
absstep = if isnothing(backend.absstep)
3232
relstep
3333
else
34-
backend.relstep
34+
backend.absstep
3535
end
3636
dir = backend.dir
3737
return FiniteDiffTwoArgPushforwardPrep(_sig, cache, relstep, absstep, dir)
@@ -175,7 +175,7 @@ function DI.prepare_derivative_nokwarg(
175175
absstep = if isnothing(backend.absstep)
176176
relstep
177177
else
178-
backend.relstep
178+
backend.absstep
179179
end
180180
dir = backend.dir
181181
return FiniteDiffTwoArgDerivativePrep(_sig, cache, relstep, absstep, dir)
@@ -295,7 +295,7 @@ function DI.prepare_jacobian_nokwarg(
295295
absstep = if isnothing(backend.absstep)
296296
relstep
297297
else
298-
backend.relstep
298+
backend.absstep
299299
end
300300
dir = backend.dir
301301
return FiniteDiffTwoArgJacobianPrep(_sig, cache, relstep, absstep, dir)

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,45 @@ end
7272
logging=LOGGING,
7373
)
7474
end;
75+
76+
@testset "Step size" begin # fix 811
77+
backend = AutoFiniteDiff(; absstep=1000, relstep=0.1)
78+
preps = [
79+
prepare_pushforward(identity, backend, 1.0, (1.0,)),
80+
prepare_pushforward(copyto!, [0.0], backend, [1.0], ([1.0],)),
81+
prepare_derivative(identity, backend, 1.0),
82+
prepare_derivative((y, x) -> y .= x, [0.0], backend, 1.0),
83+
prepare_gradient(sum, backend, [1.0]),
84+
prepare_jacobian(identity, backend, [1.0]),
85+
prepare_jacobian(copyto!, [0.0], backend, [1.0]),
86+
]
87+
for prep in preps
88+
@test prep.relstep == 0.1
89+
@test prep.absstep == 1000
90+
end
91+
prep = prepare_hessian(sum, backend, [1.0])
92+
@test prep.absstep_g == 1000
93+
@test prep.absstep_h == 1000
94+
@test prep.relstep_g == 0.1
95+
@test prep.relstep_h == 0.1
96+
97+
backend = AutoFiniteDiff(; relstep=0.1)
98+
preps = [
99+
prepare_pushforward(identity, backend, 1.0, (1.0,)),
100+
prepare_pushforward(copyto!, [0.0], backend, [1.0], ([1.0],)),
101+
prepare_derivative(identity, backend, 1.0),
102+
prepare_derivative((y, x) -> y .= x, [0.0], backend, 1.0),
103+
prepare_gradient(sum, backend, [1.0]),
104+
prepare_jacobian(identity, backend, [1.0]),
105+
prepare_jacobian(copyto!, [0.0], backend, [1.0]),
106+
]
107+
for prep in preps
108+
@test prep.relstep == 0.1
109+
@test prep.absstep == 0.1
110+
end
111+
prep = prepare_hessian(sum, backend, [1.0])
112+
@test prep.absstep_g == 0.1
113+
@test prep.absstep_h == 0.1
114+
@test prep.relstep_g == 0.1
115+
@test prep.relstep_h == 0.1
116+
end

0 commit comments

Comments
 (0)