[Feat] Sage Attention Kernels Support for sm80, sm89, sm90 #9848

Merged (20 commits, merged Mar 6, 2025)

Changes from all commits:
1,027 changes: 1,027 additions & 0 deletions csrc/gpu/sage_attn_kernels/sageattn_fused.cu

Large diffs are not rendered by default.

1,571 changes: 1,571 additions & 0 deletions csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu

Large diffs are not rendered by default.

1,108 changes: 1,108 additions & 0 deletions csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu

Large diffs are not rendered by default.

1,832 changes: 1,832 additions & 0 deletions csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu

Large diffs are not rendered by default.

994 changes: 994 additions & 0 deletions csrc/gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu

Large diffs are not rendered by default.

2,741 changes: 2,741 additions & 0 deletions csrc/gpu/sage_attn_kernels/sageattn_utils.cuh

Large diffs are not rendered by default.
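The six kernel sources above are not rendered here. Going by the file names (qk_int for the Q/K path, sv_f16/sv_f8 for the P·V path, dsk for the DeepSeek head-dim variant), they follow the SageAttention recipe: Q and K are quantized per block to INT8 with per-block scales (K is first smoothed by subtracting its per-channel mean), Q·K^T runs on INT8 tensor cores and is dequantized before the softmax, and P·V stays in FP16 on sm80 but uses FP8 on sm89/sm90. The sketch below only illustrates the per-block INT8 quantization idea in Paddle; it is an assumption about what the kernels fuse on the GPU, not code from this PR, and the block size is arbitrary.

    import paddle

    def per_block_int8(x, block_size=128):
        # x: [batch, heads, seq, head_dim]; one INT8 block and one FP scale per
        # `block_size` rows of the sequence axis. Illustrative only.
        b, h, s, d = x.shape
        assert s % block_size == 0, "sketch assumes seq_len is a multiple of block_size"
        blocks = x.reshape([b, h, s // block_size, block_size, d])
        scale = paddle.max(paddle.abs(blocks), axis=[-2, -1], keepdim=True) / 127.0 + 1e-8
        q_int8 = paddle.round(blocks / scale).clip(-127, 127).astype("int8")
        return q_int8.reshape([b, h, s, d]), scale

    # SageAttention additionally subtracts K's per-channel mean before quantizing
    # ("smoothing K"), which keeps the INT8 Q @ K^T close to the FP16 result.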

27 changes: 26 additions & 1 deletion csrc/setup_cuda.py
@@ -168,6 +168,31 @@ def get_gencode_flags():
     "gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
 ]

+if cc >= 80 and cuda_version >= 12.4:
+    nvcc_compile_args += [
+        "-std=c++17",
+        "--use_fast_math",
+        "--threads=8",
+        "-D_GLIBCXX_USE_CXX11_ABI=1",
+    ]
+    sources += ["./gpu/sage_attn_kernels/sageattn_fused.cu"]
+    if cc >= 80 and cc < 89:
+        sources += [
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu"
+        ]
+        nvcc_compile_args += ["-gencode", f"arch=compute_80,code=compute_80"]
+    elif cc >= 89 and cc < 90:
+        sources += [
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu"
+        ]
+        nvcc_compile_args += ["-gencode", f"arch=compute_89,code=compute_89"]
+    elif cc >= 90:
+        sources += [
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
+            "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu"
+        ]
+        nvcc_compile_args += ["-gencode", f"arch=compute_90a,code=compute_90a"]
+
 if cc >= 90 and cuda_version >= 12.0:
     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py --cuda_arch 90")
     os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_ptr_scale_sm90.py --cuda_arch 90")
@@ -188,7 +213,7 @@ def get_gencode_flags():
     name=ops_name,
     ext_modules=CUDAExtension(
         sources=sources,
-        extra_compile_args={"cxx": ["-O3"], "nvcc": nvcc_compile_args},
+        extra_compile_args={"cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"], "nvcc": nvcc_compile_args},
         libraries=["cublasLt"],
         library_dirs=[library_path],
     ),
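The build logic above only kicks in when the detected compute capability is at least 80 and the CUDA toolkit is at least 12.4, and it selects exactly one kernel family per architecture. As a plain restatement of that mapping (the helper name below is mine, not part of the PR):

    def select_sage_kernels(cc: int):
        # Mirrors the setup_cuda.py change: compute capability -> SageAttention
        # sources and -gencode flag. Assumes cc >= 80 and CUDA >= 12.4.
        sources = ["./gpu/sage_attn_kernels/sageattn_fused.cu"]
        if 80 <= cc < 89:     # Ampere-class: INT8 QK with FP16 PV
            sources.append("./gpu/sage_attn_kernels/sageattn_qk_int_sv_f16_kernel_sm80.cu")
            arch = "compute_80"
        elif 89 <= cc < 90:   # Ada (sm89): INT8 QK with FP8 PV
            sources.append("./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm89.cu")
            arch = "compute_89"
        else:                 # Hopper (sm90+): FP8 PV plus the DeepSeek (dsk) head-dim variant
            sources += [
                "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_kernel_sm90.cu",
                "./gpu/sage_attn_kernels/sageattn_qk_int_sv_f8_dsk_kernel_sm90.cu",
            ]
            arch = "compute_90a"  # the Hopper path is built for sm90a
        return sources, ["-gencode", f"arch={arch},code={arch}"]

Note that the prefill integration below imports only the sm90 dsk kernel, so the SageAttention prefill path currently targets Hopper.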
145 changes: 109 additions & 36 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -71,6 +71,7 @@
         transpose_remove_padding,
         write_cache_kv,
     )
+
 except:
     pass

@@ -2969,18 +2970,42 @@
         if kwargs["max_enc_len_this_time"]: # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]

             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
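The SageAttention branch above has to bridge a head-dim mismatch: in DeepSeek-style MLA the query/key head dim (qk_head_dim) is 192 while the value head dim (v_head_dim) is 128, and the epilogue that follows still reshapes to 192 before slicing back to 128. The snippet below only walks through that shape bookkeeping with dummy tensors; the kernel call is replaced by a placeholder, and the token/head counts are made up.

    import paddle

    total_tokens, num_heads = 8, 16                          # packed, unpadded tokens (illustrative)
    query = paddle.randn([total_tokens, num_heads, 192])     # qk_head_dim = 192 (128 nope + 64 rope)
    key = paddle.randn([total_tokens, num_heads, 192])
    value = paddle.randn([total_tokens, num_heads, 192])     # only the first 128 dims carry V

    query_192 = paddle.unsqueeze(query, axis=0)              # add a batch dim: [1, tokens, heads, 192]
    key_192 = paddle.unsqueeze(key, axis=0)
    value_128, _ = paddle.split(value, [128, 64], axis=-1)   # keep v_head_dim = 128
    value_128 = paddle.unsqueeze(value_128, axis=0)

    # Placeholder for sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(...); its output has V's head dim.
    out = paddle.zeros([1, total_tokens, num_heads, 128])

    # zero-pad the head dim back to 192 (the PR uses paddle.nn.functional.pad for this)
    out = paddle.concat([out, paddle.zeros([1, total_tokens, num_heads, 64])], axis=-1)
    out = paddle.squeeze(out, axis=0)                        # so the shared reshape-to-192 /
    out = out.reshape([-1, num_heads, 192])[:, :, :128]      # slice-to-128 epilogue applies unchanged
    print(out.shape)                                         # [8, 16, 128]

The zero padding exists only so both branches feed the same reshape/slice code; the padded 64 dims are dropped again immediately afterwards.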
@@ -3302,18 +3327,42 @@
         if kwargs["max_enc_len_this_time"]: # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]

             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
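Both branches operate on unpadded, packed sequences: query/key/value hold every token of the batch concatenated along dim 0, and cu_seqlens_q / cu_seqlens_k are cumulative token offsets marking where each sequence starts. For reference, the usual construction looks like this (a sketch, not code from this PR):

    import paddle

    seq_lens = paddle.to_tensor([5, 3, 7], dtype="int32")              # per-sequence prompt lengths
    cu_seqlens = paddle.concat(
        [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens).astype("int32")]
    )                                                                  # -> [0, 5, 8, 15]
    total_tokens = int(cu_seqlens[-1])                                 # 15 packed tokens in this batch
    # tokens of sequence i occupy rows cu_seqlens[i] : cu_seqlens[i + 1]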
@@ -4997,18 +5046,42 @@
         if kwargs["max_enc_len_this_time"]: # prefill phase
             query, key, value = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

-            fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
-                query,
-                key,
-                value,
-                kwargs.get("cu_seqlens_q", None),
-                kwargs.get("cu_seqlens_k", None),
-                kwargs.get("max_enc_len_this_time", -1),
-                kwargs.get("max_enc_len_this_time", -1),
-                self.softmax_scale,
-                causal=True,
-                training=False,
-            )[0]
+            from paddlenlp.utils.env import PREFILL_USE_SAGE_ATTN
+
+            if PREFILL_USE_SAGE_ATTN:
+                from .sageattention import sageattn_qk_int8_pv_fp8_cuda_dsk_sm90
+
+                query_192 = paddle.unsqueeze(query, axis=0)
+                key_192 = paddle.unsqueeze(key, axis=0)
+
+                value_128, _ = paddle.split(value, [128, 64], axis=-1)
+                value_128 = paddle.unsqueeze(value_128, axis=0)
+
+                fmha_out_prefill = sageattn_qk_int8_pv_fp8_cuda_dsk_sm90(
+                    query_192,
+                    key_192,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    value_128,
+                    is_causal=True,
+                    sm_scale=self.softmax_scale,
+                    tensor_layout="NHD",
+                )
+                fmha_out_prefill = paddle.nn.functional.pad(fmha_out_prefill, (0, 192 - 128))
+                fmha_out_prefill = paddle.squeeze(fmha_out_prefill, axis=0)
+            else:
+                fmha_out_prefill = paddle.nn.functional.flash_attention.flash_attn_unpadded(
+                    query,
+                    key,
+                    value,
+                    kwargs.get("cu_seqlens_q", None),
+                    kwargs.get("cu_seqlens_k", None),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    kwargs.get("max_enc_len_this_time", -1),
+                    self.softmax_scale,
+                    causal=True,
+                    training=False,
+                )[0]

             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_heads, self.config.mla_config.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.config.mla_config.v_head_dim]
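The switch between the two prefill paths is PREFILL_USE_SAGE_ATTN from paddlenlp.utils.env, and the only kernel imported here is the sm90 DeepSeek variant, so the SageAttention path presumes a Hopper GPU and custom ops built with CUDA >= 12.4 (see the setup_cuda.py change above). Assuming the flag follows the usual paddlenlp.utils.env convention of being read from an environment variable of the same name, enabling it would look like:

    import os

    # Assumption: the flag is read from the environment when paddlenlp is imported,
    # so set it before importing the predictor / fused transformer layers.
    os.environ["PREFILL_USE_SAGE_ATTN"] = "1"

With the flag unset, the code falls back to flash_attn_unpadded exactly as before.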