
Add fused_stack_transpose_quant kernel (optional transpose) #10649


Merged: 1 commit, May 28, 2025

Conversation

@lshpku lshpku commented May 23, 2025

PR types

Function optimization

PR changes

APIs

Description

This PR implements two fused operators, `fused_stack_quant` and `fused_stack_transpose_quant`, defined as follows:

/**
 * Stack tensors in X and do quantization on both dim[-1] and dim[-2].
 *
 * Inputs:
 *   X    : N tensors of [M, K], bfloat16
 *
 * Outputs:
 *   out  : [N * M, K], float8_e4m3fn
 *   scale: [N * M / 128, K / 128], float
 *
 * Requirements:
 *   1) N <= 65535
 *   2) M % 128 == 0
 *   3) K % 128 == 0
 */
std::vector<paddle::Tensor> fused_stack_quant(
    const std::vector<paddle::Tensor>& X);
/**
 * Stack tensors in X, transpose dim[-1] and dim[-2], and do quantization 
 * on both dim[-1] and dim[-2].
 *
 * Inputs:
 *   X    : N tensors of [M, K], bfloat16
 *
 * Outputs:
 *   out  : [N * K, M], float8_e4m3fn
 *   scale: [N * K / 128, M / 128], float
 *
 * Requirements:
 *   1) N <= 65535
 *   2) M % 128 == 0
 *   3) K % 128 == 0
 */
std::vector<paddle::Tensor> fused_stack_transpose_quant(
    const std::vector<paddle::Tensor>& X);

(The only difference between the two is that the transpose variant swaps the M and K dimensions after stacking, before performing the subsequent quant step.)
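For illustration, the semantics of the transpose variant can be sketched in NumPy. This is a reference sketch only: it uses float32 in place of bfloat16/float8_e4m3fn, and the per-128×128-block max-abs scaling below is an assumption for clarity, not necessarily the kernel's exact scaling formula.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def fused_stack_transpose_quant_ref(X):
    """Reference sketch: stack N tensors of [M, K], transpose the last two
    dims, then quantize per 128x128 block.
    NOTE: the max-abs block scaling here is an illustrative assumption;
    the actual kernel's scaling scheme may differ."""
    stacked = np.stack(X)                               # [N, M, K]
    n, m, k = stacked.shape
    t = stacked.transpose(0, 2, 1).reshape(n * k, m)    # [N*K, M]
    scale = np.empty((n * k // 128, m // 128), dtype=np.float32)
    out = np.empty_like(t)
    for i in range(scale.shape[0]):
        for j in range(scale.shape[1]):
            block = t[i*128:(i+1)*128, j*128:(j+1)*128]
            s = np.abs(block).max() / FP8_E4M3_MAX
            scale[i, j] = s
            out[i*128:(i+1)*128, j*128:(j+1)*128] = block / max(s, 1e-12)
    # On real hardware, out would be cast to float8_e4m3fn here.
    return out, scale
```

The non-transpose variant is identical except that it skips the `transpose(0, 2, 1)` step, giving `out` of shape `[N*M, K]` and `scale` of shape `[N*M/128, K/128]`.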

Performance tests

Preliminary tests were run on an A100-40G. Since the A100 does not support fp8, int8 was substituted for the fp32-to-fp8 cast; the results should still roughly reflect performance on H-series GPUs.

| Kernel | Inputs (count × shape) | Time (ns) | Bandwidth (GB/s) | Bandwidth utilization |
| --- | --- | --- | --- | --- |
| fused_stack_quant | 4 × [7168, 4096] | 258,841 | 1361 | 87.5% |
| fused_stack_transpose_quant | 4 × [7168, 4096] | 270,972 | 1300 | 83.6% |
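As a sanity check, the reported bandwidth for the `fused_stack_quant` row is consistent with total bytes moved divided by time, assuming a bf16 read, an fp8 write, and fp32 scales (the byte accounting is an assumption based on the dtypes in the declarations above):

```python
# Bandwidth sanity check for the fused_stack_quant row.
N, M, K = 4, 7168, 4096
elems = N * M * K
read_bytes  = elems * 2                          # bf16 input
write_bytes = elems * 1                          # fp8 output
scale_bytes = (N * M // 128) * (K // 128) * 4    # fp32 scales
total = read_bytes + write_bytes + scale_bytes
time_ns = 258_841
gbps = total / time_ns                           # bytes/ns == GB/s
print(round(gbps))                               # -> 1361, matching the table
```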

Pcard-85711


paddle-bot bot commented May 23, 2025

Thanks for your contribution!

@lshpku lshpku changed the title Add fused_stack_transpose_quant kernel (optional transpose) Add fused_stack_(transpose_)quant kernel May 26, 2025
@lshpku lshpku force-pushed the fused-stack-transpose-quant branch from ae277da to 7b72a18 Compare May 28, 2025 02:52
@lshpku lshpku changed the title Add fused_stack_(transpose_)quant kernel Add fused_stack_transpose_quant kernel (optional transpose) May 28, 2025
@phlrain phlrain self-requested a review May 28, 2025 06:34
@phlrain phlrain merged commit 6c206e1 into PaddlePaddle:dsv3_dev May 28, 2025
2 of 5 checks passed