
Commit 409fc40

[Inference] fix the bug for dynamic/static inference and dy2st for deepseekv2 in weight only int4/int8 (#10491)
* fix the bug in export deepseek_v2
* fix the deepseekv2 dynamic weight only int8
* fix
* fix
* update
1 parent f0c5aa2 commit 409fc40
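
What the fix does, in one line of plumbing: forward already receives per-request tensors (seq_lens_encoder and friends) in **kwargs, and the MLA compute_qkv_linear reads them back out via kwargs.get(...), but the intermediate compute_qkv dropped them, so the kwargs arrived empty. Below is a minimal sketch of the restored call chain; the class and variable names are hypothetical stand-ins for the real fused_transformer_layers.py methods.

# Minimal sketch of the **kwargs plumbing this commit restores
# (hypothetical names, not the library's actual classes).
class FusedBlockSketch:
    def compute_qkv_linear(self, ln_out, i, **kwargs):
        # The MLA path reads per-request tensors out of kwargs here; before
        # the fix this was always None because the caller swallowed **kwargs.
        seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
        return ln_out, seq_lens_encoder

    def compute_qkv(self, src, residual_input, i, **kwargs):
        # The fix: accept and forward **kwargs instead of dropping it.
        return self.compute_qkv_linear(src, i, **kwargs)

    def forward(self, src, **kwargs):
        out, seqs = src, None
        for i in range(2):  # stand-in for range(self.num_layers)
            out, seqs = self.compute_qkv(out, src, i, **kwargs)
        return out, seqs

block = FusedBlockSketch()
print(block.forward("ln_out", seq_lens_encoder=[3, 2]))
# -> ('ln_out', [3, 2]): the kwarg now reaches compute_qkv_linear
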

File tree: 1 file changed (+7, -9 lines)

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 7 additions & 9 deletions
@@ -1131,13 +1131,13 @@ def compute_qkv_linear(self, ln_out, i, latent_cache=None, **kwargs):
             qkv_out = paddle.add(qkv_out, self.qkv_biases[i])
         return qkv_out

-    def compute_qkv(self, src, residual_input, i):
+    def compute_qkv(self, src, residual_input, i, **kwargs):
         ln_out = self.compute_layernorm_before_qkv(src, i)

         if self.config.mla_config.use_absorb():
             qkv_out = ln_out
         else:
-            qkv_out = self.compute_qkv_linear(ln_out, i)
+            qkv_out = self.compute_qkv_linear(ln_out, i, **kwargs)

         return qkv_out, residual_input

@@ -1523,7 +1523,7 @@ def forward(

         residual_input = src
         for i in range(self.num_layers):
-            qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
+            qkv_out, residual_input = self.compute_qkv(src, residual_input, i, **kwargs)
             fmha_out = self.compute_attn(
                 time_step,
                 qkv_out,
@@ -1596,7 +1596,7 @@ class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)

-    def compute_qkv(self, src, residual_input, i):
+    def compute_qkv(self, src, residual_input, i, **kwargs):
         qkv_out = self.compute_qkv_linear(src, i)
         return qkv_out, src

@@ -2055,9 +2055,7 @@ def compute_qkv_linear(self, ln_out, i, latent_cache=None, **kwargs):
                 epsilon=self._epsilon,
                 begin_norm_axis=1,
             )[0]
-            query_pe, key_pe = self.config.rotary_emb(
-                self.position_ids[0 : kwargs.get("seq_lens_encoder", None).sum()], query_pe, key_pe
-            )
+            query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)

             if self.config.mla_config.use_absorb():
                 from paddlenlp_ops import prefill_mla_write_cache
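
Why the one-line rotary_emb change above matters for dy2st: the old code sliced position_ids up to seq_lens_encoder.sum(), a bound computed from a runtime tensor, so the slice's output shape depends on data, which is exactly the kind of op that dynamic-to-static graph capture cannot pin down ahead of time. Passing the full position_ids keeps the op shape-static. A hedged, standalone illustration of the hazard follows (standard Paddle APIs assumed; this is not the library's code):

import paddle

position_ids = paddle.arange(0, 8)
seq_lens_encoder = paddle.to_tensor([3, 2])

# Data-dependent: the result's length is only known once the tensor
# values are known, so static-graph (dy2st) capture cannot fix the shape.
sliced = position_ids[0 : int(seq_lens_encoder.sum())]

# Shape-static: the full tensor the commit now passes to rotary_emb.
full = position_ids

print(sliced.shape, full.shape)  # [5] vs [8]
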
@@ -2689,7 +2687,7 @@ def compute_layernorm_before_qkv(self, src, i):

         return ln_out

-    def compute_qkv_linear(self, ln_out, i):
+    def compute_qkv_linear(self, ln_out, i, **kwargs):
         if self.config.mla_config.use_mla():
             raise NotImplementedError("Not support MLA yet.")
         else:
@@ -5140,7 +5138,7 @@ def compute_layernorm_before_qkv(self, src, i):

         return ln_out

-    def compute_qkv_linear(self, ln_out, i):
+    def compute_qkv_linear(self, ln_out, i, **kwargs):
         if self.config.mla_config.use_mla():
             raise NotImplementedError("Not support MLA yet.")
         else:
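
A note on the last two hunks: these compute_qkv_linear overrides never touch the extra arguments, but once the base class calls self.compute_qkv_linear(ln_out, i, **kwargs), every override must also accept **kwargs or the call raises TypeError whenever extra keys are passed. Widening the signatures keeps the non-MLA paths compatible with the new call site.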
