Skip to content

Commit 370702a

Browse files
authored
update setup_cuda.py (#10493)
1 parent d3ee14f commit 370702a

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

csrc/setup_cuda.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,24 @@
1515
import os
1616
import shutil
1717
import subprocess
18+
from packaging.version import parse, Version
1819

1920
import paddle
2021
from paddle.utils.cpp_extension import CUDAExtension, setup
2122

2223
sm_version = int(os.getenv("CUDA_SM_VERSION", "0"))
2324

25+
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
26+
"""Get the CUDA version from nvcc.
27+
28+
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
29+
"""
30+
nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
31+
universal_newlines=True)
32+
output = nvcc_output.split()
33+
release_idx = output.index("release") + 1
34+
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
35+
return nvcc_cuda_version
2436

2537
def update_git_submodule():
2638
try:
@@ -153,6 +165,7 @@ def get_gencode_flags():
153165
]
154166
cc = get_sm_version()
155167
cuda_version = float(paddle.version.cuda())
168+
nvcc_version = get_nvcc_cuda_version(os.environ.get("CUDA_HOME", "/usr/local/cuda"))
156169

157170
if cc >= 80:
158171
sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]
@@ -178,7 +191,8 @@ def get_gencode_flags():
178191
"gpu/fp8_gemm_with_cutlass/fp8_fp8_fp8_dual_gemm.cu",
179192
]
180193

181-
if cc >= 80 and cuda_version >= 12.4:
194+
if cc >= 80 and nvcc_version >= Version("12.4"):
195+
os.environ.pop('PADDLE_CUDA_ARCH_LIST', None)
182196
nvcc_compile_args += [
183197
"-std=c++17",
184198
"--use_fast_math",

0 commit comments

Comments
 (0)