Skip to content

Commit 07b4032

Browse files
committed
default muon lr to 1e-3 and beta1 to 0.95
1 parent a4033e8 commit 07b4032

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

adam_atan2_pytorch/muon_adam_atan2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
muon_params,
5656
params,
5757
lr = 1e-4,
58-
muon_lr = None,
58+
muon_lr = 1e-3,
5959
betas: tuple[float, float] = (0.9, 0.99),
6060
weight_decay = 0.,
6161
regen_reg_rate = 0.,
@@ -64,6 +64,7 @@ def __init__(
6464
a = 1.27,
6565
b = 1.,
6666
muon_steps = 5,
67+
muon_beta1 = 0.95,
6768
muon_newton_schulz5_coefs = (3.4445, -4.7750, 2.0315),
6869
muon_eps = 1e-7,
6970
remove_muon_params_from_params = True
@@ -82,9 +83,12 @@ def __init__(
8283

8384
self.decoupled_wd = decoupled_wd
8485

86+
beta1, beta2 = betas
87+
8588
defaults = dict(
8689
lr = lr,
87-
betas = betas,
90+
beta1 = beta1,
91+
beta2 = beta2,
8892
a = a,
8993
b = b,
9094
weight_decay = weight_decay,
@@ -101,7 +105,7 @@ def __init__(
101105

102106
param_groups = [
103107
dict(params = params, lr = lr),
104-
dict(params = muon_params, lr = muon_lr, use_muon = True)
108+
dict(params = muon_params, lr = muon_lr, beta1 = muon_beta1, use_muon = True)
105109
]
106110

107111
super().__init__(param_groups, defaults)
@@ -123,7 +127,7 @@ def step(
123127

124128
use_muon = group['use_muon']
125129

126-
grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr, init_muon_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr, self._init_muon_lr
130+
grad, lr, wd, regen_rate, cautious_factor, beta1, beta2, a, b, state, init_lr, init_muon_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['cautious_factor'], group['beta1'], group['beta2'], group['a'], group['b'], self.state[p], self._init_lr, self._init_muon_lr
127131

128132
param_init_lr = init_lr if not use_muon else init_muon_lr
129133

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "adam-atan2-pytorch"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
description = "Adam-atan2 for Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

0 commit comments

Comments
 (0)