File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
paddle/fluid/primitive/decomp_rule/decomp_vjp Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -2307,15 +2307,18 @@ void group_norm_grad(const Tensor& x,
2307
2307
auto tmp1 = out_grad_data * (x_data - mean_new) * sqrt_var_1;
2308
2308
2309
2309
auto scale_grad_tmp = reshape<T>(
2310
- tmp1.sum (reduce_axis_except_channel, scale->dtype (), false ), {-1 });
2310
+ tmp1.sum (reduce_axis_except_channel, x_data.dtype (), false ), {-1 });
2311
+ scale_grad_tmp = ConvertToOrig<T>(scale_grad_tmp, scale->dtype ());
2312
+
2311
2313
set_output<T>(scale_grad_tmp, scale_grad);
2312
2314
}
2313
2315
}
2314
2316
2315
2317
if (bias_grad) {
2316
2318
if (bias) {
2317
2319
auto bias_grad_tmp =
2318
- out_grad_data.sum (reduce_axis_except_channel, bias->dtype (), false );
2320
+ out_grad_data.sum (reduce_axis_except_channel, x_data.dtype (), false );
2321
+ bias_grad_tmp = ConvertToOrig<T>(bias_grad_tmp, bias->dtype ());
2319
2322
2320
2323
set_output<T>(reshape<T>(bias_grad_tmp, {-1 }), bias_grad);
2321
2324
}
You can’t perform that action at this time.
0 commit comments