
Commit b36d306

[Feat] Sage Attention Kernels Support for sm80, sm89, sm90 (#9848)

* add sage attn sm90 kernels
* fix
* add ds sageattn kernel
* update kernels
* update setup_cuda.py
* update dsk MLA kernel
* clean PR branch
* fix sa usage
* bugfix
* modify, for static mode inference SA
* add license info
* add license info for py file
* modify license info
* modify license info
* bsz=1 assert
* fix kernel
* move to import line
* merge develop & support wint8&fp8
1 parent ed7f01d commit b36d306

10 files changed: 10041 additions, 37 deletions
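
The new prefill path in fused_transformer_layers.py (see the diff further below) is gated by PREFILL_USE_SAGE_ATTN, imported from paddlenlp.utils.env. A minimal sketch of opting in, assuming the flag is driven by an environment variable of the same name (the variable name is an assumption, not shown in this diff):

import os

# Assumption: paddlenlp.utils.env reads this flag from a like-named environment variable.
os.environ["PREFILL_USE_SAGE_ATTN"] = "1"

import paddlenlp  # the MLA prefill branch in fused_transformer_layers then routes to SageAttention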

csrc/gpu/sage_attn_kernels/sageattn_fused.cu: 1027 additions, 0 deletions (large diff not rendered)
csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu: 1571 additions, 0 deletions (large diff not rendered)
csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu: 1108 additions, 0 deletions (large diff not rendered)
csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu: 1832 additions, 0 deletions (large diff not rendered)
csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu: 994 additions, 0 deletions (large diff not rendered)
csrc/gpu/sage_attn_kernels/sageattn_utils.cuh: 2741 additions, 0 deletions (large diff not rendered)

csrc/setup_cuda.py: 26 additions, 1 deletion
@@ -168,6 +168,31 @@ def get_gencode_flags():
         "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
     ]
 
+if cc >= 80 and cuda_version >= 12.4:
+    nvcc_compile_args += [
+        "-std=c++17",
+        "--use_fast_math",
+        "--threads=8",
+        "-D_GLIBCXX_USE_CXX11_ABI=1",
+    ]
+    sources += ["./gpu/sage_attn_kernels/sageattn_fused.cu"]
+    if cc >= 80 and cc < 89:
+        sources += [
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu"
+        ]
+        nvcc_compile_args += ["-gencode", f"arch=compute_80,code=compute_80"]
+    elif cc >= 89 and cc < 90:
+        sources += [
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"
+        ]
+        nvcc_compile_args += ["-gencode", f"arch=compute_89,code=compute_89"]
+    elif cc >= 90:
+        sources += [
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu"
+        ]
+        nvcc_compile_args += ["-gencode", f"arch=compute_90a,code=compute_90a"]
+
 if cc >= 90 and cuda_version >= 12.0:
     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py --cuda_arch 90")
     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_ptr_scale_sm90.py --cuda_arch 90")
@@ -188,7 +213,7 @@ def get_gencode_flags():
     name=ops_name,
     ext_modules=CUDAExtension(
         sources=sources,
-        extra_compile_args={"cxx": ["-O3"], "nvcc": nvcc_compile_args},
+        extra_compile_args={"cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"], "nvcc": nvcc_compile_args},
         libraries=["cublasLt"],
         library_dirs=[library_path],
     ),
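
Taken together, the block above compiles SageAttention only when the build sees cc >= 80 and CUDA >= 12.4, and selects one kernel family per architecture; the Hopper build targets compute_90a, the arch-specific variant that exposes the Hopper-only instructions the SM90 FP8 kernels rely on. A standalone sketch (not part of the commit) of that selection, for reference:

def sage_attn_kernel_selection(cc: int):
    """Mirror of the cc-based branches added to setup_cuda.py above (sketch only)."""
    sources = ["./gpu/sage_attn_kernels/sageattn_fused.cu"]
    if 80 <= cc < 89:
        sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu"]
        gencode = ["-gencode", "arch=compute_80,code=compute_80"]
    elif 89 <= cc < 90:
        sources += ["./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"]
        gencode = ["-gencode", "arch=compute_89,code=compute_89"]
    elif cc >= 90:
        sources += [
            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu",
        ]
        gencode = ["-gencode", "arch=compute_90a,code=compute_90a"]
    else:
        return [], []  # below sm80: SageAttention kernels are not compiled
    return sources, gencode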

paddlenlp/experimental/transformers/fused_transformer_layers.py: 109 additions, 36 deletions
@@ -71,6 +71,7 @@ def use_cutlass_fp8_gemm():
         transpose_remove_padding,
         write_cache_kv,
     )
+
 except:
     pass
 
@@ -2969,18 +2970,42 @@ def compute_mla_absorb(
         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
 
-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]
 
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
@@ -3302,18 +3327,42 @@ def compute_mla_absorb(
         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
 
-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]
 
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
@@ -4997,18 +5046,42 @@ def compute_mla_absorb(
         if kwargs["max_enc_len_this_time"]:  # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)
 
-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]
 
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
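
The three hunks above apply the same change to three MLA layer variants: when PREFILL_USE_SAGE_ATTN is set, prefill attention goes through the SM90 DeepSeek (dsk) SageAttention kernel instead of flash_attn_unpadded. A minimal sketch of the surrounding shape handling, not from the repo: the helper name, the sage_kernel parameter (a stand-in for the compiled op), and the assumed [1, total_tokens, num_heads, head_dim] "NHD" layout are illustrative. Q/K keep their 192-dim heads, V is sliced to its 128 used channels, a leading batch axis is added (the kernel expects bsz == 1, per the commit message), and the 128-dim output is padded back to 192 so the existing reshape and v_head_dim slice below it still apply.

import paddle

def sage_prefill_sketch(query, key, value, cu_seqlens_q, cu_seqlens_k, softmax_scale, sage_kernel):
    # query/key/value: assumed [total_tokens, num_heads, 192] varlen layout
    query_192 = paddle.unsqueeze(query, axis=0)              # add the bsz=1 batch axis
    key_192 = paddle.unsqueeze(key, axis=0)
    value_128, _ = paddle.split(value, [128, 64], axis=-1)   # keep only the first 128 value channels
    value_128 = paddle.unsqueeze(value_128, axis=0)
    out = sage_kernel(                                        # e.g. sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
        query_192, key_192, cu_seqlens_q, cu_seqlens_k, value_128,
        is_causal=True, sm_scale=softmax_scale, tensor_layout="NHD",
    )
    out = paddle.nn.functional.pad(out, (0, 192 - 128))      # pad the head dim back to 192, as in the diff
    return paddle.squeeze(out, axis=0)                        # drop the batch axis again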
