|
25 | 25 | from torch._inductor.test_case import TestCase as InductorTestCase
|
26 | 26 | from torch.testing._internal import common_utils
|
27 | 27 |
|
28 |
| -from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl |
| 28 | +from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale |
29 | 29 | from torchao.float8.float8_utils import compute_error
|
30 | 30 | from torchao.quantization import (
|
31 | 31 | Float8DynamicActivationFloat8WeightConfig,
|
@@ -630,6 +630,51 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
|
630 | 630 | error = compute_error(ref_output, quant_output)
|
631 | 631 | self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
|
632 | 632 |
|
    def test_preprocess_scale_3d_reshape(self):
        """Test that preprocess_scale flattens scale tensors to rank-2.

        preprocess_scale(scale, input_shape) is expected to return a tensor of
        shape (rows, 1) so the scale can broadcast against a 2D-reshaped
        activation: scalar / one-element scales become (1, 1), and a per-row
        scale covering the first N-1 dims of an N-D input is flattened to
        (prod(leading dims), 1).
        """
        device = "cpu"  # Use CPU for basic functionality test

        # Test 1: PerTensor scale (0-D scalar) - should reshape to (1, 1)
        per_tensor_scale = torch.tensor(0.5, device=device)
        result = preprocess_scale(per_tensor_scale, (2, 4, 8))
        expected_shape = (1, 1)
        self.assertEqual(result.shape, expected_shape)
        self.assertEqual(result.item(), 0.5)

        # Test 2: 1D scale tensor with one element - should reshape to (1, 1)
        one_element_scale = torch.tensor([0.3], device=device)
        result = preprocess_scale(one_element_scale, (2, 4, 8))
        expected_shape = (1, 1)
        self.assertEqual(result.shape, expected_shape)
        self.assertEqual(result.item(), 0.3)

        # Test 3: per-row scale for a 3D input - should flatten the first N-1
        # dims. This is the key test for the 3D reshape fix.
        # NOTE(review): the scale itself is 2D; it spans the first two dims of
        # the 3D input shape (2, 4, 8). The `scale_3d` name refers to the
        # input's rank, not the scale's.
        scale_3d = torch.randn(
            2, 4, device=device
        )  # Shape matches first 2 dims of (2, 4, 8)
        result = preprocess_scale(scale_3d, (2, 4, 8))
        expected_shape = (8, 1)  # Flattened to (2*4, 1)
        self.assertEqual(result.shape, expected_shape)

        # Verify the values are preserved correctly: row-major flatten with a
        # trailing singleton dimension appended
        expected_values = scale_3d.flatten().unsqueeze(-1)
        self.assertTrue(torch.allclose(result, expected_values))

        # Test 4: 1D per-row scale for a 2D input (already one value per row)
        # - should just gain a trailing dimension of size 1
        scale_2d = torch.randn(8, device=device)
        result = preprocess_scale(scale_2d, (8, 16))
        expected_shape = (8, 1)
        self.assertEqual(result.shape, expected_shape)

        # Test 5: edge case with a higher-rank (4D) input - the 3D scale spans
        # the first 3 dims of (2, 2, 2, 8)
        scale_4d = torch.randn(
            2, 2, 2, device=device
        )  # Shape matches first 3 dims of (2, 2, 2, 8)
        result = preprocess_scale(scale_4d, (2, 2, 2, 8))
        expected_shape = (8, 1)  # Flattened to (2*2*2, 1)
        self.assertEqual(result.shape, expected_shape)
633 | 678 |
|
634 | 679 | common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
|
635 | 680 |
|
|
0 commit comments