
[Prim][CINN] Fix group_norm_grad decomp sum dtype #73037


Merged

Conversation

lshpku
Contributor

@lshpku lshpku commented May 30, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

Fix the incorrect sum dtype in the group_norm_grad composite-operator decomposition: under half precision, the sum should be computed in float. This resolves the precision failure in the test_group_norm_op_deprecated unit test (this test is not run in CI and can only be verified manually).
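To illustrate why the reduction should accumulate in a wider type, here is a minimal, self-contained analogy using float vs. double in place of float16 vs. float32 (illustrative only, not Paddle code):

```cpp
#include <cstdio>

// Analogy: float stands in for float16, double for float32.
// Accumulating a long reduction in the narrow type loses low-order
// bits once the running sum dwarfs each addend.
int main() {
    const int n = 20000000;     // 2e7 addends of 0.1
    float  sum_narrow = 0.0f;   // like summing in half precision
    double sum_wide   = 0.0;    // like summing in float
    for (int i = 0; i < n; ++i) {
        sum_narrow += 0.1f;
        sum_wide   += 0.1f;
    }
    // True sum is 2,000,000. On a typical IEEE-754 build the narrow
    // accumulator ends near 2,097,152: each 0.1 gets rounded to the
    // ULP of the large running sum, and is dropped entirely once the
    // sum passes 2^21. The wide accumulator stays accurate.
    std::printf("narrow: %.1f\nwide:   %.1f\n", sum_narrow, sum_wide);
}
```

The same effect hits float16 reductions at far smaller tensor sizes, since float16 carries only an 11-bit significand.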

For reference, layer_norm_grad handles this correctly:

auto scale_grad_tmp = (x_sub_mean_mul_sqrt_var_1 * out_grad_cast)
.sum(un_normalized_axis, x_cast.dtype(), true);
scale_grad_tmp = reshape<T>(scale_grad_tmp, {-1});
scale_grad_tmp = ConvertToOrig<T>(scale_grad_tmp, scale_ptr->dtype());
set_output<T>(scale_grad_tmp, scale_grad);
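
A sketch of what the corresponding fix in group_norm_grad looks like, modeled on the snippet above (variable names such as x_hat and reduce_axes are assumptions for illustration; only .sum(), ConvertToOrig<T>, and set_output<T> are taken from the quoted layer_norm_grad code):

```cpp
// Buggy pattern: under half precision the reduction accumulates in
// the original float16/bfloat16 dtype, losing precision.
//   auto scale_grad_tmp =
//       (x_hat * out_grad_cast).sum(reduce_axes, x.dtype(), true);

// Fixed pattern: accumulate in the dtype of the float-cast input,
// then convert back to the parameter's original dtype at the end,
// exactly as layer_norm_grad does.
auto scale_grad_tmp =
    (x_hat * out_grad_cast).sum(reduce_axes, x_cast.dtype(), true);
scale_grad_tmp = ConvertToOrig<T>(scale_grad_tmp, scale_ptr->dtype());
set_output<T>(scale_grad_tmp, scale_grad);
```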

I checked the other backward decompositions (batch_norm/layer_norm/instance_norm) and they are all correct; not sure why only group_norm_grad got this wrong.

Pcard-85711


paddle-bot bot commented May 30, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@lshpku lshpku changed the title from "[Prim][CINN] Fix group norm grad decomp sum type" to "[Prim][CINN] Fix group_norm_grad decomp sum type" on May 30, 2025
@lshpku lshpku changed the title from "[Prim][CINN] Fix group_norm_grad decomp sum type" to "[Prim][CINN] Fix group_norm_grad decomp sum dtype" on May 30, 2025
@lshpku lshpku force-pushed the fix-group-norm-grad-decomp branch from c70e309 to c81847f on May 30, 2025 11:51
Contributor

@wanghuancoder wanghuancoder left a comment


LGTM

@wanghuancoder wanghuancoder merged commit 4973f07 into PaddlePaddle:develop Jun 4, 2025
52 of 53 checks passed