
Commit f6ab117

add an atan2 flavor for adopt
1 parent 16bdbcd commit f6ab117

File tree

adam_atan2_pytorch/adopt_atan2.py
pyproject.toml

2 files changed: +118 -1 lines changed

adam_atan2_pytorch/adopt_atan2.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class AdoptAtan2(Optimizer):
    """
    the proposed Adam substitute from University of Tokyo

    Algorithm 2 in https://arxiv.org/abs/2411.02853
    """

    def __init__(
        self,
        params,
        lr = 1e-4,
        betas: tuple[float, float] = (0.9, 0.9999),
        weight_decay = 0.,
        decoupled_wd = True,
        a = 1.27,
        b = 1.
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert weight_decay >= 0.

        self._init_lr = lr
        self.decoupled_wd = decoupled_wd

        defaults = dict(
            lr = lr,
            betas = betas,
            a = a,
            b = b,
            weight_decay = weight_decay,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

                # maybe decoupled weight decay

                if self.decoupled_wd:
                    wd /= init_lr

                # weight decay

                if wd > 0.:
                    p.mul_(1. - lr * wd)

                # init state if needed

                if len(state) == 0:
                    state['steps'] = 0
                    state['m'] = torch.empty_like(grad)
                    state['v'] = grad * grad

                # get some of the states

                m, v, steps = state['m'], state['v'], state['steps']

                # for the first step do nothing

                if steps == 0:
                    state['steps'] += 1
                    continue

                # logic

                steps += 1

                # calculate m

                grad_sq = grad * grad

                next_m = grad.atan2(b * v.sqrt())

                if steps > 1:
                    m.lerp_(next_m, 1. - beta1)

                # then update parameters

                p.add_(m, alpha = -lr * a)

                # update exp grad sq (v)

                v.lerp_(grad_sq, 1. - beta2)

                # increment steps

                state['steps'] = steps

        return loss
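
For orientation, a minimal usage sketch of the new optimizer follows. The toy linear model and random data are invented purely for illustration, and the import path assumes the installed package exposes the module added above. The atan2 call in the update replaces the usual grad / (sqrt(v) + eps) division with atan2(grad, b * sqrt(v)), which stays bounded and drops the eps hyperparameter; note also that, as written above, the very first step() call only initializes state and does not move the parameters.

import torch
from torch import nn

from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2

# hypothetical toy setup, for illustration only

model = nn.Linear(10, 1)

opt = AdoptAtan2(model.parameters(), lr = 1e-4, weight_decay = 1e-2)

for _ in range(3):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()

    # the first call only records state; parameter updates begin on the second call
    opt.step()
    opt.zero_grad()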

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.4"
+version = "0.1.5"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments
