1 change: 1 addition & 0 deletions bench/bench_qk_int8_pv_fp16_cuda.py
@@ -1,4 +1,5 @@
import torch
from torch.testing._internal.optests import fake_check
from flash_attn.utils.benchmark import benchmark_forward

import sageattention._qattn_sm80 as qattn
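The newly imported fake_check suggests the benchmark also sanity-checks the fake (meta) registrations added in this PR. As a rough illustration (not part of this diff; FakeTensorMode is an internal API, and the shapes, scale layouts, and flag values below are placeholders assuming the HND layout, i.e. tensor_layout=1), one of the wrapped ops can be run entirely on fake tensors without launching a kernel:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
import sageattention.sm80_compile  # importing registers the custom ops and their fake impls (needs the built _qattn_sm80 extension)

with FakeTensorMode():
    # INT8 Q/K and FP16 V/O in (batch, heads, seq_len, head_dim) layout; scale shapes are placeholders.
    q = torch.empty(2, 8, 1024, 128, dtype=torch.int8, device="cuda")
    k = torch.empty(2, 8, 1024, 128, dtype=torch.int8, device="cuda")
    v = torch.empty(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
    o = torch.empty(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
    q_scale = torch.empty(2, 8, 32, dtype=torch.float32, device="cuda")
    k_scale = torch.empty(2, 8, 32, dtype=torch.float32, device="cuda")
    lse = torch.ops.sageattention.qk_int8_sv_f16_accum_f16_attn(
        q, k, v, o, q_scale, k_scale, 1, 0, 2, 128 ** -0.5, 1)
    print(lse.shape)  # torch.Size([2, 8, 1024]), produced by the registered fake impl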
43 changes: 23 additions & 20 deletions sageattention/core.py
@@ -27,19 +27,19 @@
from .triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton

try:
from . import _qattn_sm80
from . import sm80_compile
SM80_ENABLED = True
except:
SM80_ENABLED = False

try:
from . import _qattn_sm89
from . import sm89_compile
SM89_ENABLED = True
except:
SM89_ENABLED = False

try:
from . import _qattn_sm90
from . import sm90_compile
SM90_ENABLED = True
except:
SM90_ENABLED = False
@@ -52,9 +52,10 @@
from typing import Any, List, Literal, Optional, Tuple, Union
import warnings


import subprocess
import re


def get_cuda_version():
try:
output = subprocess.check_output(['nvcc', '--version']).decode()
@@ -66,13 +67,15 @@ def get_cuda_version():
print("Failed to get CUDA version:", e)
return None, None


def get_cuda_arch_versions():
cuda_archs = []
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cuda_archs.append(f"sm{major}{minor}")
return cuda_archs


def sageattn(
q: torch.Tensor,
k: torch.Tensor,
@@ -151,7 +154,7 @@ def sageattn(
else:
raise ValueError(f"Unsupported CUDA architecture: {arch}")

@torch.compiler.disable

def sageattn_qk_int8_pv_fp16_triton(
q: torch.Tensor,
k: torch.Tensor,
@@ -294,7 +297,7 @@ def sageattn_qk_int8_pv_fp16_triton(
else:
return o

@torch.compiler.disable

def sageattn_varlen(
q: torch.Tensor,
k: torch.Tensor,
@@ -411,7 +414,7 @@ def sageattn_varlen(

return o

@torch.compiler.disable

def sageattn_qk_int8_pv_fp16_cuda(
q: torch.Tensor,
k: torch.Tensor,
@@ -566,17 +569,17 @@ def sageattn_qk_int8_pv_fp16_cuda(

if pv_accum_dtype == 'fp32':
v = v.to(torch.float16)
lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm80_compile.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp16":
if smooth_v:
smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm80_compile.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
else:
v = v.to(torch.float16)
lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp16+fp32":
v = v.to(torch.float16)
lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
else:
raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")

@@ -587,7 +590,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
else:
return o

@torch.compiler.disable

def sageattn_qk_int8_pv_fp8_cuda(
q: torch.Tensor,
k: torch.Tensor,
@@ -756,13 +759,13 @@ def sageattn_qk_int8_pv_fp8_cuda(

if pv_accum_dtype == "fp32":
if smooth_v:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
else:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp32+fp16":
lse = _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm89_compile.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)

o = o[..., :head_dim_og]

@@ -771,7 +774,7 @@ def sageattn_qk_int8_pv_fp8_cuda(
else:
return o

@torch.compiler.disable

def sageattn_qk_int8_pv_fp8_cuda_sm90(
q: torch.Tensor,
k: torch.Tensor,
@@ -921,13 +924,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(

if pv_accum_dtype == "fp32":
raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)

o = o[..., :head_dim_og]

if return_lse:
return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
else:
return o
return o
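Dropping the @torch.compiler.disable decorators is the point of the wrapper switch: the raw extension calls are now registered custom ops with fake implementations, so Dynamo can trace the public entry points instead of graph-breaking around them. A minimal usage sketch (not part of this diff; assumes a CUDA device with a supported architecture and a head dim the kernels accept, e.g. 128):

import torch
from sageattention import sageattn

def attn_block(q, k, v):
    # sageattn dispatches to the per-arch CUDA/Triton kernel internally.
    return sageattn(q, k, v, tensor_layout="HND", is_causal=False)

compiled_attn = torch.compile(attn_block)

q = torch.randn(2, 8, 2048, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 8, 2048, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 8, 2048, 128, dtype=torch.float16, device="cuda")
out = compiled_attn(q, k, v)  # same shape as q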
149 changes: 149 additions & 0 deletions sageattention/sm80_compile.py
@@ -0,0 +1,149 @@
from . import _qattn_sm80
import torch


@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn", mutates_args=(), device_types="cuda")
def qk_int8_sv_f16_accum_f16_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
query_scale: torch.Tensor,
key_scale: torch.Tensor,
tensor_layout: int,
is_causal: int,
qk_quant_gran: int,
sm_scale: float,
return_lse: int,
) -> torch.Tensor:
"""
Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation.
"""
return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
query, key, value, output, query_scale, key_scale, tensor_layout,
is_causal, qk_quant_gran, sm_scale, return_lse
)


@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f32_attn", mutates_args=(), device_types="cuda")
def qk_int8_sv_f16_accum_f32_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
query_scale: torch.Tensor,
key_scale: torch.Tensor,
tensor_layout: int,
is_causal: int,
qk_quant_gran: int,
sm_scale: float,
return_lse: int,
) -> torch.Tensor:
"""
Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP32 accumulation.
"""
return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
query, key, value, output, query_scale, key_scale, tensor_layout,
is_causal, qk_quant_gran, sm_scale, return_lse
)


@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf", mutates_args=(), device_types="cuda")
def qk_int8_sv_f16_accum_f16_attn_inst_buf(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
query_scale: torch.Tensor,
key_scale: torch.Tensor,
tensor_layout: int,
is_causal: int,
qk_quant_gran: int,
sm_scale: float,
return_lse: int,
) -> torch.Tensor:
"""
    Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation using the instruction-buffer ("inst_buf") kernel variant (the pv_accum_dtype="fp16+fp32" path in core.py).
"""
return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
query, key, value, output, query_scale, key_scale, tensor_layout,
is_causal, qk_quant_gran, sm_scale, return_lse
)


@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_fuse_v_mean_attn", mutates_args=(), device_types="cuda")
def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
query_scale: torch.Tensor,
key_scale: torch.Tensor,
value_mean: torch.Tensor,
tensor_layout: int,
is_causal: int,
qk_quant_gran: int,
sm_scale: float,
return_lse: int,
) -> torch.Tensor:
"""
    Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation, with the V mean (value_mean, produced by sub_mean in core.py) fused back into the output for the smooth_v path.
"""
return _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
query, key, value, output, query_scale, key_scale, value_mean,
tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
)


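# Shared fake (meta) implementation for the three wrappers above: it only reproduces the
# shape of the LSE output, so the ops can be traced with fake tensors (e.g. by torch.compile)
# without launching a CUDA kernel.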
def sm80_qk_fake_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
query_scale: torch.Tensor,
key_scale: torch.Tensor,
tensor_layout: int,
is_causal: int,
qk_quant_gran: int,
sm_scale: float,
return_lse: int,
) -> torch.Tensor:
batch_size = query.size(0)

if tensor_layout == 0:
num_qo_heads = query.size(2)
qo_len = query.size(1)
else:
num_qo_heads = query.size(1)
qo_len = query.size(2)

if return_lse:
lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda")
else:
lse = torch.empty((0))
return lse

torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(sm80_qk_fake_impl)
torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(sm80_qk_fake_impl)
torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(sm80_qk_fake_impl)


@torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")
def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
query_scale: torch.Tensor,
key_scale: torch.Tensor,
value_mean: torch.Tensor,
tensor_layout: int,
is_causal: int,
qk_quant_gran: int,
sm_scale: float,
return_lse: int,
) -> torch.Tensor:
return sm80_qk_fake_impl(
query, key, value, output, query_scale, key_scale, tensor_layout,
is_causal, qk_quant_gran, sm_scale, return_lse
)
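The file follows the standard torch.library recipe: wrap the opaque extension kernel with torch.library.custom_op, then register a shape-only fake implementation for the compiler stack. A self-contained toy version of the same pattern (the demo::scaled_add names are invented for illustration and are not part of SageAttention) can be validated with torch.library.opcheck:

import torch

@torch.library.custom_op("demo::scaled_add", mutates_args=())
def scaled_add(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
    # Stand-in for an opaque CUDA/C++ kernel the compiler cannot trace into.
    return a + alpha * b

@torch.library.register_fake("demo::scaled_add")
def _(a: torch.Tensor, b: torch.Tensor, alpha: float) -> torch.Tensor:
    # Fake impl: describe only the output's shape/dtype/device.
    return torch.empty_like(a)

x, y = torch.randn(4, 8), torch.randn(4, 8)
# Checks the op schema and the fake registration against the real implementation.
torch.library.opcheck(torch.ops.demo.scaled_add.default, (x, y, 0.5),
                      test_utils=("test_schema", "test_faketensor"))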