@@ -2328,6 +2328,7 @@ def _kernel_quantize_fp8_row(
     M,
     N,
     K,
+    K_fp8,  # used when padding
     stride_ab,
     stride_am,
     stride_an,
@@ -2364,7 +2365,8 @@ def _kernel_quantize_fp8_row(
         B (int): Size of dimension 0
         M (int): Size of dimension 1
         N (int): Size of dimension 2
-        K (int): Size of dimension 3
+        K (int): Size of dimension 3 (input row size)
+        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
         stride_ab (int): Stride of b dimension of A.
         stride_am (int): Stride of m dimension of A.
         stride_an (int): Stride of n dimension of A.
@@ -2433,21 +2435,31 @@ def _kernel_quantize_fp8_row(
     tl.store(A_scale + pid, 1.0 / a_scale)
     n_offset = tl.arange(0, BLOCK_SIZE)

-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    # Write quantized values for the first K elements (from A), padding the rest with zeros up to K_fp8.
+    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+        # For the first K elements, use A; for the rest, use 0.
+        # Compute the start of the current tile within the row.
+        tile_start = _k * BLOCK_SIZE
+        # Calculate masks for the input (K_in) and output (K_fp8) ranges.
+        mask_in_A = (n_offset + tile_start) < K_in
+        mask_in_A_fp8 = (n_offset + tile_start) < K_fp8
+
+        # Load from A where in range, else 0 (we iterate all the way to K_fp8).
         a = tl.load(
-            A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < K_in,
+            A + a_offset_base + (n_offset + tile_start) * stride_ak,
+            mask=mask_in_A & mask_in_A_fp8,
             other=0.0,
         )
+        # For elements >= K_in, a is 0.
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+        # For padded elements (>= K_in), a_fp8 is already 0.
         tl.store(
-            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+            A_fp8 + a_fp8_offset_base + (n_offset + tile_start) * stride_ok,
             a_fp8,
-            mask=n_offset < K,
+            mask=mask_in_A_fp8,
         )
-        n_offset += BLOCK_SIZE

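
For intuition, the padded loop above is equivalent per row to the following plain-PyTorch sketch. quantize_row_padded is a hypothetical reference helper written for this review, not part of the patch; it assumes row holds the K_in valid input elements and that k_fp8 >= row.numel():

import torch

def quantize_row_padded(row: torch.Tensor, scale: float, k_fp8: int, max_fp8: float) -> torch.Tensor:
    # Quantize the valid elements and zero-pad the row out to k_fp8,
    # mirroring the masked load/store in _kernel_quantize_fp8_row.
    out = torch.zeros(k_fp8, dtype=torch.float32)
    out[: row.numel()] = (row.float() * scale).clamp(-max_fp8, max_fp8)
    return out  # indices >= row.numel() stay zero, matching the kernel's padding
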
@@ -2456,6 +2468,9 @@ def triton_quantize_fp8_row(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
     zero_start_index_M: Optional[Tensor] = None,
+    pad_rows_to_multiple_of: Optional[
+        int
+    ] = None,  # TODO(yyetim): Add a test case to validate padding and zero fill
 ) -> Tuple[Tensor, Tensor]:
     """
     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
@@ -2464,6 +2479,7 @@ def triton_quantize_fp8_row(
         a (Tensor): higher precision input tensor of 4 dimensions.
         scale_ub (Tensor): Maximum allowed value for scale.
         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+        pad_rows_to_multiple_of (int): Pad the last dimension of each row up to a multiple of this value. Useful for downstream kernels that require specific sizes (e.g., a multiple of 16).

     Returns:
         torch.Tensor: fp8 scaled tensor.
@@ -2485,7 +2501,18 @@ def triton_quantize_fp8_row(
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
-    a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+    # If pad_rows_to_multiple_of is provided, pad the last dimension to be a multiple of it.
+    if pad_rows_to_multiple_of is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + pad_rows_to_multiple_of - 1) // pad_rows_to_multiple_of
+        ) * pad_rows_to_multiple_of
+        a_fp8 = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
+        )
+        a_shape = (*a_shape[:-1], padded_last_dim)
+    else:
+        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

     # If input tensor is sufficiently large, we need to use int64 indexing.
     use_int64 = a.numel() > (2**31 - 1)
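
The padded size above is the standard ceil-division round-up, ((last_dim + m - 1) // m) * m. A small self-contained check of that arithmetic; round_up is a hypothetical helper used only for illustration:

def round_up(last_dim: int, multiple: int) -> int:
    # Round last_dim up to the nearest multiple of `multiple`.
    return ((last_dim + multiple - 1) // multiple) * multiple

assert round_up(100, 16) == 112  # padded up
assert round_up(112, 16) == 112  # already aligned, unchanged
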
@@ -2504,6 +2531,7 @@
         a.shape[1],
         a.shape[2],
         a.shape[3],
+        a_fp8.shape[3],
         a.stride(0),
         a.stride(1),
         a.stride(2),
@@ -2908,6 +2936,7 @@ def quantize_fp8_row(
     zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    pad_rows_to_multiple_of: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a to fp8 with row-wise scalings and optionally move to output device.
@@ -2928,7 +2957,12 @@ def quantize_fp8_row(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
+        return triton_quantize_fp8_row(
+            a,
+            scale_ub,
+            zero_start_index_M,
+            pad_rows_to_multiple_of=pad_rows_to_multiple_of,
+        )
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
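
With the argument now forwarded, a minimal usage sketch of the new padding path. This assumes quantize_fp8_row is imported from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm (the module this diff appears to patch) and a CUDA device is available; since the TODO above notes tests are still pending, treat the shapes as the intended behavior rather than a verified result:

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

a = torch.randn(2, 3, 4, 100, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_row(a, pad_rows_to_multiple_of=16)
# The last dimension is rounded up to a multiple of 16 and zero-filled past K.
assert a_fp8.shape == (2, 3, 4, 112)
assert a_scale.numel() == 2 * 3 * 4  # one scale per row
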
@@ -2958,18 +2992,28 @@
 def quantize_fp8_row_meta(
     a: Tensor,
     scale_ub: Optional[Tensor] = None,
+    zero_start_index_M: Optional[Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    pad_rows_to_multiple_of: Optional[int] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Shape function for torch compile."""
     if output_device is None:
         output_device = a.device
     a_shape = a.shape
-    # Flatten to 2D since each row of each potential batch gets a scale.
     dtype = get_fp8_constants()[0]
-    fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
+    if pad_rows_to_multiple_of is not None:
+        last_dim = a.shape[-1]
+        padded_last_dim = (
+            (last_dim + pad_rows_to_multiple_of - 1) // pad_rows_to_multiple_of
+        ) * pad_rows_to_multiple_of
+        fake_out = torch.empty(
+            (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
+        )
+    else:
+        fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
     fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
     return fake_out, fake_scale


 @triton.autotune(