99#ifndef FUSED_DEEP_MOE_H
1010#define FUSED_DEEP_MOE_H
1111
12- #include < kernel_operator.h>
1312#include " lib/matmul_intf.h"
13+ #include < kernel_operator.h>
1414
15- #include " ../utils/op_kernel/operator/catlass/catlass/catlass .hpp"
16- #include " ../utils/op_kernel/operator/catlass/catlass /arch/arch.hpp"
17- #include " ../utils/op_kernel/operator/catlass/catlass /layout/layout.hpp"
18- #include " ../utils/op_kernel/operator/catlass/catlass /epilogue/tile/tile_broadcast_mul.hpp"
19- #include " ../utils/op_kernel/operator/catlass/catlass /epilogue/tile/tile_broadcast_one_blk.hpp"
20- #include " ../utils/op_kernel/operator/catlass/catlass /epilogue/tile/tile_swizzle.hpp"
21- #include " ../utils/op_kernel/operator/catlass/catlass /gemm/block/block_swizzle.hpp"
22- #include " ../utils/op_kernel/operator/catlass/catlass /gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp"
23- #include " ../utils/op_kernel/operator/catlass/catlass /gemm/gemm_type.hpp"
15+ #include " ../utils/op_kernel/operator/catlass/act/act .hpp"
16+ #include " ../utils/op_kernel/operator/catlass/act /arch/arch.hpp"
17+ #include " ../utils/op_kernel/operator/catlass/act /layout/layout.hpp"
18+ #include " ../utils/op_kernel/operator/catlass/act /epilogue/tile/tile_broadcast_mul.hpp"
19+ #include " ../utils/op_kernel/operator/catlass/act /epilogue/tile/tile_broadcast_one_blk.hpp"
20+ #include " ../utils/op_kernel/operator/catlass/act /epilogue/tile/tile_swizzle.hpp"
21+ #include " ../utils/op_kernel/operator/catlass/act /gemm/block/block_swizzle.hpp"
22+ #include " ../utils/op_kernel/operator/catlass/act /gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp"
23+ #include " ../utils/op_kernel/operator/catlass/act /gemm/gemm_type.hpp"
2424#include " ../utils/op_kernel/operator/epilogue/dispatch_policy.h"
2525#include " ../utils/op_kernel/operator/gemm/dispatch_policy.h"
2626#include " ../utils/op_kernel/operator/epilogue/block/block_epilogue.h"
2727#include " ../utils/op_kernel/operator/gemm/block/block_mmad.h"
2828#include " ../utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h"
2929
3030#include " ../utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h"
31- #include " ../utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h"
3231
3332#include " fused_deep_moe_tiling.h"
3433#include " fused_deep_moe_base.h"
3534
3635#define ENABLE_GMM2_COMBINE
37- constexpr uint32_t GMM1_HIDDEN_SIZE = 4096 ;
38- constexpr uint32_t TOKEN_LENGTH = 7168 ;
36+ # define GMM1_HIDDEN_SIZE 4096
37+ # define TOKEN_LENGTH 7168
3938
40- using namespace Catlass ;
39+ using namespace Act ;
4140
4241using MmadAtlasA2Custom =
4342 Gemm::MmadAtlasA2PreloadAsyncWithCallback<CUSTOM_PRELOAD_STAGES, CUSTOM_L1_STAGES, CUSTOM_L0A_STAGES,
@@ -60,16 +59,16 @@ using Gmm2DispatchPolicy =
6059
6160template <uint32_t EXEC_FLAG, typename XType_, class L1TileShape_ , class L0TileShape_ , class EpilogueTileShape_ ,
6261 class BlockScheduler_ , class DispatchPolicy_ = MmadAtlasA2Custom>
63- CATLASS_DEVICE void GmmDeqSwigluQuant (GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
64- layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
65- layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
66- layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
67- GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale,
68- GM_ADDR gmWorkspace , GM_ADDR gmX , GM_ADDR debugGm , GM_ADDR gmexpertIds ,
69- GM_ADDR gmExpandIdx, GM_ADDR gmEpSendCount, GM_ADDR gmResvered ,
70- uint32_t epRankSize , uint32_t epRankId , uint32_t moeExpertNum ,
71- uint32_t moeExpertNumPerRank , uint32_t sharedExpertNum ,
72- uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs )
62+ ACT_DEVICE void GmmDeqSwigluQuant (GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
63+ layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
64+ layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
65+ layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
66+ GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace ,
67+ GM_ADDR gmX , GM_ADDR debugGm , GM_ADDR gmexpertIds , GM_ADDR gmExpandIdx ,
68+ GM_ADDR gmEpSendCount, GM_ADDR gmResvered, uint32_t epRankSize, uint32_t epRankId ,
69+ uint32_t moeExpertNum , uint32_t moeExpertNumPerRank , uint32_t sharedExpertNum ,
70+ uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs , uint32_t bs ,
71+ uint32_t topK )
7372{
7473 using ArchTag = Arch::AtlasA2;
7574 using DispatchPolicy = DispatchPolicy_;
@@ -149,7 +148,8 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
149148 sharedExpertRankNum,
150149 quantMode,
151150 globalBs,
152- bs};
151+ bs,
152+ topK};
153153 // call a kernel
154154 GemmKernel gemm;
155155 gemm (params);
@@ -178,11 +178,11 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
178178
179179template <TemplateMC2TypeClass, class L1TileShape_ , class L0TileShape_ , class EpilogueTileShape_ , class BlockScheduler_ ,
180180 class DispatchPolicy_ = MmadAtlasA2Custom>
181- CATLASS_DEVICE void GmmDeq (GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
182- layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
183- layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
184- layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
185- GM_ADDR gmWorkspace, void *combiner)
181+ ACT_DEVICE void GmmDeq (GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
182+ layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
183+ layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
184+ layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
185+ GM_ADDR gmWorkspace, void *combiner)
186186{
187187 using ArchTag = Arch::AtlasA2;
188188 using DispatchPolicy = DispatchPolicy_;
@@ -196,7 +196,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
196196 using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;
197197
198198 constexpr uint32_t ubStages = 1 ;
199- using EpilogueDispatchPolicy = Catlass:: Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
199+ using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages, EXEC_FLAG >;
200200 using ScaleType = Gemm::GemmType<float , layout::VectorLayout>;
201201 using PerTokenScaleType = Gemm::GemmType<float , layout::VectorLayout>;
202202 using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
@@ -214,23 +214,20 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
214214 using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
215215 using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;
216216
217- using BlockEpilogue =
218- Catlass::Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType, DType,
219- TileRowBroadcastMul, TileBroadcastOneBlk, TileOneBlkColumnBroadcastMul,
220- TileCopy, TileScheduler>;
217+ using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
218+ DType, TileRowBroadcastMul, TileBroadcastOneBlk,
219+ TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;
221220
222221 using BlockScheduler = BlockScheduler_;
223222
224223 // kernel level
225224 using ElementGroupList = int64_t ;
226- using GemmKernel =
227- Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<BlockMmad, BlockEpilogue, BlockScheduler,
228- WORKSPACE_STAGES, ElementGroupList>;
225+ using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<
226+ TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
229227
230228 typename GemmKernel::Params params{
231229 problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale,
232- layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace,
233- };
230+ layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner};
234231
235232 // call a kernel
236233 GemmKernel gemm;
@@ -282,6 +279,7 @@ class FusedDeepMoe
282279 uint32_t quantMode_{0 };
283280 uint32_t globalBs_{0 };
284281 uint32_t bs_{0 };
282+ uint32_t maxBs_{0 };
285283 uint32_t topK_{0 };
286284
287285 AscendC::TPipe *tpipe_{nullptr };
@@ -324,12 +322,13 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Init(
324322 globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .globalBs ;
325323 bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .bs ;
326324 topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .k ;
325+ maxBs_ = globalBs_ / epRankSize_;
327326
328327 bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
329328 if (isShareExpert) {
330- m_ = bs_ * epRankSize_ / sharedExpertRankNum_;
329+ m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
331330 } else {
332- m_ = bs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
331+ m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
333332 }
334333
335334 n_ = GMM1_HIDDEN_SIZE;
@@ -421,8 +420,7 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Process()
421420 layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, layoutPerTokenScale2,
422421 gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount,
423422 gmResvered, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_,
424- sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_);
425-
423+ sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_);
426424#ifdef ENABLE_GMM2_COMBINE
427425 AscendC::PipeBarrier<PIPE_ALL>();
428426 Arch::CrossCoreFlag gmm1AivFinished{0 };
0 commit comments