
Commit a8daff3

Enhance CUDA architecture support in setup.py by allowing user-defined architectures via an environment variable. Refactor GPU capability checks and streamline NVCC flags for the SM89 and SM90 extensions. Improve the build process by creating separate output directories per extension.
1 parent 798c791 commit a8daff3

File tree

1 file changed: +95 -38 lines

setup.py

Lines changed: 95 additions & 38 deletions
@@ -66,15 +66,22 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
-# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
 compute_capabilities = set()
-device_count = torch.cuda.device_count()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 8:
-        warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
-        continue
-    compute_capabilities.add(f"{major}.{minor}")
+cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
+if cuda_architectures is not None:
+    for arch in cuda_architectures.split(","):
+        arch = arch.strip()
+        if arch:
+            compute_capabilities.add(arch)
+else:
+    # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
+            continue
+        compute_capabilities.add(f"{major}.{minor}")
 
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
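For reference, the override path added in this hunk can be read as a small standalone sketch; the invocation in the comment is hypothetical, and the accepted values are simply whatever capability strings the rest of setup.py turns into -gencode flags.

# Minimal sketch of the CUDA_ARCHITECTURES override introduced above.
# Hypothetical invocation: CUDA_ARCHITECTURES="8.0,8.9,9.0" pip install .
import os

compute_capabilities = set()
cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
if cuda_architectures is not None:
    # Comma-separated capability strings; whitespace and empty entries are ignored.
    for arch in cuda_architectures.split(","):
        arch = arch.strip()
        if arch:
            compute_capabilities.add(arch)
print(compute_capabilities)  # e.g. {'8.0', '8.9', '9.0'}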
@@ -119,54 +126,96 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 ext_modules = []
 
 if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm80_sources = [
+        "csrc/qattn/pybind_sm80.cpp",
+        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
+    ]
+
+    qattn_extension_sm80 = CUDAExtension(
         name="sageattention._qattn_sm80",
-        sources=[
-            "csrc/qattn/pybind_sm80.cpp",
-            "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
-        ],
+        sources=sm80_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
             "nvcc": NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm80)
 
 if HAS_SM89 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm89_sources = [
+        "csrc/qattn/pybind_sm89.cpp",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
+        #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
+    ]
+
+    sm89_nvcc_flags = [flag for flag in NVCC_FLAGS]
+
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(sm89_nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(sm89_nvcc_flags):
+                arch_flag = sm89_nvcc_flags[i + 1]
+                if "compute_89" in arch_flag or "compute_90" in arch_flag or "compute_120" in arch_flag:
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+
+    qattn_extension_sm89 = CUDAExtension(
         name="sageattention._qattn_sm89",
-        sources=[
-            "csrc/qattn/pybind_sm89.cpp",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
-            #"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
-        ],
+        sources=sm89_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm89)
 
 if HAS_SM90:
-    qattn_extension = CUDAExtension(
+    sm90_sources = [
+        "csrc/qattn/pybind_sm90.cpp",
+        "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
+    ]
+
+    sm90_nvcc_flags = [flag for flag in NVCC_FLAGS]
+
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(sm90_nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(sm90_nvcc_flags):
+                arch_flag = sm90_nvcc_flags[i + 1]
+                if "compute_90" in arch_flag or "compute_120" in arch_flag:
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+
+    qattn_extension_sm90 = CUDAExtension(
         name="sageattention._qattn_sm90",
-        sources=[
-            "csrc/qattn/pybind_sm90.cpp",
-            "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
-        ],
+        sources=sm90_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
         extra_link_args=['-lcuda'],
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm90)
 
 # Fused kernels.
 fused_extension = CUDAExtension(
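The -gencode filtering loop above is duplicated almost verbatim for the SM89 and SM90 extensions, differing only in which compute_XX substrings are kept. A possible follow-up, not part of this commit, would be to factor it into a helper; this is an illustrative sketch, and filter_gencode_flags is a hypothetical name.

def filter_gencode_flags(nvcc_flags, wanted_substrings):
    """Keep all non -gencode flags, plus -gencode/arch pairs whose arch
    string contains one of wanted_substrings (e.g. "compute_89")."""
    filtered = []
    skip_next = False
    for i, flag in enumerate(nvcc_flags):
        if skip_next:
            skip_next = False
            continue
        if flag == "-gencode" and i + 1 < len(nvcc_flags):
            arch_flag = nvcc_flags[i + 1]
            if any(s in arch_flag for s in wanted_substrings):
                filtered.extend([flag, arch_flag])
            skip_next = True  # consume the pair whether or not it was kept
        elif flag != "-gencode":
            filtered.append(flag)
    return filtered

# e.g. for the SM89 extension:
#   sm89_flags = filter_gencode_flags(NVCC_FLAGS, ("compute_89", "compute_90", "compute_120"))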
@@ -208,15 +257,23 @@ def compile_new(*args, **kwargs):
                     **kwargs,
                     "output_dir": os.path.join(
                         kwargs["output_dir"],
-                        self.thread_ext_name_map[threading.current_thread().ident]),
+                        self.thread_ext_name_map.get(threading.current_thread().ident, "default")),
                 })
             self.compiler.compile = compile_new
             self.compiler._compile_separate_output_dir = True
         self.thread_ext_name_map[threading.current_thread().ident] = ext.name
-        objects = super().build_extension(ext)
+
+        original_build_temp = self.build_temp
+        self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        try:
+            objects = super().build_extension(ext)
+        finally:
+            self.build_temp = original_build_temp
+
         return objects
 
-
 setup(
     name='sageattention',
     version='2.2.0',
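Read in isolation, the build_temp change amounts to the following pattern; a minimal sketch assuming a torch BuildExtension subclass like the one already defined in setup.py (the class name here is illustrative).

import os
from torch.utils.cpp_extension import BuildExtension

class PerExtensionBuildExt(BuildExtension):  # illustrative name
    def build_extension(self, ext):
        # Give each extension its own temp directory, e.g.
        # build/temp.linux-x86_64/sageattention__qattn_sm89/, so parallel
        # builds of _qattn_sm80/_qattn_sm89/_qattn_sm90 do not clobber
        # each other's object files.
        original_build_temp = self.build_temp
        self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
        os.makedirs(self.build_temp, exist_ok=True)
        try:
            return super().build_extension(ext)
        finally:
            # Always restore the shared temp dir for the next extension.
            self.build_temp = original_build_temp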
