
Commit 91e714b

yyetim authored and facebook-github-bot committed
Add ability to pad the rowwise quantized tensors (#4877)
Summary:
Pull Request resolved: #4877
X-link: facebookresearch/FBGEMM#1899

Some downstream kernels assume a certain width from quantized tensors. This adds the ability to do this as part of the triton fp8 quantize kernel.

Reviewed By: RandySheriff

Differential Revision: D82486197
1 parent fd32631 commit 91e714b

File tree

2 files changed: +132 -10 lines changed

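For orientation, here is a minimal usage sketch of the new option (not part of the commit; the import path is inferred from the file location below, and the resulting shapes follow the shape function added in this diff):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# Row-wise fp8 quantization, padding the last dimension up to a multiple of 16.
a = torch.randn(4, 30, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scale = quantize_fp8_row(a, align_rows_to=16)

print(a_fp8.shape)    # torch.Size([4, 32]); columns 30-31 are zero padding
print(a_scale.shape)  # torch.Size([4]); one scale per original row, unchanged by padding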

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 70 additions & 0 deletions
@@ -26,6 +26,7 @@
     quantize_fp8_packed_row,
     quantize_fp8_packed_row_raw,
     quantize_fp8_row,
+    quantize_fp8_row_meta,
     scale_fp8_row,
 )

@@ -48,6 +49,8 @@ def _test_quantize_fp8_row(
             use_jagged: bool = False,
             use_scale_ub: bool = False,
             transpose_inputs: bool = False,
+            align_rows_to: Optional[int] = None,
+            expected_padded_size: Optional[int] = None,  # only set with align_rows_to
         ) -> None:
             a = torch.randn(shape, dtype=torch.bfloat16, device=device)
             inputs = [a]
@@ -91,8 +94,23 @@ def _test_quantize_fp8_row(
                 zero_start_index_M=zero_start_index_M,
                 use_triton=use_triton,
                 output_device=output_device,
+                align_rows_to=align_rows_to,
             )

+            a_fp8_meta, a_scale_meta = quantize_fp8_row_meta(
+                input_a,
+                scale_ub=scale_ub,
+                zero_start_index_M=zero_start_index_M,
+                use_triton=use_triton,
+                output_device=output_device,
+                align_rows_to=align_rows_to,
+            )
+
+            self.assertEqual(a_fp8.dtype, a_fp8_meta.dtype)
+            self.assertEqual(a_fp8.shape, a_fp8_meta.shape)
+            self.assertEqual(a_scale.dtype, a_scale_meta.dtype)
+            self.assertEqual(a_scale.shape, a_scale_meta.shape)
+
             # Undo scaling.
             a_torch = a_fp8.to(torch.bfloat16)
             broadcast_shape = list(a_torch.shape[:-1]) + [-1]
@@ -101,6 +119,20 @@ def _test_quantize_fp8_row(

             a_torch *= a_scale.view(broadcast_shape)

+            if align_rows_to is not None:
+                # Pad input_a's row dimension to expected_padded_size if specified.
+                assert expected_padded_size is not None
+                pad_rows = expected_padded_size - input_a.shape[-1]
+                if pad_rows > 0:
+                    pad_shape = list(input_a.shape)
+                    pad_shape[-1] = pad_rows
+                    pad_tensor = torch.zeros(
+                        pad_shape,
+                        dtype=input_a.dtype,
+                        device=input_a.device,
+                    )
+                    input_a = torch.cat([input_a, pad_tensor], dim=-1)
+
             self.assertTrue(
                 torch.allclose(
                     input_a.to(device=output_device),
@@ -112,8 +144,38 @@ def _test_quantize_fp8_row(

         for n_col in range(1, 9000, 100):
             _test_quantize_fp8_row((2, n_col), True, torch.device("cuda"))
+
+            # Test with padding. These go up to 9000 (larger than max BLOCK_SIZE)
+
+            # Calculate expected_padded_size from align_rows_to=8.
+            # Using a different math here, just to make tests different from implementation.
+            align_rows_to = 8
+            truncated_size = align_rows_to - n_col % align_rows_to
+            expected_padded_size = n_col + truncated_size
+            _test_quantize_fp8_row(
+                (2, n_col),
+                True,
+                torch.device("cuda"),
+                align_rows_to=align_rows_to,
+                expected_padded_size=expected_padded_size,
+            )
+
         # Test with batched input.
         _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
+        _test_quantize_fp8_row(
+            (4, 2, 3),
+            True,
+            torch.device("cuda"),
+            align_rows_to=8,
+            expected_padded_size=8,
+        )
+        _test_quantize_fp8_row(
+            (4, 2, 13),
+            True,
+            torch.device("cuda"),
+            align_rows_to=8,
+            expected_padded_size=16,
+        )
         _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cuda"))
         # Test with non-contiguous input
         _test_quantize_fp8_row(
@@ -132,6 +194,14 @@ def _test_quantize_fp8_row(
         _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cpu"))
         # Test with zero_start_index_M
         _test_quantize_fp8_row((20, 30), True, torch.device("cuda"), use_jagged=True)
+        _test_quantize_fp8_row(
+            (20, 30),
+            True,
+            torch.device("cuda"),
+            use_jagged=True,
+            align_rows_to=16,
+            expected_padded_size=32,
+        )
         _test_quantize_fp8_row(
             (6, 4, 2, 3), True, torch.device("cuda"), use_jagged=True
         )
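A note on the test's padded-size arithmetic (a quick check, not part of the commit): the test predicts the output width as n_col + (align_rows_to - n_col % align_rows_to), while the kernel below rounds up with ceil-division. The two formulas differ only when n_col is already a multiple of align_rows_to, which never happens for n_col in range(1, 9000, 100), so the deliberately different math still matches the implementation for every tested shape:

align_rows_to = 8
for n_col in range(1, 9000, 100):
    test_padded = n_col + (align_rows_to - n_col % align_rows_to)
    kernel_padded = ((n_col + align_rows_to - 1) // align_rows_to) * align_rows_to
    assert n_col % align_rows_to != 0
    assert test_padded == kernel_padded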

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 62 additions & 10 deletions
@@ -2328,6 +2328,7 @@ def _kernel_quantize_fp8_row(
     M,
     N,
     K,
+    K_fp8,  # used when padding
     stride_ab,
     stride_am,
     stride_an,
@@ -2364,7 +2365,8 @@
         B (int): Size of dimenion 0
         M (int): Size of dimenion 1
         N (int): Size of dimenion 2
-        K (int): Size of dimenion 3
+        K (int): Size of dimenion 3 (input row size)
+        K_fp8 (int): Size of dimenion 3 for A_fp8 (output row size, can be >= K)
         stride_ab (int): Stride of b dimension of A.
         stride_am (int): Stride of m dimension of A.
         stride_an (int): Stride of n dimension of A.
@@ -2433,21 +2435,32 @@
     tl.store(A_scale + pid, 1.0 / a_scale)
     n_offset = tl.arange(0, BLOCK_SIZE)

-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
+    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+        # For the first K elements, use A; for the rest, use 0
+        mask_in_A = n_offset < K_in
+
+        # Load from A if in range, else 0 (we're going all the way to K_fp8)
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < K_in,
+            mask=mask_in_A,
             other=0.0,
         )
+        # For elements >= K, a will be 0
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+        # Get the mask for A_fp8
+        mask_in_A_fp8 = n_offset < K_fp8
+
+        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
         tl.store(
             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
             a_fp8,
-            mask=n_offset < K,
+            mask=mask_in_A_fp8,
         )
         n_offset += BLOCK_SIZE
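For readers less used to Triton masking, a plain PyTorch sketch of what the rewritten loop computes for one row (a reading aid only, not the library API; it ignores the kernel's scale_ub, eps, and zero_start_index_M handling and assumes the e4m3 format with max value 448):

import torch

def quantize_row_and_pad(row: torch.Tensor, k_fp8: int, max_fp8: float = 448.0):
    # Row-wise scale from the max magnitude; the reciprocal is what A_scale stores.
    a_scale = max_fp8 / row.abs().max().float().clamp_min(1e-12)
    q = (row.float() * a_scale).clamp(-max_fp8, max_fp8)
    # mask_in_A: only the first K (= row.numel()) elements come from the input;
    # mask_in_A_fp8: the output row is written out to K_fp8, the tail staying zero.
    padded = torch.zeros(k_fp8, dtype=torch.float8_e4m3fn, device=row.device)
    padded[: row.numel()] = q.to(torch.float8_e4m3fn)
    return padded, 1.0 / a_scale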

@@ -2456,6 +2469,9 @@ def triton_quantize_fp8_row(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
     zero_start_index_M: Optional[Tensor] = None,
+    align_rows_to: Optional[
+        int
+    ] = None,  # TODO(yyetim) Add a test case, validate padding and 0 setting
 ) -> Tuple[Tensor, Tensor]:
     """
     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -2464,6 +2480,7 @@
         a (Tensor): higher precision input tensor of 4 dimension.
         scale_ub (Tensor): Maximum allowed value for scale.
         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)

     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2485,7 +2502,18 @@
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
-    a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+    # If align_rows_to is provided, pad the last dimension to be a multiple of it
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        a_fp8 = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
+        )
+        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+    else:
+        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

     # If input tensor is sufficiently large, we need to use int64 indexing.
     use_int64 = a.numel() > (2**31 - 1)
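The rounding used here is ordinary ceil-division alignment; a small illustration (values chosen for this note, not taken from the diff):

align_rows_to = 16
for last_dim in (30, 32, 33):
    padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
    print(last_dim, "->", padded_last_dim)  # 30 -> 32, 32 -> 32, 33 -> 48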
@@ -2504,6 +2532,7 @@
         a.shape[1],
         a.shape[2],
         a.shape[3],
+        a_fp8.shape[3],
         a.stride(0),
         a.stride(1),
         a.stride(2),
@@ -2908,6 +2937,7 @@ def quantize_fp8_row(
     zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -2918,6 +2948,7 @@
         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
         use_triton (bool): Whether to use triton kernel or pytorch.
         output_device (torch.device): Device to optionally move the scaled tensors to.
+        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)

     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2928,7 +2959,12 @@
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return triton_quantize_fp8_row(
+            a,
+            scale_ub,
+            zero_start_index_M,
+            align_rows_to=align_rows_to,
+        )
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
@@ -2958,18 +2994,34 @@
 def quantize_fp8_row_meta(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
     a_shape = a.shape
-    # Flatten to 2D since each row of each potential batch gets a scale.
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
-    fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
-    return fake_out, fake_scale
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        fake_out = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
+        )
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
+    else:
+        fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale


 @triton.autotune(
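Since quantize_fp8_row_meta is the shape function torch.compile uses in place of the real kernel, it has to report the same padded output shape. A minimal consistency check in the spirit of the new test (import path inferred from the file location; the (4, 2, 13) shape mirrors one of the added test cases):

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,
    quantize_fp8_row_meta,
)

a = torch.randn(4, 2, 13, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scale = quantize_fp8_row(a, align_rows_to=8)
a_fp8_meta, a_scale_meta = quantize_fp8_row_meta(a, align_rows_to=8)

assert a_fp8.shape == a_fp8_meta.shape == (4, 2, 16)
assert a_scale.shape == a_scale_meta.shape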
