@@ -35,15 +35,15 @@ def __init__(
         self,
         dim,
         dim_out,
-        norm_dim = -1
+        norm_dim_in = True
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
         parametrize.register_parametrization(
             self.linear,
             'weight',
-            L2Norm(dim = norm_dim)
+            L2Norm(dim = -1 if norm_dim_in else 0)
         )
 
     @property
@@ -66,9 +66,9 @@ def __init__(
     ):
         super().__init__()
         dim_inner = dim_head * heads
-        self.to_q = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.to_k = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.to_v = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_q = NormLinear(dim, dim_inner)
+        self.to_k = NormLinear(dim, dim_inner)
+        self.to_v = NormLinear(dim, dim_inner)
 
         self.rotary_emb = RotaryEmbedding(dim_head)
         self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** -0.25))
@@ -77,7 +77,7 @@ def __init__(
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.to_out = NormLinear(dim_inner, dim)
+        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
 
     def forward(
         self,
@@ -123,13 +123,13 @@ def __init__(
         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
 
-        self.to_hidden = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.to_gate = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_hidden = NormLinear(dim, dim_inner)
+        self.to_gate = NormLinear(dim, dim_inner)
 
         self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
         self.gate_scale = nn.Parameter(torch.ones(dim_inner))
 
-        self.to_out = NormLinear(dim_inner, dim)
+        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
 
     def forward(self, x):
         hidden, gate = self.to_hidden(x), self.to_gate(x)
@@ -177,7 +177,7 @@ def __init__(
             nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
         ]))
 
-        self.to_logits = NormLinear(dim, num_tokens, norm_dim = 0)
+        self.to_logits = NormLinear(dim, num_tokens)
 
         self.logit_scale = nn.Parameter(torch.ones(num_tokens))
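For reference, here is a minimal sketch of what NormLinear looks like after this change, assuming an L2Norm parametrization that simply l2-normalizes the weight along the given axis (the L2Norm body and the forward method are filled in as assumptions; the constructor and the register_parametrization call are taken directly from the first hunk). Since nn.Linear stores its weight as (dim_out, dim), norm_dim_in = True (dim = -1) normalizes each weight row over the input axis, while norm_dim_in = False (dim = 0) normalizes over the output axis.

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import parametrize

# Hypothetical parametrization assumed by this sketch: keeps the registered
# weight l2-normalized along `dim` whenever it is accessed.
class L2Norm(nn.Module):
    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return F.normalize(t, dim = self.dim, p = 2)

class NormLinear(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        norm_dim_in = True  # True: normalize along dim = -1 (input axis); False: dim = 0 (output axis)
    ):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)

        parametrize.register_parametrization(
            self.linear,
            'weight',
            L2Norm(dim = -1 if norm_dim_in else 0)
        )

    def forward(self, x):
        return self.linear(x)

At the call sites in this diff, the new default covers the layers whose input is the model dimension (to_q, to_k, to_v, to_hidden, to_gate, to_logits), while norm_dim_in = False is passed for the to_out projections, whose output is the model dimension.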