Skip to content

Commit c70e309

Browse files
committed
[Prim][CINN] Fix group norm grad decomp sum type
1 parent 317a3de commit c70e309

File tree

1 file changed

+5
-2
lines changed
  • paddle/fluid/primitive/decomp_rule/decomp_vjp

1 file changed

+5
-2
lines changed

paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2307,15 +2307,18 @@ void group_norm_grad(const Tensor& x,
23072307
auto tmp1 = out_grad_data * (x_data - mean_new) * sqrt_var_1;
23082308

23092309
auto scale_grad_tmp = reshape<T>(
2310-
tmp1.sum(reduce_axis_except_channel, scale->dtype(), false), {-1});
2310+
tmp1.sum(reduce_axis_except_channel, x_data.dtype(), false), {-1});
2311+
scale_grad_tmp = ConvertToOrig<T>(scale_grad_tmp, scale->dtype());
2312+
23112313
set_output<T>(scale_grad_tmp, scale_grad);
23122314
}
23132315
}
23142316

23152317
if (bias_grad) {
23162318
if (bias) {
23172319
auto bias_grad_tmp =
2318-
out_grad_data.sum(reduce_axis_except_channel, bias->dtype(), false);
2320+
out_grad_data.sum(reduce_axis_except_channel, x_data.dtype(), false);
2321+
bias_grad_tmp = ConvertToOrig<T>(bias_grad_tmp, bias->dtype());
23192322

23202323
set_output<T>(reshape<T>(bias_grad_tmp, {-1}), bias_grad);
23212324
}

0 commit comments

Comments
 (0)