From 62cd7ddad3b9830be4417a763ba91d3d7ceaaeb4 Mon Sep 17 00:00:00 2001
From: Symon <59005260+4SAnalyticsnModelling@users.noreply.github.com>
Date: Wed, 12 Jul 2023 10:53:13 -0600
Subject: [PATCH 1/6] Update rules.jl - add rules for a new optimizer PAdam

---
 src/rules.jl | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/src/rules.jl b/src/rules.jl
index e994b740..74036a1b 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -543,7 +543,41 @@ function apply!(o::AdaBelief, state, x, dx)
   return (mt, st, βt .* β), dx′
 end

+"""
+    PAdam(η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1, ϵ = eps(typeof(η)))
+
+The partially adaptive momentum estimation method ([PAdam](https://arxiv.org/pdf/1806.06763v1.pdf)).
+# Parameters
+- Learning rate (`η`): Amount by which gradients are discounted before updating
+  the weights.
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+- Partially adaptive parameter (`ρ`): Lies between 0 and 0.5; `ρ = 0.5` recovers AMSGrad.
+- Machine epsilon (`ϵ`): Constant to prevent division by zero
+  (no need to change default)
+"""
+struct PAdam{T} <: AbstractRule
+  eta::T
+  beta::Tuple{T, T}
+  rho::T
+  epsilon::T
+end
+PAdam(η = 1f-2, β = (9f-1, 9.99f-1), ρ = 2.5f-1, ϵ = eps(typeof(η))) = PAdam{typeof(η)}(η, β, ρ, ϵ)
+
+init(o::PAdam, x::AbstractArray) = (onevalue(o.epsilon, x), onevalue(o.epsilon, x), onevalue(o.epsilon, x))
+
+function apply!(o::PAdam, state, x, dx)
+  η, β, ρ, ϵ = o.eta, o.beta, o.rho, o.epsilon
+  mt, vt, v̂t = state
+
+  @.. mt = β[1] * mt + (1 - β[1]) * dx
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+  @.. v̂t = max(v̂t, vt)
+  dx′ = @lazy η * mt / (v̂t ^ ρ + ϵ)
+
+  return (mt, vt, v̂t), dx′
+end
 """
     WeightDecay(γ = 5f-4)

From ed70f924a8d1727e8c4667f5ba851ac99670d40b Mon Sep 17 00:00:00 2001
From: Symon <59005260+4SAnalyticsnModelling@users.noreply.github.com>
Date: Wed, 12 Jul 2023 10:53:53 -0600
Subject: [PATCH 2/6] Update Optimisers.jl - export new optimizer PAdam

---
 src/Optimisers.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 20fc8aad..da5df27d 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -13,7 +13,7 @@ export destructure
 include("rules.jl")

 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
-       AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
+       AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, PAdam,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
        AccumGrad

From c0cb80b5d3c3b11fe1dacd5a7b054f12b2fe0b42 Mon Sep 17 00:00:00 2001
From: Symon <59005260+4SAnalyticsnModelling@users.noreply.github.com>
Date: Wed, 12 Jul 2023 10:58:31 -0600
Subject: [PATCH 3/6] Update rules.jl - add test for new PAdam optimizer

---
 test/rules.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/rules.jl b/test/rules.jl
index a10e055f..ec178dfd 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -8,7 +8,7 @@ RULES = [
   # All the rules at default settings:
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
-  AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
+  AdamW(), RAdam(), OAdam(), AdaBelief(), PAdam(), Lion(),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
@@ -181,7 +181,7 @@ end
   empty!(LOG)
   @testset "$(name(opt))" for opt in [
     # The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
-    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2),
+    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2), PAdam(1e-2)
     # These weren't in Flux PR:
     Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2),
     ]
@@ -266,4 +266,4 @@ end
   tree, x4 = Optimisers.update(tree, x3, g4)
   @test x4 ≈ x3

-end
\ No newline at end of file
+end

From 6be2690672c35d5ae477794611b911b252c6018f Mon Sep 17 00:00:00 2001
From: Symon <59005260+4SAnalyticsnModelling@users.noreply.github.com>
Date: Wed, 12 Jul 2023 11:00:00 -0600
Subject: [PATCH 4/6] Update api.md - add the new PAdam optimizer to documentation

---
 docs/src/api.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/api.md b/docs/src/api.md
index 6c021f25..efc886e9 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -17,6 +17,7 @@ Optimisers.AMSGrad
 Optimisers.NAdam
 Optimisers.AdamW
 Optimisers.AdaBelief
+Optimisers.PAdam
 ```

 In addition to the main course, you may wish to order some of these condiments:

From ede9cea8d7b6c0daa2c1fdea23f67ab131223f1b Mon Sep 17 00:00:00 2001
From: Symon <59005260+4SAnalyticsnModelling@users.noreply.github.com>
Date: Wed, 12 Jul 2023 11:20:27 -0600
Subject: [PATCH 5/6] Update rules.jl

---
 test/rules.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/rules.jl b/test/rules.jl
index ec178dfd..d0bc41fe 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -181,7 +181,7 @@ end
   empty!(LOG)
   @testset "$(name(opt))" for opt in [
     # The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
-    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2), PAdam(1e-2)
+    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2), PAdam(1e-2),
     # These weren't in Flux PR:
     Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2),
     ]

From 880a2605c7837b951cd7db678a041acb03525a74 Mon Sep 17 00:00:00 2001
From: Symon <59005260+4SAnalyticsnModelling@users.noreply.github.com>
Date: Wed, 12 Jul 2023 13:22:15 -0600
Subject: [PATCH 6/6] Update rules.jl - take PAdam out of the complex number testset

---
 test/rules.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/rules.jl b/test/rules.jl
index d0bc41fe..a4d3b8f6 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -181,7 +181,7 @@ end
   empty!(LOG)
   @testset "$(name(opt))" for opt in [
     # The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
-    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2), PAdam(1e-2),
+    Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2),
     # These weren't in Flux PR:
     Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2),
     ]
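
For reference, a minimal sketch of how the new rule slots into Optimisers.jl's existing setup/update workflow once these patches apply. Optimisers.setup and Optimisers.update are the package's standard entry points; the toy model and gradient values below are invented purely for illustration.

```julia
using Optimisers

# A toy "model": any nested structure of arrays works with Optimisers.jl.
model = (w = Float32[0.0, 1.0, 2.0], b = Float32[3.0])

# Attach PAdam state (mt, vt, v̂t per array) at its defaults:
# η = 0.01, β = (0.9, 0.999), ρ = 0.25.
state = Optimisers.setup(PAdam(), model)

# An invented gradient with the same structure as the model.
grad = (w = Float32[0.1, 0.1, 0.1], b = Float32[0.5])

# One optimisation step: returns the new state tree and the updated model.
state, model = Optimisers.update(state, model, grad)
```

The role of ρ: the update divides by v̂t ^ ρ, so ρ = 0.5 reduces to AMSGrad's sqrt(v̂t) denominator while ρ near 0 behaves like SGD with momentum, which is why the docstring restricts ρ to the range 0 to 0.5.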