Skip to content

Commit 4973f07

Browse files
authored
[Prim][CINN] Fix group_norm_grad decomp sum dtype (#73037)
1 parent 4c802cc commit 4973f07

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2307,15 +2307,18 @@ void group_norm_grad(const Tensor& x,
23072307
auto tmp1 = out_grad_data * (x_data - mean_new) * sqrt_var_1;
23082308

23092309
auto scale_grad_tmp = reshape<T>(
2310-
tmp1.sum(reduce_axis_except_channel, scale->dtype(), false), {-1});
2310+
tmp1.sum(reduce_axis_except_channel, x_data.dtype(), false), {-1});
2311+
scale_grad_tmp = ConvertToOrig<T>(scale_grad_tmp, scale->dtype());
2312+
23112313
set_output<T>(scale_grad_tmp, scale_grad);
23122314
}
23132315
}
23142316

23152317
if (bias_grad) {
23162318
if (bias) {
23172319
auto bias_grad_tmp =
2318-
out_grad_data.sum(reduce_axis_except_channel, bias->dtype(), false);
2320+
out_grad_data.sum(reduce_axis_except_channel, x_data.dtype(), false);
2321+
bias_grad_tmp = ConvertToOrig<T>(bias_grad_tmp, bias->dtype());
23192322

23202323
set_output<T>(reshape<T>(bias_grad_tmp, {-1}), bias_grad);
23212324
}

test/deprecated/legacy_test/test_group_norm_op_deprecated.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ def setUp(self):
586586
self.python_out_sig = ["Y"]
587587
self.data_format = "NHWC"
588588
self.prim_op_type = "comp"
589+
self.channel_last = True
589590

590591
self.dtype = np.uint16
591592
self.shape = (1, 3, 5, 512)

0 commit comments

Comments
 (0)