
Commit 068150a

add regenerative regularization for better continual plasticity
1 parent 665ec96 commit 068150a

4 files changed, +48 −12 lines

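For context: regenerative regularization (Kumar et al. 2023, cited in the README diff below) replaces weight decay's pull toward zero with a pull back toward each parameter's value at initialization, which is reported to preserve plasticity in continual learning. A minimal sketch of the per-parameter update the diffs below add, reusing the names from the diff (the helper function itself is only for illustration):

```python
import torch

# sketch of regenerative regularization: interpolate the parameter back toward its
# snapshot at optimizer-state init, scaled by the schedule-normalized learning rate
def regen_reg_update_(p, param_init, lr, init_lr, regen_reg_rate):
    # p <- p + t * (param_init - p), with t = lr / init_lr * regen_reg_rate
    p.lerp_(param_init, lr / init_lr * regen_reg_rate)

p = torch.randn(3)
param_init = p.clone()   # snapshot stored in optimizer state on the first step

# ... after some gradient updates have moved `p` ...
regen_reg_update_(p, param_init, lr = 1e-4, init_lr = 1e-4, regen_reg_rate = 1e-2)
```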

README.md

Lines changed: 9 additions & 0 deletions
@@ -50,3 +50,12 @@ for _ in range(100):
 url = {https://api.semanticscholar.org/CorpusID:271051056}
 }
 ```
+
+```bibtex
+@inproceedings{Kumar2023MaintainingPI,
+    title  = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
+    author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
+    year   = {2023},
+    url    = {https://api.semanticscholar.org/CorpusID:261076021}
+}
+```

adam_atan2_pytorch/adam_atan2.py

Lines changed: 18 additions & 5 deletions
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Tuple, Callable
+from typing import Callable
 
 import torch
 from torch import atan2, sqrt
@@ -17,15 +17,18 @@ def __init__(
         self,
         params,
         lr = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
+        betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        regen_reg_rate = 0.,
         decoupled_wd = False,
         a = 1.27,
         b = 1.
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.)
 
         self._init_lr = lr
         self.decoupled_wd = decoupled_wd
@@ -35,7 +38,8 @@ def __init__(
             betas = betas,
             a = a,
             b = b,
-            weight_decay = weight_decay
+            weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate
         )
 
         super().__init__(params, defaults)
@@ -54,7 +58,7 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
 
                 # maybe decoupled weight decay
 
@@ -64,7 +68,13 @@ def step(
                 # weight decay
 
                 if wd > 0.:
-                    p.mul_(1. - lr / init_lr * wd)
+                    p.mul_(1. - lr * wd)
+
+                # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
+
+                if regen_rate > 0. and 'param_init' in state:
+                    param_init = state['param_init']
+                    p.lerp_(param_init, lr / init_lr * regen_rate)
 
                 # init state if needed
 
@@ -73,6 +83,9 @@ def step(
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)
 
+                    if regen_rate > 0.:
+                        state['param_init'] = p.clone()
+
                 # get some of the states
 
                 exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
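For reference, a usage sketch of the new `regen_reg_rate` argument. The import path and the optimizer class name `AdamAtan2` are assumed from this repo's layout, and the rate value is only illustrative; per the asserts above, `weight_decay` and `regen_reg_rate` are mutually exclusive.

```python
import torch
from adam_atan2_pytorch import AdamAtan2  # assumed export; see adam_atan2_pytorch/adam_atan2.py

model = torch.nn.Linear(512, 512)

# use regenerative regularization in place of weight decay
opt = AdamAtan2(model.parameters(), lr = 1e-4, regen_reg_rate = 1e-3)

loss = model(torch.randn(2, 512)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```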

adam_atan2_pytorch/foreach.py

Lines changed: 20 additions & 6 deletions
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Tuple, List, Callable
+from typing import Callable
 
 import torch
 from torch import atan2, sqrt, Tensor
@@ -18,7 +18,7 @@ def default(*args):
 
 # slow foreach atan2
 
-def slow_foreach_atan2_(nums: List[Tensor], dens: List[Tensor]):
+def slow_foreach_atan2_(nums: list[Tensor], dens: list[Tensor]):
     for num, den, in zip(nums, dens):
         num.atan2_(den)
 
@@ -29,8 +29,9 @@ def __init__(
         self,
         params,
         lr = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
+        betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        regen_reg_rate = 0.,
         decoupled_wd = False,
         a = 1.27,
         b = 1.,
@@ -39,6 +40,8 @@ def __init__(
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert regen_reg_rate >= 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.)
         assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'lerp', 'sqrt')]), 'this version of torch does not have the prerequisite foreach functions'
 
         self._init_lr = lr
@@ -55,7 +58,8 @@ def __init__(
             betas = betas,
             a = a,
             b = b,
-            weight_decay = weight_decay
+            weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate
         )
 
         super().__init__(params, defaults)
@@ -74,13 +78,16 @@ def step(
 
         for group in self.param_groups:
 
-            wd, lr, beta1, beta2, a, b = group['weight_decay'], group['lr'], *group['betas'], group['a'], group['b']
+            wd, regen_rate, lr, beta1, beta2, a, b = group['weight_decay'], group['regen_reg_rate'], group['lr'], *group['betas'], group['a'], group['b']
 
             has_weight_decay = wd > 0
 
+            has_regenerative_reg = regen_rate > 0
+
             # accumulate List[Tensor] for foreach inplace updates
 
             params = []
+            params_init = []
             grads = []
             grad_squared = []
             exp_avgs = []
@@ -101,10 +108,11 @@ def step(
                     state['steps'] = 0
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)
+                    state['param_init'] = p.clone()
 
                 # get some of the states
 
-                exp_avg, exp_avg_sq, steps = state['exp_avg'], state['exp_avg_sq'], state['steps']
+                exp_avg, exp_avg_sq, param_init, steps = state['exp_avg'], state['exp_avg_sq'], state['param_init'], state['steps']
 
                 steps += 1
 
@@ -116,6 +124,7 @@ def step(
                 # append to list
 
                 params.append(p)
+                params_init.append(param_init)
                 grads.append(grad)
                 grad_squared.append(grad * grad)
                 exp_avgs.append(exp_avg)
@@ -130,6 +139,11 @@ def step(
             if has_weight_decay:
                 torch._foreach_mul_(params, 1. - lr * wd)
 
+            # regenerative regularization
+
+            if has_regenerative_reg:
+                torch._foreach_lerp_(params, params_init, lr / init_lr * regen_rate)
+
             # decay running averages
 
             torch._foreach_lerp_(exp_avgs, grads, 1. - beta1)
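The foreach variant applies the same regularization, but batched across all collected parameters with `torch._foreach_lerp_`. A quick equivalence sketch (not from the repo) showing the batched call matches the per-tensor `lerp_` used in `adam_atan2.py`:

```python
import torch

params      = [torch.randn(4), torch.randn(4)]
params_init = [torch.randn_like(p) for p in params]   # stand-ins for the stored init snapshots
reference   = [p.clone() for p in params]

t = 1e-4 / 1e-4 * 1e-3   # lr / init_lr * regen_reg_rate

# batched update, as in foreach.py
torch._foreach_lerp_(params, params_init, t)

# per-tensor update, as in adam_atan2.py
for p, p0 in zip(reference, params_init):
    p.lerp_(p0, t)

assert all(torch.allclose(a, b) for a, b in zip(params, reference))
```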

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "adam-atan2-pytorch"
-version = "0.0.10"
+version = "0.0.12"
 description = "Adam-atan2 for Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
