
Commit 70f2b85

Enhance test_autoquant_compile to support ROCm (#2100)
* Enhance test_autoquant_compile to support ROCm and improve device checks
  - Added checks for ROCm availability alongside CUDA.
  - Improved device capability checks for CUDA to ensure compatibility with bfloat16 and specific tensor shapes.
  - Updated skip conditions for unsupported devices and older PyTorch versions.
* lint
* Refactor device checks in test_autoquant_compile for improved clarity
  - Simplified the logic for checking supported devices by consolidating the CUDA and ROCm checks.
  - Enhanced readability of the device capability validation for CUDA.
  - Updated skip conditions for unsupported devices to ensure accurate test execution.
* Fix formatting in device check condition for consistency in test_autoquant_compile
* Refactor device check formatting in test_autoquant_compile for improved readability
  - Adjusted the formatting of the device check condition to enhance clarity.
  - Consolidated the logic for checking supported devices while maintaining functionality.
1 parent: 9452640 · commit: 70f2b85
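
Context for the check introduced in the diff below: PyTorch's ROCm builds reuse the torch.cuda namespace, and on those builds torch.version.hip is set to the HIP version string (it is None on CUDA-only and CPU-only builds, where torch.version.cuda plays the same role). A minimal sketch of the detection pattern the commit relies on; the gpu_backend helper is ours for illustration and is not part of the patch:

    import torch

    def gpu_backend() -> str:
        # torch.version.hip is a version string on ROCm builds, None otherwise;
        # torch.version.cuda behaves the same way for CUDA builds.
        if torch.version.hip is not None:
            return "rocm"
        if torch.version.cuda is not None:
            return "cuda"
        return "cpu"

    # The commit's consolidated check: device "cuda" is treated as supported
    # on both NVIDIA (CUDA) and AMD (ROCm) builds of PyTorch.
    def is_supported_device(device: str) -> bool:
        return device == "cuda" and (
            torch.cuda.is_available() or torch.version.hip is not None
        )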

File tree

1 file changed: +20 additions, −8 deletions

test/integration/test_integration.py

Lines changed: 20 additions & 8 deletions
@@ -1602,15 +1602,27 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
     def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         undo_recommended_configs()
-        if device != "cuda" or not torch.cuda.is_available():
+
+        is_supported_device = device == "cuda" and (
+            torch.cuda.is_available() or torch.version.hip is not None
+        )
+
+        if not is_supported_device:
             self.skipTest(f"autoquant currently does not support {device}")
-        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
-            if dtype == torch.bfloat16:
-                self.skipTest("bfloat16 requires sm80+")
-            if m1 == 1 or m2 == 1:
-                self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
-        # This test fails on v0.4.0 and torch 2.4, so skipping for now.
-        if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5:
+
+        # Check CUDA-specific requirements if running on CUDA
+        if (
+            is_supported_device and torch.version.hip is None
+        ):  # Only apply to CUDA, not ROCm
+            device_capability = torch.cuda.get_device_capability()
+            if device_capability < (8, 0):
+                if dtype == torch.bfloat16:
+                    self.skipTest("bfloat16 requires sm80+")
+                if m1 == 1 or m2 == 1:
+                    self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+
+        # Skip certain shapes on older PyTorch versions
+        if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5:
             self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4")
         model = (
             torch.nn.Sequential(
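
Aside from the ROCm handling, note that parenthesizing the shape condition in the last hunk is a behavioral fix, not just formatting: Python's and binds tighter than or, so the old expression skipped every m1 == 1 case regardless of the installed torch version. A quick illustration, with values chosen only for demonstration:

    # Old: parses as  m1 == 1 or (m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5)
    # New: parses as  (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5
    m1, m2, at_least_2_5 = 1, 8, True
    old = m1 == 1 or m2 == 1 and not at_least_2_5    # True  -> test skipped
    new = (m1 == 1 or m2 == 1) and not at_least_2_5  # False -> test runs
    print(old, new)  # True False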
