From 0116592f6eace985216ac3e97b40ecc200aaf101 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 25 Jul 2025 03:07:57 +0000 Subject: [PATCH 1/2] [wip] add torch.compile --- bench/bench_qk_int8_pv_fp16_cuda.py | 1 + sageattention/core.py | 154 ++++++++++++++++++++++++++-- sageattention/sm80_compile.py | 96 +++++++++++++++++ script.sh | 12 +++ tests/test_torch_compile.py | 83 +++++++++++++++ 5 files changed, 335 insertions(+), 11 deletions(-) create mode 100644 sageattention/sm80_compile.py create mode 100644 script.sh create mode 100644 tests/test_torch_compile.py diff --git a/bench/bench_qk_int8_pv_fp16_cuda.py b/bench/bench_qk_int8_pv_fp16_cuda.py index 669b7e4..ef5fac3 100644 --- a/bench/bench_qk_int8_pv_fp16_cuda.py +++ b/bench/bench_qk_int8_pv_fp16_cuda.py @@ -1,4 +1,5 @@ import torch +from torch.testing._internal.optests import fake_check from flash_attn.utils.benchmark import benchmark_forward import sageattention._qattn_sm80 as qattn diff --git a/sageattention/core.py b/sageattention/core.py index 1121f92..8d9ede8 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -52,9 +52,10 @@ from typing import Any, List, Literal, Optional, Tuple, Union import warnings - import subprocess import re + + def get_cuda_version(): try: output = subprocess.check_output(['nvcc', '--version']).decode() @@ -66,6 +67,7 @@ def get_cuda_version(): print("Failed to get CUDA version:", e) return None, None + def get_cuda_arch_versions(): cuda_archs = [] for i in range(torch.cuda.device_count()): @@ -73,6 +75,136 @@ def get_cuda_arch_versions(): cuda_archs.append(f"sm{major}{minor}") return cuda_archs + +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f16_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. + """ + return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f32_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f32_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP32 accumulation. 
+ """ + return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f16_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. + """ + return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +def qk_int8_sv_f16_accum_attn_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + + batch_size = query.size(0) + num_qo_heads = query.size(2) + + if tensor_layout == 0: + qo_len = query.size(1) + else: + qo_len = query.size(2) + + if return_lse: + lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda") + else: + lse = torch.empty((0)) + return lse + + +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(qk_int8_sv_f16_accum_attn_fake) +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(qk_int8_sv_f16_accum_attn_fake) +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(qk_int8_sv_f16_accum_attn_fake) + + +# @torch.library.custom_op("sageattention::qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn", mutates_args=(), device_types="cuda") +# def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# output: torch.Tensor, +# query_scale: torch.Tensor, +# key_scale: torch.Tensor, +# value_scale: torch.Tensor, +# value_mean: torch.Tensor, +# tensor_layout: int, +# is_causal: int, +# qk_quant_gran: int, +# sm_scale: float, +# return_lse: int, +# ) -> torch.Tensor: +# """ +# Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. 
+# """ +# return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( +# query, key, value, output, query_scale, key_scale, value_scale, +# value_mean, tensor_layout, is_causal, qk_quant_gran, sm_scale, +# return_lse +# ) + + def sageattn( q: torch.Tensor, k: torch.Tensor, @@ -151,7 +283,7 @@ def sageattn( else: raise ValueError(f"Unsupported CUDA architecture: {arch}") -@torch.compiler.disable + def sageattn_qk_int8_pv_fp16_triton( q: torch.Tensor, k: torch.Tensor, @@ -294,7 +426,7 @@ def sageattn_qk_int8_pv_fp16_triton( else: return o -@torch.compiler.disable + def sageattn_varlen( q: torch.Tensor, k: torch.Tensor, @@ -411,7 +543,7 @@ def sageattn_varlen( return o -@torch.compiler.disable + def sageattn_qk_int8_pv_fp16_cuda( q: torch.Tensor, k: torch.Tensor, @@ -566,17 +698,17 @@ def sageattn_qk_int8_pv_fp16_cuda( if pv_accum_dtype == 'fp32': v = v.to(torch.float16) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v: smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: v = v.to(torch.float16) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp16+fp32": v = v.to(torch.float16) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") @@ -587,7 +719,7 @@ def sageattn_qk_int8_pv_fp16_cuda( else: return o -@torch.compiler.disable + def sageattn_qk_int8_pv_fp8_cuda( q: torch.Tensor, k: torch.Tensor, @@ -771,7 +903,7 @@ def sageattn_qk_int8_pv_fp8_cuda( else: return o -@torch.compiler.disable + def sageattn_qk_int8_pv_fp8_cuda_sm90( q: torch.Tensor, k: torch.Tensor, @@ -930,4 +1062,4 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( if return_lse: return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504 else: - return o \ No newline at end of file + return o diff --git a/sageattention/sm80_compile.py b/sageattention/sm80_compile.py new file mode 100644 index 0000000..4f0b910 --- /dev/null +++ b/sageattention/sm80_compile.py @@ -0,0 +1,96 @@ +from . 
import _qattn_sm80 +import torch + + +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f16_attn( + q_int8: torch.Tensor, + k_int8: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. + """ + return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f32_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f32_attn( + q_int8: torch.Tensor, + k_int8: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP32 accumulation. + """ + return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8: torch.Tensor, + k_int8: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. 
+ """ + return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + + +def qk_int8_sv_f16_accum_attn_fake( + q_int8: torch.Tensor, + k_int8: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + + batch_size = q_int8.size(0) + num_qo_heads = q_int8.size(2) + + if tensor_layout == 0: + qo_len = q_int8.size(1) + else: + qo_len = q_int8.size(2) + + if return_lse: + lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda") + else: + lse = torch.empty((0)) + return lse + + +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(qk_int8_sv_f16_accum_attn_fake) +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(qk_int8_sv_f16_accum_attn_fake) +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(qk_int8_sv_f16_accum_attn_fake) diff --git a/script.sh b/script.sh new file mode 100644 index 0000000..491b856 --- /dev/null +++ b/script.sh @@ -0,0 +1,12 @@ +#!bin/bash + +set -e + +( + export PYTHONBREAKPOINT="pdbp.set_trace" + python setup.py install + ( + cd tests + python -m pytest --tb=line -x + ) +) \ No newline at end of file diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py new file mode 100644 index 0000000..06731f4 --- /dev/null +++ b/tests/test_torch_compile.py @@ -0,0 +1,83 @@ +import pytest +import torch +from torch.testing._internal.optests import fake_check + +import sageattention._qattn_sm80 as qattn +from sageattention.core import ( + SM80_ENABLED, + SM89_ENABLED, + SM90_ENABLED, + qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf, + qk_int8_sv_f16_accum_f16_attn, + # qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + # qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + # qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, +) + + +def run_fake_check(fn): + def wrapper(*args, **kwargs): + fake_check(fn, args, kwargs) + return wrapper + + +@pytest.mark.skipif(not SM80_ENABLED, reason="SM80 not enabled") +class Test_SM80: + kernels = { + "fp32": qk_int8_sv_f16_accum_f32_attn, + "fp16+fp32": qk_int8_sv_f16_accum_f16_attn_inst_buf, + "fp16": qk_int8_sv_f16_accum_f16_attn + } + + def get_kernel(self, pv_accum_dtype): + return self.kernels[pv_accum_dtype] + + @pytest.mark.parametrize("is_causal", (False, True)) + @pytest.mark.parametrize("seq_len", (1024, 2048,)) + @pytest.mark.parametrize("head", (32,)) + @pytest.mark.parametrize("batch", (4,)) + @pytest.mark.parametrize("headdim", (128,)) + @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread")) + @pytest.mark.parametrize("pv_accum_dtype", ("fp16", "fp16+fp32", "fp32")) + def test_qk_int8_sv_f16_accum_f16_attn(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype): + flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1) + + q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") + k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") + + vm = torch.randn(batch, head, headdim, dtype=torch.float16, device="cuda") + + WARP_Q = 16 if (headdim == 128 and pv_accum_dtype == "fp16+fp32") else 32 + WARP_K = 64 + + if quant_gran == 'per_warp': + q_scale = torch.randn(batch, head, seq_len // 
WARP_Q, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float, device="cuda") + else: + q_scale = torch.randn(batch, head, seq_len // WARP_Q * 8, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // WARP_K * 4, dtype=torch.float, device="cuda") + + v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") + o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") + sm_scale = 1 / (headdim ** 0.5) + _qk_quant_gran = 3 if quant_gran == 'per_thread' else 2 + + kernel = self.get_kernel(pv_accum_dtype) + run_fake_check(kernel)(q, k, v, o, q_scale, k_scale, 0, is_causal, _qk_quant_gran, sm_scale, 0) + + +# @pytest.mark.skipif(not SM89_ENABLED, reason="SM89 not enabled") +# class Test_SM89(_TestTorchCompileBase): +# kernels = { +# "fp32": qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, +# "fp16+fp32": qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, +# "fp16": qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, +# } + + +# @pytest.mark.skipif(not SM90_ENABLED, reason="SM90 not enabled") +# class Test_SM90(_TestTorchCompileBase): +# kernels = { +# "fp16+fp32": qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, +# } \ No newline at end of file From d2808675fd6ae021018422a639420d32b97af0dd Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 25 Jul 2025 04:27:47 +0000 Subject: [PATCH 2/2] Add torch.compile support to SageAttention I'm not so sure about the tests that I added, if it is representative of the common usage of SageAttention --- sageattention/core.py | 155 +++------------------------------- sageattention/sm80_compile.py | 123 +++++++++++++++++++-------- sageattention/sm89_compile.py | 146 ++++++++++++++++++++++++++++++++ sageattention/sm90_compile.py | 94 +++++++++++++++++++++ script.sh | 2 +- setup.py | 2 +- tests/test_torch_compile.py | 139 ++++++++++++++++++------------ 7 files changed, 430 insertions(+), 231 deletions(-) create mode 100644 sageattention/sm89_compile.py create mode 100644 sageattention/sm90_compile.py diff --git a/sageattention/core.py b/sageattention/core.py index 8d9ede8..4399801 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -27,19 +27,19 @@ from .triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton try: - from . import _qattn_sm80 + from . import sm80_compile SM80_ENABLED = True except: SM80_ENABLED = False try: - from . import _qattn_sm89 + from . import sm89_compile SM89_ENABLED = True except: SM89_ENABLED = False try: - from . import _qattn_sm90 + from . import sm90_compile SM90_ENABLED = True except: SM90_ENABLED = False @@ -76,135 +76,6 @@ def get_cuda_arch_versions(): return cuda_archs -@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn", mutates_args=(), device_types="cuda") -def qk_int8_sv_f16_accum_f16_attn( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - query_scale: torch.Tensor, - key_scale: torch.Tensor, - tensor_layout: int, - is_causal: int, - qk_quant_gran: int, - sm_scale: float, - return_lse: int, -) -> torch.Tensor: - """ - Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. 
- """ - return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn( - query, key, value, output, query_scale, key_scale, tensor_layout, - is_causal, qk_quant_gran, sm_scale, return_lse - ) - - -@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f32_attn", mutates_args=(), device_types="cuda") -def qk_int8_sv_f16_accum_f32_attn( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - query_scale: torch.Tensor, - key_scale: torch.Tensor, - tensor_layout: int, - is_causal: int, - qk_quant_gran: int, - sm_scale: float, - return_lse: int, -) -> torch.Tensor: - """ - Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP32 accumulation. - """ - return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn( - query, key, value, output, query_scale, key_scale, tensor_layout, - is_causal, qk_quant_gran, sm_scale, return_lse - ) - - -@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf", mutates_args=(), device_types="cuda") -def qk_int8_sv_f16_accum_f16_attn_inst_buf( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - query_scale: torch.Tensor, - key_scale: torch.Tensor, - tensor_layout: int, - is_causal: int, - qk_quant_gran: int, - sm_scale: float, - return_lse: int, -) -> torch.Tensor: - """ - Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. - """ - return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf( - query, key, value, output, query_scale, key_scale, tensor_layout, - is_causal, qk_quant_gran, sm_scale, return_lse - ) - - -def qk_int8_sv_f16_accum_attn_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - query_scale: torch.Tensor, - key_scale: torch.Tensor, - tensor_layout: int, - is_causal: int, - qk_quant_gran: int, - sm_scale: float, - return_lse: int, -) -> torch.Tensor: - - batch_size = query.size(0) - num_qo_heads = query.size(2) - - if tensor_layout == 0: - qo_len = query.size(1) - else: - qo_len = query.size(2) - - if return_lse: - lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda") - else: - lse = torch.empty((0)) - return lse - - -torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(qk_int8_sv_f16_accum_attn_fake) -torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(qk_int8_sv_f16_accum_attn_fake) -torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(qk_int8_sv_f16_accum_attn_fake) - - -# @torch.library.custom_op("sageattention::qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn", mutates_args=(), device_types="cuda") -# def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# output: torch.Tensor, -# query_scale: torch.Tensor, -# key_scale: torch.Tensor, -# value_scale: torch.Tensor, -# value_mean: torch.Tensor, -# tensor_layout: int, -# is_causal: int, -# qk_quant_gran: int, -# sm_scale: float, -# return_lse: int, -# ) -> torch.Tensor: -# """ -# Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. 
-# """ -# return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( -# query, key, value, output, query_scale, key_scale, value_scale, -# value_mean, tensor_layout, is_causal, qk_quant_gran, sm_scale, -# return_lse -# ) - - def sageattn( q: torch.Tensor, k: torch.Tensor, @@ -698,17 +569,17 @@ def sageattn_qk_int8_pv_fp16_cuda( if pv_accum_dtype == 'fp32': v = v.to(torch.float16) - lse = qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v: smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) - lse = qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: v = v.to(torch.float16) - lse = qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp16+fp32": v = v.to(torch.float16) - lse = qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") @@ -888,13 +759,13 @@ def sageattn_qk_int8_pv_fp8_cuda( if pv_accum_dtype == "fp32": if smooth_v: - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp32+fp32": - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp32+fp16": - lse = _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, 
v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) o = o[..., :head_dim_og] @@ -1053,9 +924,9 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( if pv_accum_dtype == "fp32": raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.") - lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp32+fp32": - lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) o = o[..., :head_dim_og] diff --git a/sageattention/sm80_compile.py b/sageattention/sm80_compile.py index 4f0b910..ac5db6e 100644 --- a/sageattention/sm80_compile.py +++ b/sageattention/sm80_compile.py @@ -4,12 +4,12 @@ @torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn", mutates_args=(), device_types="cuda") def qk_int8_sv_f16_accum_f16_attn( - q_int8: torch.Tensor, - k_int8: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - q_scale: torch.Tensor, - k_scale: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, tensor_layout: int, is_causal: int, qk_quant_gran: int, @@ -19,17 +19,20 @@ def qk_int8_sv_f16_accum_f16_attn( """ Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. """ - return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) @torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f32_attn", mutates_args=(), device_types="cuda") def qk_int8_sv_f16_accum_f32_attn( - q_int8: torch.Tensor, - k_int8: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - q_scale: torch.Tensor, - k_scale: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, tensor_layout: int, is_causal: int, qk_quant_gran: int, @@ -39,17 +42,20 @@ def qk_int8_sv_f16_accum_f32_attn( """ Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP32 accumulation. 
""" - return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) @torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf", mutates_args=(), device_types="cuda") def qk_int8_sv_f16_accum_f16_attn_inst_buf( - q_int8: torch.Tensor, - k_int8: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - q_scale: torch.Tensor, - k_scale: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, tensor_layout: int, is_causal: int, qk_quant_gran: int, @@ -59,30 +65,57 @@ def qk_int8_sv_f16_accum_f16_attn_inst_buf( """ Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. """ - return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse) + return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) -def qk_int8_sv_f16_accum_attn_fake( - q_int8: torch.Tensor, - k_int8: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - q_scale: torch.Tensor, - k_scale: torch.Tensor, +@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_fuse_v_mean_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_mean: torch.Tensor, tensor_layout: int, is_causal: int, qk_quant_gran: int, sm_scale: float, return_lse: int, ) -> torch.Tensor: + """ + Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation. 
+ """ + return _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse + ) - batch_size = q_int8.size(0) - num_qo_heads = q_int8.size(2) + +def sm80_qk_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + batch_size = query.size(0) if tensor_layout == 0: - qo_len = q_int8.size(1) + num_qo_heads = query.size(2) + qo_len = query.size(1) else: - qo_len = q_int8.size(2) + num_qo_heads = query.size(1) + qo_len = query.size(2) if return_lse: lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda") @@ -90,7 +123,27 @@ def qk_int8_sv_f16_accum_attn_fake( lse = torch.empty((0)) return lse +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(sm80_qk_fake_impl) +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(sm80_qk_fake_impl) +torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(sm80_qk_fake_impl) + -torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(qk_int8_sv_f16_accum_attn_fake) -torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(qk_int8_sv_f16_accum_attn_fake) -torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(qk_int8_sv_f16_accum_attn_fake) +@torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_fuse_v_mean_attn") +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_mean: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return sm80_qk_fake_impl( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) \ No newline at end of file diff --git a/sageattention/sm89_compile.py b/sageattention/sm89_compile.py new file mode 100644 index 0000000..42e56e2 --- /dev/null +++ b/sageattention/sm89_compile.py @@ -0,0 +1,146 @@ +from . 
import _qattn_sm89 +import torch + + +@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse + ) + + + +@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", mutates_args=(), device_types="cuda") +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf", mutates_args=(), device_types="cuda") +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +def sm89_qk_with_key_value( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + batch_size = query.size(0) + + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + + if return_lse: + lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda") + else: + lse = torch.empty((0)) + return lse + + +torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")(sm89_qk_with_key_value) +torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")(sm89_qk_with_key_value) +torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")(sm89_qk_with_key_value) + + +@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn", mutates_args=(), device_types="cuda") +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + value_mean: 
torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + query, key, value, output, query_scale, key_scale, value_scale, + value_mean, tensor_layout, is_causal, qk_quant_gran, sm_scale, + return_lse + ) + + +@torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn") +def sm89_qk_with_key_value_mean( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + value_mean: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return sm89_qk_with_key_value( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse + ) diff --git a/sageattention/sm90_compile.py b/sageattention/sm90_compile.py new file mode 100644 index 0000000..60847c0 --- /dev/null +++ b/sageattention/sm90_compile.py @@ -0,0 +1,94 @@ +from . import _qattn_sm90 +import torch + + +@torch.library.custom_op("sageattention_sm90::qk_int8_sv_f8_accum_f32_attn_inst_buf", mutates_args=(), device_types="cuda") +def qk_int8_sv_f8_accum_f32_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return _qattn_sm90.qk_int8_sv_f8_accum_f32_attn_inst_buf( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +@torch.library.register_fake("sageattention_sm90::qk_int8_sv_f8_accum_f32_attn_inst_buf") +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + batch_size = query.size(0) + + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + + if return_lse: + lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda") + else: + lse = torch.empty((0)) + return lse + + +@torch.library.custom_op("sageattention_sm90::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", mutates_args=(), device_types="cuda") +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse + ) + + +@torch.library.register_fake("sageattention_sm90::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf") +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + 
query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + qk_quant_gran: int, + sm_scale: float, + return_lse: int, +) -> torch.Tensor: + return qk_int8_sv_f8_accum_f32_attn_inst_buf_fake_impl( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, qk_quant_gran, sm_scale, return_lse + ) diff --git a/script.sh b/script.sh index 491b856..c90017a 100644 --- a/script.sh +++ b/script.sh @@ -7,6 +7,6 @@ set -e python setup.py install ( cd tests - python -m pytest --tb=line -x + python -m pytest --tb=short -rs -sv -x -k SM89 ) ) \ No newline at end of file diff --git a/setup.py b/setup.py index 5e4779d..81084ed 100644 --- a/setup.py +++ b/setup.py @@ -132,7 +132,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: ) ext_modules.append(qattn_extension) -if HAS_SM89 or HAS_SM120: +if HAS_SM89 or HAS_SM90 or HAS_SM120: qattn_extension = CUDAExtension( name="sageattention._qattn_sm89", sources=[ diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index 06731f4..81ded3b 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -2,20 +2,15 @@ import torch from torch.testing._internal.optests import fake_check -import sageattention._qattn_sm80 as qattn from sageattention.core import ( SM80_ENABLED, SM89_ENABLED, SM90_ENABLED, - qk_int8_sv_f16_accum_f32_attn, - qk_int8_sv_f16_accum_f16_attn_inst_buf, - qk_int8_sv_f16_accum_f16_attn, - # qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, - # qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, - # qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, ) - def run_fake_check(fn): def wrapper(*args, **kwargs): fake_check(fn, args, kwargs) @@ -23,61 +18,101 @@ def wrapper(*args, **kwargs): @pytest.mark.skipif(not SM80_ENABLED, reason="SM80 not enabled") -class Test_SM80: - kernels = { - "fp32": qk_int8_sv_f16_accum_f32_attn, - "fp16+fp32": qk_int8_sv_f16_accum_f16_attn_inst_buf, - "fp16": qk_int8_sv_f16_accum_f16_attn - } - +class TestSM80: def get_kernel(self, pv_accum_dtype): - return self.kernels[pv_accum_dtype] + return sageattn_qk_int8_pv_fp16_cuda @pytest.mark.parametrize("is_causal", (False, True)) - @pytest.mark.parametrize("seq_len", (1024, 2048,)) + @pytest.mark.parametrize("seq_len", (64, 128)) @pytest.mark.parametrize("head", (32,)) @pytest.mark.parametrize("batch", (4,)) - @pytest.mark.parametrize("headdim", (128,)) + @pytest.mark.parametrize("headdim", (32, 64)) @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread")) @pytest.mark.parametrize("pv_accum_dtype", ("fp16", "fp16+fp32", "fp32")) - def test_qk_int8_sv_f16_accum_f16_attn(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype): - flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1) - - q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") - k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") - - vm = torch.randn(batch, head, headdim, dtype=torch.float16, device="cuda") - - WARP_Q = 16 if (headdim == 128 and pv_accum_dtype == "fp16+fp32") else 32 - WARP_K = 64 - - if quant_gran == 'per_warp': - q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float, device="cuda") - k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float, device="cuda") - else: - q_scale = torch.randn(batch, head, 
seq_len // WARP_Q * 8, dtype=torch.float, device="cuda") - k_scale = torch.randn(batch, head, seq_len // WARP_K * 4, dtype=torch.float, device="cuda") - - v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") - o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") + @pytest.mark.parametrize("tensor_layout", ("NHD", "HND")) + @pytest.mark.parametrize("smooth_k", (False, True)) + @pytest.mark.parametrize("smooth_v", (False, True)) + @pytest.mark.parametrize("return_lse", (False, True)) + @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16)) + def test_SM80(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, smooth_v, return_lse, dtype): + q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=dtype, device="cuda") + k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=dtype, device="cuda") + + v = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") sm_scale = 1 / (headdim ** 0.5) - _qk_quant_gran = 3 if quant_gran == 'per_thread' else 2 kernel = self.get_kernel(pv_accum_dtype) - run_fake_check(kernel)(q, k, v, o, q_scale, k_scale, 0, is_causal, _qk_quant_gran, sm_scale, 0) + run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran, + sm_scale, pv_accum_dtype, smooth_k, smooth_v, + return_lse) + + +@pytest.mark.skipif(not SM89_ENABLED, reason="SM89 not enabled") +class TestSM89: + + def get_kernel(self): + return sageattn_qk_int8_pv_fp8_cuda + + @pytest.mark.parametrize("is_causal", (False, True)) + @pytest.mark.parametrize("seq_len", (64, 128)) + @pytest.mark.parametrize("head", (32,)) + @pytest.mark.parametrize("batch", (4,)) + @pytest.mark.parametrize("headdim", (32, 64)) + @pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread")) + @pytest.mark.parametrize("pv_accum_dtype", ("fp32+fp32", "fp32+fp16", "fp32")) + @pytest.mark.parametrize("tensor_layout", ("NHD", "HND")) + @pytest.mark.parametrize("smooth_k", (False, True)) + @pytest.mark.parametrize("smooth_v", (False, True)) + @pytest.mark.parametrize("return_lse", (False, True)) + @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16)) + def test_kernels(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, smooth_v, return_lse, dtype): + kernel = self.get_kernel() + + + if tensor_layout == "HND": + q = torch.randint(-128, 127, (batch, head, seq_len, headdim), dtype=dtype, device="cuda") + k = torch.randint(-128, 127, (batch, head, seq_len, headdim), dtype=dtype, device="cuda") + v = torch.randn(batch, head, seq_len, headdim, dtype=dtype, device="cuda") + else: # NHD + q = torch.randint(-128, 127, (batch, seq_len, head, headdim), dtype=dtype, device="cuda") + k = torch.randint(-128, 127, (batch, seq_len, head, headdim), dtype=dtype, device="cuda") + v = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + + sm_scale = 1.0 / (headdim ** 0.5) + + run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran, + sm_scale, pv_accum_dtype, smooth_k, smooth_v, + return_lse) + + +@pytest.mark.skipif(not SM90_ENABLED, reason="SM90 not enabled") +class TestSM90: + def get_kernel(self): + return sageattn_qk_int8_pv_fp8_cuda_sm90 + + @pytest.mark.parametrize("is_causal", (False, True)) + @pytest.mark.parametrize("seq_len", (64, 128)) + @pytest.mark.parametrize("head", (32,)) + @pytest.mark.parametrize("batch", (4,)) + @pytest.mark.parametrize("headdim", (32, 64)) + 
@pytest.mark.parametrize("quant_gran", ("per_warp", "per_thread")) + @pytest.mark.parametrize("pv_accum_dtype", ("fp32+fp32",)) + @pytest.mark.parametrize("tensor_layout", ("NHD", "HND")) + @pytest.mark.parametrize("smooth_k", (False, True)) + @pytest.mark.parametrize("return_lse", (False, True)) + @pytest.mark.parametrize("dtype", (torch.float16, torch.bfloat16)) + def test_kernels(self, is_causal, seq_len, head, batch, headdim, quant_gran, pv_accum_dtype, tensor_layout, smooth_k, return_lse, dtype): + kernel = self.get_kernel() + q = torch.randint(-128, 127, (batch, seq_len, head, headdim), dtype=dtype, device="cuda") + k = torch.randint(-128, 127, (batch, seq_len, head, headdim), dtype=dtype, device="cuda") -# @pytest.mark.skipif(not SM89_ENABLED, reason="SM89 not enabled") -# class Test_SM89(_TestTorchCompileBase): -# kernels = { -# "fp32": qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, -# "fp16+fp32": qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, -# "fp16": qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, -# } + if tensor_layout == "HND": + v = torch.randn(batch, head, seq_len, headdim, dtype=dtype, device="cuda") + else: # NHD + v = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + sm_scale = 1.0 / (headdim ** 0.5) -# @pytest.mark.skipif(not SM90_ENABLED, reason="SM90 not enabled") -# class Test_SM90(_TestTorchCompileBase): -# kernels = { -# "fp16+fp32": qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, -# } \ No newline at end of file + run_fake_check(kernel)(q, k, v, tensor_layout, is_causal, quant_gran, + sm_scale, pv_accum_dtype, smooth_k, return_lse)