@@ -1131,13 +1131,13 @@ def compute_qkv_linear(self, ln_out, i, latent_cache=None, **kwargs):
             qkv_out = paddle.add(qkv_out, self.qkv_biases[i])
         return qkv_out
 
-    def compute_qkv(self, src, residual_input, i):
+    def compute_qkv(self, src, residual_input, i, **kwargs):
         ln_out = self.compute_layernorm_before_qkv(src, i)
 
         if self.config.mla_config.use_absorb():
             qkv_out = ln_out
         else:
-            qkv_out = self.compute_qkv_linear(ln_out, i)
+            qkv_out = self.compute_qkv_linear(ln_out, i, **kwargs)
 
         return qkv_out, residual_input
 
@@ -1523,7 +1523,7 @@ def forward(
 
         residual_input = src
         for i in range(self.num_layers):
-            qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
+            qkv_out, residual_input = self.compute_qkv(src, residual_input, i, **kwargs)
             fmha_out = self.compute_attn(
                 time_step,
                 qkv_out,
@@ -1596,7 +1596,7 @@ class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)
 
-    def compute_qkv(self, src, residual_input, i):
+    def compute_qkv(self, src, residual_input, i, **kwargs):
         qkv_out = self.compute_qkv_linear(src, i)
         return qkv_out, src
 
@@ -2055,9 +2055,7 @@ def compute_qkv_linear(self, ln_out, i, latent_cache=None, **kwargs):
                 epsilon=self._epsilon,
                 begin_norm_axis=1,
             )[0]
-            query_pe, key_pe = self.config.rotary_emb(
-                self.position_ids[0 : kwargs.get("seq_lens_encoder", None).sum()], query_pe, key_pe
-            )
+            query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
 
             if self.config.mla_config.use_absorb():
                 from paddlenlp_ops import prefill_mla_write_cache
@@ -2689,7 +2687,7 @@ def compute_layernorm_before_qkv(self, src, i):
 
         return ln_out
 
-    def compute_qkv_linear(self, ln_out, i):
+    def compute_qkv_linear(self, ln_out, i, **kwargs):
         if self.config.mla_config.use_mla():
             raise NotImplementedError("Not support MLA yet.")
         else:
@@ -5140,7 +5138,7 @@ def compute_layernorm_before_qkv(self, src, i):
 
         return ln_out
 
-    def compute_qkv_linear(self, ln_out, i):
+    def compute_qkv_linear(self, ln_out, i, **kwargs):
         if self.config.mla_config.use_mla():
             raise NotImplementedError("Not support MLA yet.")
         else:
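Taken together, these hunks thread the caller's keyword arguments from forward() through compute_qkv() into compute_qkv_linear(), so per-call tensors such as seq_lens_encoder can reach the QKV projection without being stored on the layer. A minimal, self-contained sketch of the same pattern; the class and helper names here (MiniTransformer, the elided projection) are illustrative only, not the real PaddleNLP layers:

# Sketch of the **kwargs threading introduced by the diff above.
# MiniTransformer and its bodies are hypothetical stand-ins.
class MiniTransformer:
    def __init__(self, num_layers):
        self.num_layers = num_layers

    def compute_qkv_linear(self, ln_out, i, **kwargs):
        # The leaf method can now read forwarded per-call tensors.
        seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
        return ln_out  # actual QKV projection elided

    def compute_qkv(self, src, residual_input, i, **kwargs):
        # Intermediate methods pass **kwargs through unchanged.
        return self.compute_qkv_linear(src, i, **kwargs), residual_input

    def forward(self, src, **kwargs):
        residual_input = src
        for i in range(self.num_layers):
            qkv_out, residual_input = self.compute_qkv(src, residual_input, i, **kwargs)
        return qkv_out

The design point is that only the outermost signature needs to know which extra tensors exist; intermediate overrides (such as the post-layernorm variant above) stay source-compatible because **kwargs absorbs anything they do not use.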