Commit bbd42d3 ("fix the conflict")
Parent: e789537
7 files changed: +103 -18 lines

csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h

Lines changed: 12 additions & 3 deletions

@@ -211,6 +211,8 @@ struct MoeFCGemm {
     // Only used by device-level operator
     GemmCoord* host_problem_sizes;
 
+    int group_size;
+
     //
     // Methods
     //
@@ -220,6 +222,7 @@
     Arguments()
         : problem_count(0),
           threadblock_count(0),
+          group_size(-1),
           ptr_A(nullptr),
           ptr_B(nullptr),
           weight_scales(nullptr),
@@ -243,10 +246,12 @@
               int64_t* total_rows_before_expert,
               int64_t gemm_n,
               int64_t gemm_k,
+              int group_size,
               GemmCoord* host_problem_sizes = nullptr)
         : problem_count(problem_count),
           threadblock_count(threadblock_count),
           output_op(output_op),
+          group_size(group_size),
           ptr_A(const_cast<ElementA*>(ptr_A)),
           ptr_B(const_cast<ElementB*>(ptr_B)),
           weight_scales(const_cast<ElementScale*>(weight_scales)),
@@ -280,6 +285,8 @@
     ElementC* ptr_C;
     ElementC* ptr_D;
 
+    int group_size;
+
     //
     // Methods
     //
@@ -290,7 +297,8 @@
           ptr_B(nullptr),
           weight_scales(nullptr),
           ptr_C(nullptr),
-          ptr_D(nullptr) {}
+          ptr_D(nullptr),
+          group_size(-1) {}
 
     CUTLASS_HOST_DEVICE
     Params(Arguments const& args,
@@ -308,7 +316,8 @@
           ptr_B(args.ptr_B),
           weight_scales(args.weight_scales),
           ptr_C(args.ptr_C),
-          ptr_D(args.ptr_D) {}
+          ptr_D(args.ptr_D),
+          group_size(args.group_size) {}
 
     CUTLASS_HOST_DEVICE
     void update(Arguments const& args,
@@ -498,7 +507,7 @@
     auto CreateMMA = [&]() {
       if constexpr (use_dq_gemm<Mma>::value)
         return Mma(
-            shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx);
+            shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
       else
         return Mma(
             shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
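Note: group_size here is the weight-only quantization group width along the reduction (K) dimension: every group_size weights share one dequantization scale, and -1 keeps the old per-channel behaviour (one scale per output column), which is why both constructors default to -1 and why CreateMMA previously hardcoded it. A minimal sketch of the scheme this parameter describes, using hypothetical names rather than the actual CUTLASS mainloop code:

// Hypothetical illustration (not the CUTLASS mainloop): group-wise
// weight-only dequantization of one K-length column of B.
#include <cstdint>
#include <vector>

std::vector<float> dequantize_column(const std::vector<int8_t>& q,  // quantized weights, length K
                                     const std::vector<float>& scales,
                                     int group_size) {
  const int k = static_cast<int>(q.size());
  const int gs = (group_size == -1) ? k : group_size;  // -1: one scale per column
  std::vector<float> w(k);
  for (int i = 0; i < k; ++i) {
    w[i] = static_cast<float>(q[i]) * scales[i / gs];  // scale shared within each group
  }
  return w;
}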

csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h

Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,7 @@ class MoeGemmRunner {
                          int64_t gemm_k,
                          int num_experts,
                          std::string activation_type,
+                         const int32_t weightonly_group_size,
                          cudaStream_t stream);
 
   void moe_gemm(const T* A,
@@ -48,6 +49,7 @@
                 int64_t gemm_n,
                 int64_t gemm_k,
                 int num_experts,
+                int group_size,
                 cudaStream_t stream);
 
  private:
@@ -62,6 +64,7 @@
                         int64_t gemm_n,
                         int64_t gemm_k,
                         int num_experts,
+                        int group_size,
                         CutlassGemmConfig gemm_config,
                         cudaStream_t stream,
                         int* occupancy = nullptr);
@@ -77,6 +80,7 @@
                 int64_t gemm_n,
                 int64_t gemm_k,
                 int num_experts,
+                int group_size,
                 cudaStream_t stream);
 
  private:
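Note: every public entry point of MoeGemmRunner now takes the group size explicitly instead of assuming per-channel scales. A caller-side sketch of validating the value before it reaches the runner; checked_group_size is a hypothetical helper, not part of this PR:

// Hypothetical pre-flight check before calling MoeGemmRunner::moe_gemm:
// group-wise scales only make sense when K splits evenly into groups.
#include <cassert>
#include <cstdint>

inline int32_t checked_group_size(int32_t weightonly_group_size, int64_t gemm_k) {
  if (weightonly_group_size == -1) return -1;    // per-channel, the pre-PR default
  assert(gemm_k % weightonly_group_size == 0);   // whole groups along K
  return weightonly_group_size;
}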

csrc/gpu/moe/fused_moe/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h

Lines changed: 23 additions & 3 deletions

@@ -66,6 +66,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
                                      int64_t gemm_n,
                                      int64_t gemm_k,
                                      int num_experts,
+                                     int group_size,
                                      CutlassGemmConfig gemm_config,
                                      const int multi_processor_count,
                                      cudaStream_t stream,
@@ -191,7 +192,8 @@ void generic_moe_gemm_kernelLauncher(const T* A,
       reinterpret_cast<ElementType*>(C),
       total_rows_before_expert,
       gemm_n,
-      gemm_k);
+      gemm_k,
+      group_size);
 
   GemmGrouped gemm;
 
@@ -237,6 +239,7 @@ struct dispatch_stages {
                         int64_t gemm_n,
                         int64_t gemm_k,
                         int num_experts,
+                        int group_size,
                         CutlassGemmConfig gemm_config,
                         int multi_processor_count,
                         cudaStream_t stream,
@@ -271,6 +274,7 @@ struct dispatch_stages<T,
                         int64_t gemm_n,
                         int64_t gemm_k,
                         int num_experts,
+                        int group_size,
                         CutlassGemmConfig gemm_config,
                         int multi_processor_count,
                         cudaStream_t stream,
@@ -290,6 +294,7 @@ struct dispatch_stages<T,
         gemm_n,
         gemm_k,
         num_experts,
+        group_size,
         gemm_config,
         multi_processor_count,
         stream,
@@ -320,6 +325,7 @@ struct dispatch_stages<T,
                         int64_t gemm_n,
                         int64_t gemm_k,
                         int num_experts,
+                        int group_size,
                         CutlassGemmConfig gemm_config,
                         int multi_processor_count,
                         cudaStream_t stream,
@@ -339,6 +345,7 @@ struct dispatch_stages<T,
         gemm_n,
         gemm_k,
         num_experts,
+        group_size,
         gemm_config,
         multi_processor_count,
         stream,
@@ -361,6 +368,7 @@ void dispatch_gemm_config(const T* A,
                           int64_t gemm_n,
                           int64_t gemm_k,
                           int num_experts,
+                          int group_size,
                           CutlassGemmConfig gemm_config,
                           int multi_processor_count,
                           cudaStream_t stream,
@@ -382,6 +390,7 @@ void dispatch_gemm_config(const T* A,
         gemm_n,                 \
         gemm_k,                 \
         num_experts,            \
+        group_size,             \
         gemm_config,            \
         multi_processor_count,  \
         stream,                 \
@@ -419,6 +428,7 @@ void dispatch_gemm_config(const T* A,
         gemm_n,                 \
         gemm_k,                 \
         num_experts,            \
+        group_size,             \
         gemm_config,            \
         multi_processor_count,  \
         stream,                 \
@@ -444,6 +454,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
                                   int64_t gemm_n,
                                   int64_t gemm_k,
                                   int num_experts,
+                                  int group_size,
                                   CutlassGemmConfig gemm_config,
                                   int sm_version,
                                   int multi_processor_count,
@@ -489,6 +500,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
                                   int64_t gemm_n,
                                   int64_t gemm_k,
                                   int num_experts,
+                                  int group_size,
                                   CutlassGemmConfig gemm_config,
                                   int sm_version,
                                   int multi_processor_count,
@@ -555,6 +567,7 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
                                   int64_t gemm_n,
                                   int64_t gemm_k,
                                   int num_experts,
+                                  int group_size,
                                   CutlassGemmConfig gemm_config,
                                   int sm_version,
                                   int multi_processor_count,
@@ -602,6 +615,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
     int64_t gemm_n,
     int64_t gemm_k,
     int num_experts,
+    int group_size,
     CutlassGemmConfig gemm_config,
     cudaStream_t stream,
     int* occupancy) {
@@ -617,6 +631,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
         gemm_n,                  \
         gemm_k,                  \
         num_experts,             \
+        group_size,              \
         gemm_config,             \
         sm_,                     \
         multi_processor_count_,  \
@@ -647,11 +662,12 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
     int64_t gemm_n,
     int64_t gemm_k,
     int num_experts,
+    int group_size,
     cudaStream_t stream) {
   static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
   static constexpr bool only_simt_configs = std::is_same<T, float>::value;
   std::vector<CutlassGemmConfig> candidate_configs =
-      get_candidate_configs(sm_, -1, is_weight_only, only_simt_configs, true);
+      get_candidate_configs(sm_, group_size, is_weight_only, only_simt_configs, true);
   static constexpr int warm_time = 5;
   static constexpr int test_time = 10;
   auto& gemmConfigManager = GemmConfigManager::Instance();
@@ -670,7 +686,6 @@ void MoeGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
   int profile_total_rows =
       std::min(gemmConfigManager.nextPowerOfTwo(total_rows),
                gemmConfigManager.getMaxProfileM());
-
   for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
     for (int i = 0; i < warm_time; i++) {
       dispatch_to_arch<EpilogueTag>(A,
@@ -748,6 +763,7 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(
     int64_t gemm_k,
     int num_experts,
     std::string activation_type,
+    const int32_t weightonly_group_size,
     cudaStream_t stream) {
   if (activation_type == "none") {
     if (biases) {
@@ -761,6 +777,7 @@
                                gemm_n,
                                gemm_k,
                                num_experts,
+                               weightonly_group_size,
                                stream);
     } else {
       run_gemm<EpilogueOpNoBias>(A,
@@ -773,6 +790,7 @@
                                  gemm_n,
                                  gemm_k,
                                  num_experts,
+                                 weightonly_group_size,
                                  stream);
     }
   }
@@ -788,6 +806,7 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
                                             int64_t gemm_n,
                                             int64_t gemm_k,
                                             int num_experts,
+                                            int group_size,
                                             cudaStream_t stream) {
   run_gemm<EpilogueOpNoBias>(A,
                              B,
@@ -799,5 +818,6 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A,
                              gemm_n,
                              gemm_k,
                              num_experts,
+                             group_size,
                              stream);
 }
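Note: beyond parameter plumbing, the one behavioural change in this file is in run_gemm, where get_candidate_configs now receives the real group size instead of a hardcoded -1, so the auto-tuning loop only profiles kernel configurations valid for group-wise scales. A hypothetical compatibility predicate illustrating why the config search needs this value (the actual get_candidate_configs logic is not shown in this diff):

// Hypothetical filter: a threadblock tile's K extent should align with the
// quantization groups so a scale group is never split awkwardly across tiles.
bool tile_compatible(int tile_k, int group_size) {
  if (group_size == -1) return true;  // per-channel: any tile shape works
  return group_size % tile_k == 0 || tile_k % group_size == 0;
}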

csrc/gpu/moe/fused_moe/fused_moe.cu

Lines changed: 6 additions & 0 deletions

@@ -57,6 +57,7 @@ void FusedMoeKernel(const paddle::Tensor& input,
                     const paddle::optional<paddle::Tensor>& ffn2_scale,
                     const paddle::optional<paddle::Tensor>& ffn2_bias,
                     const std::string& quant_method,
+                    const int weightonly_group_size,
                     const int moe_topk,
                     const bool group_moe,
                     const bool norm_topk_prob,
@@ -86,6 +87,7 @@ void FusedMoeKernel(const paddle::Tensor& input,
       ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
       ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
       nullptr,
+      weightonly_group_size,
       moe_topk,
       group_moe,
       norm_topk_prob,
@@ -105,6 +107,7 @@ std::vector<paddle::Tensor> FusedExpertMoe(
     const paddle::optional<paddle::Tensor>& ffn2_bias,
     const paddle::optional<paddle::Tensor>& ffn2_scale,
     const std::string& quant_method,
+    const int weightonly_group_size,
     const int moe_topk,
     const bool norm_topk_prob,
     const bool group_moe) {
@@ -122,6 +125,7 @@ std::vector<paddle::Tensor> FusedExpertMoe(
                     ffn2_scale,
                     ffn2_bias,
                     quant_method,
+                    weightonly_group_size,
                     moe_topk,
                     group_moe,
                     norm_topk_prob,
@@ -137,6 +141,7 @@ std::vector<paddle::Tensor> FusedExpertMoe(
                     ffn2_scale,
                     ffn2_bias,
                     quant_method,
+                    weightonly_group_size,
                     moe_topk,
                     group_moe,
                     norm_topk_prob,
@@ -184,6 +189,7 @@ PD_BUILD_OP(fused_expert_moe)
                 paddle::Optional("ffn2_scale")})
     .Outputs({"output"})
     .Attrs({"quant_method:std::string",
+            "weightonly_group_size:int",
             "moe_topk:int",
             "norm_topk_prob:bool",
             "group_moe:bool"})

csrc/gpu/moe/fused_moe/moe/fused_moe_helper.h

Lines changed: 7 additions & 0 deletions

@@ -126,6 +126,7 @@ class MoeHelper {
                  const paddle::Tensor *ffn2_scale,
                  const paddle::Tensor *ffn2_bias,
                  const paddle::Tensor *moe_token_type_ids,
+                 const int weightonly_group_size,
                  const int moe_topk,
                  const bool group_moe,
                  const bool norm_topk_prob,
@@ -304,6 +305,7 @@ class MoeHelper {
           hidden_size,
           num_experts,
           "none",
+          weightonly_group_size,
           stream);
     } else if (gemm_method_ == "weight_only_int4") {
       int4_moe_gemm_runner_->moe_gemm_bias_act(
@@ -319,6 +321,7 @@ class MoeHelper {
           hidden_size,
           num_experts,
           "none",
+          weightonly_group_size,
           stream);
     } else {
       fp16_moe_gemm_runner_->moe_gemm_bias_act(
@@ -333,6 +336,7 @@ class MoeHelper {
           hidden_size,
           num_experts,
           "none",
+          weightonly_group_size,
           stream);
     }
 
@@ -356,6 +360,7 @@ class MoeHelper {
           hidden_size,
           inter_size / 2,
           num_experts,
+          weightonly_group_size,
           stream);
     } else if (gemm_method_ == "weight_only_int4") {
       int4_moe_gemm_runner_->moe_gemm(
@@ -369,6 +374,7 @@ class MoeHelper {
           hidden_size,
           inter_size / 2,
           num_experts,
+          weightonly_group_size,
           stream);
     } else {
       fp16_moe_gemm_runner_->moe_gemm(
@@ -381,6 +387,7 @@ class MoeHelper {
           hidden_size,
           inter_size / 2,
           num_experts,
+          weightonly_group_size,
           stream);
     }
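Note: MoeHelper forwards weightonly_group_size unchanged to whichever runner gemm_method_ selects (weight_only_int8, weight_only_int4, or the fp16 fallback). One bookkeeping consequence, assumed here rather than shown in the diff: group-wise quantization grows each expert's scale tensor from one scale per output channel to k / group_size scales per output channel. A small sketch of that accounting:

// Assumed scale bookkeeping for group-wise weight-only MoE quantization
// (shapes illustrative; the scale layout is not part of this diff):
//   per-channel (group_size == -1): scales[num_experts][n]
//   group-wise  (group_size == g):  scales[num_experts][k / g][n]
#include <cstdint>

inline int64_t scale_rows_per_expert(int64_t k, int group_size) {
  return (group_size == -1) ? 1 : k / group_size;  // scale rows along K
}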
