
Commit a4033e8

add an adam atan2 + muon
1 parent 8f14cf5 commit a4033e8

File tree

4 files changed: +226 -1 lines changed

README.md

Lines changed: 9 additions & 0 deletions
@@ -90,3 +90,12 @@ for _ in range(100):
     url = {https://api.semanticscholar.org/CorpusID:274234738}
 }
 ```
+
+```bibtex
+@misc{jordan2024muon,
+    author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and Franz Cesista and Laker Newhouse and Jeremy Bernstein},
+    title = {Muon: An optimizer for hidden layers in neural networks},
+    year = {2024},
+    url = {https://kellerjordan.github.io/posts/muon/}
+}
+```

adam_atan2_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 from adam_atan2_pytorch.adam_atan2 import AdamAtan2
 from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2
 
+from adam_atan2_pytorch.muon_adam_atan2 import MuonAdamAtan2
+
 Adam = AdamAtan2
 Adopt = AdoptAtan2
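
With this export the combined optimizer is importable from the package root alongside the existing classes; a minimal check (assuming the package from this commit is installed):

```python
# the new class sits next to the existing AdamAtan2 / AdoptAtan2 exports
from adam_atan2_pytorch import AdamAtan2, AdoptAtan2, MuonAdamAtan2
```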
adam_atan2_pytorch/muon_adam_atan2.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# einops pack / unpack handle any number of leading batch dimensions
from einops import pack, unpack

# functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# muon related

def newtonschulz5(
    t,
    steps = 5,
    eps = 1e-7,
    coefs = (3.4445, -4.7750, 2.0315)
):
    # newton-schulz orthogonalization only applies to matrices - vectors and scalars pass through untouched
    if t.ndim <= 1:
        return t

    shape = t.shape
    should_transpose = shape[-2] > shape[-1]

    if should_transpose:
        t = t.transpose(-1, -2)

    t, packed_shape = pack([t], '* i j')
    t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)

    a, b, c = coefs

    for _ in range(steps):
        A = t @ t.transpose(-1, -2)
        B = b * A + c * A @ A
        t = a * t + B @ t

    t, = unpack(t, packed_shape, '* i j')

    if should_transpose:
        t = t.transpose(-1, -2)

    return t

# class

class MuonAdamAtan2(Optimizer):
    def __init__(
        self,
        muon_params,
        params,
        lr = 1e-4,
        muon_lr = None,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay = 0.,
        regen_reg_rate = 0.,
        decoupled_wd = False,
        cautious_factor = 1., # set to 0. for zeroing out any updates not in same direction as gradient as in https://arxiv.org/abs/2411.16085
        a = 1.27,
        b = 1.,
        muon_steps = 5,
        muon_newton_schulz5_coefs = (3.4445, -4.7750, 2.0315),
        muon_eps = 1e-7,
        remove_muon_params_from_params = True
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert weight_decay >= 0.
        assert regen_reg_rate >= 0.
        assert not (weight_decay > 0. and regen_reg_rate > 0.)
        assert 0. <= cautious_factor <= 1.

        self._init_lr = lr

        muon_lr = default(muon_lr, lr)
        self._init_muon_lr = muon_lr

        self.decoupled_wd = decoupled_wd

        defaults = dict(
            lr = lr,
            betas = betas,
            a = a,
            b = b,
            weight_decay = weight_decay,
            regen_reg_rate = regen_reg_rate,
            cautious_factor = cautious_factor,
            use_muon = False,
            muon_steps = muon_steps,
            muon_newton_schulz5_coefs = muon_newton_schulz5_coefs,
            muon_eps = muon_eps,
        )

        if remove_muon_params_from_params:
            params = list(set(params) - set(muon_params))

        param_groups = [
            dict(params = params, lr = lr),
            dict(params = muon_params, lr = muon_lr, use_muon = True)
        ]

        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            for p in filter(lambda p: exists(p.grad), group['params']):

                use_muon = group['use_muon']

                grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr, init_muon_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr, self._init_muon_lr

                param_init_lr = init_lr if not use_muon else init_muon_lr

                # maybe decoupled weight decay

                if self.decoupled_wd:
                    wd /= param_init_lr

                # weight decay

                if wd > 0.:
                    p.mul_(1. - lr * wd)

                # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958

                if regen_rate > 0. and 'param_init' in state:
                    param_init = state['param_init']
                    p.lerp_(param_init, lr / init_lr * regen_rate)

                # init state if needed

                if len(state) == 0:
                    state['steps'] = 0
                    state['exp_avg'] = torch.zeros_like(grad)

                    if not use_muon:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    if regen_rate > 0.:
                        state['param_init'] = p.clone()

                # get some of the states

                exp_avg, steps = state['exp_avg'], state['steps']

                steps += 1

                # bias corrections

                bias_correct1 = 1. - beta1 ** steps

                if not use_muon:
                    exp_avg_sq = state['exp_avg_sq']
                    bias_correct2 = 1. - beta2 ** steps

                # decay running averages

                exp_avg.lerp_(grad, 1. - beta1)

                if not use_muon:
                    exp_avg_sq.lerp_(grad * grad, 1. - beta2)

                    # the following line is the proposed change to the update rule
                    # using atan2 instead of a division with epsilon in denominator
                    # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))

                    den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
                    update = exp_avg.mul(1. / bias_correct1).atan2_(den)

                else:

                    muon_steps, muon_coefs, muon_eps = group['muon_steps'], group['muon_newton_schulz5_coefs'], group['muon_eps']

                    # Muon from Keller Jordan
                    # https://kellerjordan.github.io/posts/muon/

                    update = newtonschulz5(
                        exp_avg,
                        steps = muon_steps,
                        coefs = muon_coefs,
                        eps = muon_eps
                    )

                # maybe cautious update - algorithm 2 in https://arxiv.org/abs/2411.16085

                if cautious_factor < 1.:
                    align_mask = (update * grad) > 0
                    scale = torch.where(align_mask, torch.ones_like(grad), cautious_factor)
                    update *= (scale / scale.mean().clamp(min = 1e-5))

                # update parameters

                p.add_(update, alpha = -lr * a)

                # increment steps

                state['steps'] = steps

        return loss
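
A usage sketch, not part of this commit - the toy network, dimensions, and learning rates below are made up for illustration. Weight matrices are handed to the Muon group and everything else (here, the biases) to the adam-atan2 group; with the default `remove_muon_params_from_params = True`, the muon params are removed from the second group automatically.

```python
import torch
from torch import nn
from adam_atan2_pytorch import MuonAdamAtan2

# hypothetical toy network
net = nn.Sequential(nn.Linear(512, 1024), nn.SiLU(), nn.Linear(1024, 512))

params = list(net.parameters())
muon_params = [p for p in params if p.ndim >= 2]  # 2D weight matrices -> Muon branch

opt = MuonAdamAtan2(
    muon_params,
    params,           # muon_params are removed from this group by default
    lr = 1e-4,
    muon_lr = 1e-3    # the Muon group can take its own learning rate
)

for _ in range(100):
    loss = net(torch.randn(2, 512)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```

Setting `cautious_factor` below 1. additionally downweights update components whose sign disagrees with the current gradient, per the cautious optimizer paper linked in the code.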

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.1.18"
+version = "0.2.0"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments
