Skip to content

Commit f2e2357

Browse files
author
root
committed
support swapAB for m_grouped_fp8_gemm_nt_masked
1 parent 51d1e9c commit f2e2357

File tree

7 files changed

+457
-7
lines changed

7 files changed

+457
-7
lines changed

csrc/jit_kernels/heuristics/common.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,17 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
152152

153153
// Select M/N block sizes
154154
// TODO: support `% 16 == 8` block size on SM90
155-
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
155+
156+
std::vector<int> block_ms = gemm_type == GemmType::MGroupedContiguous ?
156157
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
157158
std::vector<int> block_ns;
158159
for (int i = 16; i <= 256; i += 16)
159160
block_ns.push_back(i);
161+
if(get_env<int>("ENABLE_SWAPAB")){
162+
block_ms = std::vector{32}; // 32, 64
163+
block_ns = std::vector{256}; // 64, 128, 256
164+
}
165+
160166

161167
// K block size is selected in a fixed manner
162168
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));

csrc/jit_kernels/heuristics/sm90.hpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@ struct SM90ArchSpec {
4242

4343
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
4444
// Or too many register spills
45-
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
46-
return false;
4745

46+
if(get_env<int>("ENABLE_SWAPAB")){
47+
if (block_n != 64 and block_n != 128 and block_n != 256)
48+
return false;
49+
}else{
50+
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
51+
return false;
52+
}
53+
4854
// Avoid bank conflicts for FP32 output
4955
if (cd_dtype == torch::kFloat and block_n % 16 == 0)
5056
return false;
@@ -77,7 +83,13 @@ struct SM90ArchSpec {
7783

7884
// Builds the SM90 thread configuration from the block shape.
// This span is a diff hunk: the row marked `80-` is the removed implementation,
// the rows marked `86+`..`92+` are the replacement.
// `kernel_type` is not read by either version shown here.
static ThreadConfig get_thread_config(const KernelType& kernel_type,
7985
const int& block_m, const int& block_n) {
80-
return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128);
// New version: pick the dimension that decides the second `ThreadConfig::sm90` argument.
// A truthy `ENABLE_SWAPAB` environment value selects `block_n`; otherwise `block_m` is used.
86+
int tile = 64;
87+
if(get_env<int>("ENABLE_SWAPAB")){
88+
tile = block_n;
89+
}else{
90+
tile = block_m;
91+
}
// Second argument is 256 when `tile > 64`, else 128.
// NOTE(review): the removed row tested `block_m == 64`, so a `block_m` below 64 used to
// give 256 and now gives 128 even with ENABLE_SWAPAB unset — confirm this is intended.
92+
return ThreadConfig::sm90(128, (tile > 64 ? 2 : 1) * 128);
8193
}
8294

8395
static int get_smem_cd_size(const KernelType& kernel_type,
@@ -102,7 +114,8 @@ struct SM90ArchSpec {
102114

103115
// Returns `ceil_div(k, block_k) * sizeof(float) * use_uniform_sfb`, passed through `align<int>(..., 8)`.
// This span is a diff hunk: the row marked `105-` is the removed definition of
// `use_uniform_sfb`, the row marked `117+` is its replacement.
// `m`, `n` and `block_m` are not read by the rows shown here.
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
104116
const int& block_m, const int& block_n, const int& block_k) {
105-
const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
// New version: with a truthy `ENABLE_SWAPAB` the multiplier is `block_n / 64`;
// otherwise it is the original 1-or-2 choice based on `block_k % block_n`.
// NOTE(review): `block_n / 64` is integer division, so a `block_n` below 64 would give a
// zero multiplier — confirm only multiples of 64 reach this path.
117+
const auto& use_uniform_sfb = get_env<int>("ENABLE_SWAPAB") ? (block_n / 64):(block_k % block_n == 0 ? 1 : 2);
118+
106119
return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
107120
}
108121

csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,19 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
2929
};
3030

3131
static std::string generate_impl(const Args& args) {
32+
33+
const char* kernel_name =
34+
get_env<int>("ENABLE_SWAPAB") ?
35+
"swapAB_sm90_fp8_gemm_1d2d_impl" :
36+
"sm90_fp8_gemm_1d2d_impl";
37+
3238
return fmt::format(R"(
3339
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
3440
3541
using namespace deep_gemm;
3642
3743
static void __instantiate_kernel() {{
38-
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
44+
auto ptr = reinterpret_cast<void*>(&{}<
3945
{}, {}, {},
4046
{},
4147
{}, {}, {},
@@ -47,6 +53,7 @@ static void __instantiate_kernel() {{
4753
>);
4854
}};
4955
)",
56+
kernel_name,
5057
// TODO: add CD dtype
5158
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
5259
args.num_groups,

deep_gemm/include/deep_gemm/common/sm90_utils.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ struct SM90_U32x2_STSM_N {
144144
}
145145
};
146146

147+
// Store helper that issues one `stmatrix.sync.aligned.x2.m8n8.shared.b16.trans` instruction.
// `copy` takes two `dtype_t` values, reinterprets each as a `uint32_t`, and passes them as the
// two source registers; `smem_dst` is the address operand.
// NOTE(review): each argument is read as 4 bytes, so `dtype_t` is presumably a 32-bit packed
// type — confirm against callers.
// NOTE(review): `smem_dst` is passed as a 64-bit ("l") operand to a `.shared` instruction and
// the asm statement declares no "memory" clobber — confirm this matches the PTX addressing
// rules and the ordering the callers rely on.
template <typename dtype_t>
148+
struct SM90_U32x2_STSM_T
149+
{
150+
__device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst)
151+
{
// Bit-reinterpret both inputs as 32-bit words for the "r" register operands.
152+
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
153+
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]),
154+
"r"(src[1]));
155+
}
156+
};
157+
147158
// Emits a single `wgmma.fence.sync.aligned` instruction; the "memory" clobber tells the
// compiler the statement may read or write memory.
__forceinline__ __device__ void warpgroup_arrive() {
148159
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
149160
}

deep_gemm/include/deep_gemm/common/utils.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ __device__ __forceinline__ float ld_shared(const float* ptr) {
122122
return ret;
123123
}
124124

125+
// `float2` overload of `ld_shared`: reads both components with one `ld.shared.v2.f32`
// instruction and returns them as `ret.x` / `ret.y`.
// NOTE(review): vector loads presumably need `ptr` aligned to the full 8-byte access size —
// confirm against the PTX documentation and the callers.
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
126+
float2 ret;
127+
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
128+
return ret;
129+
}
130+
125131
// Writes `val` to the address `ptr` with one `st.shared.f32` instruction.
// NOTE(review): `ptr` is declared `const float*` although the instruction stores through it.
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
126132
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
127133
}

0 commit comments

Comments
 (0)