
Commit 02d93d5

yyetim authored and facebook-github-bot committed
Add ability to pad the rowwise quantized tensors (pytorch#4877)
Summary:
X-link: facebookresearch/FBGEMM#1899

Some downstream kernels assume a certain width from quantized tensors. This change adds the ability to pad the rowwise quantized tensor to a requested alignment as part of the triton fp8 quantize kernel.

Reviewed By: RandySheriff

Differential Revision: D82486197
1 parent fd32631 commit 02d93d5
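
For context, a minimal usage sketch of the new argument. This is not part of the commit; it assumes the import path mirrors the file path in this diff and that a CUDA device is available. The (4, 2, 13) / align_rows_to=8 case mirrors the "multiple padding case" test added below.

import torch

# Assumed import path (mirrors fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py).
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

a = torch.randn(4, 2, 13, dtype=torch.bfloat16, device="cuda")

# Without padding, the fp8 output keeps the input's last-dim width (13).
a_fp8, a_scale = quantize_fp8_row(a)

# With align_rows_to=8, the last dim is rounded up to the next multiple of 8 (16);
# the extra columns are zero-filled and the per-row scales keep their shape.
a_fp8_padded, a_scale_padded = quantize_fp8_row(a, align_rows_to=8)
assert a_fp8_padded.shape[-1] == 16
assert a_scale_padded.shape == a_scale.shape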

File tree

2 files changed: +135, -9 lines

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 82 additions & 0 deletions
@@ -26,6 +26,7 @@
     quantize_fp8_packed_row,
     quantize_fp8_packed_row_raw,
     quantize_fp8_row,
+    quantize_fp8_row_meta,
     scale_fp8_row,
 )

@@ -48,6 +49,8 @@ def _test_quantize_fp8_row(
             use_jagged: bool = False,
             use_scale_ub: bool = False,
             transpose_inputs: bool = False,
+            align_rows_to: Optional[int] = None,
+            expected_padded_size: Optional[int] = None,  # only set with align_rows_to
         ) -> None:
             a = torch.randn(shape, dtype=torch.bfloat16, device=device)
             inputs = [a]
@@ -91,8 +94,23 @@ def _test_quantize_fp8_row(
                 zero_start_index_M=zero_start_index_M,
                 use_triton=use_triton,
                 output_device=output_device,
+                align_rows_to=align_rows_to,
             )

+            a_fp8_meta, a_scale_meta = quantize_fp8_row_meta(
+                input_a,
+                scale_ub=scale_ub,
+                zero_start_index_M=zero_start_index_M,
+                use_triton=use_triton,
+                output_device=output_device,
+                align_rows_to=align_rows_to,
+            )
+
+            self.assertEqual(a_fp8.dtype, a_fp8_meta.dtype)
+            self.assertEqual(a_fp8.shape, a_fp8_meta.shape)
+            self.assertEqual(a_scale.dtype, a_scale_meta.dtype)
+            self.assertEqual(a_scale.shape, a_scale_meta.shape)
+
             # Undo scaling.
             a_torch = a_fp8.to(torch.bfloat16)
             broadcast_shape = list(a_torch.shape[:-1]) + [-1]
@@ -101,6 +119,20 @@ def _test_quantize_fp8_row(

            a_torch *= a_scale.view(broadcast_shape)

+            if align_rows_to is not None:
+                # Pad input_a's row dimension to expected_padded_size if specified.
+                assert expected_padded_size is not None
+                pad_rows = expected_padded_size - input_a.shape[-1]
+                if pad_rows > 0:
+                    pad_shape = list(input_a.shape)
+                    pad_shape[-1] = pad_rows
+                    pad_tensor = torch.zeros(
+                        pad_shape,
+                        dtype=input_a.dtype,
+                        device=input_a.device,
+                    )
+                    input_a = torch.cat([input_a, pad_tensor], dim=-1)
+
            self.assertTrue(
                torch.allclose(
                    input_a.to(device=output_device),
@@ -112,8 +144,50 @@ def _test_quantize_fp8_row(

        for n_col in range(1, 9000, 100):
            _test_quantize_fp8_row((2, n_col), True, torch.device("cuda"))
+
+            # Test with padding. These go up to 9000 (larger than max BLOCK_SIZE)
+
+            # Calculate expected_padded_size from align_rows_to=8.
+            # Using a different math here, just to make tests different from implementation.
+            align_rows_to = 8
+            trailing_beyond_alignment = n_col % align_rows_to
+            padding_size = (
+                align_rows_to - trailing_beyond_alignment
+                if trailing_beyond_alignment > 0
+                else 0
+            )
+            expected_padded_size = n_col + padding_size
+            _test_quantize_fp8_row(
+                (2, n_col),
+                True,
+                torch.device("cuda"),
+                align_rows_to=align_rows_to,
+                expected_padded_size=expected_padded_size,
+            )
+
        # Test with batched input.
        _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
+        _test_quantize_fp8_row(  # simple padding case
+            (4, 2, 3),
+            True,
+            torch.device("cuda"),
+            align_rows_to=8,
+            expected_padded_size=8,
+        )
+        _test_quantize_fp8_row(  # multiple padding case
+            (4, 2, 13),
+            True,
+            torch.device("cuda"),
+            align_rows_to=8,
+            expected_padded_size=16,
+        )
+        _test_quantize_fp8_row(  # 0 padding case
+            (4, 2, 8),
+            True,
+            torch.device("cuda"),
+            align_rows_to=8,
+            expected_padded_size=8,
+        )
        _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cuda"))
        # Test with non-contiguous input
        _test_quantize_fp8_row(
@@ -132,6 +206,14 @@ def _test_quantize_fp8_row(
        _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cpu"))
        # Test with zero_start_index_M
        _test_quantize_fp8_row((20, 30), True, torch.device("cuda"), use_jagged=True)
+        _test_quantize_fp8_row(
+            (20, 30),
+            True,
+            torch.device("cuda"),
+            use_jagged=True,
+            align_rows_to=16,
+            expected_padded_size=32,
+        )
        _test_quantize_fp8_row(
            (6, 4, 2, 3), True, torch.device("cuda"), use_jagged=True
        )
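
A small standalone check, not part of the diff, showing that the remainder-based expected size used by the test matches the ceiling-division rounding used in triton_quantize_fp8_row:

def padded_size_test_style(n_col: int, align_rows_to: int) -> int:
    # Test-file math: add the remainder needed to reach the next multiple.
    trailing = n_col % align_rows_to
    return n_col + (align_rows_to - trailing if trailing > 0 else 0)

def padded_size_impl_style(last_dim: int, align_rows_to: int) -> int:
    # Implementation math: ceiling division, then scale back up.
    return ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to

for n in range(1, 64):
    assert padded_size_test_style(n, 8) == padded_size_impl_style(n, 8)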

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 53 additions & 9 deletions
@@ -2328,6 +2328,7 @@ def _kernel_quantize_fp8_row(
    M,
    N,
    K,
+    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
@@ -2364,7 +2365,8 @@ def _kernel_quantize_fp8_row(
        B (int): Size of dimenion 0
        M (int): Size of dimenion 1
        N (int): Size of dimenion 2
-        K (int): Size of dimenion 3
+        K (int): Size of dimenion 3 (input row size)
+        K_fp8 (int): Size of dimenion 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
@@ -2433,21 +2435,26 @@ def _kernel_quantize_fp8_row(
    tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
+    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
+        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
-            mask=n_offset < K,
+            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE

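In plain PyTorch terms, the loop above now writes a row of width K_fp8 whose first K elements are the quantized input and whose tail is zero. A reference sketch of that semantics, not the Triton kernel itself:

import torch

def pad_quantized_row(row_fp8: torch.Tensor, k_fp8: int) -> torch.Tensor:
    # Keep the first K quantized values and zero-fill up to K_fp8.
    k = row_fp8.shape[-1]
    out = torch.zeros(k_fp8, dtype=row_fp8.dtype, device=row_fp8.device)
    out[:k] = row_fp8
    return out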
@@ -2456,6 +2463,7 @@ def triton_quantize_fp8_row(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
+    align_rows_to: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -2464,6 +2472,7 @@ def triton_quantize_fp8_row(
        a (Tensor): higher precision input tensor of 4 dimension.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)

    Returns:
        torch.Tensor: fp8 scaled tensor.
@@ -2485,7 +2494,18 @@ def triton_quantize_fp8_row(
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
-    a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+    # If align_rows_to is provided, pad the last dimension to be a multiple of it
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        a_fp8 = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
+        )
+        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+    else:
+        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If input tensor is sufficiently large, we need to use int64 indexing.
    use_int64 = a.numel() > (2**31 - 1)
@@ -2504,6 +2524,7 @@ def triton_quantize_fp8_row(
        a.shape[1],
        a.shape[2],
        a.shape[3],
+        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
@@ -2908,6 +2929,7 @@ def quantize_fp8_row(
    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -2918,6 +2940,7 @@ def quantize_fp8_row(
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
+        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)

    Returns:
        torch.Tensor: fp8 scaled tensor.
@@ -2928,7 +2951,12 @@ def quantize_fp8_row(
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
-        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return triton_quantize_fp8_row(
+            a,
+            scale_ub,
+            zero_start_index_M,
+            align_rows_to=align_rows_to,
+        )
    # else use pytorch implementation.
    if not output_device:
        output_device = a.device
@@ -2958,18 +2986,34 @@ def quantize_fp8_row(
def quantize_fp8_row_meta(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shape function for torch compile."""
    if output_device is None:
        output_device = a.device
    a_shape = a.shape
-    # Flatten to 2D since each row of each potential batch gets a scale.
    dtype = get_fp8_constants()[0]
-    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
-    fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
-    return fake_out, fake_scale
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        fake_out = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
+        )
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
+    else:
+        fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale


@triton.autotune(
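
Since quantize_fp8_row_meta only reports output shapes and dtypes for torch.compile, it has to mirror the padded allocation above. A shape-only sketch, with the import path assumed from the file path and sizes taken from the jagged test added in this commit:

import torch

# Assumed import path (mirrors fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py).
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row_meta

a = torch.randn(20, 30, dtype=torch.bfloat16, device="cuda")
fake_out, fake_scale = quantize_fp8_row_meta(a, align_rows_to=16)

# Last dim rounded up to a multiple of 16; one scale per row.
assert fake_out.shape == (20, 32)
assert fake_scale.shape == (20,)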
