Commit 2ed5768

[BENCHMARK] Update cutlass's gemm configuration file (#4840)
Following up on PR [#4720](#4720), this PR updates the GEMM configuration file used by the CUTLASS provider to improve performance. As mentioned in this [comment](#4720 (comment)), the updated configuration is not the official CUTLASS one and does not deliver the best known GEMM performance in CUTLASS.

**Note:** Work to integrate the best-known-performance configuration is already being tracked in issue [#4775](#4775).

Signed-off-by: Jefferson Le Quellec <jefferson.lequellec@codeplay.com>
1 parent ce03fe5 commit 2ed5768

File tree: 2 files changed (+30, -3 lines)

benchmarks/cutlass_kernel/gemm/input_gemm.in

Lines changed: 23 additions & 3 deletions
@@ -1,10 +1,12 @@
+# OLD SHAPES : `cutlass-sycl/benchmarks/device/pvc/input_files/input_gemm.in`
+
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
-PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
@@ -17,5 +19,23 @@ PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
 PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
-PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
-PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
+PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
+
+# NEW SHAPES : `cutlass-sycl/benchmarks/device/pvc/input_files/input_pytorch_2.in`
+
+PvcGemmBF16BF16FP32_RCR_16 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=1024
+PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=4096
+PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=14336 --n=4096
+PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=6144
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=14336
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=28672
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=128256
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=4096 --n=1024
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=4096 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=14336 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=4096 --n=6144
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=4096 --n=14336
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=4096 --n=28672
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=8 --k=4096 --n=128256
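
Each line of `input_gemm.in` encodes one benchmark case: a CUTLASS benchmark configuration name (the `RRR`/`RCR` part of the name appears to encode the operand layouts; the Python change below shows that the `RCR` entry expects a transposed B matrix) followed by `--l`, `--m`, `--k`, `--n` shape flags. As an illustration of the format only, here is a hypothetical `parse_gemm_line` helper; it is not part of the benchmark harness:

```python
# Hypothetical helper, for illustration only: parse one line of input_gemm.in
# into its configuration name, benchmark name, and (l, m, k, n) shape.
def parse_gemm_line(line: str):
    tokens = line.split()
    name = tokens[0]
    flags = dict(tok.lstrip('-').split('=', 1) for tok in tokens[1:])
    l, m, k, n = (int(flags[key]) for key in ('l', 'm', 'k', 'n'))
    return name, flags['bm_name'], (l, m, k, n)


print(parse_gemm_line(
    'PvcGemmBF16BF16FP32_RCR_16 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=4096 --n=1024'))
# -> ('PvcGemmBF16BF16FP32_RCR_16', 'bf16_bf16_fp32', (1, 1, 4096, 1024))
```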

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 7 additions & 0 deletions
@@ -430,6 +430,13 @@ def xetla_func_with_acc_allocation():
         name = 'gemm'
         func = getattr(cutlass_kernel, name)
 
+        # Special case where the b matrix needs to be transposed (see: `./cutlass_kernel/gemm/input_gemm.in`)
+        if (B, M, N, K) == (1, 1, 1024, 4096):
+            _, b_shape = get_shapes(B, M, N, K, transpose_a=False, transpose_b=True)
+            b = torch.reshape(b, b_shape)
+            torch_b = b
+            torch_b = torch.transpose(torch_b, -2, -1)
+
         def cutlass_invoker():
             if B == 1:
                 c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
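
The special case added above reshapes `b` into the transposed-B layout used by the `RCR` configuration for the `--l=1 --m=1 --k=4096 --n=1024` entry, while the PyTorch reference still multiplies against the logical `(K, N)` operand, hence the final `transpose` on `torch_b`. The following is a minimal standalone sketch of that bookkeeping, using plain PyTorch on CPU and independent of the benchmark's `get_shapes` helper:

```python
import torch

# Minimal sketch (assumption: CPU tensors, not the benchmark's XPU setup) of the
# transposed-B bookkeeping for the RCR shape l=1, m=1, k=4096, n=1024.
M, N, K = 1, 1024, 4096

a = torch.randn(M, K, dtype=torch.bfloat16)
# B operand stored with shape (N, K), i.e. the transposed/column-major layout
# that the RCR kernel configuration consumes.
b_storage = torch.randn(N, K, dtype=torch.bfloat16)

# The PyTorch reference still needs the logical (K, N) matrix, so transpose the
# view back (no data copy).
torch_b = torch.transpose(b_storage, -2, -1)
ref = torch.matmul(a.float(), torch_b.float())  # (M, N) reference result

assert ref.shape == (M, N)
```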
