15
15
import os
16
16
import shutil
17
17
import subprocess
18
+ from packaging .version import parse , Version
18
19
19
20
import paddle
20
21
from paddle .utils .cpp_extension import CUDAExtension , setup
21
22
22
23
sm_version = int (os .getenv ("CUDA_SM_VERSION" , "0" ))
23
24
25
+ def get_nvcc_cuda_version (cuda_dir : str ) -> Version :
26
+ """Get the CUDA version from nvcc.
27
+
28
+ Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
29
+ """
30
+ nvcc_output = subprocess .check_output ([cuda_dir + "/bin/nvcc" , "-V" ],
31
+ universal_newlines = True )
32
+ output = nvcc_output .split ()
33
+ release_idx = output .index ("release" ) + 1
34
+ nvcc_cuda_version = parse (output [release_idx ].split ("," )[0 ])
35
+ return nvcc_cuda_version
24
36
25
37
def update_git_submodule ():
26
38
try :
@@ -153,6 +165,7 @@ def get_gencode_flags():
153
165
]
154
166
cc = get_sm_version ()
155
167
cuda_version = float (paddle .version .cuda ())
168
+ nvcc_version = get_nvcc_cuda_version (os .environ .get ("CUDA_HOME" , "/usr/local/cuda" ))
156
169
157
170
if cc >= 80 :
158
171
sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu" ]
@@ -178,7 +191,8 @@ def get_gencode_flags():
178
191
"gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu" ,
179
192
]
180
193
181
- if cc >= 80 and cuda_version >= 12.4 :
194
+ if cc >= 80 and nvcc_version >= Version ("12.4" ):
195
+ os .environ .pop ('PADDLE_CUDA_ARCH_LIST' , None )
182
196
nvcc_compile_args += [
183
197
"-std=c++17" ,
184
198
"--use_fast_math" ,
0 commit comments