146 changes: 100 additions & 46 deletions setup.py
@@ -66,22 +66,56 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version
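
For context, the body of get_nvcc_cuda_version above tokenizes the output of `nvcc --version` and takes the token following "release". A minimal sketch of that parsing, assuming the usual nvcc banner format (the sample string below is illustrative, not captured from a real run):

    from packaging.version import parse

    sample = "Cuda compilation tools, release 12.4, V12.4.131".split()
    release_idx = sample.index("release") + 1          # token following "release"
    print(parse(sample[release_idx].split(",")[0]))    # 12.4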

# Iterate over all GPUs on the current machine. You can also modify this part to build for specific GPU architectures.
def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
"""Filter NVCC flags, only keep gencode flags for specified architectures"""
filtered_flags = []
skip_next = False
for i, flag in enumerate(nvcc_flags):
if skip_next:
skip_next = False
continue
if flag == "-gencode":
if i + 1 < len(nvcc_flags):
arch_flag = nvcc_flags[i + 1]
if any(sub in arch_flag for sub in arch_substrings):
filtered_flags.append(flag)
filtered_flags.append(arch_flag)
skip_next = True
else:
filtered_flags.append(flag)
return filtered_flags
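
As a quick illustration of what the filter keeps and drops, here is a hypothetical flag list (the -gencode values are examples, not the project's actual NVCC_FLAGS):

    flags = ["-O3",
             "-gencode", "arch=compute_80,code=sm_80",
             "-gencode", "arch=compute_89,code=sm_89"]
    print(filter_nvcc_flags_for_arch(flags, ["sm_89", "compute_89"]))
    # ['-O3', '-gencode', 'arch=compute_89,code=sm_89']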

compute_capabilities = set()
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
continue
compute_capabilities.add(f"{major}.{minor}")
cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
if cuda_architectures is not None:
for arch in cuda_architectures.split(","):
arch = arch.strip()
if arch:
Copilot AI commented on Aug 14, 2025:

The code doesn't validate that each architecture value is a valid decimal format, as mentioned in the PR description. Consider adding validation to ensure each arch value matches the expected pattern (e.g., a regex check for decimal format).

Suggested change (note this also requires `import re` at the top of setup.py):

if arch:
    if not re.match(r"^\d+\.\d+$", arch):
        raise ValueError(f"Invalid architecture value '{arch}' in CUDA_ARCHITECTURES. Expected decimal format like '8.0', '8.6', etc.")

compute_capabilities.add(arch)
else:
# Iterate over all GPUs on the current machine. You can also modify this part to build for specific GPU architectures.
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
continue
compute_capabilities.add(f"{major}.{minor}")
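
With this fallback in place, the environment-variable path lets you build on a machine without the target GPUs, e.g. `CUDA_ARCHITECTURES="8.0,8.9" pip install .`. A minimal sketch of how a value is parsed by the code above (the architecture values are examples):

    import os

    os.environ["CUDA_ARCHITECTURES"] = "8.0, 8.9"
    caps = {a.strip()
            for a in os.environ["CUDA_ARCHITECTURES"].split(",")
            if a.strip()}
    print(caps)  # {'8.0', '8.9'} (set order may vary)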

nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
else:

unsupported_archs = compute_capabilities - SUPPORTED_ARCHS
if unsupported_archs:
warnings.warn(f"Unsupported GPU architectures detected: {unsupported_archs}. Supported architectures: {SUPPORTED_ARCHS}")
compute_capabilities = compute_capabilities & SUPPORTED_ARCHS
if not compute_capabilities:
raise RuntimeError(f"No supported GPU architectures found. Detected: {compute_capabilities | unsupported_archs}, Supported: {SUPPORTED_ARCHS}")

print(f"Detect GPUs with compute capabilities: {compute_capabilities}")

nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("12.0"):
raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
@@ -119,54 +153,66 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
ext_modules = []

if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
qattn_extension = CUDAExtension(
sm80_sources = [
"csrc/qattn/pybind_sm80.cpp",
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
]

qattn_extension_sm80 = CUDAExtension(
name="sageattention._qattn_sm80",
sources=[
"csrc/qattn/pybind_sm80.cpp",
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
],
sources=sm80_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(qattn_extension)
ext_modules.append(qattn_extension_sm80)

if HAS_SM89 or HAS_SM120:
qattn_extension = CUDAExtension(
sm89_sources = [
"csrc/qattn/pybind_sm89.cpp",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
#"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
]

arch_substrings = ["sm_89", "compute_89", "sm_90a", "compute_90a", "sm_120", "compute_120"]
filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)

Copilot AI commented on Aug 14, 2025:

The flag filtering logic is duplicated between the sm89 and sm90 extensions. Consider extracting this into a helper function to reduce code duplication and improve maintainability.

Suggested change:

sm89_arch_list = ["sm_89", "compute_89", "sm_90a", "compute_90a", "sm_120", "compute_120"]
filtered_flags = filter_nvcc_flags(NVCC_FLAGS, sm89_arch_list)
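
A sketch of the extraction the reviewer is suggesting, assuming the helper simply wraps filter_nvcc_flags_for_arch defined above and keeps the existing fall-back-to-all-flags behavior (the helper name is hypothetical):

    def nvcc_flags_for(arch_substrings):
        # Keep only the -gencode pairs matching arch_substrings; if the
        # filter removes everything, fall back to the full flag list.
        filtered = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)
        return filtered if filtered else NVCC_FLAGS

Each extension could then pass "nvcc": nvcc_flags_for([...]) directly in extra_compile_args.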

qattn_extension_sm89 = CUDAExtension(
name="sageattention._qattn_sm89",
sources=[
"csrc/qattn/pybind_sm89.cpp",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
#"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
],
sources=sm89_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
"nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
},
)
ext_modules.append(qattn_extension)
ext_modules.append(qattn_extension_sm89)

if HAS_SM90:
qattn_extension = CUDAExtension(
sm90_sources = [
"csrc/qattn/pybind_sm90.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
]

arch_substrings = ["sm_90a", "compute_90a"]
filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)

qattn_extension_sm90 = CUDAExtension(
name="sageattention._qattn_sm90",
sources=[
"csrc/qattn/pybind_sm90.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
],
sources=sm90_sources,
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
"nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
},
extra_link_args=['-lcuda'],
)
ext_modules.append(qattn_extension)
ext_modules.append(qattn_extension_sm90)

# Fused kernels.
fused_extension = CUDAExtension(
@@ -208,24 +254,32 @@ def compile_new(*args, **kwargs):
**kwargs,
"output_dir": os.path.join(
kwargs["output_dir"],
self.thread_ext_name_map[threading.current_thread().ident]),
self.thread_ext_name_map.get(threading.current_thread().ident, f"thread_{threading.current_thread().ident}")),
})
self.compiler.compile = compile_new
self.compiler._compile_separate_output_dir = True
self.thread_ext_name_map[threading.current_thread().ident] = ext.name
objects = super().build_extension(ext)
return objects

original_build_temp = self.build_temp
self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
os.makedirs(self.build_temp, exist_ok=True)

try:
objects = super().build_extension(ext)
finally:
self.build_temp = original_build_temp

return objects
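
To make the threading change above concrete, here is the wrapper pattern in isolation: each compiling thread looks up which extension it is building and routes its object files into a matching subdirectory, with a per-thread fallback when no mapping exists. A minimal self-contained sketch, not the verbatim setup.py code:

    import os
    import threading

    def make_thread_scoped_compile(original_compile, thread_ext_name_map):
        def compile_new(*args, **kwargs):
            ident = threading.current_thread().ident
            subdir = thread_ext_name_map.get(ident, f"thread_{ident}")
            # Redirect this thread's object files into its own subdirectory.
            return original_compile(*args, **{
                **kwargs,
                "output_dir": os.path.join(kwargs["output_dir"], subdir),
            })
        return compile_new

This keeps concurrent extension builds from writing objects into the same output_dir, which is also why build_temp is pointed at a per-extension subdirectory and restored in the finally block.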

setup(
name='sageattention',
version='2.2.0',
name='sageattention',
version='2.2.0',
author='SageAttention team',
license='Apache 2.0 License',
description='Accurate and efficient plug-and-play low-bit attention.',
long_description=open('README.md', encoding='utf-8').read(),
long_description_content_type='text/markdown',
url='https://github.com/thu-ml/SageAttention',
license='Apache 2.0 License',
description='Accurate and efficient plug-and-play low-bit attention.',
long_description=open('README.md', encoding='utf-8').read(),
long_description_content_type='text/markdown',
url='https://github.com/thu-ml/SageAttention',
packages=find_packages(),
python_requires='>=3.9',
ext_modules=ext_modules,