From 5d2ef2f0a2e1db38bb46bb023bf54016bcf7eb2c Mon Sep 17 00:00:00 2001
From: gaclove
Date: Thu, 14 Aug 2025 09:45:19 +0000
Subject: [PATCH] Enhance CUDA architecture support in setup.py by allowing
 user-defined architectures via an environment variable. Refactor the GPU
 capability checks and streamline NVCC flags for the SM89 and SM90
 extensions. Improve the build process by creating separate output
 directories for the extensions.

---
 setup.py | 146 +++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 100 insertions(+), 46 deletions(-)

diff --git a/setup.py b/setup.py
index 5e4779dd..54b95154 100644
--- a/setup.py
+++ b/setup.py
@@ -66,22 +66,56 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
-# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
+    """Filter NVCC flags, keeping only the -gencode flags for the specified architectures."""
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(nvcc_flags):
+                arch_flag = nvcc_flags[i + 1]
+                if any(sub in arch_flag for sub in arch_substrings):
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+    return filtered_flags
+
 compute_capabilities = set()
-device_count = torch.cuda.device_count()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 8:
-        warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
-        continue
-    compute_capabilities.add(f"{major}.{minor}")
+cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
+if cuda_architectures is not None:
+    for arch in cuda_architectures.split(","):
+        arch = arch.strip()
+        if arch:
+            compute_capabilities.add(arch)
+else:
+    # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
+            continue
+        compute_capabilities.add(f"{major}.{minor}")
 
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
     raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
 else:
+
+    unsupported_archs = compute_capabilities - SUPPORTED_ARCHS
+    if unsupported_archs:
+        warnings.warn(f"Unsupported GPU architectures detected: {unsupported_archs}. Supported architectures: {SUPPORTED_ARCHS}")
+        compute_capabilities = compute_capabilities & SUPPORTED_ARCHS
+        if not compute_capabilities:
+            raise RuntimeError(f"No supported GPU architectures found. Detected: {compute_capabilities | unsupported_archs}, Supported: {SUPPORTED_ARCHS}")
+
     print(f"Detect GPUs with compute capabilities: {compute_capabilities}")
 
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 # Validate the NVCC CUDA version.
 if nvcc_cuda_version < Version("12.0"):
     raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
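Note on the hunk above: when CUDA_ARCHITECTURES is set, the build no longer queries torch.cuda at all, and anything outside SUPPORTED_ARCHS is warned about and dropped. A minimal sketch of that selection path, with an illustrative SUPPORTED_ARCHS value (the real set is defined earlier in setup.py and is not part of this patch):

    import os

    SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}  # assumed for illustration only

    os.environ["CUDA_ARCHITECTURES"] = "8.9, 9.0, 7.5"      # what a user might export before building
    requested = {a.strip() for a in os.environ["CUDA_ARCHITECTURES"].split(",") if a.strip()}
    unsupported = requested - SUPPORTED_ARCHS                # {'7.5'} -> triggers the warning
    compute_capabilities = requested & SUPPORTED_ARCHS       # {'8.9', '9.0'} -> used for the gencode flags
    print(sorted(compute_capabilities))

In practice this lets a CPU-only CI machine build wheels by exporting something like CUDA_ARCHITECTURES="8.0,8.9,9.0" before the usual pip/setup.py invocation.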
@@ -119,54 +153,66 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 ext_modules = []
 
 if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm80_sources = [
+        "csrc/qattn/pybind_sm80.cpp",
+        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
+    ]
+
+    qattn_extension_sm80 = CUDAExtension(
         name="sageattention._qattn_sm80",
-        sources=[
-            "csrc/qattn/pybind_sm80.cpp",
-            "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
-        ],
+        sources=sm80_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
             "nvcc": NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm80)
 
 if HAS_SM89 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm89_sources = [
+        "csrc/qattn/pybind_sm89.cpp",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
+        #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
+    ]
+
+    arch_substrings = ["sm_89", "compute_89", "sm_90a", "compute_90a", "sm_120", "compute_120"]
+    filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)
+
+    qattn_extension_sm89 = CUDAExtension(
         name="sageattention._qattn_sm89",
-        sources=[
-            "csrc/qattn/pybind_sm89.cpp",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
-            #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
-        ],
+        sources=sm89_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm89)
 
 if HAS_SM90:
-    qattn_extension = CUDAExtension(
+    sm90_sources = [
+        "csrc/qattn/pybind_sm90.cpp",
+        "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
+    ]
+
+    arch_substrings = ["sm_90a", "compute_90a"]
+    filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, arch_substrings)
+
+    qattn_extension_sm90 = CUDAExtension(
         name="sageattention._qattn_sm90",
-        sources=[
-            "csrc/qattn/pybind_sm90.cpp",
-            "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
-        ],
+        sources=sm90_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
         extra_link_args=['-lcuda'],
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm90)
 
 # Fused kernels.
 fused_extension = CUDAExtension(
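Note on the hunk above: the per-extension flag filtering keeps every non-gencode flag and keeps a "-gencode <spec>" pair only when <spec> mentions one of the requested architecture substrings. A self-contained sketch of that behavior with made-up flag values (the real NVCC_FLAGS list is assembled earlier in setup.py from the detected compute capabilities):

    # Same filtering logic as filter_nvcc_flags_for_arch in this patch; the flag
    # values below are illustrative, not the project's actual NVCC_FLAGS.
    def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
        filtered, skip_next = [], False
        for i, flag in enumerate(nvcc_flags):
            if skip_next:
                skip_next = False
                continue
            if flag == "-gencode" and i + 1 < len(nvcc_flags):
                # keep the "-gencode <spec>" pair only if <spec> names a wanted arch
                if any(sub in nvcc_flags[i + 1] for sub in arch_substrings):
                    filtered += [flag, nvcc_flags[i + 1]]
                skip_next = True
            elif flag != "-gencode":
                filtered.append(flag)
        return filtered

    NVCC_FLAGS = [
        "-O3", "--use_fast_math",
        "-gencode", "arch=compute_80,code=sm_80",
        "-gencode", "arch=compute_89,code=sm_89",
        "-gencode", "arch=compute_90a,code=sm_90a",
    ]
    print(filter_nvcc_flags_for_arch(NVCC_FLAGS, ["sm_89", "compute_89"]))
    # ['-O3', '--use_fast_math', '-gencode', 'arch=compute_89,code=sm_89']

The `filtered_flags if filtered_flags else NVCC_FLAGS` fallback in the hunk guards against the corner case where no flag survives the filter.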
@@ -208,24 +254,32 @@ def compile_new(*args, **kwargs):
                         **kwargs,
                         "output_dir": os.path.join(
                             kwargs["output_dir"],
-                            self.thread_ext_name_map[threading.current_thread().ident]),
+                            self.thread_ext_name_map.get(threading.current_thread().ident, f"thread_{threading.current_thread().ident}")),
                     })
                 self.compiler.compile = compile_new
                 self.compiler._compile_separate_output_dir = True
         self.thread_ext_name_map[threading.current_thread().ident] = ext.name
-        objects = super().build_extension(ext)
-        return objects
+        original_build_temp = self.build_temp
+        self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        try:
+            objects = super().build_extension(ext)
+        finally:
+            self.build_temp = original_build_temp
+
+        return objects
 
 
 setup(
-    name='sageattention',
-    version='2.2.0',
+    name='sageattention',
+    version='2.2.0',
     author='SageAttention team',
-    license='Apache 2.0 License',
-    description='Accurate and efficient plug-and-play low-bit attention.',
-    long_description=open('README.md', encoding='utf-8').read(),
-    long_description_content_type='text/markdown',
-    url='https://github.com/thu-ml/SageAttention',
+    license='Apache 2.0 License',
+    description='Accurate and efficient plug-and-play low-bit attention.',
+    long_description=open('README.md', encoding='utf-8').read(),
+    long_description_content_type='text/markdown',
+    url='https://github.com/thu-ml/SageAttention',
     packages=find_packages(),
     python_requires='>=3.9',
     ext_modules=ext_modules,
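Note on the hunk above: routing each extension into its own build_temp subdirectory (and its own output_dir inside the compiler patch) keeps object files from the three _qattn extensions from clobbering each other when they are compiled in parallel. A small sketch of the directory names the patched build_extension() would derive, assuming a typical setuptools build_temp path:

    import os

    build_temp = "build/temp.linux-x86_64-cpython-310"  # assumed; the real value comes from setuptools
    for ext_name in ("sageattention._qattn_sm80",
                     "sageattention._qattn_sm89",
                     "sageattention._qattn_sm90"):
        print(os.path.join(build_temp, ext_name.replace(".", "_")))
    # build/temp.linux-x86_64-cpython-310/sageattention__qattn_sm80
    # build/temp.linux-x86_64-cpython-310/sageattention__qattn_sm89
    # build/temp.linux-x86_64-cpython-310/sageattention__qattn_sm90

The try/finally that restores self.build_temp ensures later extensions (e.g. the fused kernels) still start from the original base directory, and the .get(..., f"thread_{...}") lookup avoids a KeyError if the compiler patch runs on a thread that was never registered.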