
Commit 771341a

first add vanilla adopt for running experiments against adam
1 parent 8600496 commit 771341a

4 files changed: 126 additions, 1 deletion


README.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -72,3 +72,12 @@ for _ in range(100):
     url = {https://api.semanticscholar.org/CorpusID:270380086}
 }
 ```
+
+```bibtex
+@inproceedings{Taniguchi2024ADOPTMA,
+    title = {ADOPT: Modified Adam Can Converge with Any $\beta_2$ with the Optimal Rate},
+    author = {Shohei Taniguchi and Keno Harada and Gouki Minegishi and Yuta Oshima and Seong Cheol Jeong and Go Nagahara and Tomoshi Iiyama and Masahiro Suzuki and Yusuke Iwasawa and Yutaka Matsuo},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273822148}
+}
+```
````

adam_atan2_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,3 +1,4 @@
 from adam_atan2_pytorch.adam_atan2 import AdamAtan2
+from adam_atan2_pytorch.adopt import Adopt
 
 Adam = AdamAtan2
```
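With `Adopt` now exported alongside `AdamAtan2`, it can be dropped into the same kind of training loop as the existing README example. A minimal sketch (the linear model and random data are hypothetical placeholders, purely for illustration):

```python
import torch
from torch import nn

from adam_atan2_pytorch import Adopt

# toy model and optimizer, only for illustration

model = nn.Linear(512, 256)
opt = Adopt(model.parameters(), lr = 1e-4)

for _ in range(100):
    # dummy forward pass and scalar loss
    loss = model(torch.randn(512)).sum()
    loss.backward()

    opt.step()
    opt.zero_grad()
```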

adam_atan2_pytorch/adopt.py

Lines changed: 115 additions & 0 deletions
New file:

```python
from __future__ import annotations
from typing import Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class Adopt(Optimizer):
    """
    the proposed Adam substitute from University of Tokyo

    Algorithm 2 in https://arxiv.org/abs/2411.02853
    """

    def __init__(
        self,
        params,
        lr = 1e-4,
        betas: tuple[float, float] = (0.9, 0.9999),
        eps = 1e-6,
        weight_decay = 0.,
        decoupled_wd = True
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert weight_decay >= 0.

        self._init_lr = lr
        self.decoupled_wd = decoupled_wd

        defaults = dict(
            lr = lr,
            betas = betas,
            eps = eps,
            weight_decay = weight_decay,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):
        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, eps, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['eps'], self.state[p], self._init_lr

                # maybe decoupled weight decay

                if self.decoupled_wd:
                    wd /= init_lr

                # weight decay

                if wd > 0.:
                    p.mul_(1. - lr * wd)

                # init state if needed - v starts as the squared gradient, m gets filled in on the first real update

                if len(state) == 0:
                    state['steps'] = 0
                    state['m'] = torch.zeros_like(grad)
                    state['v'] = grad * grad

                # get some of the states

                m, v, steps = state['m'], state['v'], state['steps']

                # on the very first step only v has been initialized, so do nothing

                if steps == 0:
                    state['steps'] += 1
                    continue

                steps += 1

                # calculate m - the gradient normalized by the previous second moment
                # they claim that a max(value, eps) performs better than adding the epsilon

                grad_sq = grad * grad

                next_m = grad.div(v.sqrt().clamp(min = eps))

                # the first update takes the normalized gradient as is, afterwards m is an exponential moving average with beta1

                if steps == 2:
                    m.copy_(next_m)
                else:
                    m.lerp_(next_m, 1. - beta1)

                # then update parameters

                p.add_(m, alpha = -lr)

                # update exp grad sq (v) with beta2

                v.lerp_(grad_sq, 1. - beta2)

                # increment steps

                state['steps'] = steps

        return loss
```
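For reference, the update the file above implements can be written compactly in the paper's notation. A sketch of ADOPT's Algorithm 2 (without the gradient clipping of Algorithm 3), with the first momentum value taken directly from the normalized gradient as the code does:

```latex
% ADOPT (Algorithm 2) without clipping, all operations elementwise
% initialization: v_0 = g_0^2, and m_1 = g_1 / max(sqrt(v_0), eps)
\begin{aligned}
m_t      &= \beta_1 \, m_{t-1} + (1 - \beta_1) \, \frac{g_t}{\max\!\left(\sqrt{v_{t-1}},\ \epsilon\right)} \\
\theta_t &= \theta_{t-1} - \alpha \, m_t \\
v_t      &= \beta_2 \, v_{t-1} + (1 - \beta_2) \, g_t^2
\end{aligned}
```

The ordering is the point of difference from Adam: at step t the gradient is normalized by v at step t−1, which does not yet contain g_t, and only afterwards is v updated; this is what the paper credits for convergence with any β2. It is also why the code updates `v` after the parameter update.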

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.1"
+version = "0.1.2"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```
