Commit 83663b8

Fix Per Tensor 3d reshape (#2293)
stack-info: PR: #2293, branch: drisspg/stack/64
1 parent 4c06318 commit 83663b8
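
Context for the change: torch._scaled_mm only accepts 2D operands, so the float8 linear kernel flattens a 3D activation to (-1, K) before the matmul, and every scale tensor has to be reshaped to line up with that flattened view. This commit routes every scale through preprocess_scale and adds an explicit numel() == 1 fast path for the per-tensor case. Below is a minimal standalone sketch of the logic being landed (reshape_scale_for_mm is a hypothetical name for illustration; the shape logic mirrors the preprocess_scale diff further down and runs on CPU with plain float tensors):

import torch

def reshape_scale_for_mm(scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor: a single value becomes (1, 1) so it broadcasts against
    # an activation flattened to (-1, K).
    if scale.numel() == 1:
        return scale.reshape(1, 1)
    # Per-row: append a trailing dim, then flatten any leading dims so a
    # (B, S) scale lines up with an activation reshaped to (B * S, K).
    scale = scale.unsqueeze(-1)
    if scale.dim() > 2:
        scale = scale.reshape(-1, scale.shape[-1])
    return scale

print(reshape_scale_for_mm(torch.tensor(0.5)).shape)  # torch.Size([1, 1])
print(reshape_scale_for_mm(torch.randn(2, 4)).shape)  # torch.Size([8, 1])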

File tree

3 files changed: +70 −22 lines changed

test/dtypes/test_affine_quantized_float.py

Lines changed: 46 additions & 1 deletion
@@ -25,7 +25,7 @@
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
 
-from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
@@ -630,6 +630,51 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         error = compute_error(ref_output, quant_output)
         self.assertGreater(error, 15, f"Quantization SQNR too low: {error}")
 
+    def test_preprocess_scale_3d_reshape(self):
+        """Test that preprocess_scale correctly reshapes scales for 3D inputs"""
+        device = "cpu"  # Use CPU for basic functionality test
+
+        # Test 1: PerTensor scale (scalar) - should reshape to (1, 1)
+        per_tensor_scale = torch.tensor(0.5, device=device)
+        result = preprocess_scale(per_tensor_scale, (2, 4, 8))
+        expected_shape = (1, 1)
+        self.assertEqual(result.shape, expected_shape)
+        self.assertEqual(result.item(), 0.5)
+
+        # Test 2: 1D scale tensor with one element - should reshape to (1, 1)
+        one_element_scale = torch.tensor([0.3], device=device)
+        result = preprocess_scale(one_element_scale, (2, 4, 8))
+        expected_shape = (1, 1)
+        self.assertEqual(result.shape, expected_shape)
+        self.assertEqual(result.item(), 0.3)
+
+        # Test 3: per-row scale for a 3D input - should flatten the leading dims
+        # This is the key test for the 3D reshape fix
+        scale_3d = torch.randn(
+            2, 4, device=device
+        )  # Shape matches first 2 dims of (2, 4, 8)
+        result = preprocess_scale(scale_3d, (2, 4, 8))
+        expected_shape = (8, 1)  # Flattened (2*4, 1)
+        self.assertEqual(result.shape, expected_shape)
+
+        # Verify the values are preserved correctly
+        expected_values = scale_3d.flatten().unsqueeze(-1)
+        self.assertTrue(torch.allclose(result, expected_values))
+
+        # Test 4: 1D per-row scale (already flat) - should just add a trailing dimension
+        scale_2d = torch.randn(8, device=device)
+        result = preprocess_scale(scale_2d, (8, 16))
+        expected_shape = (8, 1)
+        self.assertEqual(result.shape, expected_shape)
+
+        # Test 5: Edge case with a higher-dimensional (4D) input
+        scale_4d = torch.randn(
+            2, 2, 2, device=device
+        )  # Shape matches first 3 dims of (2, 2, 2, 8)
+        result = preprocess_scale(scale_4d, (2, 2, 2, 8))
+        expected_shape = (8, 1)  # Flattened (2*2*2, 1)
+        self.assertEqual(result.shape, expected_shape)
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
 
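As a quick interactive check of the property Test 3 pins down (the scale is only reshaped, never rewritten), using the same import the test adds:

import torch
from torchao.dtypes.floatx.float8_layout import preprocess_scale

scale = torch.randn(2, 4)                  # per-row scale for a (2, 4, 8) input
out = preprocess_scale(scale, (2, 4, 8))
assert out.shape == (8, 1)                 # leading dims flattened
assert torch.equal(out, scale.flatten().unsqueeze(-1))  # values untouched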
torchao/dtypes/floatx/float8_layout.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,18 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
     return check_aqt(input_tensor) and check_aqt(weight_tensor)
 
 
-def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
-    """Ensures input tensor is correctly formated for _scaled_mm"""
+def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int, ...]):
+    """Ensures input tensor is correctly formatted for _scaled_mm"""
+
+    # For PerTensor quantization, scale should be a scalar or have shape [1]
+    if input_scale.numel() == 1:
+        # Already a scalar, ensure it has the right shape for _scaled_mm
+        return input_scale.reshape(1, 1)
+
+    # For per-row/block quantization, we need to handle the reshaping
     input_scale = input_scale.unsqueeze(-1)
 
+    # Match: input_data.reshape(-1, input_data.shape[-1])
     if input_scale.dim() > 2:
         input_scale = input_scale.reshape(-1, input_scale.shape[-1])
 
@@ -388,31 +396,28 @@ def _linear_fp8_act_fp8_weight_impl(
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
     scaled_mm_config = weight_tensor._layout.mm_config
     assert scaled_mm_config is not None
-    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
+    assert not weight_tensor.tensor_impl.transposed, "Weight tensor must be contiguous"
 
-    # Weight tensor preprocessing
-    w_tensor_impl = weight_tensor.tensor_impl
-    assert not w_tensor_impl.transposed, "Weight tensor must be contiguous"
-    w_data = w_tensor_impl.float8_data
-    w_scale = w_tensor_impl.scale
+    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
-    # Input tensor preprocessing
-    inpt_data = input_tensor.tensor_impl.float8_data
+    # Extract tensor data and scales
+    inpt_data = input_tensor.tensor_impl.float8_data.reshape(
+        -1, input_tensor.tensor_impl.float8_data.shape[-1]
+    )
+    w_data = weight_tensor.tensor_impl.float8_data
     input_scale = input_tensor.tensor_impl.scale
-    # Handle case where input tensor is more than 2D
-    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
-    # Handle rowwise case
+    w_scale = weight_tensor.tensor_impl.scale
+
+    # Handle rowwise scaling
     if _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
             "Input tensor must be rowwise block size"
         )
-        w_scale = w_scale.T
-        input_scale = preprocess_scale(input_scale, input_tensor.shape)
+        w_scale = w_scale.transpose(-1, -2)
 
-    # Preprocess data
+    input_scale = preprocess_scale(input_scale, input_tensor.shape)
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
-    # Perform the computation
     return addmm_float8_unwrapped_inference(
         inpt_data,
         input_scale,
torchao/float8/inference.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,8 @@ def addmm_float8_unwrapped_inference(
             out_dtype=output_dtype,
             use_fast_accum=use_fast_accum,
         )
-        output += bias
-        return output
-    output = torch._scaled_mm(
+        return output + bias
+    return torch._scaled_mm(
         a_data,
         b_data,
         scale_a=a_scale,
@@ -106,7 +105,6 @@ def addmm_float8_unwrapped_inference(
         out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
-    return output
 
 
 def _is_rowwise_scaled(x) -> bool:
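
Taken together, the three files fix the path exercised whenever a float8-quantized linear sees a 3D activation. A hedged end-to-end usage sketch (assumptions: a CUDA GPU with FP8 support, e.g. sm89+; Float8DynamicActivationFloat8WeightConfig is the config imported in the test file above, quantize_ is torchao's standard entry point, and the config's default per-tensor granularity is assumed):

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Linear(128, 256).cuda().to(torch.bfloat16)
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

x = torch.randn(2, 4, 128, device="cuda", dtype=torch.bfloat16)  # 3D activation
out = model(x)    # flattened to (8, 128) internally, reshaped back afterwards
print(out.shape)  # torch.Size([2, 4, 256])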
