@@ -55,7 +55,7 @@ def __init__(
5555 muon_params ,
5656 params ,
5757 lr = 1e-4 ,
58- muon_lr = None ,
58+ muon_lr = 1e-3 ,
5959 betas : tuple [float , float ] = (0.9 , 0.99 ),
6060 weight_decay = 0. ,
6161 regen_reg_rate = 0. ,
@@ -64,6 +64,7 @@ def __init__(
6464 a = 1.27 ,
6565 b = 1. ,
6666 muon_steps = 5 ,
67+ muon_beta1 = 0.95 ,
6768 muon_newton_schulz5_coefs = (3.4445 , - 4.7750 , 2.0315 ),
6869 muon_eps = 1e-7 ,
6970 remove_muon_params_from_params = True
@@ -82,9 +83,12 @@ def __init__(
8283
8384 self .decoupled_wd = decoupled_wd
8485
86+ beta1 , beta2 = betas
87+
8588 defaults = dict (
8689 lr = lr ,
87- betas = betas ,
90+ beta1 = beta1 ,
91+ beta2 = beta2 ,
8892 a = a ,
8993 b = b ,
9094 weight_decay = weight_decay ,
@@ -101,7 +105,7 @@ def __init__(
101105
102106 param_groups = [
103107 dict (params = params , lr = lr ),
104- dict (params = muon_params , lr = muon_lr , use_muon = True )
108+ dict (params = muon_params , lr = muon_lr , beta1 = muon_beta1 , use_muon = True )
105109 ]
106110
107111 super ().__init__ (param_groups , defaults )
@@ -123,7 +127,7 @@ def step(
123127
124128 use_muon = group ['use_muon' ]
125129
126- grad , lr , wd , regen_rate , cautious_factor , beta1 , beta2 , a , b , state , init_lr , init_muon_lr = p .grad , group ['lr' ], group ['weight_decay' ], group ['regen_reg_rate' ], group ['cautious_factor' ], * group ['betas ' ], group ['a' ], group ['b' ], self .state [p ], self ._init_lr , self ._init_muon_lr
130+ grad , lr , wd , regen_rate , cautious_factor , beta1 , beta2 , a , b , state , init_lr , init_muon_lr = p .grad , group ['lr' ], group ['weight_decay' ], group ['regen_reg_rate' ], group ['cautious_factor' ], group ['beta1' ], group [ 'beta2 ' ], group ['a' ], group ['b' ], self .state [p ], self ._init_lr , self ._init_muon_lr
127131
128132 param_init_lr = init_lr if not use_muon else init_muon_lr
129133
0 commit comments