@@ -71,6 +71,7 @@ def use_cutlass_fp8_gemm():
         transpose_remove_padding,
         write_cache_kv,
     )
+
 except:
     pass
@@ -2969,18 +2970,42 @@ def compute_mla_absorb(
         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]

             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
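The new branch is gated by PREFILL_USE_SAGE_ATTN imported from paddlenlp.utils.env; this patch does not show how that flag is defined. A minimal sketch of how such a switch is commonly derived from an environment variable (the _str2bool helper below is hypothetical, not part of the patch):

import os

def _str2bool(v):
    # Hypothetical helper: treat common truthy spellings of an env var as True.
    return str(v).lower() in ("1", "true", "yes", "y", "on")

# Assumed definition; the real paddlenlp.utils.env may use a different helper.
PREFILL_USE_SAGE_ATTN = _str2bool(os.getenv("PREFILL_USE_SAGE_ATTN", "0"))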
@@ -3302,18 +3327,42 @@ def compute_mla_absorb(
         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]

             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
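The head-dimension bookkeeping around the SageAttention call follows from MLA's asymmetric head sizes (qk_head_dim = 192, v_head_dim = 128): the 192-wide value tensor is split so only its first 128 channels go into the kernel, and the 128-wide output is zero-padded back to 192 so the unchanged reshape-and-slice code below the hunk still applies. A minimal standalone sketch of that round trip, with toy shapes standing in for the real tensors:

import paddle

toks, heads = 4, 2
value = paddle.randn([toks, heads, 192])                 # toy stand-in for the MLA value tensor

value_128, _ = paddle.split(value, [128, 64], axis=-1)   # keep only the first 128 (v_head_dim) channels
out_128 = value_128                                      # stand-in for the kernel's 128-wide output
out_192 = paddle.nn.functional.pad(out_128, (0, 192 - 128))  # zero-pad back to qk_head_dim

# Slicing off the padding, as the code after the hunk does, recovers the 128 channels exactly.
assert paddle.allclose(out_192[:, :, :128], out_128).item()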
@@ -4997,18 +5046,42 @@ def compute_mla_absorb(
         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]

             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
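All three compute_mla_absorb variants receive the same guarded branch, so FlashAttention remains the default prefill path and the SageAttention kernel (an SM90/Hopper-specific INT8-QK / FP8-PV kernel, judging from its name) is strictly opt-in. Assuming paddlenlp.utils.env reads the flag from an identically named environment variable, which the patch itself does not show, the path would be enabled before the model modules are imported:

import os

# Assumption: the flag is read from an environment variable of the same name.
os.environ["PREFILL_USE_SAGE_ATTN"] = "1"
# ... import and launch the model afterwards so the flag is picked up at import time.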