File tree Expand file tree Collapse file tree 3 files changed +15
-8
lines changed Expand file tree Collapse file tree 3 files changed +15
-8
lines changed Original file line number Diff line number Diff line change @@ -16,7 +16,7 @@ class Adopt(Optimizer):
16
16
"""
17
17
the proposed Adam substitute from University of Tokyo
18
18
19
- Algorithm 2 in https://arxiv.org/abs/2411.02853
19
+ Algorithm 3 in https://arxiv.org/abs/2411.02853
20
20
"""
21
21
22
22
def __init__ (
@@ -74,7 +74,7 @@ def step(
74
74
75
75
if len (state ) == 0 :
76
76
state ['steps' ] = 0
77
- state ['m' ] = torch .empty_like (grad )
77
+ state ['m' ] = torch .zeros_like (grad )
78
78
state ['v' ] = grad * grad
79
79
80
80
# get some of the states
@@ -91,9 +91,16 @@ def step(
91
91
92
92
grad_sq = grad * grad
93
93
94
- next_m = grad .div (v .sqrt ().clamp (min = eps )) # they claim that a max(value, eps) performs better than adding the epsilon
94
+ update = grad .div (v .sqrt ().clamp (min = eps )) # they claim that a max(value, eps) performs better than adding the epsilon
95
95
96
- m .lerp_ (next_m , 1. - (beta1 * int (steps > 1 )))
96
+ # clip with t ^ 0.25 as in Algorithm 3
97
+
98
+ clip_value = steps ** 0.25
99
+ update .clamp_ (min = - clip_value , max = clip_value )
100
+
101
+ # update m
102
+
103
+ m .lerp_ (update , 1. - beta1 )
97
104
98
105
# then update parameters
99
106
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ class AdoptAtan2(Optimizer):
17
17
the proposed Adam substitute from University of Tokyo
18
18
combined with the proposed atan2 method for ridding of the eps from Google
19
19
20
- Algorithm 2 in https://arxiv.org/abs/2411.02853
20
+ Algorithm 3 in https://arxiv.org/abs/2411.02853
21
21
"""
22
22
23
23
def __init__ (
@@ -77,7 +77,7 @@ def step(
77
77
78
78
if len (state ) == 0 :
79
79
state ['steps' ] = 0
80
- state ['m' ] = torch .empty_like (grad )
80
+ state ['m' ] = torch .zeros_like (grad )
81
81
state ['v' ] = grad * grad
82
82
83
83
# get some of the states
@@ -96,7 +96,7 @@ def step(
96
96
97
97
next_m = grad .atan2 (b * v .sqrt ())
98
98
99
- m .lerp_ (next_m , 1. - ( beta1 * int ( steps > 1 )) )
99
+ m .lerp_ (next_m , 1. - beta1 )
100
100
101
101
# then update parameters
102
102
Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " adam-atan2-pytorch"
3
- version = " 0.1.9 "
3
+ version = " 0.1.10 "
4
4
description = " Adam-atan2 for Pytorch"
5
5
authors = [
6
6
{ name = " Phil Wang" , email = " lucidrains@gmail.com" }
You can’t perform that action at this time.
0 commit comments