+from functools import partial
+
 import torch
 from torch import nn
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
-import torch.nn.utils.parametrize as parametrize
+from torch.nn.utils.parametrize import register_parametrization
 
 from einops import rearrange
 from einops.layers.torch import Rearrange
@@ -35,16 +37,33 @@ def __init__(
         self,
         dim,
         dim_out,
-        norm_dim_in = True
+        norm_dim_in = True,
+        parametrize = True
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
-        parametrize.register_parametrization(
-            self.linear,
-            'weight',
-            L2Norm(dim = -1 if norm_dim_in else 0)
-        )
+        self.parametrize = parametrize
+        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0)
+
+        if parametrize:
+            register_parametrization(
+                self.linear,
+                'weight',
+                self.l2norm
+            )
+
+        self.norm_weights_()
+
+    @torch.no_grad()
+    def norm_weights_(self):
+        if self.parametrize:
+            normed = self.weight
+            original = self.linear.parametrizations.weight.original
+
+            original.copy_(normed)
+        else:
+            self.weight.copy_(self.l2norm(self.weight))
 
     @property
     def weight(self):
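The hunk above gives NormLinear two paths for keeping its weight L2-normalized. With parametrize = True, register_parametrization makes every read of linear.weight return the normalized tensor while the raw values are stored under linear.parametrizations.weight.original. With parametrize = False, the weight stays a plain parameter and norm_weights_() has to be called explicitly to re-normalize it in place, typically after each optimizer update. A minimal sketch of the manual path, assuming only the NormLinear defined in this diff (the sizes and the dummy objective are illustrative):

import torch

# manual path: no parametrization hook, so the weight drifts off unit norm
# during optimizer updates and must be re-projected by hand
layer = NormLinear(512, 512, parametrize = False)
optim = torch.optim.Adam(layer.parameters(), lr = 3e-4)

x = torch.randn(2, 16, 512)
loss = layer.linear(x).sum()   # stand-in objective, just to produce gradients
loss.backward()
optim.step()

layer.norm_weights_()          # copies the re-normalized weight back in place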
@@ -62,13 +81,16 @@ def __init__(
         *,
         dim_head = 64,
         heads = 8,
-        norm_qk = True
+        norm_qk = True,
+        manual_norm_weights = False
     ):
         super().__init__()
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+
         dim_inner = dim_head * heads
-        self.to_q = NormLinear(dim, dim_inner)
-        self.to_k = NormLinear(dim, dim_inner)
-        self.to_v = NormLinear(dim, dim_inner)
+        self.to_q = NormLinear_(dim, dim_inner)
+        self.to_k = NormLinear_(dim, dim_inner)
+        self.to_v = NormLinear_(dim, dim_inner)
 
         self.rotary_emb = RotaryEmbedding(dim_head)
         self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
@@ -77,7 +99,7 @@ def __init__(
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
+        self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
     def forward(
         self,
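A note on the pattern introduced here and repeated in the hunks below: partial(NormLinear, parametrize = not manual_norm_weights) pre-binds the new flag, so NormLinear_(dim, dim_inner) behaves like NormLinear(dim, dim_inner, parametrize = not manual_norm_weights) and the option threads through every projection without being repeated at each call site. A tiny illustration (sizes are arbitrary):

from functools import partial

NormLinear_ = partial(NormLinear, parametrize = False)   # flag pre-bound

proj = NormLinear_(64, 128)   # same as NormLinear(64, 128, parametrize = False)
assert proj.linear.weight.shape == (128, 64)   # nn.Linear stores (out_features, in_features)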
@@ -117,19 +139,22 @@ def __init__(
         self,
         dim,
         *,
-        expand_factor = 4
+        expand_factor = 4,
+        manual_norm_weights = False
     ):
         super().__init__()
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+
         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
 
-        self.to_hidden = NormLinear(dim, dim_inner)
-        self.to_gate = NormLinear(dim, dim_inner)
+        self.to_hidden = NormLinear_(dim, dim_inner)
+        self.to_gate = NormLinear_(dim, dim_inner)
 
         self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
         self.gate_scale = nn.Parameter(torch.ones(dim_inner))
 
-        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
+        self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)
 
     def forward(self, x):
         hidden, gate = self.to_hidden(x), self.to_gate(x)
@@ -154,30 +179,33 @@ def __init__(
         attn_norm_qk = True,  # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
-        residual_lerp_scale_init = None
+        residual_lerp_scale_init = None,
+        manual_norm_weights = False
     ):
         super().__init__()
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+
         self.dim = dim
 
         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
 
-        self.token_embed = NormLinear(dim, num_tokens)
+        self.token_embed = NormLinear_(dim, num_tokens)
 
         self.layers = ModuleList([])
         self.residual_lerp_scales = nn.ParameterList([])
 
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk),
-                FeedForward(dim, expand_factor = ff_expand_factor),
+                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
+                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
             ]))
 
             self.residual_lerp_scales.append(nn.ParameterList([
                 nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
                 nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
             ]))
 
-        self.to_logits = NormLinear(dim, num_tokens)
+        self.to_logits = NormLinear_(dim, num_tokens)
 
         self.logit_scale = nn.Parameter(torch.ones(num_tokens))
@@ -189,10 +217,7 @@ def norm_weights_(self):
             if not isinstance(module, NormLinear):
                 continue
 
-            normed = module.weight
-            original = module.linear.parametrizations.weight.original
-
-            original.copy_(normed)
+            module.norm_weights_()
 
     def forward(
         self,
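Taken together, manual_norm_weights = True drops the register_parametrization hooks and instead relies on calling the model-level norm_weights_() once per training step, which (per the last hunk) now simply delegates to each NormLinear. A rough training-loop sketch under that reading; the class name nGPT, the return_loss keyword, and the hyperparameters are assumptions for illustration, not lines from this diff:

import torch

# assumed: the transformer class defined in this file is nGPT
model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    manual_norm_weights = True   # NormLinear layers skip register_parametrization
)

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

ids = torch.randint(0, 256, (2, 1024))
loss = model(ids, return_loss = True)   # assumed forward signature; cross-entropy per ce_ignore_index
loss.backward()
optim.step()

model.norm_weights_()   # re-normalize every NormLinear weight after the update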