diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index d5b197c8ce38..13a27b15e0f2 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -194,9 +194,7 @@ def get_gencode_flags(): if cc >= 80 and nvcc_version >= Version("12.4"): os.environ.pop('PADDLE_CUDA_ARCH_LIST', None) nvcc_compile_args += [ - "-std=c++17", "--use_fast_math", - "--threads=8", "-D_GLIBCXX_USE_CXX11_ABI=1", ] sources += ["./gpu/sage_attn_kernels/sageattn_fused.cu"] @@ -235,7 +233,7 @@ def get_gencode_flags(): ext_modules=CUDAExtension( sources=sources, extra_compile_args={ - "cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"], + "cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16", "--threads=8"], "nvcc": nvcc_compile_args, }, libraries=["cublasLt"],