
Commit 7137f5d

the ADOPT authors have updated the paper with clipping for stability

1 parent 4ba86bc commit 7137f5d

File tree

3 files changed: +15 -8 lines changed

adam_atan2_pytorch/adopt.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -16,7 +16,7 @@ class Adopt(Optimizer):
     """
     the proposed Adam substitute from University of Tokyo

-    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    Algorithm 3 in https://arxiv.org/abs/2411.02853
     """

     def __init__(
@@ -74,7 +74,7 @@ def step(

         if len(state) == 0:
             state['steps'] = 0
-            state['m'] = torch.empty_like(grad)
+            state['m'] = torch.zeros_like(grad)
             state['v'] = grad * grad

         # get some of the states
@@ -91,9 +91,16 @@ def step(

         grad_sq = grad * grad

-        next_m = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon
+        update = grad.div(v.sqrt().clamp(min = eps)) # they claim that a max(value, eps) performs better than adding the epsilon

-        m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
+        # clip with t ^ 0.25 as in Algorithm 3
+
+        clip_value = steps ** 0.25
+        update.clamp_(min = -clip_value, max = clip_value)
+
+        # update m
+
+        m.lerp_(update, 1. - beta1)

         # then update parameters

```
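For reference, the clipped update in the hunks above can be read as a standalone step. Below is a minimal sketch, not the library's API: the tensor arguments (`param`, `grad`, `m`, `v`) and the defaults are assumptions, and `steps` is assumed to be 1-indexed so the clip bound `steps ** 0.25` is never zero.

```python
import torch

# hypothetical standalone rendering of the clipped ADOPT step above
def adopt_step(param, grad, m, v, steps, lr = 1e-4, betas = (0.9, 0.9999), eps = 1e-6):
    beta1, beta2 = betas

    # normalize the gradient by the second moment estimate,
    # using max(., eps) rather than adding eps to the denominator
    update = grad / v.sqrt().clamp(min = eps)

    # clip the normalized update to [-t^0.25, t^0.25], per Algorithm 3
    clip_value = steps ** 0.25
    update.clamp_(min = -clip_value, max = clip_value)

    # fold the clipped update into the first moment
    m.lerp_(update, 1. - beta1)

    # take the parameter step
    param.add_(m, alpha = -lr)

    # update the second moment with the raw squared gradient
    v.lerp_(grad * grad, 1. - beta2)
```

With `m` now initialized by `zeros_like` rather than `empty_like`, the first-step gate `beta1 * int(steps > 1)` that previously overwrote the uninitialized buffer is no longer needed.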

adam_atan2_pytorch/adopt_atan2.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -17,7 +17,7 @@ class AdoptAtan2(Optimizer):
     the proposed Adam substitute from University of Tokyo
     combined with the proposed atan2 method for ridding of the eps from Google

-    Algorithm 2 in https://arxiv.org/abs/2411.02853
+    Algorithm 3 in https://arxiv.org/abs/2411.02853
     """

     def __init__(
@@ -77,7 +77,7 @@ def step(

         if len(state) == 0:
             state['steps'] = 0
-            state['m'] = torch.empty_like(grad)
+            state['m'] = torch.zeros_like(grad)
             state['v'] = grad * grad

         # get some of the states
@@ -96,7 +96,7 @@ def step(

         next_m = grad.atan2(b * v.sqrt())

-        m.lerp_(next_m, 1. - (beta1 * int(steps > 1)))
+        m.lerp_(next_m, 1. - beta1)

         # then update parameters

```
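The atan2 variant picks up the same `zeros_like` init and simplified `lerp_`, but needs no explicit clip: since `v.sqrt()` is nonnegative, `grad.atan2(b * v.sqrt())` is already bounded in [-pi/2, pi/2]. A similar hedged sketch (hypothetical names and defaults, not the library's API):

```python
import torch

# hypothetical standalone rendering of the adopt-atan2 step above
def adopt_atan2_step(param, grad, m, v, lr = 1e-4, betas = (0.9, 0.9999), b = 1.):
    beta1, beta2 = betas

    # atan2 replaces the eps-guarded division; its output is bounded,
    # so no stability clipping is required
    next_m = grad.atan2(b * v.sqrt())

    # m starts at zeros, so the old first-step gate is unnecessary
    m.lerp_(next_m, 1. - beta1)

    # parameter step, then second moment update
    param.add_(m, alpha = -lr)
    v.lerp_(grad * grad, 1. - beta2)
```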

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.9"
+version = "0.1.10"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```
