@@ -657,3 +657,196 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
657
657
assert x .t ().dtype == x_reconstructed_t .dtype , (
658
658
f"Transpose dtype mismatch: { x .t ().dtype } vs { x_reconstructed_t .dtype } "
659
659
)
660
+
661
+
662
@pytest.mark.parametrize(
    "shape",
    [
        (128, 4),
        (256, 8),
        (100, 3),
        (4, 4),
        (50, 10),
        (384, 12),
    ],
)
@pytest.mark.parametrize("use_triton_kernel", [False, True])
@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
    """
    Test that to_blocked and from_blocked are proper inverses of each other
    for various input shapes that may require padding.

    Args:
        shape: (rows, cols) of the matrix to round-trip.
        use_triton_kernel: forwarded to ``to_blocked`` to select its Triton
            implementation instead of the default one.
    """
    from torchao.prototype.mx_formats.utils import from_blocked, to_blocked

    has_cuda = torch.cuda.is_available()

    # The Triton path cannot run on the CPU fallback device chosen below, so
    # skip it explicitly instead of failing on machines without a GPU.
    if use_triton_kernel and not has_cuda:
        pytest.skip("Triton kernel requires CUDA")

    rows, cols = shape

    # Use CUDA if available, otherwise CPU
    device = "cuda" if has_cuda else "cpu"

    # Round-trip random data through the implementation under test.
    original = torch.randn(rows, cols, device=device, dtype=torch.float32)
    blocked = to_blocked(original, use_triton_kernel=use_triton_kernel)
    reconstructed = from_blocked(blocked, rows, cols)

    torch.testing.assert_close(
        original,
        reconstructed,
        atol=1e-6,
        rtol=1e-6,
        msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
    )

    # Sanity check with constant data. This always goes through the
    # non-Triton implementation, regardless of the parametrized flag.
    ones = torch.ones(rows, cols, device=device, dtype=torch.float32)
    blocked_ones = to_blocked(ones, use_triton_kernel=False)
    reconstructed_ones = from_blocked(blocked_ones, rows, cols)
    torch.testing.assert_close(ones, reconstructed_ones, atol=1e-6, rtol=1e-6)
709
+
710
+
711
@pytest.mark.parametrize("store_swizzled", [False, True])
@pytest.mark.parametrize(
    "shape",
    [
        (32, 64),
        (16, 32),
        (64, 128),
        (384, 128),
    ],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_nvfp4_swizzled_scales_construction(store_swizzled, shape):
    """
    Build an NVFP4Tensor with and without swizzled scale storage and check
    that ``_swizzled_scales`` mirrors the requested ``store_swizzled`` value
    and that dequantizing yields a tensor of the input's shape.
    """
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    # High-precision input of the parametrized (M, K) shape.
    hp_input = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # Quantize using the requested scale storage layout.
    nvfp4 = NVFP4Tensor.to_nvfp4(hp_input, store_swizzled=store_swizzled)

    # The flag on the tensor must reflect what was requested.
    assert nvfp4._swizzled_scales == store_swizzled

    # Dequantization must succeed and preserve the logical shape.
    dequantized = nvfp4.to_dtype(torch.bfloat16)
    assert dequantized.shape == hp_input.shape
741
+
742
+
743
@pytest.mark.parametrize(
    "slice_dim,slice_spec",
    [
        pytest.param(0, slice(0, 16), id="slice_rows[0:16]"),
        pytest.param(0, slice(8, 24), id="slice_rows[8:24]"),
        pytest.param(1, slice(0, 32), id="slice_cols[0:32]"),
        pytest.param(1, slice(16, 48), id="slice_cols[16:48]"),
    ],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
    """
    Test that slicing works correctly with swizzled scales and maintains
    the swizzled state in the output tensor.

    Args:
        slice_dim: 0 to slice rows, any other value to slice columns.
        slice_spec: the ``slice`` applied along ``slice_dim``.
    """
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    M, K = 32, 64
    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

    # Create tensor with swizzled scales
    tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=True)
    assert tensor._swizzled_scales

    # Build the 2-D index once so the quantized tensor and the reference are
    # guaranteed to be sliced identically.
    full = slice(None)
    index = (slice_spec, full) if slice_dim == 0 else (full, slice_spec)

    # Perform slice operation
    sliced_tensor = tensor[index]

    # Verify sliced tensor maintains swizzled state
    assert sliced_tensor._swizzled_scales

    # Verify sliced tensor can be dequantized
    sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)

    # Reference: dequantize the full tensor first, then apply the same slice.
    expected = tensor.to_dtype(torch.bfloat16)[index]

    torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)
787
+
788
+
789
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_nvfp4_swizzled_scales_serialization():
    """
    Test that tensor flatten/unflatten preserves the swizzled scales state.

    Round-trips an NVFP4Tensor through ``__tensor_flatten__`` and
    ``__tensor_unflatten__`` and checks both the ``_swizzled_scales`` flag and
    the dequantized values.
    """
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    M, K = 32, 64
    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

    # Create tensor with swizzled scales
    original_tensor = NVFP4Tensor.to_nvfp4(data, store_swizzled=True)

    # Flatten into (inner tensor attribute names, context)
    tensor_list, ctx = original_tensor.__tensor_flatten__()

    # Verify swizzled flag is preserved in context; the membership check comes
    # first so a missing key fails as an assertion rather than a KeyError.
    assert "_swizzled_scales" in ctx
    assert ctx["_swizzled_scales"]

    # Collect the inner tensors by attribute name, as unflatten expects
    inner_tensors = {name: getattr(original_tensor, name) for name in tensor_list}

    reconstructed_tensor = NVFP4Tensor.__tensor_unflatten__(
        inner_tensors, ctx, None, None
    )

    # Verify the swizzled state is preserved
    assert reconstructed_tensor._swizzled_scales

    # Verify functionality is preserved
    original_dq = original_tensor.to_dtype(torch.bfloat16)
    reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16)

    torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)
826
+
827
+
828
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_nvfp4_swizzled_scales_get_scales_method():
    """
    Check that ``get_hp_scales()`` yields matching values, with shape
    ``(M, K // 16)``, whether the tensor was built with
    ``store_swizzled=False`` or ``store_swizzled=True``.
    """
    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

    rows, cols = 32, 64
    hp_input = torch.randn(rows, cols, device="cuda", dtype=torch.bfloat16)

    # Quantize the same input once per scale storage layout.
    plain = NVFP4Tensor.to_nvfp4(hp_input, store_swizzled=False)
    swizzled = NVFP4Tensor.to_nvfp4(hp_input, store_swizzled=True)

    # High-precision scales as exposed by each tensor.
    plain_scales = plain.get_hp_scales()
    swizzled_scales = swizzled.get_hp_scales()

    # The storage layout must not change the scale values that come back.
    torch.testing.assert_close(plain_scales, swizzled_scales, atol=1e-6, rtol=1e-6)

    # Both results must have one column per group of 16 input columns.
    scales_shape = (rows, cols // 16)
    for scales in (plain_scales, swizzled_scales):
        assert scales.shape == scales_shape
0 commit comments