Commit 591d128

Fixing extended trace failure for Adam and AdaMax and generalising the alpha parameter to accept a callable object (scheduler) (#1115)
* adding alpha to state and generalizing alpha to accept a function (scheduler)
* removing unused variables
* adding alpha to AdaMax state and generalizing alpha to accept a function (scheduler)
* updating the docstring for Adam to add description of scheduled alpha constructors
* updating the docstring for AdaMax to add description of scheduled alpha constructors
* adding tests for scheduled Adam and AdaMax, which covers testing extended_trace=true case
* adding default constant alpha case tests for extended_trace=true for Adam and AdaMax
1 parent 711dfec commit 591d128

3 files changed: +164 additions, -18 deletions
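
In short, the commit lets `alpha` be either a constant or any callable mapping the iteration count to a step size. A minimal sketch of the new usage, mirroring the tests added below (the decay schedule itself is illustrative, not prescribed by the commit):

```julia
using Optim

# Toy problem taken from the new tests.
f(x) = x[1]^4
g!(storage, x) = (storage[1] = 4 * x[1]^3; nothing)

# Any callable mapping the iteration count to a step size can be passed as alpha.
alpha_scheduler(iter) = 0.0001 * (1 + 0.99^iter)

results = Optim.optimize(f, g!, [1.0], Adam(alpha = alpha_scheduler),
                         Optim.Options(iterations = 100_000, allow_f_increases = true))
@show Optim.minimum(results)
```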

src/multivariate/solvers/first_order/adam.jl

Lines changed: 53 additions & 8 deletions
@@ -1,17 +1,37 @@
 """
 # Adam
-## Constructor
+## Constant `alpha` case (default) constructor:
+
 ```julia
 Adam(; alpha=0.0001, beta_mean=0.9, beta_var=0.999, epsilon=1e-8)
 ```
+
+## Scheduled `alpha` case constructor:
+
+Alternative to the above (default) usage where `alpha` is a fixed constant for
+all the iterations, the following constructor provides flexibility for `alpha`
+to be a callable object (a scheduler) that maps the current iteration count to
+a value of `alpha` that is to-be used for the current optimization iteraion's
+update step. This helps us in scheduling `alpha` over the iterations as
+desired, using the following usage,
+
+```julia
+# Let alpha_scheduler be iteration -> alpha value mapping callable object
+Adam(; alpha=alpha_scheduler, other_kwargs...)
+```
+
 ## Description
-Adam is a gradient based optimizer that choses its search direction by building up estimates of the first two moments of the gradient vector. This makes it suitable for problems with a stochastic objective and thus gradient. The method is introduced in [1] where the related AdaMax method is also introduced, see `?AdaMax` for more information on that method.
+Adam is a gradient based optimizer that choses its search direction by building
+up estimates of the first two moments of the gradient vector. This makes it
+suitable for problems with a stochastic objective and thus gradient. The method
+is introduced in [1] where the related AdaMax method is also introduced, see
+`?AdaMax` for more information on that method.
 
 ## References
 [1] https://arxiv.org/abs/1412.6980
 """
-struct Adam{T, Tm} <: FirstOrderOptimizer
-    α::T
+struct Adam{Tα, T, Tm} <: FirstOrderOptimizer
+    α::Tα
     β₁::T
     β₂::T
     ϵ::T
@@ -32,20 +52,29 @@ mutable struct AdamState{Tx, T, Tm, Tu, Ti} <: AbstractOptimizerState
     s::Tx
     m::Tm
     u::Tu
+    alpha::T
     iter::Ti
 end
 function reset!(method, state::AdamState, obj, x)
     value_gradient!!(obj, x)
 end
+
+function _get_init_params(method::Adam{T}) where T <: Real
+    method.α, method.β₁, method.β₂
+end
+
+function _get_init_params(method::Adam)
+    method.α(1), method.β₁, method.β₂
+end
+
 function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) where T
     initial_x = copy(initial_x)
 
     value_gradient!!(d, initial_x)
-    α, β₁, β₂ = method.α, method.β₁, method.β₂
+    α, β₁, β₂ = _get_init_params(method)
 
     m = copy(gradient(d))
     u = zero(m)
-    a = 1 - β₁
     iter = 0
 
     AdamState(initial_x, # Maintain current state in state.x
@@ -54,13 +83,29 @@ function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) wh
               similar(initial_x), # Maintain current search direction in state.s
              m,
              u,
+             α,
             iter)
 end
 
+function _update_iter_alpha_in_state!(
+    state::AdamState, method::Adam{T}) where T <: Real
+
+    state.iter = state.iter+1
+end
+
+function _update_iter_alpha_in_state!(
+    state::AdamState, method::Adam)
+
+    state.iter = state.iter+1
+    state.alpha = method.α(state.iter)
+end
+
 function update_state!(d, state::AdamState{T}, method::Adam) where T
-    state.iter = state.iter+1
+
+    _update_iter_alpha_in_state!(state, method)
     value_gradient!(d, state.x)
-    α, β₁, β₂, ϵ = method.α, method.β₁, method.β₂, method.ϵ
+
+    α, β₁, β₂, ϵ = state.alpha, method.β₁, method.β₂, method.ϵ
     a = 1 - β₁
     b = 1 - β₂
 
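The constant and scheduled cases are separated purely by dispatch on the new `Tα` type parameter: when `α` is stored as a `Real`, the `Adam{T} where T <: Real` methods return it directly; otherwise the generic methods treat `α` as a callable of the iteration count. A standalone sketch of that pattern, using a stand-in type rather than Optim's actual `Adam`:

```julia
# Standalone illustration of the dispatch idea; `Opt` is a stand-in, not Optim's type.
struct Opt{Tα}
    α::Tα
end

# Constant step size: Tα is a Real, so this more specific method wins.
step_size(o::Opt{T}, iter) where T <: Real = o.α

# Scheduled step size: any other Tα is assumed callable on the iteration count.
step_size(o::Opt, iter) = o.α(iter)

step_size(Opt(1e-4), 10)                      # 0.0001
step_size(Opt(i -> 1e-4 * (1 + 0.99^i)), 10)  # ≈ 0.00019
```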
src/multivariate/solvers/first_order/adamax.jl

Lines changed: 52 additions & 10 deletions
@@ -1,18 +1,37 @@
 """
 # AdaMax
-## Constructor
+## Constant `alpha` case (default) constructor:
+
 ```julia
 AdaMax(; alpha=0.002, beta_mean=0.9, beta_var=0.999, epsilon=1e-8)
 ```
-## Description
-AdaMax is a gradient based optimizer that choses its search direction by building up estimates of the first two moments of the gradient vector. This makes it suitable for problems with a stochastic objective and thus gradient. The method is introduced in [1] where the related Adam method is also introduced, see `?Adam` for more information on that method.
 
+## Scheduled `alpha` case constructor:
+
+Alternative to the above (default) usage where `alpha` is a fixed constant for
+all the iterations, the following constructor provides flexibility for `alpha`
+to be a callable object (a scheduler) that maps the current iteration count to
+a value of `alpha` that is to-be used for the current optimization iteraion's
+update step. This helps us in scheduling `alpha` over the iterations as
+desired, using the following usage,
 
+```julia
+# Let alpha_scheduler be iteration -> alpha value mapping callable object
+AdaMax(; alpha=alpha_scheduler, other_kwargs...)
+```
+
+## Description
+AdaMax is a gradient based optimizer that choses its search direction by
+building up estimates of the first two moments of the gradient vector. This
+makes it suitable for problems with a stochastic objective and thus gradient.
+The method is introduced in [1] where the related Adam method is also
+introduced, see `?Adam` for more information on that method.
+
+## References
 [1] https://arxiv.org/abs/1412.6980
 """
-
-struct AdaMax{T,Tm} <: FirstOrderOptimizer
-    α::T
+struct AdaMax{Tα, T, Tm} <: FirstOrderOptimizer
+    α::Tα
     β₁::T
     β₂::T
     ϵ::T
@@ -33,20 +52,29 @@ mutable struct AdaMaxState{Tx, T, Tm, Tu, Ti} <: AbstractOptimizerState
     s::Tx
     m::Tm
     u::Tu
+    alpha::T
     iter::Ti
 end
 function reset!(method, state::AdaMaxState, obj, x)
     value_gradient!!(obj, x)
 end
+
+function _get_init_params(method::AdaMax{T}) where T <: Real
+    method.α, method.β₁, method.β₂
+end
+
+function _get_init_params(method::AdaMax)
+    method.α(1), method.β₁, method.β₂
+end
+
 function initial_state(method::AdaMax, options, d, initial_x::AbstractArray{T}) where T
     initial_x = copy(initial_x)
 
     value_gradient!!(d, initial_x)
-    α, β₁, β₂ = method.α, method.β₁, method.β₂
+    α, β₁, β₂ = _get_init_params(method)
 
     m = copy(gradient(d))
     u = zero(m)
-    a = 1 - β₁
     iter = 0
 
     AdaMaxState(initial_x, # Maintain current state in state.x
@@ -55,13 +83,27 @@ function initial_state(method::AdaMax, options, d, initial_x::AbstractArray{T})
                 similar(initial_x), # Maintain current search direction in state.s
                m,
                u,
+               α,
               iter)
 end
 
+function _update_iter_alpha_in_state!(
+    state::AdaMaxState, method::AdaMax{T}) where T <: Real
+
+    state.iter = state.iter+1
+end
+
+function _update_iter_alpha_in_state!(
+    state::AdaMaxState, method::AdaMax)
+
+    state.iter = state.iter+1
+    state.alpha = method.α(state.iter)
+end
+
 function update_state!(d, state::AdaMaxState{T}, method::AdaMax) where T
-    state.iter = state.iter+1
+    _update_iter_alpha_in_state!(state, method)
     value_gradient!(d, state.x)
-    α, β₁, β₂, ϵ = method.α, method.β₁, method.β₂, method.ϵ
+    α, β₁, β₂, ϵ = state.alpha, method.β₁, method.β₂, method.ϵ
     a = 1 - β₁
     m, u = state.m, state.u
 
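Because `alpha` only has to be callable on the iteration count, a functor (callable struct) also qualifies as a scheduler. A hypothetical example, not part of the commit, assuming a simple exponential decay:

```julia
using Optim

# Hypothetical functor scheduler; any object callable on the iteration count works.
struct ExpDecay
    alpha0::Float64
    decay::Float64
end
(s::ExpDecay)(iter) = s.alpha0 * s.decay^iter

f(x) = x[1]^4
g!(storage, x) = (storage[1] = 4 * x[1]^3; nothing)

results = Optim.optimize(f, g!, [1.0], AdaMax(alpha = ExpDecay(0.002, 0.999)),
                         Optim.Options(iterations = 10_000))
```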
test/multivariate/solvers/first_order/adam_adamax.jl

Lines changed: 59 additions & 0 deletions
@@ -21,6 +21,7 @@
               skip = skip,
               show_name = debug_printing)
 end
+
 @testset "AdaMax" begin
     f(x) = x[1]^4
     function g!(storage, x)
@@ -45,3 +46,61 @@ end
               show_name=debug_printing,
               iteration_exceptions = (("Trigonometric", 1_000_000,),))
 end
+
+@testset "Adam-scheduler" begin
+    f(x) = x[1]^4
+    function g!(storage, x)
+        storage[1] = 4 * x[1]^3
+        return
+    end
+
+    initial_x = [1.0]
+    options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=100_000)
+    alpha_scheduler(iter) = 0.0001*(1 + 0.99^iter)
+    results = Optim.optimize(f, g!, initial_x, Adam(alpha=alpha_scheduler), options)
+    @test norm(Optim.minimum(results)) < 1e-6
+    @test summary(results) == "Adam"
+
+    # verifying the alpha values over iterations and also testing extended_trace
+    # this way we test both alpha scheduler and the working of
+    # extended_trace=true option
+
+    options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+    results = Optim.optimize(f, g!, initial_x, Adam(alpha=1e-5), options)
+
+    @test prod(map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) .== 1e-5)
+
+    options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+    results = Optim.optimize(f, g!, initial_x, Adam(alpha=alpha_scheduler), options)
+
+    @test map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) == alpha_scheduler.(1:results.iterations)
+end
+
+@testset "AdaMax-scheduler" begin
+    f(x) = x[1]^4
+    function g!(storage, x)
+        storage[1] = 4 * x[1]^3
+        return
+    end
+
+    initial_x = [1.0]
+    options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=100_000)
+    alpha_scheduler(iter) = 0.002*(1 + 0.99^iter)
+    results = Optim.optimize(f, g!, initial_x, AdaMax(alpha=alpha_scheduler), options)
+    @test norm(Optim.minimum(results)) < 1e-6
+    @test summary(results) == "AdaMax"
+
+    # verifying the alpha values over iterations and also testing extended_trace
+    # this way we test both alpha scheduler and the working of
+    # extended_trace=true option
+
+    options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+    results = Optim.optimize(f, g!, initial_x, AdaMax(alpha=1e-4), options)
+
+    @test prod(map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) .== 1e-4)
+
+    options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+    results = Optim.optimize(f, g!, initial_x, AdaMax(alpha=alpha_scheduler), options)
+
+    @test map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) == alpha_scheduler.(1:results.iterations)
+end
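
The extended trace is what originally failed: with `extended_trace=true` and `store_trace=true`, each trace entry's metadata records the step size actually used, which the new tests compare against the scheduler. A short sketch of inspecting it, following the tests above:

```julia
using Optim

f(x) = x[1]^4
g!(storage, x) = (storage[1] = 4 * x[1]^3; nothing)
alpha_scheduler(iter) = 0.0001 * (1 + 0.99^iter)

opts = Optim.Options(iterations = 1000, extended_trace = true, store_trace = true)
results = Optim.optimize(f, g!, [1.0], Adam(alpha = alpha_scheduler), opts)

# Entry 1 is the initial state; entries 2:end hold the per-iteration metadata.
used = [results.trace[i].metadata["Current step size"] for i in 2:results.iterations+1]
@assert used == alpha_scheduler.(1:results.iterations)
```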
