
Commit 5c059bf

Store NVFP4 block scales in swizzled layout on tensor

stack-info: PR: #2438, branch: drisspg/stack/80

1 parent: faf788a
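For context: the "swizzled" layout is the tiled block-scale layout that block-scaled GEMMs consume (cuBLAS-style 128x4 scale tiles). Below is a minimal sketch of the roundtrip the new tests exercise; the padding-to-whole-tiles detail is an assumption about to_blocked's behavior, not something stated in this diff.

import torch
from torchao.prototype.mx_formats.utils import from_blocked, to_blocked

# NVFP4 keeps one scale per 16-element block, so a (32, 64) tensor has a
# (32, 64 // 16) = (32, 4) scale matrix.
scales = torch.randn(32, 4)

# to_blocked rearranges into the swizzled layout (assumed here to pad up to
# whole 128x4 tiles first); from_blocked inverts it given the original dims.
swizzled = to_blocked(scales, use_triton_kernel=False)
roundtrip = from_blocked(swizzled, *scales.shape)
torch.testing.assert_close(scales, roundtrip)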

File tree

4 files changed, +313 -20 lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 193 additions & 0 deletions

@@ -657,3 +657,196 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     assert x.t().dtype == x_reconstructed_t.dtype, (
         f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
     )
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (128, 4),
+        (256, 8),
+        (100, 3),
+        (4, 4),
+        (50, 10),
+        (384, 12),
+    ],
+)
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
+    """
+    Test that to_blocked and from_blocked are proper inverses of each other
+    for various input shapes that may require padding.
+    """
+    from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+
+    rows, cols = shape
+
+    # Use CUDA if available, otherwise CPU
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Test with random data
+    original = torch.randn(rows, cols, device=device, dtype=torch.float32)
+
+    # Test both triton and PyTorch implementations
+    # (triton is only exercised on torch 2.8+ per the skipif above)
+    blocked = to_blocked(original, use_triton_kernel=use_triton_kernel)
+    reconstructed = from_blocked(blocked, rows, cols)
+
+    torch.testing.assert_close(
+        original,
+        reconstructed,
+        atol=1e-6,
+        rtol=1e-6,
+        msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
+    )
+
+    ones = torch.ones(rows, cols, device=device, dtype=torch.float32)
+    blocked_ones = to_blocked(ones, use_triton_kernel=False)
+    reconstructed_ones = from_blocked(blocked_ones, rows, cols)
+    torch.testing.assert_close(ones, reconstructed_ones, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.parametrize("store_swizzled", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (32, 64),
+        (16, 32),
+        (64, 128),
+        (384, 128),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_nvfp4_swizzled_scales_construction(store_swizzled, shape):
+    """
+    Test that NVFP4Tensor can be constructed with swizzled scales and
+    that the _swizzled_scales flag is set correctly.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = shape
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    # Create tensor with the specified scale storage layout
+    tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=store_swizzled)
+
+    # Verify the flag is set correctly
+    assert tensor._swizzled_scales == store_swizzled
+
+    # Verify the tensor can be dequantized correctly
+    reconstructed = tensor.to_dtype(torch.bfloat16)
+    assert reconstructed.shape == data.shape
+
+
+@pytest.mark.parametrize(
+    "slice_dim,slice_spec",
+    [
+        pytest.param(0, slice(0, 16), id="slice_rows[0:16]"),
+        pytest.param(0, slice(8, 24), id="slice_rows[8:24]"),
+        pytest.param(1, slice(0, 32), id="slice_cols[0:32]"),
+        pytest.param(1, slice(16, 48), id="slice_cols[16:48]"),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
+    """
+    Test that slicing works correctly with swizzled scales and maintains
+    the swizzled state in the output tensor.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 32, 64
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    # Create tensor with swizzled scales
+    tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=True)
+    assert tensor._swizzled_scales
+
+    # Perform slice operation
+    if slice_dim == 0:
+        sliced_tensor = tensor[slice_spec, :]
+    else:
+        sliced_tensor = tensor[:, slice_spec]
+
+    # Verify sliced tensor maintains swizzled state
+    assert sliced_tensor._swizzled_scales
+
+    # Verify sliced tensor can be dequantized
+    sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)
+
+    # Compare with the same slice of the dequantized original
+    original_reconstructed = tensor.to_dtype(torch.bfloat16)
+    if slice_dim == 0:
+        expected = original_reconstructed[slice_spec, :]
+    else:
+        expected = original_reconstructed[:, slice_spec]
+
+    torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_nvfp4_swizzled_scales_serialization():
+    """
+    Test that tensor flatten/unflatten preserves the swizzled scales state.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 32, 64
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    # Create tensor with swizzled scales
+    original_tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=True)
+
+    # Test serialization
+    tensor_list, ctx = original_tensor.__tensor_flatten__()
+
+    # Verify swizzled flag is preserved in context
+    assert "_swizzled_scales" in ctx
+    assert ctx["_swizzled_scales"]
+
+    # Test deserialization
+    inner_tensors = {}
+    for name in tensor_list:
+        inner_tensors[name] = getattr(original_tensor, name)
+
+    reconstructed_tensor = NVFP4Tensor.__tensor_unflatten__(
+        inner_tensors, ctx, None, None
+    )
+
+    # Verify the swizzled state is preserved
+    assert reconstructed_tensor._swizzled_scales
+
+    # Verify functionality is preserved
+    original_dq = original_tensor.to_dtype(torch.bfloat16)
+    reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16)
+
+    torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_nvfp4_swizzled_scales_get_scales_method():
+    """
+    Test that get_hp_scales() correctly unswizzles scales when needed.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 32, 64
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    # Create tensors with both storage layouts
+    regular_tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=False)
+    swizzled_tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=True)
+
+    # Get scales from both tensors
+    regular_scales = regular_tensor.get_hp_scales()
+    swizzled_scales = swizzled_tensor.get_hp_scales()
+
+    # Scales should be equivalent (within quantization error)
+    torch.testing.assert_close(regular_scales, swizzled_scales, atol=1e-6, rtol=1e-6)
+
+    # Verify scales have the expected shape: one scale per 16-element block
+    expected_shape = (M, K // 16)
+    assert regular_scales.shape == expected_shape
+    assert swizzled_scales.shape == expected_shape
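A quick usage sketch (not part of the diff) of the property these tests pin down: both storage layouts should dequantize to the same values, since only how the scales are laid out on the tensor differs.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)

t_linear = NVFP4Tensor.to_nvfp4(x, store_swizzled=False)
t_swizzled = NVFP4Tensor.to_nvfp4(x, store_swizzled=True)

# Same logical values either way; only the scale storage layout differs.
torch.testing.assert_close(
    t_linear.to_dtype(torch.bfloat16), t_swizzled.to_dtype(torch.bfloat16)
)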

torchao/prototype/mx_formats/mx_subclass.py

Lines changed: 6 additions & 1 deletion

@@ -184,6 +184,11 @@ def _nvfp4_inference_linear_transform(
 
     weight = module.weight
 
+    if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+        raise RuntimeError(
+            f"NVFP4 only supports weight shape divisible by 16, got {weight.shape}"
+        )
+
     if module.bias is not None and weight.dtype == torch.float32:
         raise RuntimeError(
             "Bias is not supported when module weight is in fp32 (out_dtype=Float32). "
@@ -193,8 +198,8 @@ def _nvfp4_inference_linear_transform(
     quantized_weight = NVFP4Tensor.to_nvfp4(
        weight,
        mm_config=config.mm_config,
+       store_swizzled=True,
    )
-
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module
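A hedged sketch of what the new guard means for callers; the shapes are illustrative and the error text comes from the diff above.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# Weights whose dims are both divisible by 16 quantize fine, and the
# inference transform above now stores their scales swizzled:
w_ok = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
qw = NVFP4Tensor.to_nvfp4(w_ok, store_swizzled=True)
assert qw._swizzled_scales

# A module whose weight is e.g. (100, 3) is now rejected up front:
# RuntimeError: NVFP4 only supports weight shape divisible by 16, got torch.Size([100, 3])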
