A set of fairly intuitive optimizations for single-precision matrix multiplication (SGEMM).
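For reference, the usual baseline is a naive kernel with one thread per element of C. In this kernel each element of A is fetched from global memory by N different threads and each element of B by M threads, and that redundancy is what the optimizations listed below target. This is a minimal sketch that assumes row-major A (M×K), B (K×N), and C (M×N). `sgemm_naive` and the `naive_max_err` test helper are hypothetical names, and the source does not say whether its v1 is exactly this kernel.

```cuda
#include <cuda_runtime.h>
#include <cmath>
#include <cstdlib>
#include <vector>

// Naive SGEMM sketch: C = A * B with row-major A (M x K), B (K x N), C (M x N).
// One thread per output element; every operand comes straight from global memory.
__global__ void sgemm_naive(const float* A, const float* B, float* C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) return;
    float acc = 0.f;
    for (int k = 0; k < K; ++k) acc += A[row * K + k] * B[k * N + col];
    C[row * N + col] = acc;
}

// Hypothetical test helper: random inputs, CPU reference, returns max |GPU - CPU|.
float naive_max_err(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N), C(M * N), ref(M * N, 0.f);
    for (auto& x : A) x = rand() / (float)RAND_MAX - 0.5f;
    for (auto& x : B) x = rand() / (float)RAND_MAX - 0.5f;
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < N; ++j) ref[i * N + j] += A[i * K + k] * B[k * N + j];
    float *dA = nullptr, *dB = nullptr, *dC = nullptr;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);
    dim3 block(16, 16), grid((N + 15) / 16, (M + 15) / 16);
    sgemm_naive<<<grid, block>>>(dA, dB, dC, M, N, K);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);  // waits for the kernel
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    float err = 0.f;
    for (int i = 0; i < M * N; ++i) err = fmaxf(err, fabsf(C[i] - ref[i]));
    return err;
}
```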
Measured results on a 4060 (K = 1024 in all runs):
```
===== M = 256, N = 256, K = 1024 =====
cublas | time: 0.038912 ms | max diff: N/A
v1 | time: 1.213440 ms | max diff: 0.000046
v2 | time: 0.249856 ms | max diff: 0.000046
v3 | time: 0.196448 ms | max diff: 0.000046
v4 | time: 0.116736 ms | max diff: 0.000046
v5 | time: 0.132000 ms | max diff: 0.000046
v6 | time: 0.121856 ms | max diff: 0.000046
v7 | time: 0.139264 ms | max diff: 0.000046
v8 | time: 0.155648 ms | max diff: 0.000046
===== M = 512, N = 512, K = 1024 =====
cublas | time: 0.114688 ms | max diff: N/A
v1 | time: 0.908224 ms | max diff: 0.000061
v2 | time: 0.698368 ms | max diff: 0.000061
v3 | time: 0.621568 ms | max diff: 0.000061
v4 | time: 0.149472 ms | max diff: 0.000061
v5 | time: 0.158720 ms | max diff: 0.000061
v6 | time: 0.220160 ms | max diff: 0.000061
v7 | time: 0.186304 ms | max diff: 0.000061
v8 | time: 0.176128 ms | max diff: 0.000061
===== M = 1024, N = 1024, K = 1024 =====
cublas | time: 0.401408 ms | max diff: N/A
v1 | time: 3.567616 ms | max diff: 0.000072
v2 | time: 2.742272 ms | max diff: 0.000072
v3 | time: 2.458624 ms | max diff: 0.000072
v4 | time: 0.491520 ms | max diff: 0.000072
v5 | time: 0.535296 ms | max diff: 0.000072
v6 | time: 0.604160 ms | max diff: 0.000072
v7 | time: 0.578560 ms | max diff: 0.000072
v8 | time: 0.616448 ms | max diff: 0.000072
===== M = 2048, N = 2048, K = 1024 =====
cublas | time: 0.966656 ms | max diff: N/A
v1 | time: 8.936448 ms | max diff: 0.000000
v2 | time: 6.849536 ms | max diff: 0.000000
v3 | time: 6.164480 ms | max diff: 0.000000
v4 | time: 1.184768 ms | max diff: 0.000000
v5 | time: 1.287168 ms | max diff: 0.000000
v6 | time: 1.438720 ms | max diff: 0.000000
v7 | time: 1.427456 ms | max diff: 0.000000
v8 | time: 1.478656 ms | max diff: 0.000000
===== M = 4096, N = 4096, K = 1024 =====
cublas | time: 3.634176 ms | max diff: N/A
v1 | time: 36.820992 ms | max diff: 0.000000
v2 | time: 30.008320 ms | max diff: 0.000000
v3 | time: 27.866112 ms | max diff: 0.000000
v4 | time: 4.813824 ms | max diff: 0.000000
v5 | time: 5.111808 ms | max diff: 0.000000
v6 | time: 5.800960 ms | max diff: 0.000000
v7 | time: 5.719040 ms | max diff: 0.000000
v8 | time: 5.867520 ms | max diff: 0.000000
===== M = 8192, N = 8192, K = 1024 =====
cublas | time: 14.385152 ms | max diff: N/A
v1 | time: 179.564545 ms | max diff: 0.000000
v2 | time: 136.337402 ms | max diff: 0.000000
v3 | time: 112.858109 ms | max diff: 0.000000
v4 | time: 19.419136 ms | max diff: 0.000000
v5 | time: 20.897793 ms | max diff: 0.000000
v6 | time: 26.982401 ms | max diff: 0.000000
v7 | time: 25.618431 ms | max diff: 0.000000
v8 | time: 24.745983 ms | max diff: 0.000000
```
- Use shared memory to speed up reads from global memory
- Use vectorized `float4` loads
- Use registers to speed up reads from shared memory
- Use double buffering to hide the latency of loading data from gmem into smem
- Mitigate bank conflicts in three ways: padding, swizzling, and changing how data is loaded
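The shared-memory item can be sketched as classic tiling. Each 16×16 thread block stages one tile of A and one tile of B in shared memory for every step along K. As a result, each global element is read once per block rather than once per thread, which cuts global traffic by roughly the tile width. This is a minimal sketch that assumes row-major storage and `TILE = 16`, and both of those are illustrative choices rather than necessarily what the benchmarked kernels use. `sgemm_smem` and `smem_max_err` are hypothetical names.

```cuda
#include <cuda_runtime.h>
#include <cmath>
#include <cstdlib>
#include <vector>

constexpr int TILE = 16;  // assumed tile width; block is TILE x TILE threads

// Shared-memory tiled SGEMM sketch (row-major A: M x K, B: K x N, C: M x N).
__global__ void sgemm_smem(const float* A, const float* B, float* C, int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];
    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.f;
    for (int k0 = 0; k0 < K; k0 += TILE) {
        int ak = k0 + threadIdx.x, bk = k0 + threadIdx.y;
        // Each thread copies one element of each tile; out-of-range cells are zero-filled.
        As[threadIdx.y][threadIdx.x] = (row < M && ak < K) ? A[row * K + ak] : 0.f;
        Bs[threadIdx.y][threadIdx.x] = (bk < K && col < N) ? B[bk * N + col] : 0.f;
        __syncthreads();  // tiles fully written before anyone reads them
        for (int k = 0; k < TILE; ++k) acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();  // everyone done reading before the next overwrite
    }
    if (row < M && col < N) C[row * N + col] = acc;
}

// Hypothetical test helper: random inputs, CPU reference, returns max |GPU - CPU|.
float smem_max_err(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N), C(M * N), ref(M * N, 0.f);
    for (auto& x : A) x = rand() / (float)RAND_MAX - 0.5f;
    for (auto& x : B) x = rand() / (float)RAND_MAX - 0.5f;
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < N; ++j) ref[i * N + j] += A[i * K + k] * B[k * N + j];
    float *dA = nullptr, *dB = nullptr, *dC = nullptr;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);
    dim3 block(TILE, TILE), grid((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    sgemm_smem<<<grid, block>>>(dA, dB, dC, M, N, K);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    float err = 0.f;
    for (int i = 0; i < M * N; ++i) err = fmaxf(err, fabsf(C[i] - ref[i]));
    return err;
}
```

The two `__syncthreads()` calls are both required. The first makes the tiles visible to every thread, and the second stops fast threads from overwriting a tile that slower threads are still reading.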
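Vectorized access can be shown on its own. Here each thread computes four adjacent outputs in one row of C and moves data with 128-bit `float4` loads and stores, which issues a quarter as many memory instructions as scalar code. This is a minimal sketch that assumes K and N are multiples of 4, so every `float4` address is 16-byte aligned (`cudaMalloc` returns pointers aligned to at least 256 bytes). It leaves out shared memory to keep the idea isolated, whereas a tuned kernel would more typically apply `float4` to the gmem-to-smem copies. `sgemm_float4` and `float4_max_err` are hypothetical names.

```cuda
#include <cuda_runtime.h>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <vector>

// float4 SGEMM sketch: each thread produces C[row][col .. col+3].
// Assumes K % 4 == 0 and N % 4 == 0 so all float4 accesses are 16-byte aligned.
__global__ void sgemm_float4(const float* __restrict__ A, const float* __restrict__ B,
                             float* __restrict__ C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = (blockIdx.x * blockDim.x + threadIdx.x) * 4;  // first of 4 output columns
    if (row >= M || col >= N) return;
    float4 acc = make_float4(0.f, 0.f, 0.f, 0.f);
    for (int k = 0; k < K; k += 4) {
        float4 a = *reinterpret_cast<const float4*>(&A[row * K + k]);  // A[row][k .. k+3]
        float av[4] = {a.x, a.y, a.z, a.w};
#pragma unroll
        for (int i = 0; i < 4; ++i) {
            float4 b = *reinterpret_cast<const float4*>(&B[(k + i) * N + col]);  // B[k+i][col .. col+3]
            acc.x += av[i] * b.x; acc.y += av[i] * b.y;
            acc.z += av[i] * b.z; acc.w += av[i] * b.w;
        }
    }
    *reinterpret_cast<float4*>(&C[row * N + col]) = acc;  // one 128-bit store
}

// Hypothetical test helper: random inputs, CPU reference, returns max |GPU - CPU|.
float float4_max_err(int M, int N, int K) {
    assert(N % 4 == 0 && K % 4 == 0);
    std::vector<float> A(M * K), B(K * N), C(M * N), ref(M * N, 0.f);
    for (auto& x : A) x = rand() / (float)RAND_MAX - 0.5f;
    for (auto& x : B) x = rand() / (float)RAND_MAX - 0.5f;
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < N; ++j) ref[i * N + j] += A[i * K + k] * B[k * N + j];
    float *dA = nullptr, *dB = nullptr, *dC = nullptr;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);
    dim3 block(16, 16), grid((N / 4 + 15) / 16, (M + 15) / 16);
    sgemm_float4<<<grid, block>>>(dA, dB, dC, M, N, K);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    float err = 0.f;
    for (int i = 0; i < M * N; ++i) err = fmaxf(err, fabsf(C[i] - ref[i]));
    return err;
}
```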
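Register tiling keeps a TM×TN block of C in registers for each thread. For every k, the thread copies TM values of A and TN values of B from shared memory into registers and then performs a TM×TN outer product. With 4×4 that means 8 shared-memory reads per 16 FMAs, compared with 2 reads per FMA in the plain shared-memory kernel. This is a minimal sketch. The sizes BM = BN = 64, BK = 8, TM = TN = 4 (256 threads per block) are illustrative assumptions, and `sgemm_regtile` and `regtile_max_err` are hypothetical names.

```cuda
#include <cuda_runtime.h>
#include <cmath>
#include <cstdlib>
#include <vector>

constexpr int BM = 64, BN = 64, BK = 8, TM = 4, TN = 4;  // 16 x 16 = 256 threads per block

__global__ void sgemm_regtile(const float* A, const float* B, float* C, int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];
    int tid = threadIdx.x;
    int ty = tid / (BN / TN), tx = tid % (BN / TN);  // position in the 16 x 16 thread grid
    int rowBase = blockIdx.y * BM, colBase = blockIdx.x * BN;
    float acc[TM][TN] = {};  // this thread's 4 x 4 block of C, held in registers
    float aFrag[TM], bFrag[TN];
    for (int k0 = 0; k0 < K; k0 += BK) {
        // Cooperative, bounds-checked copy of the A and B tiles into shared memory.
        for (int i = tid; i < BM * BK; i += blockDim.x) {
            int r = rowBase + i / BK, c = k0 + i % BK;
            As[i / BK][i % BK] = (r < M && c < K) ? A[r * K + c] : 0.f;
        }
        for (int i = tid; i < BK * BN; i += blockDim.x) {
            int r = k0 + i / BN, c = colBase + i % BN;
            Bs[i / BN][i % BN] = (r < K && c < N) ? B[r * N + c] : 0.f;
        }
        __syncthreads();
#pragma unroll
        for (int k = 0; k < BK; ++k) {
#pragma unroll
            for (int m = 0; m < TM; ++m) aFrag[m] = As[ty * TM + m][k];  // smem -> registers
#pragma unroll
            for (int n = 0; n < TN; ++n) bFrag[n] = Bs[k][tx * TN + n];
#pragma unroll
            for (int m = 0; m < TM; ++m)
#pragma unroll
                for (int n = 0; n < TN; ++n) acc[m][n] += aFrag[m] * bFrag[n];  // outer product
        }
        __syncthreads();
    }
    for (int m = 0; m < TM; ++m)
        for (int n = 0; n < TN; ++n) {
            int r = rowBase + ty * TM + m, c = colBase + tx * TN + n;
            if (r < M && c < N) C[r * N + c] = acc[m][n];
        }
}

// Hypothetical test helper: random inputs, CPU reference, returns max |GPU - CPU|.
float regtile_max_err(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N), C(M * N), ref(M * N, 0.f);
    for (auto& x : A) x = rand() / (float)RAND_MAX - 0.5f;
    for (auto& x : B) x = rand() / (float)RAND_MAX - 0.5f;
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < N; ++j) ref[i * N + j] += A[i * K + k] * B[k * N + j];
    float *dA = nullptr, *dB = nullptr, *dC = nullptr;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);
    dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
    sgemm_regtile<<<grid, (BM / TM) * (BN / TN)>>>(dA, dB, dC, M, N, K);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    float err = 0.f;
    for (int i = 0; i < M * N; ++i) err = fmaxf(err, fabsf(C[i] - ref[i]));
    return err;
}
```

The fully unrolled loops with compile-time bounds let the compiler keep `acc`, `aFrag`, and `bFrag` in registers instead of spilling them to local memory.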
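Double buffering can be sketched with two shared-memory buffers per operand. While the block computes on tile t in one buffer, each thread already has its global loads for tile t+1 in flight into registers. Once the math is done, it writes those values into the other buffer. Reads and writes always go to different buffers, so one `__syncthreads()` per K step is enough. This is a minimal sketch of the register-staged variant, built on the 16×16 shared-memory tiling with no register tiling, so that the pipelining stands out. `sgemm_double_buffer` and `double_buffer_max_err` are hypothetical names.

```cuda
#include <cuda_runtime.h>
#include <cmath>
#include <cstdlib>
#include <vector>

constexpr int TILE = 16;

__global__ void sgemm_double_buffer(const float* A, const float* B, float* C, int M, int N, int K) {
    __shared__ float As[2][TILE][TILE];
    __shared__ float Bs[2][TILE][TILE];
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = blockIdx.y * TILE + ty, col = blockIdx.x * TILE + tx;
    int numTiles = (K + TILE - 1) / TILE;
    // Bounds-checked fetch of this thread's element of K-tile t.
    auto loadA = [&](int t) { int k = t * TILE + tx; return (row < M && k < K) ? A[row * K + k] : 0.f; };
    auto loadB = [&](int t) { int k = t * TILE + ty; return (k < K && col < N) ? B[k * N + col] : 0.f; };

    As[0][ty][tx] = loadA(0);  // prologue: fill buffer 0
    Bs[0][ty][tx] = loadB(0);
    __syncthreads();
    float acc = 0.f;
    for (int t = 0; t < numTiles; ++t) {
        int cur = t & 1, nxt = cur ^ 1;
        float aNext = 0.f, bNext = 0.f;
        bool hasNext = t + 1 < numTiles;
        if (hasNext) { aNext = loadA(t + 1); bNext = loadB(t + 1); }  // issue gmem loads early
        for (int k = 0; k < TILE; ++k) acc += As[cur][ty][k] * Bs[cur][k][tx];  // overlaps the loads
        if (hasNext) { As[nxt][ty][tx] = aNext; Bs[nxt][ty][tx] = bNext; }  // fill the idle buffer
        __syncthreads();  // next tile visible; current buffer free for reuse
    }
    if (row < M && col < N) C[row * N + col] = acc;
}

// Hypothetical test helper: random inputs, CPU reference, returns max |GPU - CPU|.
float double_buffer_max_err(int M, int N, int K) {
    std::vector<float> A(M * K), B(K * N), C(M * N), ref(M * N, 0.f);
    for (auto& x : A) x = rand() / (float)RAND_MAX - 0.5f;
    for (auto& x : B) x = rand() / (float)RAND_MAX - 0.5f;
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < N; ++j) ref[i * N + j] += A[i * K + k] * B[k * N + j];
    float *dA = nullptr, *dB = nullptr, *dC = nullptr;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);
    dim3 block(TILE, TILE), grid((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    sgemm_double_buffer<<<grid, block>>>(dA, dB, dC, M, N, K);
    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    float err = 0.f;
    for (int i = 0; i < M * N; ++i) err = fmaxf(err, fabsf(C[i] - ref[i]));
    return err;
}
```

Buffer `nxt` was last read during iteration t−1, and that iteration ends with a barrier. This guarantees that writing it in iteration t cannot race with those reads.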
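Shared memory is divided into 32 banks of 4-byte words. When threads in a warp access different addresses that map to the same bank, the hardware serializes those accesses. Reading one column of a 32×32 `float` tile is the worst case, because all 32 addresses land in a single bank. The sketch below illustrates two of the three listed remedies using a tile transpose, which stands in for the column-wise access that occurs in SGEMM when a tile is stored transposed. Padding widens each row to 33 words, so row r is shifted by r banks. Swizzling keeps a 32-word stride but stores element (r, c) at column `c ^ r`. The third approach, changing the load pattern, is not shown. `transpose_padded`, `transpose_swizzled`, and `transpose_max_err` are hypothetical names.

```cuda
#include <cuda_runtime.h>
#include <cmath>
#include <vector>

constexpr int TILE_DIM = 32;  // one warp = one row of threadIdx.x; 32 banks

// Padding: the read tile[tx][ty] hits bank (tx * 33 + ty) % 32 = (tx + ty) % 32, all distinct.
__global__ void transpose_padded(const float* in, float* out, int n) {
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];
    int x = blockIdx.x * TILE_DIM + threadIdx.x, y = blockIdx.y * TILE_DIM + threadIdx.y;
    if (x < n && y < n) tile[threadIdx.y][threadIdx.x] = in[y * n + x];  // row-wise write
    __syncthreads();
    x = blockIdx.y * TILE_DIM + threadIdx.x;
    y = blockIdx.x * TILE_DIM + threadIdx.y;
    if (x < n && y < n) out[y * n + x] = tile[threadIdx.x][threadIdx.y];  // column-wise read
}

// XOR swizzle: element (r, c) lives at tile[r][c ^ r]; no extra memory is used.
// The read tile[tx][ty ^ tx] hits bank ty ^ tx, which is distinct for each tx in a warp.
__global__ void transpose_swizzled(const float* in, float* out, int n) {
    __shared__ float tile[TILE_DIM][TILE_DIM];
    int x = blockIdx.x * TILE_DIM + threadIdx.x, y = blockIdx.y * TILE_DIM + threadIdx.y;
    if (x < n && y < n) tile[threadIdx.y][threadIdx.x ^ threadIdx.y] = in[y * n + x];
    __syncthreads();
    x = blockIdx.y * TILE_DIM + threadIdx.x;
    y = blockIdx.x * TILE_DIM + threadIdx.y;
    if (x < n && y < n) out[y * n + x] = tile[threadIdx.x][threadIdx.y ^ threadIdx.x];
}

// Hypothetical test helper: transposes an n x n matrix on the GPU, returns max error vs CPU.
float transpose_max_err(bool swizzle, int n) {
    std::vector<float> in(n * n), out(n * n);
    for (int i = 0; i < n * n; ++i) in[i] = (float)i;
    float *dIn = nullptr, *dOut = nullptr;
    cudaMalloc(&dIn, in.size() * sizeof(float));
    cudaMalloc(&dOut, out.size() * sizeof(float));
    cudaMemcpy(dIn, in.data(), in.size() * sizeof(float), cudaMemcpyHostToDevice);
    dim3 block(TILE_DIM, TILE_DIM), grid((n + TILE_DIM - 1) / TILE_DIM, (n + TILE_DIM - 1) / TILE_DIM);
    if (swizzle) transpose_swizzled<<<grid, block>>>(dIn, dOut, n);
    else transpose_padded<<<grid, block>>>(dIn, dOut, n);
    cudaMemcpy(out.data(), dOut, out.size() * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dIn); cudaFree(dOut);
    float err = 0.f;
    for (int r = 0; r < n; ++r)
        for (int c = 0; c < n; ++c) err = fmaxf(err, fabsf(out[r * n + c] - in[c * n + r]));
    return err;
}
```

Padding uses 32 extra words per tile but keeps indexing simple. Swizzling uses no extra memory, and the XOR must be applied the same way on every write and read of the tile.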