
Commit a14cafd

fix bug (#10348)
1 parent 4113c5d commit a14cafd

2 files changed: +65, -25 lines


paddlenlp/transformers/deepseek_v2/fp8_linear.py

Lines changed: 61 additions & 23 deletions
@@ -188,18 +188,21 @@ def kitchen_quant(x, backend=None, is_1d_scaled=True, return_transpose=False):


 def kitchen_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled):
-    y = kitchen.ops.fp8_gemm_blockwise(
-        a=x_fp8,
-        a_decode_scale=x_scale,
-        b=w_fp8,
-        b_decode_scale=w_scale,
-        out_dtype=paddle.bfloat16,
-        out=None,
-        accumulate=False,
-        use_split_accumulator=True,
-        is_a_1d_scaled=is_a_1d_scaled,
-        is_b_1d_scaled=is_b_1d_scaled,
-    )
+    if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0:
+        y = kitchen.ops.fp8_gemm_blockwise(
+            a=x_fp8,
+            a_decode_scale=x_scale,
+            b=w_fp8,
+            b_decode_scale=w_scale,
+            out_dtype=paddle.bfloat16,
+            out=None,
+            accumulate=False,
+            use_split_accumulator=True,
+            is_a_1d_scaled=is_a_1d_scaled,
+            is_b_1d_scaled=is_b_1d_scaled,
+        )
+    else:
+        y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], paddle.bfloat16)
     return y

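The hunk above makes kitchen_fp8_gemm tolerate empty operands: when either input has zero elements, the FP8 kernel is skipped and a zero tensor of the expected [M, N] output shape is returned. A minimal sketch of the same guard pattern, with paddle.matmul standing in for kitchen.ops.fp8_gemm_blockwise (the kitchen kernel and its scale arguments are omitted; this is an illustration, not the library call):

    import numpy
    import paddle

    def safe_gemm(a, b_t):
        # a: [M, K], b_t: [N, K] -> [M, N]; tolerates zero-sized operands.
        if numpy.prod(a.shape) != 0 and numpy.prod(b_t.shape) != 0:
            # Normal path: b_t is stored as [N, K], so multiply against its transpose.
            return paddle.matmul(a, b_t, transpose_y=True)
        # Empty operand: skip the kernel and return zeros in the shape the caller expects.
        return paddle.zeros([a.shape[0], b_t.shape[0]], dtype=a.dtype)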
@@ -229,8 +232,15 @@ def forward(ctx, x, weight):
         x_t = x.T
         # padding
         x_t_shape = x_t.shape
-        if x_t.shape[-1] % 8 != 0:
-            x_t = paddle.concat([x_t, paddle.zeros([x_t.shape[0], 8 - (x_t.shape[-1] % 8)], dtype=x_t.dtype)], axis=-1)
+        if x_t.shape[-1] % 128 != 0 or x_t.shape[-1] % 512 != 0:
+            if (x_t.shape[-1] + 128 - (x_t.shape[-1] % 128)) % 512 != 0:
+                padding_size = 512
+            else:
+                padding_size = 128
+            x_t = paddle.concat(
+                [x_t, paddle.zeros([x_t.shape[0], padding_size - (x_t.shape[-1] % padding_size)], dtype=x_t.dtype)],
+                axis=1,
+            )
         x_t_quant, x_t_scale = kitchen_quant(
             x_t.contiguous(), backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
         )
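This hunk, together with the three backward hunks below, replaces the old pad-to-a-multiple-of-8 logic with padding that lands the quantized dimension on the 128/512 boundaries used by the blockwise FP8 path. A small standalone helper reproducing the same size arithmetic (padded_len is our name, for illustration only); as far as the arithmetic goes, it always rounds the dimension up to the next multiple of 512:

    def padded_len(n: int) -> int:
        # Length that this commit pads a dimension of size n to.
        if n % 128 != 0 or n % 512 != 0:  # skipped only when n is already a multiple of 512
            # If rounding up to the next 128 boundary would not land on a
            # 512 boundary, pad relative to a 512 block instead.
            if (n + 128 - (n % 128)) % 512 != 0:
                padding_size = 512
            else:
                padding_size = 128
            n = n + padding_size - (n % padding_size)
        return n

    assert padded_len(100) == 512
    assert padded_len(384) == 512
    assert padded_len(512) == 512
    assert padded_len(640) == 1024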
@@ -262,9 +272,20 @@ def backward(ctx, dout):
         # compute dw = mm(x_t, dout_t)
         dout_t = dout.reshape([-1, dout.shape[-1]]).T.contiguous()
         # padding
-        if dout_t.shape[-1] % 8 != 0:
-            pad_size = 8 - (dout_t.shape[-1] % 8)
-            dout_t = paddle.concat([dout_t, paddle.zeros([dout_t.shape[0], pad_size], dtype=dout_t.dtype)], axis=-1)
+        if dout_t.shape[-1] % 128 != 0 or dout_t.shape[-1] % 512 != 0:
+            if (dout_t.shape[-1] + 128 - (dout_t.shape[-1] % 128)) % 512 != 0:
+                padding_size = 512
+            else:
+                padding_size = 128
+            dout_t = paddle.concat(
+                [
+                    dout_t,
+                    paddle.zeros(
+                        [dout_t.shape[0], padding_size - (dout_t.shape[-1] % padding_size)], dtype=dout_t.dtype
+                    ),
+                ],
+                axis=1,
+            )

         dout_t_quant, dout_t_scale = kitchen_quant(
             dout_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=False
@@ -301,8 +322,15 @@ def backward(ctx, dout):
         dx_orig_shape = x.shape
         # padding
         x = x.reshape([-1, x.shape[-1]])
-        if x.shape[0] % 8 != 0:
-            x = paddle.concat([x, paddle.zeros([8 - (x.shape[0] % 8), x.shape[-1]], dtype=x.dtype)], axis=0)
+        if x.shape[0] % 128 != 0 or x.shape[0] % 512 != 0:
+            if (x.shape[0] + 128 - (x.shape[0] % 128)) % 512 != 0:
+                padding_size = 512
+            else:
+                padding_size = 128
+            x = paddle.concat(
+                [x, paddle.zeros([padding_size - (x.shape[0] % padding_size), x.shape[-1]], dtype=x.dtype)],
+                axis=0,
+            )

         _, _, x_t_quant, x_t_scale = kitchen_quant(
             x, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=True
@@ -325,10 +353,20 @@ def backward(ctx, dout):

         # compute dw = mm(x_t, dout_t)
         dout_t = dout.reshape([-1, dout.shape[-1]])
-
-        if dout_t.shape[0] % 8 != 0:
-            pad_size = 8 - (dout_t.shape[0] % 8)
-            dout_t = paddle.concat([dout_t, paddle.zeros([pad_size, dout_t.shape[-1]], dtype=dout_t.dtype)], axis=0)
+        if dout_t.shape[0] % 128 != 0 or dout_t.shape[0] % 512 != 0:
+            if (dout_t.shape[0] + 128 - (dout_t.shape[0] % 128)) % 512 != 0:
+                padding_size = 512
+            else:
+                padding_size = 128
+            dout_t = paddle.concat(
+                [
+                    dout_t,
+                    paddle.zeros(
+                        [padding_size - (dout_t.shape[0] % padding_size), dout_t.shape[-1]], dtype=dout_t.dtype
+                    ),
+                ],
+                axis=0,
+            )

         _, _, dout_t_quant, dout_t_scale = kitchen_quant(
             dout_t, backend=kitchen.ops.Backend.CUBLAS, is_1d_scaled=True, return_transpose=True

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 4 additions & 2 deletions
@@ -1361,12 +1361,14 @@ def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps):

         kv = paddle.matmul(hidden_states, kv_down_weight)

-        ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight, eps)
+        ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight)
+        ctx.eps = eps
         return q, kv

     @staticmethod
     def backward(ctx, d_q, d_kv):
-        x, rms_norm_weight, q_down_weight, kv_down_weight, eps = ctx.saved_tensor()
+        x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor()
+        eps = ctx.eps
         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)

         h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False)
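The modeling.py fix stops passing the Python float eps through ctx.save_for_backward, which is meant for tensors, and stashes it as a plain attribute on ctx instead, restoring it in backward. A toy PyLayer sketch of that save/restore pattern (ScaleByEps is a made-up example layer, not the DeepSeek code):

    import paddle
    from paddle.autograd import PyLayer

    class ScaleByEps(PyLayer):
        @staticmethod
        def forward(ctx, x, eps):
            # Tensors go through save_for_backward; plain Python scalars
            # are kept as attributes on ctx.
            ctx.save_for_backward(x)
            ctx.eps = eps
            return x * eps

        @staticmethod
        def backward(ctx, dout):
            (x,) = ctx.saved_tensor()
            eps = ctx.eps
            return dout * eps

    x = paddle.ones([2, 2])
    x.stop_gradient = False
    y = ScaleByEps.apply(x, 1e-6)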
