
Add H100 to CI for regression #1792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 2 commits into base: main
5 changes: 5 additions & 0 deletions .github/workflows/regression_test.yml
@@ -33,6 +33,11 @@ jobs:
torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""
- name: H100
runs-on: linux.aws.h100
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124'
gpu-arch-type: "cuda"
gpu-arch-version: "12.4"

permissions:
id-token: write
3 changes: 3 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -26,6 +26,7 @@
TORCH_VERSION_AT_LEAST_2_6,
is_fbcode,
is_sm_at_least_89,
is_sm_at_least_90,
)

is_cusparselt_available = (
@@ -220,6 +221,8 @@ class TestAffineQuantizedBasic(TestCase):
def test_flatten_unflatten(self, device, dtype):
if device == "cuda" and dtype == torch.bfloat16 and is_fbcode():
raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode")
if device == "cuda" and dtype == torch.bfloat16 and is_sm_at_least_90():
raise unittest.SkipTest('TODO: Failing on H100')
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
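Note: most of the new skips in this PR gate on is_sm_at_least_90() from torchao.utils. As a reference point, a minimal sketch of such a compute-capability check is below, assuming it mirrors the existing is_sm_at_least_89 helper; this is an illustration, not a verbatim copy of torchao's code:

    import torch

    def is_sm_at_least_90() -> bool:
        # True only when a CUDA device is present and its compute capability
        # is SM 9.0 or newer (e.g. H100); False on CPU-only or ROCm runners.
        return (
            torch.cuda.is_available()
            and torch.version.cuda is not None
            and torch.cuda.get_device_capability() >= (9, 0)
        )

Gating on the capability check rather than on a runner or device name keeps the skip decorators cheap and safe to evaluate on CPU-only CI hosts.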
25 changes: 19 additions & 6 deletions test/dtypes/test_affine_quantized_float.py
@@ -27,6 +27,7 @@
quantize_,
)
from torchao.quantization.granularity import (
Granularity,
PerRow,
PerTensor,
)
@@ -142,7 +143,11 @@ def test_fp8_linear_variants(
)
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
model,
float8_dynamic_activation_float8_weight(granularity="invalid")
)

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -152,18 +157,26 @@ def test_mismatched_granularity(self):
ValueError,
match="Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
model,
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
)

@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
float8_dynamic_activation_float8_weight(
granularity=(UnsupportedGranularity(), UnsupportedGranularity())
with pytest.raises(
ValueError,
match="Invalid granularity types:",
):
model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
quantize_(
model,
float8_dynamic_activation_float8_weight(granularity=(UnsupportedGranularity(), UnsupportedGranularity()))
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2 changes: 2 additions & 0 deletions test/dtypes/test_nf4.py
@@ -34,6 +34,7 @@
to_nf4,
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import is_sm_at_least_90

bnb_available = False

@@ -616,6 +617,7 @@ def world_size(self) -> int:
reason="torch >= 2.4 required",
)
@skip_if_lt_x_gpu(2)
@pytest.mark.skipif(is_sm_at_least_90(), reason="Skipping test on SM90+") # TODO: fix
def test_qlora_fsdp2(self):
from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy

23 changes: 6 additions & 17 deletions test/integration/test_integration.py
@@ -883,23 +883,12 @@ def test_autoquantizable_flatten_unflatten(self):
)
@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
if dtype != torch.bfloat16:
with self.assertRaisesRegex(
AssertionError, "PerRow quantization only works for bfloat16 precision"
):
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float,
device,
25,
test_dtype=dtype,
)
else:
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float,
device,
25,
test_dtype=dtype,
)
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float,
device,
25,
test_dtype=dtype,
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(
3 changes: 3 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -31,6 +31,7 @@
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
get_available_devices,
is_sm_at_least_90,
)

try:
@@ -419,6 +420,7 @@ def world_size(self) -> int:
)
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.skipif(is_sm_at_least_90(), reason="Will need more investigation on H100")
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
if torch.cuda.get_device_capability() >= (8, 9):
@@ -530,6 +532,7 @@ def _test_fsdp2(self, optim_cls):
)
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.skipif(is_sm_at_least_90(), reason="Will need more investigation on H100") # TODO: investigate why this test fails on H100
def test_uneven_shard(self):
in_dim = 512
out_dim = _FSDP_WORLD_SIZE * 16 + 1
5 changes: 4 additions & 1 deletion test/prototype/test_quantized_training.py
@@ -1,6 +1,7 @@
from unittest import skipIf
import pytest

from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_90

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Requires torch>=2.4", allow_module_level=True)
@@ -295,6 +296,7 @@ def world_size(self) -> int:
return _FSDP_WORLD_SIZE

@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@pytest.mark.skipif(is_sm_at_least_90(), reason="Skipping test on SM90+") # TODO: fix
def test_fsdp2_correctness(self):
mp_policy = MixedPrecisionPolicy()

@@ -387,6 +389,7 @@ def _run_subtest(self, args):
)

@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
@pytest.mark.skipif(is_sm_at_least_90(), reason="Skipping test on SM90+") # TODO: fix
def test_precompute_bitnet_scale(self):
from torchao.prototype.quantized_training.bitnet import (
get_bitnet_scale,
3 changes: 3 additions & 0 deletions test/prototype/test_smoothquant.py
@@ -18,6 +18,7 @@
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_90,
)

if torch.version.hip is not None:
@@ -61,6 +62,7 @@ def forward(self, x):
torch._dynamo.config.cache_size_limit = 128


@pytest.mark.skipif(is_sm_at_least_90(), reason="Does not run on H100") # TODO: fix this test on H100
@pytest.mark.parametrize("bias", bias_list)
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@@ -136,6 +138,7 @@ def forward(self, x):
assert torch.allclose(out, out_ref.to(idtype), atol=atol)


@pytest.mark.skipif(is_sm_at_least_90(), reason="Does not run on H100") # TODO: fix this test on H100
@pytest.mark.parametrize("alpha", alpha_list)
@pytest.mark.parametrize("quant_mode", quant_mode_list)
@pytest.mark.parametrize("device", devices)
3 changes: 3 additions & 0 deletions test/test_rowwise_scaled_linear_cutlass.py
@@ -8,6 +8,7 @@
rowwise_scaled_linear_cutlass_s8s4,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import is_sm_at_least_89, is_sm_at_least_90

ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
@@ -84,6 +85,7 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias)
torch.testing.assert_close(output, output_ref)


@pytest.mark.skipif(is_sm_at_least_90(), reason="Does not run on H100")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
@@ -94,6 +96,7 @@ def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bia
)


@pytest.mark.skipif(is_sm_at_least_90(), reason="Does not run on H100")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
13 changes: 1 addition & 12 deletions torchao/utils.py
@@ -6,6 +6,7 @@
from importlib.metadata import version
from math import gcd
from typing import Any, Callable, Tuple
import warnings

import torch
import torch.nn.utils.parametrize as parametrize
@@ -558,18 +559,6 @@ class PlainAQTTensorImpl(...):
get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
_get_to_kwargs = _get_to_kwargs

def __tensor_flatten__(self):
raise NotImplementedError("Subclasses must implement __tensor_flatten__")

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")

def __repr__(self):
raise NotImplementedError("Subclasses must implement __repr__")

def get_layout(self):
if not hasattr(self, "_layout"):
return None
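For context on the removed __tensor_flatten__ / __tensor_unflatten__ stubs: these are the tensor-subclass hooks that tests such as test_flatten_unflatten exercise. The sketch below shows the general shape of that protocol on a hypothetical wrapper subclass; DemoQuantTensor, qdata, and scale are made-up names for illustration, not torchao's implementation:

    import torch

    class DemoQuantTensor(torch.Tensor):
        # Hypothetical wrapper subclass illustrating the flatten/unflatten protocol.

        @staticmethod
        def __new__(cls, qdata, scale):
            # The outer wrapper only carries shape/dtype; inner tensors hold the payload.
            return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=scale.dtype)

        def __init__(self, qdata, scale):
            self.qdata = qdata
            self.scale = scale

        def __tensor_flatten__(self):
            # Names of the inner tensor attributes, plus any non-tensor metadata.
            return ["qdata", "scale"], []

        @classmethod
        def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
            return cls(tensor_data_dict["qdata"], tensor_data_dict["scale"])

Real torchao subclasses also carry layout and dispatch logic beyond this sketch; the point here is only the pair of hooks whose NotImplementedError stubs this diff removes.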