@@ -2328,6 +2328,7 @@ def _kernel_quantize_fp8_row(
    M,
    N,
    K,
+    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
@@ -2364,7 +2365,8 @@ def _kernel_quantize_fp8_row(
        B (int): Size of dimenion 0
        M (int): Size of dimenion 1
        N (int): Size of dimenion 2
-        K (int): Size of dimenion 3
+        K (int): Size of dimension 3 (input row size)
+        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
@@ -2433,21 +2435,32 @@ def _kernel_quantize_fp8_row(
        tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
+    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+        # For the first K elements, use A; for the rest, use 0
+        mask_in_A = n_offset < K_in
+
+        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < K_in,
+            mask=mask_in_A,
            other=0.0,
        )
+        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+        # Get the mask for A_fp8
+        mask_in_A_fp8 = n_offset < K_fp8
+
+        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
-            mask=n_offset < K,
+            mask=mask_in_A_fp8,
        )
        n_offset += BLOCK_SIZE
@@ -2456,6 +2469,9 @@ def triton_quantize_fp8_row(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
    zero_start_index_M: Optional[Tensor] = None,
+    align_rows_to: Optional[
+        int
+    ] = None,  # TODO(yyetim) Add a test case, validate padding and 0 setting
) -> Tuple[Tensor, Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -2464,6 +2480,7 @@ def triton_quantize_fp8_row(
        a (Tensor): higher precision input tensor of 4 dimension.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+        align_rows_to (int): Pad each row's last dimension up to a multiple of this value. Useful for downstream kernels that require specific sizes (e.g., a multiple of 16).

    Returns:
        torch.Tensor: fp8 scaled tensor.
@@ -2485,7 +2502,18 @@ def triton_quantize_fp8_row(
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
-    a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+    # If align_rows_to is provided, pad the last dimension to be a multiple of it
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        a_fp8 = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
+        )
+        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+    else:
+        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If input tensor is sufficiently large, we need to use int64 indexing.
    use_int64 = a.numel() > (2**31 - 1)
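For concreteness, the padded_last_dim expression above is the usual round-up-to-a-multiple idiom; a quick standalone sketch with assumed sizes (a 100-element row aligned to 16), not part of the patch:

# Assumed example values, purely for illustration.
last_dim, align_rows_to = 100, 16
padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
print(padded_last_dim)  # 112: 100 rounded up to the next multiple of 16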
@@ -2504,6 +2532,7 @@ def triton_quantize_fp8_row(
        a.shape[1],
        a.shape[2],
        a.shape[3],
+        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
@@ -2908,6 +2937,7 @@ def quantize_fp8_row(
    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -2918,6 +2948,7 @@ def quantize_fp8_row(
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
+        align_rows_to (int): Pad each row's last dimension up to a multiple of this value. Useful for downstream kernels that require specific sizes (e.g., a multiple of 16).

    Returns:
        torch.Tensor: fp8 scaled tensor.
@@ -2928,7 +2959,12 @@ def quantize_fp8_row(
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
-        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return triton_quantize_fp8_row(
+            a,
+            scale_ub,
+            zero_start_index_M,
+            align_rows_to=align_rows_to,
+        )
    # else use pytorch implementation.
    if not output_device:
        output_device = a.device
@@ -2958,18 +2994,34 @@ def quantize_fp8_row(
def quantize_fp8_row_meta(
    a: Tensor,
    scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shape function for torch compile."""
    if output_device is None:
        output_device = a.device
    a_shape = a.shape
-    # Flatten to 2D since each row of each potential batch gets a scale.
    dtype = get_fp8_constants()[0]
-    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
-    fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
-    return fake_out, fake_scale
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        fake_out = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
+        )
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
+    else:
+        fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale


@triton.autotune(
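The TODO in triton_quantize_fp8_row notes that a test for the new padding path is still missing. As a rough illustration of the intended behavior of align_rows_to, here is a minimal usage sketch; the import path, shapes, and tolerances are assumptions rather than part of the change, and it needs a GPU since the Triton kernel runs on device.

import torch

# Assumed import path; adjust to wherever quantize_fp8_row is exposed.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# Hypothetical input: rows of 100 elements, padded up to a multiple of 16 -> 112.
a = torch.randn(2, 3, 5, 100, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scale = quantize_fp8_row(a, align_rows_to=16)

assert a_fp8.shape == (2, 3, 5, 112)  # last dim rounded up to a multiple of 16
assert a_scale.shape == (2, 3, 5)  # one scale per row, unaffected by padding

# The padded tail of every row is expected to be exactly zero.
assert torch.all(a_fp8[..., 100:].to(torch.float32) == 0)

# The first 100 elements should dequantize back to roughly the input
# (loose tolerances chosen here to absorb fp8 rounding error).
dequant = a_fp8[..., :100].to(torch.float32) * a_scale.unsqueeze(-1)
torch.testing.assert_close(dequant, a.to(torch.float32), atol=0.1, rtol=0.1)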