Commit 665ec96

default decoupled weight decay to false
1 parent: 32062de

File tree

adam_atan2_pytorch/adam_atan2.py
adam_atan2_pytorch/foreach.py
pyproject.toml

3 files changed: +16 -5 lines changed


adam_atan2_pytorch/adam_atan2.py

Lines changed: 8 additions & 1 deletion
@@ -19,6 +19,7 @@ def __init__(
         lr = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        decoupled_wd = False,
         a = 1.27,
         b = 1.
     ):
@@ -27,6 +28,7 @@ def __init__(
         assert weight_decay >= 0.
 
         self._init_lr = lr
+        self.decoupled_wd = decoupled_wd
 
         defaults = dict(
             lr = lr,
@@ -54,7 +56,12 @@ def step(
 
                 grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
 
-                # decoupled weight decay
+                # maybe decoupled weight decay
+
+                if self.decoupled_wd:
+                    wd /= init_lr
+
+                # weight decay
 
                 if wd > 0.:
                     p.mul_(1. - lr / init_lr * wd)
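
For context, a quick usage sketch of the new flag (assuming AdamAtan2 is exported at the package root; the toy module and hyperparameter values below are purely illustrative):

import torch
from adam_atan2_pytorch import AdamAtan2

model = torch.nn.Linear(8, 2)  # toy module, illustrative only

# decoupled_wd now defaults to False; pass True to have the optimizer
# divide weight_decay by the initial learning rate, as in the hunk above
opt = AdamAtan2(
    model.parameters(),
    lr = 1e-4,
    weight_decay = 1e-2,
    decoupled_wd = True
)

model(torch.randn(4, 8)).sum().backward()
opt.step()
opt.zero_grad()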

adam_atan2_pytorch/foreach.py

Lines changed: 7 additions & 3 deletions
@@ -31,6 +31,7 @@ def __init__(
         lr = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        decoupled_wd = False,
         a = 1.27,
         b = 1.,
         foreach_atan2_fn: Callable | None = None
@@ -41,6 +42,7 @@ def __init__(
         assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions'
 
         self._init_lr = lr
+        self.decoupled_wd = decoupled_wd
 
         self._foreach_atan2_ = default(
             foreach_atan2_fn,
@@ -74,6 +76,8 @@ def step(
 
             wd, lr, beta1, beta2, a, b = group['weight_decay'], group['lr'], *group['betas'], group['a'], group['b']
 
+            has_weight_decay = wd > 0
+
             # accumulate List[Tensor] for foreach inplace updates
 
             params = []
@@ -86,9 +90,9 @@ def step(
 
                 grad, state = p.grad, self.state[p]
 
-                # decoupled weight decay
+                # maybe decoupled weight decay
 
-                if wd > 0.:
+                if self.decoupled_wd and has_weight_decay:
                     wd /= init_lr
 
                 # init state if needed
@@ -123,7 +127,7 @@ def step(
 
             # weight decay
 
-            if wd > 0.:
+            if has_weight_decay:
                 torch._foreach_mul_(params, 1. - lr * wd)
 
             # decay running averages
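
A small numeric sketch of what the decoupled path does to the decay multiplier 1. - lr * wd that torch._foreach_mul_ applies above (the values are hypothetical, chosen only to make the scaling visible; this sketches the arithmetic, not the optimizer internals):

init_lr = 1e-4   # lr the optimizer was constructed with (self._init_lr)
lr = 1e-5        # current lr, e.g. after a schedule has decayed it
wd = 1e-2        # weight_decay

# default (decoupled_wd = False): decay shrinks with the absolute lr
print(1. - lr * wd)            # ~0.9999999

# decoupled_wd = True: wd is divided by init_lr, so the decay tracks
# the ratio lr / init_lr rather than the absolute lr
wd_decoupled = wd / init_lr    # ~100.0
print(1. - lr * wd_decoupled)  # ~0.999, i.e. 1 - (lr / init_lr) * wd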

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.0.9"
+version = "0.0.10"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
