
Commit 5e05d83

take care of a paper in the continual learning literature that improves on regenerative regularization by using the Wasserstein distance
1 parent 068150a

File tree

4 files changed: +155, -3 lines changed


README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -59,3 +59,14 @@ for _ in range(100):
     url = {https://api.semanticscholar.org/CorpusID:261076021}
 }
 ```
+
+```bibtex
+@article{Lewandowski2024LearningCB,
+    title = {Learning Continually by Spectral Regularization},
+    author = {Alex Lewandowski and Saurabh Kumar and Dale Schuurmans and Andr{\'a}s Gy{\"o}rgy and Marlos C. Machado},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.06811},
+    url = {https://api.semanticscholar.org/CorpusID:270380086}
+}
+```
````

adam_atan2_pytorch/adam_atan2.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -20,6 +20,7 @@ def __init__(
         betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
         regen_reg_rate = 0.,
+        wasserstein_reg = False,
         decoupled_wd = False,
         a = 1.27,
         b = 1.
@@ -39,7 +40,8 @@ def __init__(
             a = a,
             b = b,
             weight_decay = weight_decay,
-            regen_reg_rate = regen_reg_rate
+            regen_reg_rate = regen_reg_rate,
+            wasserstein_reg = wasserstein_reg,
         )

         super().__init__(params, defaults)
@@ -58,7 +60,7 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):

-                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, wasserstein_reg, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], group['wasserstein_reg'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

                 # maybe decoupled weight decay
```
Lines changed: 139 additions & 0 deletions (new file)

```python
from __future__ import annotations
from typing import Callable

import torch
from torch import atan2, sqrt
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class AdamAtan2(Optimizer):
    def __init__(
        self,
        params,
        lr = 1e-4,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay = 0.,
        regen_reg_rate = 0.,
        decoupled_wd = False,
        a = 1.27,
        b = 1.
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert weight_decay >= 0.
        assert regen_reg_rate >= 0.
        assert not (weight_decay > 0. and regen_reg_rate > 0.)

        self._init_lr = lr
        self.decoupled_wd = decoupled_wd

        defaults = dict(
            lr = lr,
            betas = betas,
            a = a,
            b = b,
            weight_decay = weight_decay,
            regen_reg_rate = regen_reg_rate,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):
        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

                # maybe decoupled weight decay

                if self.decoupled_wd:
                    wd /= init_lr

                # weight decay

                if wd > 0.:
                    p.mul_(1. - lr * wd)

                # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
                # here in a wasserstein-style variant: in 1d, optimal transport between the
                # current and initial parameter values matches them by rank (order statistics)

                if regen_rate > 0. and 'param_init' in state:
                    param_init = state['param_init']

                    shape = param_init.shape

                    # rank of each current parameter among its peers

                    indices = p.flatten().sort(dim = -1).indices
                    indices = indices.argsort(dim = -1)

                    # target is the (pre-sorted) initial value sharing that rank

                    target = param_init.flatten()[indices]
                    target = target.reshape(shape)

                    p.lerp_(target, lr / init_lr * regen_rate)

                # init state if needed

                if len(state) == 0:
                    state['steps'] = 0
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                    if regen_rate > 0.:

                        # wasserstein reg - https://arxiv.org/abs/2406.06811v1
                        # store the initial parameters pre-sorted, so the rank lookup above is a single gather
                        # (do not rebind `p` here - the parameter itself must keep receiving the update below)

                        sorted_init = p.flatten().sort(dim = -1).values

                        state['param_init'] = sorted_init.reshape(p.shape).clone()

                # get some of the states

                exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']

                steps += 1

                # bias corrections

                bias_correct1 = 1. - beta1 ** steps
                bias_correct2 = 1. - beta2 ** steps

                # decay running averages

                exp_avg.lerp_(grad, 1. - beta1)
                exp_avg_sq.lerp_(grad * grad, 1. - beta2)

                # the following lines are the proposed change to the update rule,
                # using atan2 instead of a division with an epsilon in the denominator:
                # a * atan2(exp_avg / bias_correct1, b * sqrt(exp_avg_sq / bias_correct2))

                den = exp_avg_sq.mul(b * b / bias_correct2).sqrt_()
                update = exp_avg.mul(1. / bias_correct1).atan2_(den)

                # update parameters

                p.add_(update, alpha = -lr * a)

                # increment steps

                state['steps'] = steps

        return loss
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.0.12"
+version = "0.0.15"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```
