@@ -26,13 +26,15 @@ def __init__(
         lr = 1e-4,
         betas: tuple[float, float] = (0.9, 0.99),
         weight_decay = 0.,
+        regen_reg_rate = 0.,
         decoupled_wd = True,
         a = 1.27,
         b = 1.
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert weight_decay >= 0.
+        assert not (weight_decay > 0. and regen_reg_rate > 0.)

         self._init_lr = lr
         self.decoupled_wd = decoupled_wd
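The new `regen_reg_rate` option is mutually exclusive with `weight_decay`, which the added assert enforces. A minimal usage sketch (the class name `AdamAtan2` is an assumption for illustration, since the class definition is not visible in this diff):

```python
import torch
from torch import nn

model = nn.Linear(512, 512)

# regenerative regularization stands in for weight decay here - setting both
# would trip `assert not (weight_decay > 0. and regen_reg_rate > 0.)`
opt = AdamAtan2(model.parameters(), lr = 1e-4, regen_reg_rate = 1e-2)  # hypothetical class name
```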
@@ -43,6 +45,7 @@ def __init__(
             a = a,
             b = b,
             weight_decay = weight_decay,
+            regen_reg_rate = regen_reg_rate
         )

         super().__init__(params, defaults)
@@ -61,13 +64,19 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):

-                grad, lr, wd, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr
+                grad, lr, wd, regen_rate, beta1, beta2, a, b, state, init_lr = p.grad, group['lr'], group['weight_decay'], group['regen_reg_rate'], *group['betas'], group['a'], group['b'], self.state[p], self._init_lr

                 # maybe decoupled weight decay

                 if self.decoupled_wd:
                     wd /= init_lr

+                # regenerative regularization from Kumar et al. https://arxiv.org/abs/2308.11958
+
+                if regen_rate > 0. and 'param_init' in state:
+                    param_init = state['param_init']
+                    p.lerp_(param_init, lr / init_lr * regen_rate)
+
                 # weight decay

                 if wd > 0.:
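The `lerp_` call is the entire regularizer: each step nudges the weights a small amount toward their initial values, i.e. decoupled weight decay with the anchor moved from zero to the initialization (hence the same `lr / init_lr` scaling as the decoupled decay above). A self-contained sketch of the equivalence, with illustrative values:

```python
import torch

p = torch.randn(8)
param_init = torch.randn(8)        # snapshot stored in optimizer state
lr, init_lr, regen_rate = 3e-4, 1e-4, 1e-2

rate = lr / init_lr * regen_rate   # follows the lr schedule, like decoupled wd

# the diff's in-place update p.lerp_(param_init, rate) computes
# p <- p + rate * (param_init - p)
lerped = torch.lerp(p, param_init, rate)

# explicit form: an L2-style pull toward param_init instead of toward zero
explicit = p - rate * (p - param_init)

assert torch.allclose(lerped, explicit)
```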
@@ -80,6 +89,9 @@ def step(
                     state['m'] = torch.zeros_like(grad)
                     state['v'] = grad * grad

+                    if regen_rate > 0.:
+                        state['param_init'] = p.clone()
+
                 # get some of the states

                 m, v, steps = state['m'], state['v'], state['steps']
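Because the snapshot lives inside the lazy state init, the regen block earlier in `step` is a no-op on the very first update (`'param_init'` does not exist yet), and every later update pulls toward the weights as they were when the state was created. It also costs one extra full copy of each regularized parameter. A small standalone sketch of that behavior (not the repo's API):

```python
import torch

state = {}
p = torch.randn(4)
regen_rate = 1e-2

# first step: no 'param_init' yet, so the regen pull is skipped...
assert not (regen_rate > 0. and 'param_init' in state)

# ...then lazy init takes the snapshot, at the cost of one extra copy of p
if len(state) == 0:
    state['steps'] = 0
    state['m'] = torch.zeros_like(p)
    state['v'] = torch.zeros_like(p)

    if regen_rate > 0.:
        state['param_init'] = p.clone()

assert state['param_init'].data_ptr() != p.data_ptr()  # separate storage
```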