@@ -2328,6 +2328,7 @@ def _kernel_quantize_fp8_row(
     M,
     N,
     K,
+    K_fp8,  # used when padding
     stride_ab,
     stride_am,
     stride_an,
@@ -2364,7 +2365,8 @@ def _kernel_quantize_fp8_row(
         B (int): Size of dimenion 0
         M (int): Size of dimenion 1
         N (int): Size of dimenion 2
-        K (int): Size of dimenion 3
+        K (int): Size of dimenion 3 (input row size)
+        K_fp8 (int): Size of dimenion 3 for A_fp8 (output row size, can be >= K)
         stride_ab (int): Stride of b dimension of A.
         stride_am (int): Stride of m dimension of A.
         stride_an (int): Stride of n dimension of A.
@@ -2433,21 +2435,26 @@ def _kernel_quantize_fp8_row(
     tl.store(A_scale + pid, 1.0 / a_scale)
     n_offset = tl.arange(0, BLOCK_SIZE)
 
-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
+    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+        # Load from A if in range, else 0 (we're going all the way to K_fp8)
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
             mask=n_offset < K_in,
             other=0.0,
         )
+        # For elements >= K, a will be 0
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
         tl.store(
             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
             a_fp8,
-            mask=n_offset < K,
+            mask=n_offset < K_fp8,
         )
         n_offset += BLOCK_SIZE
 
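The hunk above makes the store loop run to K_fp8 instead of K: positions past the original row length load as zero (via `mask`/`other=0.0` in `tl.load`) and are therefore written out as zero padding. A minimal pure-PyTorch sketch of just this padding semantics, for reference only (the helper name `pad_rows_reference` and argument `k_fp8` are illustrative, not part of the PR):

import torch

def pad_rows_reference(quantized_rows: torch.Tensor, k_fp8: int) -> torch.Tensor:
    # quantized_rows: (..., K) already-quantized values, with k_fp8 >= K.
    # Columns 0..K-1 keep the quantized values; columns K..k_fp8-1 are zero,
    # mirroring what the kernel writes when K_fp8 > K.
    *lead, k = quantized_rows.shape
    out = torch.zeros(
        (*lead, k_fp8), dtype=quantized_rows.dtype, device=quantized_rows.device
    )
    out[..., :k] = quantized_rows
    return out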
@@ -2456,6 +2463,7 @@ def triton_quantize_fp8_row(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
     zero_start_index_M: Optional[Tensor] = None,
+    align_rows_to: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -2464,6 +2472,7 @@ def triton_quantize_fp8_row(
         a (Tensor): higher precision input tensor of 4 dimension.
         scale_ub (Tensor): Maximum allowed value for scale.
         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+        align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels that accept specific sizes (e.g., a multiple of 16).
 
     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2485,7 +2494,18 @@ def triton_quantize_fp8_row(
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
-    a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+    # If align_rows_to is provided, pad the last dimension to be a multiple of it
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        a_fp8 = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
+        )
+        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+    else:
+        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
 
     # If input tensor is sufficiently large, we need to use int64 indexing.
     use_int64 = a.numel() > (2**31 - 1)
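The padded width is the usual ceil-to-multiple round-up of the last dimension. For example, with `align_rows_to=16` a row of 100 elements is padded to 112; a tiny worked check of the arithmetic (values are illustrative only):

last_dim, align_rows_to = 100, 16  # example values
padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
assert padded_last_dim == 112  # (100 + 15) // 16 == 7, and 7 * 16 == 112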
@@ -2504,6 +2524,7 @@ def triton_quantize_fp8_row(
         a.shape[1],
         a.shape[2],
         a.shape[3],
+        a_fp8.shape[3],
         a.stride(0),
         a.stride(1),
         a.stride(2),
@@ -2908,6 +2929,7 @@ def quantize_fp8_row(
     zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -2918,6 +2940,7 @@ def quantize_fp8_row(
         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
         use_triton (bool): Whether to use triton kernel or pytorch.
         output_device (torch.device): Device to optionally move the scaled tensors to.
+        align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels that accept specific sizes (e.g., a multiple of 16).
 
     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2928,7 +2951,12 @@ def quantize_fp8_row(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return triton_quantize_fp8_row(
+            a,
+            scale_ub,
+            zero_start_index_M,
+            align_rows_to=align_rows_to,
+        )
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
@@ -2958,18 +2986,34 @@ def quantize_fp8_row(
 def quantize_fp8_row_meta(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    align_rows_to: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
     a_shape = a.shape
-    # Flatten to 2D since each row of each potential batch gets a scale.
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
-    fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
-    return fake_out, fake_scale
+    if align_rows_to is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + align_rows_to - 1) // align_rows_to
+        ) * align_rows_to
+        fake_out = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
+        )
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
+    else:
+        fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
+        fake_scale = torch.empty(
+            a_shape[:-1], device=output_device, dtype=torch.float32
+        )
+        return fake_out, fake_scale
 
 
 @triton.autotune(
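A minimal usage sketch of the new keyword. Assumptions: `quantize_fp8_row` from this file is in scope, a CUDA device is available, and the input is 4D as the Triton path's docstring expects; shapes and dtype are illustrative only.

import torch

# Quantize row-wise and pad each output row to a multiple of 16.
a = torch.randn(2, 3, 4, 100, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_row(a, align_rows_to=16)

# The last dimension is rounded up to a multiple of 16; the padded columns are zeros.
assert a_fp8.shape[-1] == 112
assert a_scale.numel() == 2 * 3 * 4  # one scale per row; padding does not change it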