diff --git a/setup.py b/setup.py
index 5e4779dd..74e0d17f 100644
--- a/setup.py
+++ b/setup.py
@@ -66,19 +66,28 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
 
-# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
+# Check for environment variable to specify architectures for GPU-less builds
 compute_capabilities = set()
-device_count = torch.cuda.device_count()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 8:
-        warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
-        continue
-    compute_capabilities.add(f"{major}.{minor}")
+
+if os.environ.get('TORCH_CUDA_ARCH_LIST'):
+    # Parse architectures from environment variable
+    arch_list = os.environ['TORCH_CUDA_ARCH_LIST'].replace(' ', '').split(';')
+    for arch in arch_list:
+        compute_capabilities.add(arch)
+    print(f"Using TORCH_CUDA_ARCH_LIST: {compute_capabilities}")
+else:
+    # Iterate over all GPUs on the current machine
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 8:
+            warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
+            continue
+        compute_capabilities.add(f"{major}.{minor}")
 
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if not compute_capabilities:
-    raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
+    raise RuntimeError("No GPUs found. Please specify the target GPU architectures via TORCH_CUDA_ARCH_LIST environment variable or build on a machine with GPUs.")
 else:
     print(f"Detect GPUs with compute capabilities: {compute_capabilities}")
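
For reference, a minimal standalone sketch (not part of the patch) of how the new TORCH_CUDA_ARCH_LIST path is expected to behave; the architecture list "8.0; 8.6;9.0" below is purely illustrative:

    import os

    # Illustrative value; for a real GPU-less build you would export this in the
    # shell before building, e.g. TORCH_CUDA_ARCH_LIST="8.0;8.6" pip install .
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0; 8.6;9.0"

    compute_capabilities = set()
    if os.environ.get("TORCH_CUDA_ARCH_LIST"):
        # Same parsing as in the patch: drop spaces, split on ';'
        arch_list = os.environ["TORCH_CUDA_ARCH_LIST"].replace(" ", "").split(";")
        compute_capabilities.update(arch_list)

    print(compute_capabilities)  # {'8.0', '8.6', '9.0'}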