@@ -66,15 +66,22 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version

-# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
 compute_capabilities = set()
-device_count = torch.cuda.device_count()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 8:
-        warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
-        continue
-    compute_capabilities.add(f"{major}.{minor}")
+cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
+if cuda_architectures is not None:
+    for arch in cuda_architectures.split(","):
+        arch = arch.strip()
+        if arch:
+            compute_capabilities.add(arch)
+else:
+    # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
+            continue
+        compute_capabilities.add(f"{major}.{minor}")

 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
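The hunk above replaces unconditional GPU probing with an opt-in `CUDA_ARCHITECTURES` environment variable, so you can build kernels for GPUs that are not present on the build machine. A minimal, self-contained sketch of how the new parsing behaves (the variable's value here is only an example):

```python
import os

# Example only: request Ada (8.9) and Hopper (9.0) kernels without probing
# local GPUs, e.g. `CUDA_ARCHITECTURES="8.9,9.0" pip install .`
os.environ["CUDA_ARCHITECTURES"] = "8.9, 9.0"

compute_capabilities = set()
cuda_architectures = os.environ.get("CUDA_ARCHITECTURES")
if cuda_architectures is not None:
    for arch in cuda_architectures.split(","):
        arch = arch.strip()   # tolerate spaces around the commas
        if arch:              # skip empty entries, e.g. a trailing comma
            compute_capabilities.add(arch)

print(compute_capabilities)  # -> {'8.9', '9.0'}
```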
@@ -119,54 +126,96 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 ext_modules = []

 if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm80_sources = [
+        "csrc/qattn/pybind_sm80.cpp",
+        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
+    ]
+
+    qattn_extension_sm80 = CUDAExtension(
         name="sageattention._qattn_sm80",
-        sources=[
-            "csrc/qattn/pybind_sm80.cpp",
-            "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
-        ],
+        sources=sm80_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
             "nvcc": NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm80)

 if HAS_SM89 or HAS_SM120:
-    qattn_extension = CUDAExtension(
+    sm89_sources = [
+        "csrc/qattn/pybind_sm89.cpp",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
+        "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
+        # "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
+    ]
+
+    sm89_nvcc_flags = [flag for flag in NVCC_FLAGS]
+
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(sm89_nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(sm89_nvcc_flags):
+                arch_flag = sm89_nvcc_flags[i + 1]
+                if "compute_89" in arch_flag or "compute_90" in arch_flag or "compute_120" in arch_flag:
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+
+    qattn_extension_sm89 = CUDAExtension(
         name="sageattention._qattn_sm89",
-        sources=[
-            "csrc/qattn/pybind_sm89.cpp",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
-            "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
-            # "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
-        ],
+        sources=sm89_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm89)

 if HAS_SM90:
-    qattn_extension = CUDAExtension(
+    sm90_sources = [
+        "csrc/qattn/pybind_sm90.cpp",
+        "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
+    ]
+
+    sm90_nvcc_flags = [flag for flag in NVCC_FLAGS]
+
+    filtered_flags = []
+    skip_next = False
+    for i, flag in enumerate(sm90_nvcc_flags):
+        if skip_next:
+            skip_next = False
+            continue
+        if flag == "-gencode":
+            if i + 1 < len(sm90_nvcc_flags):
+                arch_flag = sm90_nvcc_flags[i + 1]
+                if "compute_90" in arch_flag or "compute_120" in arch_flag:
+                    filtered_flags.append(flag)
+                    filtered_flags.append(arch_flag)
+                skip_next = True
+        elif flag not in ["-gencode"]:
+            filtered_flags.append(flag)
+
+    qattn_extension_sm90 = CUDAExtension(
         name="sageattention._qattn_sm90",
-        sources=[
-            "csrc/qattn/pybind_sm90.cpp",
-            "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
-        ],
+        sources=sm90_sources,
         extra_compile_args={
             "cxx": CXX_FLAGS,
-            "nvcc": NVCC_FLAGS,
+            "nvcc": filtered_flags if filtered_flags else NVCC_FLAGS,
         },
         extra_link_args=['-lcuda'],
     )
-    ext_modules.append(qattn_extension)
+    ext_modules.append(qattn_extension_sm90)

 # Fused kernels.
 fused_extension = CUDAExtension(
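Both `-gencode` filters above walk `NVCC_FLAGS` as a flat list in which each `-gencode` is immediately followed by an `arch=compute_XX,code=sm_XX` spec (nvcc's standard form), keeping only the pairs whose virtual architecture matches the extension. A condensed sketch of the same filtering idea as a helper; the flag values below are illustrative, not the project's actual `NVCC_FLAGS`:

```python
def keep_gencode_for(flags, wanted=("compute_89", "compute_90", "compute_120")):
    """Keep non-gencode flags as-is; keep a -gencode pair only when its
    arch spec mentions one of the wanted virtual architectures."""
    kept, i = [], 0
    while i < len(flags):
        if flags[i] == "-gencode" and i + 1 < len(flags):
            if any(w in flags[i + 1] for w in wanted):
                kept += [flags[i], flags[i + 1]]
            i += 2  # always consume the arch spec that follows -gencode
        else:
            kept.append(flags[i])
            i += 1
    return kept

# Illustrative input with two -gencode pairs:
flags = ["-O3", "-std=c++17",
         "-gencode", "arch=compute_80,code=sm_80",
         "-gencode", "arch=compute_89,code=sm_89"]
print(keep_gencode_for(flags))
# -> ['-O3', '-std=c++17', '-gencode', 'arch=compute_89,code=sm_89']
```

The `filtered_flags if filtered_flags else NVCC_FLAGS` fallback in the diff keeps the build usable when no `-gencode` pair matched, at the cost of compiling for every architecture in that case.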
@@ -208,15 +257,23 @@ def compile_new(*args, **kwargs):
                 **kwargs,
                 "output_dir": os.path.join(
                     kwargs["output_dir"],
-                    self.thread_ext_name_map[threading.current_thread().ident]),
+                    self.thread_ext_name_map.get(threading.current_thread().ident, "default")),
             })
         self.compiler.compile = compile_new
         self.compiler._compile_separate_output_dir = True
         self.thread_ext_name_map[threading.current_thread().ident] = ext.name
-        objects = super().build_extension(ext)
+
+        original_build_temp = self.build_temp
+        self.build_temp = os.path.join(original_build_temp, ext.name.replace(".", "_"))
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        try:
+            objects = super().build_extension(ext)
+        finally:
+            self.build_temp = original_build_temp
+
         return objects

-
 setup(
     name='sageattention',
     version='2.2.0',
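Two fixes land in this last hunk: `dict.get(..., "default")` avoids a `KeyError` when `compile` runs on a thread that never registered an extension name, and each extension now compiles into its own subdirectory of `build_temp`, restored in a `finally` block so a failed build cannot leak the nested path to later extensions. A simplified stand-in showing the same save/patch/restore pattern (the mixin name is invented for illustration; the `build_temp` attribute mirrors setuptools' `build_ext`):

```python
import os

class IsolatedBuildTempMixin:
    """Give every extension its own scratch directory so object files from
    parallel extension builds cannot overwrite one another."""

    def build_extension(self, ext):
        original_build_temp = self.build_temp
        # e.g. build/temp.linux-x86_64/sageattention__qattn_sm89
        self.build_temp = os.path.join(original_build_temp,
                                       ext.name.replace(".", "_"))
        os.makedirs(self.build_temp, exist_ok=True)
        try:
            return super().build_extension(ext)
        finally:
            # Restore even on failure so subsequent extensions start from
            # the original path rather than nesting ever deeper.
            self.build_temp = original_build_temp
```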