Skip to content

Commit c3c6695

Browse files
phlrain and A-nnonymous
authored
optimize linear keepx quant (#10332)
Co-authored-by: Pan Zhaowu <panzhaowu@baidu.com>
1 parent 0e158fb commit c3c6695

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

paddlenlp/transformers/deepseek_v2/fp8_linear.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def forward(ctx, x, w1, w2):
395395
)
396396

397397
# ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
398-
w2_fp8, w2_sacle, w2_t_fp8, w2_t_scale = kitchen_quant(
398+
_, _, w2_t_fp8, w2_t_scale = kitchen_quant(
399399
w2, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=True
400400
)
401401
o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o2.dtype)
@@ -426,8 +426,7 @@ def forward(ctx, x, w1, w2):
426426
# w1_fp8,
427427
# w1_sacle,
428428
o1,
429-
w2_fp8,
430-
w2_sacle,
429+
w2,
431430
paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()),
432431
)
433432
return o3
@@ -438,9 +437,13 @@ def backward(ctx, do3):
438437
do3_orig_shape = do3.shape
439438
do3 = do3.reshape([-1, do3_orig_shape[-1]])
440439

441-
x_t_fp8, x_t_scale, w1, o1, w2_fp8, w2_sacle, x_orig_shape = ctx.saved_tensor()
440+
x_t_fp8, x_t_scale, w1, o1, w2, x_orig_shape = ctx.saved_tensor()
442441
x_orig_shape = x_orig_shape.numpy()
443442

443+
w2_fp8, w2_scale = kitchen_quant(
444+
w2, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
445+
)
446+
444447
# ===== [recompute] o2 = swiglu(o1) =====
445448
# TODO: [Fusion] swiglu + transpose + quant
446449
o2 = swiglu(o1)
@@ -454,7 +457,7 @@ def backward(ctx, do3):
454457
do3, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
455458
)
456459
do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype)
457-
deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale), (w2_fp8, w2_sacle), do2)
460+
deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale), (w2_fp8, w2_scale), do2)
458461

459462
# ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
460463
if o2_t.shape[-1] % 128 != 0 or o2_t.shape[-1] % 512 != 0:
@@ -549,3 +552,25 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size
549552

550553
def forward(self, x):
    """Run the fused FP8 feed-forward over *x* using this layer's weights."""
    out = Fuse_FFN_FP8_Func.apply(x, self.w1, self.w2)
    return out
555+
556+
557+
class FusedFP8DeepseekV2MLP(paddle.nn.Layer):
    """DeepseekV2 feed-forward block whose matmuls go through the fused FP8 path.

    Holds two bfloat16 weight matrices and delegates the whole
    gate/up -> swiglu -> down computation to ``Fuse_FFN_FP8_Func``, which
    quantizes to FP8 internally.

    NOTE(review): the weight shapes carry a factor of 4 relative to the
    configured ``hidden_size``/``intermediate_size`` — presumably required by
    the fused kernel's layout; confirm against ``Fuse_FFN_FP8_Func``.
    """

    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
        super().__init__()
        self.config = config

        # Fall back to the model-wide sizes unless explicit overrides are given.
        if hidden_size is None:
            self.hidden_size = config.hidden_size
        else:
            self.hidden_size = hidden_size
        if intermediate_size is None:
            self.intermediate_size = config.intermediate_size
        else:
            self.intermediate_size = intermediate_size

        # w1: combined gate+up projection (hence the trailing ``* 2``);
        # w2: down projection. Both kept in bf16 — FP8 quantization happens
        # inside the fused op, not here.
        self.w1 = self.create_parameter(
            shape=[4 * self.hidden_size, self.intermediate_size * 2],
            dtype="bfloat16",
            is_bias=False,
        )
        self.w2 = self.create_parameter(
            shape=[4 * self.intermediate_size, self.hidden_size],
            dtype="bfloat16",
            is_bias=False,
        )

    def forward(self, x):
        """Apply the fused FP8 FFN to *x* and return the result."""
        return Fuse_FFN_FP8_Func.apply(x, self.w1, self.w2)

0 commit comments

Comments
 (0)