import torch.nn.functional as F
from torch.nn.utils.parametrize import register_parametrization

-from einops import rearrange
+from einops import rearrange, einsum
from einops.layers.torch import Rearrange

from rotary_embedding_torch import RotaryEmbedding
@@ -180,7 +180,8 @@ def __init__(
        ff_expand_factor = 4.,
        ce_ignore_index = -1,
        residual_lerp_scale_init = None,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        tied_embedding = False
    ):
        super().__init__()
        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -205,7 +206,7 @@ def __init__(
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
            ]))

-        self.to_logits = NormLinear_(dim, num_tokens)
+        self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None

        self.logit_scale = nn.Parameter(torch.ones(num_tokens))

@@ -228,7 +229,8 @@ def forward(
        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

-        tokens = self.token_embed.weight[ids]
+        token_embed = self.token_embed.weight
+        tokens = token_embed[ids]

        for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

@@ -238,7 +240,12 @@ def forward(
            ff_out = l2norm(ff(tokens))
            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))

-        logits = self.to_logits(tokens)
+        if exists(self.to_logits):
+            logits = self.to_logits(tokens)
+        else:
+            # tied embeddings
+            logits = einsum(tokens, token_embed, 'b n d, c d -> b n c')
+
        logits = logits * self.logit_scale * (self.dim ** 0.5)

        if not return_loss:
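
For readers skimming the change: with tied_embedding = True the model keeps to_logits as None and reuses the token embedding table as the output projection, so logits are the inner products of each hidden state with every embedding row. A minimal standalone sketch of that computation follows (the toy shapes and random tensors are assumptions for illustration, not part of the repository):

import torch
from einops import einsum

# assumed toy dimensions, for illustration only
num_tokens, dim, batch, seq = 256, 64, 2, 8

# one shared table: embeds the input ids and also scores the output vocabulary
token_embed = torch.randn(num_tokens, dim)
ids = torch.randint(0, num_tokens, (batch, seq))

tokens = token_embed[ids]  # (batch, seq, dim), same lookup as token_embed[ids] in the diff

# ... the transformer layers would update `tokens` here ...

# tied-embedding logits: dot product of every position with every token embedding
logits = einsum(tokens, token_embed, 'b n d, c d -> b n c')  # (batch, seq, num_tokens)
assert logits.shape == (batch, seq, num_tokens)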