
Commit 39b8f18

yyetim authored and facebook-github-bot committed
Add ability to pad the rowwise quantized tensors
Summary:
X-link: facebookresearch/FBGEMM#1899

Some downstream kernels assume a certain width from quantized tensors. This change adds the ability to pad the rowwise-quantized output to such a width as part of the triton fp8 quantize kernel.

Differential Revision: D82486197
1 parent 12a1be8 · commit 39b8f18

2 files changed: +119 −12 lines changed
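As a rough usage sketch of the new argument (the import path is assumed from the file locations in this diff; only pad_rows_to_multiple_of itself is introduced by this commit), a caller that needs fp8 rows whose width is a fixed multiple asks for it at quantization time and gets a zero-padded output:

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# Quantize row-wise to fp8, padding each 13-element row up to 16 elements.
a = torch.randn(4, 2, 13, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scale = quantize_fp8_row(a, pad_rows_to_multiple_of=16)
# The padded tail of each row is zero-filled; there is still one scale per row.
print(a_fp8.shape)  # torch.Size([4, 2, 16])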

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 56 additions & 0 deletions
@@ -26,6 +26,7 @@
     quantize_fp8_packed_row,
     quantize_fp8_packed_row_raw,
     quantize_fp8_row,
+    quantize_fp8_row_meta,
     scale_fp8_row,
 )
@@ -48,6 +49,10 @@ def _test_quantize_fp8_row(
             use_jagged: bool = False,
             use_scale_ub: bool = False,
             transpose_inputs: bool = False,
+            pad_rows_to_multiple_of: Optional[int] = None,
+            expected_padded_size: Optional[
+                int
+            ] = None,  # only set with pad_rows_to_multiple_of
         ) -> None:
             a = torch.randn(shape, dtype=torch.bfloat16, device=device)
             inputs = [a]
@@ -91,8 +96,23 @@ def _test_quantize_fp8_row(
                 zero_start_index_M=zero_start_index_M,
                 use_triton=use_triton,
                 output_device=output_device,
+                pad_rows_to_multiple_of=pad_rows_to_multiple_of,
             )
 
+            a_fp8_meta, a_scale_meta = quantize_fp8_row_meta(
+                input_a,
+                scale_ub=scale_ub,
+                zero_start_index_M=zero_start_index_M,
+                use_triton=use_triton,
+                output_device=output_device,
+                pad_rows_to_multiple_of=pad_rows_to_multiple_of,
+            )
+
+            self.assertEqual(a_fp8.dtype, a_fp8_meta.dtype)
+            self.assertEqual(a_fp8.shape, a_fp8_meta.shape)
+            self.assertEqual(a_scale.dtype, a_scale_meta.dtype)
+            self.assertEqual(a_scale.shape, a_scale_meta.shape)
+
             # Undo scaling.
             a_torch = a_fp8.to(torch.bfloat16)
             broadcast_shape = list(a_torch.shape[:-1]) + [-1]
@@ -101,6 +121,20 @@ def _test_quantize_fp8_row(
 
             a_torch *= a_scale.view(broadcast_shape)
 
+            if pad_rows_to_multiple_of is not None:
+                # Pad input_a's row dimension to expected_padded_size if specified.
+                if expected_padded_size is not None:
+                    pad_rows = expected_padded_size - input_a.shape[-1]
+                    if pad_rows > 0:
+                        pad_shape = list(input_a.shape)
+                        pad_shape[-1] = pad_rows
+                        pad_tensor = torch.zeros(
+                            pad_shape,
+                            dtype=input_a.dtype,
+                            device=input_a.device,
+                        )
+                        input_a = torch.cat([input_a, pad_tensor], dim=-1)
+
             self.assertTrue(
                 torch.allclose(
                     input_a.to(device=output_device),
@@ -114,6 +148,20 @@ def _test_quantize_fp8_row(
         _test_quantize_fp8_row((2, n_col), True, torch.device("cuda"))
         # Test with batched input.
         _test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
+        _test_quantize_fp8_row(
+            (4, 2, 3),
+            True,
+            torch.accelerator.current_accelerator("cuda"),
+            pad_rows_to_multiple_of=8,
+            expected_padded_size=8,
+        )
+        _test_quantize_fp8_row(
+            (4, 2, 13),
+            True,
+            torch.accelerator.current_accelerator("cuda"),
+            pad_rows_to_multiple_of=8,
+            expected_padded_size=16,
+        )
         _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cuda"))
         # Test with non-contiguous input
         _test_quantize_fp8_row(
@@ -132,6 +180,14 @@ def _test_quantize_fp8_row(
         _test_quantize_fp8_row((6, 4, 2, 3), True, torch.device("cpu"))
         # Test with zero_start_index_M
         _test_quantize_fp8_row((20, 30), True, torch.device("cuda"), use_jagged=True)
+        _test_quantize_fp8_row(
+            (20, 30),
+            True,
+            torch.accelerator.current_accelerator("cuda"),
+            use_jagged=True,
+            pad_rows_to_multiple_of=16,
+            expected_padded_size=32,
+        )
         _test_quantize_fp8_row(
             (6, 4, 2, 3), True, torch.device("cuda"), use_jagged=True
         )

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 63 additions & 12 deletions
@@ -2328,6 +2328,7 @@ def _kernel_quantize_fp8_row(
     M,
     N,
     K,
+    K_fp8,  # used when padding
     stride_ab,
     stride_am,
     stride_an,
@@ -2364,7 +2365,8 @@ def _kernel_quantize_fp8_row(
         B (int): Size of dimenion 0
         M (int): Size of dimenion 1
         N (int): Size of dimenion 2
-        K (int): Size of dimenion 3
+        K (int): Size of dimenion 3 (input row size)
+        K_fp8 (int): Size of dimenion 3 for A_fp8 (output row size, can be >= K)
         stride_ab (int): Stride of b dimension of A.
         stride_am (int): Stride of m dimension of A.
         stride_an (int): Stride of n dimension of A.
@@ -2433,21 +2435,32 @@ def _kernel_quantize_fp8_row(
     tl.store(A_scale + pid, 1.0 / a_scale)
     n_offset = tl.arange(0, BLOCK_SIZE)
 
-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
+    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+        # For the first K elements, use A; for the rest, use 0
+        # Compute the valid range for this tile
+        tile_start = _k * BLOCK_SIZE
+        # Calculate masks for both cases
+        mask_in_A = (n_offset + tile_start) < K_in
+        mask_in_A_fp8 = (n_offset + tile_start) < K_fp8
+
+        # Load from A if in range, else 0 (we're going all the way to K_fp8)
         a = tl.load(
-            A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < K_in,
+            A + a_offset_base + (n_offset + tile_start) * stride_ak,
+            mask=mask_in_A & mask_in_A_fp8,
             other=0.0,
         )
+        # For elements >= K, a will be 0
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+        # For elements >= K, a_fp8 is already 0
         tl.store(
-            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+            A_fp8 + a_fp8_offset_base + (n_offset + tile_start) * stride_ok,
             a_fp8,
-            mask=n_offset < K,
+            mask=mask_in_A_fp8,
         )
         n_offset += BLOCK_SIZE
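A minimal eager-mode sketch of what the padded loop computes for one row (illustration only: the helper name and the simplified scale handling here are not part of the kernel; max_fp8 uses the e4m3 maximum of 448):

import torch

def pad_and_quantize_row(row, k_fp8, max_fp8=448.0, fp8_dtype=torch.float8_e4m3fn):
    # Row-wise scale so the largest magnitude maps to the fp8 maximum.
    a_scale = max_fp8 / row.abs().max().clamp(min=1e-12)
    # Positions 0..K-1 hold quantized values, positions K..K_fp8-1 stay zero,
    # mirroring the mask_in_A / mask_in_A_fp8 logic above.
    out = torch.zeros(k_fp8, dtype=fp8_dtype, device=row.device)
    out[: row.numel()] = (row.float() * a_scale).clamp(-max_fp8, max_fp8).to(fp8_dtype)
    return out, 1.0 / a_scale  # the kernel likewise stores the reciprocal scale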
@@ -2456,6 +2469,9 @@ def triton_quantize_fp8_row(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
     zero_start_index_M: Optional[Tensor] = None,
+    pad_rows_to_multiple_of: Optional[
+        int
+    ] = None,  # TODO(yyetim) Add a test case, validate padding and 0 setting
 ) -> Tuple[Tensor, Tensor]:
     """
     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -2464,6 +2480,7 @@ def triton_quantize_fp8_row(
         a (Tensor): higher precision input tensor of 4 dimension.
         scale_ub (Tensor): Maximum allowed value for scale.
         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+        pad_rows_to_multiple_of: Pad rows to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)
 
     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2485,7 +2502,18 @@ def triton_quantize_fp8_row(
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
-    a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+    # If pad_rows_to_multiple_of is provided, pad the last dimension to be a multiple of it
+    if pad_rows_to_multiple_of is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + pad_rows_to_multiple_of - 1) // pad_rows_to_multiple_of
+        ) * pad_rows_to_multiple_of
+        a_fp8 = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
+        )
+        a_shape = (*a_shape[:-1], padded_last_dim)
+    else:
+        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
 
     # If input tensor is sufficiently large, we need to use int64 indexing.
     use_int64 = a.numel() > (2**31 - 1)
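The padded width above is plain ceiling division to the next multiple; for the shapes exercised by the new tests, the expected padded row sizes work out as follows:

def round_up(last_dim: int, multiple: int) -> int:
    # Same expression as padded_last_dim above.
    return ((last_dim + multiple - 1) // multiple) * multiple

assert round_up(3, 8) == 8     # (4, 2, 3) with pad_rows_to_multiple_of=8
assert round_up(13, 8) == 16   # (4, 2, 13) with pad_rows_to_multiple_of=8
assert round_up(30, 16) == 32  # (20, 30) with pad_rows_to_multiple_of=16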
@@ -2504,6 +2532,7 @@ def triton_quantize_fp8_row(
         a.shape[1],
         a.shape[2],
         a.shape[3],
+        a_fp8.shape[3],
         a.stride(0),
         a.stride(1),
         a.stride(2),
@@ -2908,6 +2937,7 @@ def quantize_fp8_row(
     zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    pad_rows_to_multiple_of: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -2928,7 +2958,12 @@ def quantize_fp8_row(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return triton_quantize_fp8_row(
+            a,
+            scale_ub,
+            zero_start_index_M,
+            pad_rows_to_multiple_of=pad_rows_to_multiple_of,
+        )
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
@@ -2958,18 +2993,34 @@ def quantize_fp8_row(
 def quantize_fp8_row_meta(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    pad_rows_to_multiple_of: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
     a_shape = a.shape
-    # Flatten to 2D since each row of each potential batch gets a scale.
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
-    fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
-    return fake_out, fake_scale
+    if pad_rows_to_multiple_of is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + pad_rows_to_multiple_of - 1) // pad_rows_to_multiple_of
+        ) * pad_rows_to_multiple_of
+        fake_out = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
+        )
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
+    else:
+        fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
 
 
 @triton.autotune(
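The meta ("fake") implementation only has to report the same output shapes and dtypes as the real op, including the padded last dimension, so that torch.compile can trace quantize_fp8_row without launching the kernel; the new test asserts exactly this agreement. A small standalone version of that check (same assumed import path as in the sketch above):

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,
    quantize_fp8_row_meta,
)

a = torch.randn(4, 2, 13, dtype=torch.bfloat16, device="cuda")
real_fp8, real_scale = quantize_fp8_row(a, pad_rows_to_multiple_of=8)
fake_fp8, fake_scale = quantize_fp8_row_meta(a, pad_rows_to_multiple_of=8)
# Shapes and dtypes must match for the compiled graph to be shape-correct.
assert real_fp8.shape == fake_fp8.shape and real_fp8.dtype == fake_fp8.dtype
assert real_scale.shape == fake_scale.shape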
