-
Notifications
You must be signed in to change notification settings - Fork 206
resolve multi-CUDA_ARCHITECTURES compilation conflicts #241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR resolves CUDA compilation conflicts when building for multiple GPU architectures by isolating build artifacts and adding validation for CUDA_ARCHITECTURES environment variable.
- Enables building for multiple CUDA architectures by reading from CUDA_ARCHITECTURES environment variable
- Isolates build artifacts per architecture in separate subdirectories to prevent conflicts
- Filters NVCC flags to only include relevant architecture-specific flags for each extension
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
if cuda_architectures is not None: | ||
for arch in cuda_architectures.split(","): | ||
arch = arch.strip() | ||
if arch: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code doesn't validate that each architecture value is a valid decimal format as mentioned in the PR description. Consider adding validation to ensure each arch value matches the expected pattern (e.g., regex check for decimal format).
if arch: | |
if arch: | |
if not re.match(r"^\d+\.\d+$", arch): | |
raise ValueError(f"Invalid architecture value '{arch}' in CUDA_ARCHITECTURES. Expected decimal format like '8.0', '8.6', etc.") |
Copilot uses AI. Check for mistakes.
skip_next = True | ||
elif flag not in ["-gencode"]: | ||
filtered_flags.append(flag) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The flag filtering logic is duplicated between sm89 and sm90 extensions. Consider extracting this into a helper function to reduce code duplication and improve maintainability.
sm89_arch_list = ["sm_89", "compute_89", "sm_90a", "compute_90a", "sm_120", "compute_120"] | |
filtered_flags = filter_nvcc_flags(NVCC_FLAGS, sm89_arch_list) |
Copilot uses AI. Check for mistakes.
setup.py
Outdated
filtered_flags.append(arch_flag) | ||
skip_next = True | ||
elif flag not in ["-gencode"]: | ||
filtered_flags.append(flag) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is duplicate code from the sm89 extension filtering logic. The same flag filtering pattern should be extracted into a reusable function.
filtered_flags.append(flag) | |
def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings): | |
filtered_flags = [] | |
skip_next = False | |
for i, flag in enumerate(nvcc_flags): | |
if skip_next: | |
skip_next = False | |
continue | |
if flag == "-gencode": | |
if i + 1 < len(nvcc_flags): | |
arch_flag = nvcc_flags[i + 1] | |
if any(sub in arch_flag for sub in arch_substrings): | |
filtered_flags.append(flag) | |
filtered_flags.append(arch_flag) | |
skip_next = True | |
elif flag not in ["-gencode"]: | |
filtered_flags.append(flag) | |
return filtered_flags | |
filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, ["sm_90a", "compute_90a"]) |
Copilot uses AI. Check for mistakes.
setup.py
Outdated
@@ -208,24 +259,32 @@ def compile_new(*args, **kwargs): | |||
**kwargs, | |||
"output_dir": os.path.join( | |||
kwargs["output_dir"], | |||
self.thread_ext_name_map[threading.current_thread().ident]), | |||
self.thread_ext_name_map.get(threading.current_thread().ident, "default")), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using 'default' as a fallback directory name could lead to conflicts if multiple threads don't have mapped extension names. Consider using a more unique identifier like thread ID or timestamp.
self.thread_ext_name_map.get(threading.current_thread().ident, "default")), | |
self.thread_ext_name_map.get( | |
threading.current_thread().ident, | |
f"thread_{threading.current_thread().ident}" | |
)), |
Copilot uses AI. Check for mistakes.
testing pip install -v --no-cache-dir .
CUDA_ARCHITECTURES="9.0,12.0" pip install -v --no-cache-dir .
CUDA_ARCHITECTURES="8.9,9.0" pip install -v --no-cache-dir . |
…d architectures via environment variable. Refactor GPU capability checks and streamline NVCC flags for SM89 and SM90 extensions. Improve build process by creating separate output directories for extensions.
fix(build): resolve multi-CUDA_ARCHITECTURES compilation conflicts
build/sm_{arch}/
subdirectories