@@ -395,7 +395,7 @@ def forward(ctx, x, w1, w2):
         )
 
         # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
-        w2_fp8, w2_sacle, w2_t_fp8, w2_t_scale = kitchen_quant(
+        _, _, w2_t_fp8, w2_t_scale = kitchen_quant(
             w2, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=True
         )
         o3 = paddle.empty([o2_fp8.shape[0], w2_t_fp8.shape[0]], dtype=o2.dtype)
@@ -426,8 +426,7 @@ def forward(ctx, x, w1, w2):
             # w1_fp8,
             # w1_sacle,
             o1,
-            w2_fp8,
-            w2_sacle,
+            w2,
             paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()),
         )
         return o3
@@ -438,9 +437,13 @@ def backward(ctx, do3):
         do3_orig_shape = do3.shape
         do3 = do3.reshape([-1, do3_orig_shape[-1]])
 
-        x_t_fp8, x_t_scale, w1, o1, w2_fp8, w2_sacle, x_orig_shape = ctx.saved_tensor()
+        x_t_fp8, x_t_scale, w1, o1, w2, x_orig_shape = ctx.saved_tensor()
         x_orig_shape = x_orig_shape.numpy()
 
+        w2_fp8, w2_scale = kitchen_quant(
+            w2, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=False, return_transpose=False
+        )
+
         # ===== [recompute] o2 = swiglu(o1) =====
         # TODO: [Fusion] swiglu + transpose + quant
         o2 = swiglu(o1)
@@ -454,7 +457,7 @@ def backward(ctx, do3):
             do3, backend=kitchen.ops.Backend.CUTLASS, is_1d_scaled=True, return_transpose=False
         )
         do2 = paddle.empty([do3_fp8.shape[0], w2_fp8.shape[0]], do3.dtype)
-        deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale), (w2_fp8, w2_sacle), do2)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((do3_fp8, do3_scale), (w2_fp8, w2_scale), do2)
 
         # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8)
         if o2_t.shape[-1] % 128 != 0 or o2_t.shape[-1] % 512 != 0:
@@ -549,3 +552,25 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size
 
     def forward(self, x):
         return Fuse_FFN_FP8_Func.apply(x, self.w1, self.w2)
+
+
+class FusedFP8DeepseekV2MLP(paddle.nn.Layer):
+    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+
+        self.w1 = self.create_parameter(
+            shape=[4 * self.hidden_size, self.intermediate_size * 2],
+            dtype="bfloat16",
+            is_bias=False,
+        )
+        self.w2 = self.create_parameter(
+            shape=[4 * self.intermediate_size, self.hidden_size],
+            dtype="bfloat16",
+            is_bias=False,
+        )
+
+    def forward(self, x):
+        return Fuse_FFN_FP8_Func.apply(x, self.w1, self.w2)
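Taken together, the hunks above change how `w2` reaches the backward pass: `forward` no longer stashes an FP8 copy of `w2` plus its scale in the autograd context, it saves the bf16 parameter itself and lets `backward` re-run `kitchen_quant` on demand. Below is a minimal sketch of that deferred-quantization pattern using only stock PaddlePaddle; `absmax_quant` is a hypothetical stand-in for `kitchen_quant`, plain `matmul` stands in for the `deep_gemm` FP8 kernels, and the class name is made up for illustration.

```python
import paddle
from paddle.autograd import PyLayer


def absmax_quant(w):
    # Hypothetical stand-in for kitchen_quant: per-tensor absmax scaling.
    # Real FP8 casting and the deep_gemm kernels are omitted to keep it runnable anywhere.
    scale = paddle.max(paddle.abs(w)) + 1e-12
    return w / scale, scale


class DeferredWeightQuantLinear(PyLayer):
    @staticmethod
    def forward(ctx, x, w):
        # Mirror of this commit: save the original (bf16) parameter for backward
        # instead of a quantized copy plus its scale.
        ctx.save_for_backward(x, w)
        return paddle.matmul(x, w)

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensor()
        # Re-quantize the weight on demand, as backward() now does with
        # kitchen_quant(w2, ..., return_transpose=False).
        w_q, w_scale = absmax_quant(w)
        dx = paddle.matmul(dy, w_q.T) * w_scale
        dw = paddle.matmul(x.T, dy)
        return dx, dw


# Usage: gradients flow through the custom layer as usual.
x = paddle.randn([8, 16])
w = paddle.randn([16, 32])
x.stop_gradient = False
w.stop_gradient = False
DeferredWeightQuantLinear.apply(x, w).sum().backward()
```

The trade-off mirrors the commit: one extra quantization of the weight per backward pass, in exchange for not keeping an FP8 weight tensor and its scale alive between forward and backward.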