@@ -188,18 +188,21 @@ def kitchen_quant(x, backend=None, is_1d_scaled=True, return_transpose=False):
188
188
189
189
190
190
def kitchen_fp8_gemm (x_fp8 , x_scale , w_fp8 , w_scale , is_a_1d_scaled , is_b_1d_scaled ):
191
- y = kitchen .ops .fp8_gemm_blockwise (
192
- a = x_fp8 ,
193
- a_decode_scale = x_scale ,
194
- b = w_fp8 ,
195
- b_decode_scale = w_scale ,
196
- out_dtype = paddle .bfloat16 ,
197
- out = None ,
198
- accumulate = False ,
199
- use_split_accumulator = True ,
200
- is_a_1d_scaled = is_a_1d_scaled ,
201
- is_b_1d_scaled = is_b_1d_scaled ,
202
- )
191
+ if numpy .prod (x_fp8 .shape ) != 0 and numpy .prod (w_fp8 .shape ) != 0 :
192
+ y = kitchen .ops .fp8_gemm_blockwise (
193
+ a = x_fp8 ,
194
+ a_decode_scale = x_scale ,
195
+ b = w_fp8 ,
196
+ b_decode_scale = w_scale ,
197
+ out_dtype = paddle .bfloat16 ,
198
+ out = None ,
199
+ accumulate = False ,
200
+ use_split_accumulator = True ,
201
+ is_a_1d_scaled = is_a_1d_scaled ,
202
+ is_b_1d_scaled = is_b_1d_scaled ,
203
+ )
204
+ else :
205
+ y = paddle .zeros ([x_fp8 .shape [0 ], w_fp8 .shape [0 ]], paddle .bfloat16 )
203
206
return y
204
207
205
208
@@ -229,8 +232,15 @@ def forward(ctx, x, weight):
229
232
x_t = x .T
230
233
# padding
231
234
x_t_shape = x_t .shape
232
- if x_t .shape [- 1 ] % 8 != 0 :
233
- x_t = paddle .concat ([x_t , paddle .zeros ([x_t .shape [0 ], 8 - (x_t .shape [- 1 ] % 8 )], dtype = x_t .dtype )], axis = - 1 )
235
+ if x_t .shape [- 1 ] % 128 != 0 or x_t .shape [- 1 ] % 512 != 0 :
236
+ if (x_t .shape [- 1 ] + 128 - (x_t .shape [- 1 ] % 128 )) % 512 != 0 :
237
+ padding_size = 512
238
+ else :
239
+ padding_size = 128
240
+ x_t = paddle .concat (
241
+ [x_t , paddle .zeros ([x_t .shape [0 ], padding_size - (x_t .shape [- 1 ] % padding_size )], dtype = x_t .dtype )],
242
+ axis = 1 ,
243
+ )
234
244
x_t_quant , x_t_scale = kitchen_quant (
235
245
x_t .contiguous (), backend = kitchen .ops .Backend .CUBLAS , is_1d_scaled = True , return_transpose = False
236
246
)
@@ -262,9 +272,20 @@ def backward(ctx, dout):
262
272
# compute dw = mm(x_t, dout_t)
263
273
dout_t = dout .reshape ([- 1 , dout .shape [- 1 ]]).T .contiguous ()
264
274
# padding
265
- if dout_t .shape [- 1 ] % 8 != 0 :
266
- pad_size = 8 - (dout_t .shape [- 1 ] % 8 )
267
- dout_t = paddle .concat ([dout_t , paddle .zeros ([dout_t .shape [0 ], pad_size ], dtype = dout_t .dtype )], axis = - 1 )
275
+ if dout_t .shape [- 1 ] % 128 != 0 or dout_t .shape [- 1 ] % 512 != 0 :
276
+ if (dout_t .shape [- 1 ] + 128 - (dout_t .shape [- 1 ] % 128 )) % 512 != 0 :
277
+ padding_size = 512
278
+ else :
279
+ padding_size = 128
280
+ dout_t = paddle .concat (
281
+ [
282
+ dout_t ,
283
+ paddle .zeros (
284
+ [dout_t .shape [0 ], padding_size - (dout_t .shape [- 1 ] % padding_size )], dtype = dout_t .dtype
285
+ ),
286
+ ],
287
+ axis = 1 ,
288
+ )
268
289
269
290
dout_t_quant , dout_t_scale = kitchen_quant (
270
291
dout_t , backend = kitchen .ops .Backend .CUBLAS , is_1d_scaled = True , return_transpose = False
@@ -301,8 +322,15 @@ def backward(ctx, dout):
301
322
dx_orig_shape = x .shape
302
323
# padding
303
324
x = x .reshape ([- 1 , x .shape [- 1 ]])
304
- if x .shape [0 ] % 8 != 0 :
305
- x = paddle .concat ([x , paddle .zeros ([8 - (x .shape [0 ] % 8 ), x .shape [- 1 ]], dtype = x .dtype )], axis = 0 )
325
+ if x .shape [0 ] % 128 != 0 or x .shape [0 ] % 512 != 0 :
326
+ if (x .shape [0 ] + 128 - (x .shape [0 ] % 128 )) % 512 != 0 :
327
+ padding_size = 512
328
+ else :
329
+ padding_size = 128
330
+ x = paddle .concat (
331
+ [x , paddle .zeros ([padding_size - (x .shape [0 ] % padding_size ), x .shape [- 1 ]], dtype = x .dtype )],
332
+ axis = 0 ,
333
+ )
306
334
307
335
_ , _ , x_t_quant , x_t_scale = kitchen_quant (
308
336
x , backend = kitchen .ops .Backend .CUBLAS , is_1d_scaled = True , return_transpose = True
@@ -325,10 +353,20 @@ def backward(ctx, dout):
325
353
326
354
# compute dw = mm(x_t, dout_t)
327
355
dout_t = dout .reshape ([- 1 , dout .shape [- 1 ]])
328
-
329
- if dout_t .shape [0 ] % 8 != 0 :
330
- pad_size = 8 - (dout_t .shape [0 ] % 8 )
331
- dout_t = paddle .concat ([dout_t , paddle .zeros ([pad_size , dout_t .shape [- 1 ]], dtype = dout_t .dtype )], axis = 0 )
356
+ if dout_t .shape [0 ] % 128 != 0 or dout_t .shape [0 ] % 512 != 0 :
357
+ if (dout_t .shape [0 ] + 128 - (dout_t .shape [0 ] % 128 )) % 512 != 0 :
358
+ padding_size = 512
359
+ else :
360
+ padding_size = 128
361
+ dout_t = paddle .concat (
362
+ [
363
+ dout_t ,
364
+ paddle .zeros (
365
+ [padding_size - (dout_t .shape [0 ] % padding_size ), dout_t .shape [- 1 ]], dtype = dout_t .dtype
366
+ ),
367
+ ],
368
+ axis = 0 ,
369
+ )
332
370
333
371
_ , _ , dout_t_quant , dout_t_scale = kitchen_quant (
334
372
dout_t , backend = kitchen .ops .Backend .CUBLAS , is_1d_scaled = True , return_transpose = True
0 commit comments