Skip to content

Commit 1aae565

Browse files
committed
support moe group_wise weight quant
1 parent 8d0a39a commit 1aae565

File tree

7 files changed

+181
-88
lines changed

7 files changed

+181
-88
lines changed

csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ struct MoeFCGemm {
211211
// Only used by device-level operator
212212
GemmCoord* host_problem_sizes;
213213

214+
int group_size;
215+
214216
//
215217
// Methods
216218
//
@@ -220,6 +222,7 @@ struct MoeFCGemm {
220222
Arguments()
221223
: problem_count(0),
222224
threadblock_count(0),
225+
group_size(-1),
223226
ptr_A(nullptr),
224227
ptr_B(nullptr),
225228
weight_scales(nullptr),
@@ -243,10 +246,12 @@ struct MoeFCGemm {
243246
int64_t* total_rows_before_expert,
244247
int64_t gemm_n,
245248
int64_t gemm_k,
249+
int group_size,
246250
GemmCoord* host_problem_sizes = nullptr)
247251
: problem_count(problem_count),
248252
threadblock_count(threadblock_count),
249253
output_op(output_op),
254+
group_size(group_size),
250255
ptr_A(const_cast<ElementA*>(ptr_A)),
251256
ptr_B(const_cast<ElementB*>(ptr_B)),
252257
weight_scales(const_cast<ElementScale*>(weight_scales)),
@@ -280,6 +285,8 @@ struct MoeFCGemm {
280285
ElementC* ptr_C;
281286
ElementC* ptr_D;
282287

288+
int group_size;
289+
283290
//
284291
// Methods
285292
//
@@ -290,7 +297,8 @@ struct MoeFCGemm {
290297
ptr_B(nullptr),
291298
weight_scales(nullptr),
292299
ptr_C(nullptr),
293-
ptr_D(nullptr) {}
300+
ptr_D(nullptr),
301+
group_size(-1) {}
294302

295303
CUTLASS_HOST_DEVICE
296304
Params(Arguments const& args,
@@ -308,7 +316,8 @@ struct MoeFCGemm {
308316
ptr_B(args.ptr_B),
309317
weight_scales(args.weight_scales),
310318
ptr_C(args.ptr_C),
311-
ptr_D(args.ptr_D) {}
319+
ptr_D(args.ptr_D),
320+
group_size(args.group_size) {}
312321

313322
CUTLASS_HOST_DEVICE
314323
void update(Arguments const& args,
@@ -498,7 +507,7 @@ struct MoeFCGemm {
498507
auto CreateMMA = [&]() {
499508
if constexpr (use_dq_gemm<Mma>::value)
500509
return Mma(
501-
shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx);
510+
shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
502511
else
503512
return Mma(
504513
shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class MoeGemmRunner {
3737
int64_t gemm_k,
3838
int num_experts,
3939
std::string activation_type,
40+
const int32_t weightonly_group_size,
4041
cudaStream_t stream);
4142

4243
void moe_gemm(const T* A,
@@ -48,6 +49,7 @@ class MoeGemmRunner {
4849
int64_t gemm_n,
4950
int64_t gemm_k,
5051
int num_experts,
52+
int group_size,
5153
cudaStream_t stream);
5254

5355
private:
@@ -62,6 +64,7 @@ class MoeGemmRunner {
6264
int64_t gemm_n,
6365
int64_t gemm_k,
6466
int num_experts,
67+
int group_size,
6568
CutlassGemmConfig gemm_config,
6669
cudaStream_t stream,
6770
int* occupancy = nullptr);
@@ -77,6 +80,7 @@ class MoeGemmRunner {
7780
int64_t gemm_n,
7881
int64_t gemm_k,
7982
int num_experts,
83+
int group_size,
8084
cudaStream_t stream);
8185

8286
private:

0 commit comments

Comments
 (0)