Commit 464669e

luanyundu authored and committed
Merge branch 'main' into a2_layour
2 parents c156c5d + 2a2aef2

180 files changed: +4234 additions, -27796 deletions

csrc/deepep/deep_ep.cpp

Lines changed: 11 additions & 12 deletions
@@ -87,8 +87,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
     const int notify_send_data_size =
         num_experts * EXPERT_DATA_SIZE + server_num + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk);
     /*
-    The notify send data is constructed by 8 parameters and
-    the parameters are ordered as follows:
+    The notify send data is constructed by 8 parameters and the parameters are ordered as follows:
     1. the number of the tokens that every expert received from this NPU.
        size:[numExpert]
     2. The number of tokens received by each server from this NPU (deduplicated).
@@ -217,11 +216,11 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>

     int send_per_group = 3;  // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)

-    auto send_data = at::zeros({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
+    auto send_data = torch::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
     int64_t send_count = send_per_group * num_local_experts * num_ranks;

-    auto send_data_offset = at::zeros({num_experts}, at::dtype(at::kInt).device(x.device()));
-    at::Tensor recv_data = at::zeros({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
+    auto send_data_offset = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_data = torch::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));

     // get ep name
     char hcom_ep_name[HCOMM_NAME_LEN];
@@ -243,7 +242,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>

     auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
     std::vector<int32_t> local_expert_acc(num_experts, 0);
-    auto send_token_idx_cpu = at::zeros({num_tokens, num_topk}, options_cpu);
+    auto send_token_idx_cpu = torch::empty({num_tokens, num_topk}, options_cpu);
     auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();

     auto topk_idx_cpu = new_topk_idx.to(at::kCPU);
@@ -261,8 +260,8 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>

     EP_HOST_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
     EP_HOST_ASSERT(recv_data.size(0) % num_experts == 0);
-    at::Tensor recv_offset_cpu = at::zeros({num_experts}, options_cpu);
-    at::Tensor recv_count_cpu = at::zeros({num_experts}, options_cpu);
+    at::Tensor recv_offset_cpu = torch::empty({num_experts}, options_cpu);
+    at::Tensor recv_count_cpu = torch::empty({num_experts}, options_cpu);
     auto recv_data_cpu = recv_data.to(at::kCPU);
     auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
     auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
@@ -303,10 +302,10 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
     auto recv_count = recv_count_cpu.to(x.device());

     int num_recv_tokens = (total_recv_tokens == 0) ? 1 : total_recv_tokens;
-    auto expandx_out = use_quant ? at::zeros({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
-                                 : at::zeros({num_recv_tokens, hidden}, x.options());
-    auto dynamic_scales_out = at::zeros({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
-    auto expand_idx_out = at::zeros({num_recv_tokens * 3}, at::dtype(at::kInt).device(x.device()));
+    auto expandx_out = use_quant ? torch::empty({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
+                                 : torch::empty({num_recv_tokens, hidden}, x.options());
+    auto dynamic_scales_out = torch::empty({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
+    auto expand_idx_out = torch::empty({num_recv_tokens * 3}, at::dtype(at::kInt).device(x.device()));

     EXEC_NPU_CMD(aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx, recv_offset,
                  recv_count, hcom_ep_name,
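
Note on the allocation change above: torch::empty allocates without the zero-fill pass that at::zeros performs, which is safe only when every element is written before it is read, presumably the case for these dispatch buffers (they are filled by host-side loops, copy_, or the NPU op itself). A minimal sketch of the difference, with illustrative names and sizes rather than the ones in deep_ep.cpp:

    #include <torch/torch.h>

    void alloc_patterns()
    {
        auto opts = torch::dtype(torch::kInt);      // CPU in this sketch; the real code targets x.device()
        auto zeroed = at::zeros({1024}, opts);      // allocates and zero-fills
        auto scratch = torch::empty({1024}, opts);  // allocates only; contents are undefined
        scratch.copy_(zeroed);                      // safe: every element is written before any read
    }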

csrc/deepep/ops/op_host/fused_deep_moe_tiling.cpp

Lines changed: 134 additions & 47 deletions
Large diffs are not rendered by default.

csrc/deepep/ops/op_kernel/fused_deep_moe.h

Lines changed: 42 additions & 44 deletions
@@ -9,35 +9,34 @@
 #ifndef FUSED_DEEP_MOE_H
 #define FUSED_DEEP_MOE_H

-#include <kernel_operator.h>
 #include "lib/matmul_intf.h"
+#include <kernel_operator.h>

-#include "../utils/op_kernel/operator/catlass/catlass/catlass.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/arch/arch.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/layout/layout.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_mul.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_one_blk.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_swizzle.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/gemm/block/block_swizzle.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp"
-#include "../utils/op_kernel/operator/catlass/catlass/gemm/gemm_type.hpp"
+#include "../utils/op_kernel/operator/catlass/act/act.hpp"
+#include "../utils/op_kernel/operator/catlass/act/arch/arch.hpp"
+#include "../utils/op_kernel/operator/catlass/act/layout/layout.hpp"
+#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_mul.hpp"
+#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_broadcast_one_blk.hpp"
+#include "../utils/op_kernel/operator/catlass/act/epilogue/tile/tile_swizzle.hpp"
+#include "../utils/op_kernel/operator/catlass/act/gemm/block/block_swizzle.hpp"
+#include "../utils/op_kernel/operator/catlass/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp"
+#include "../utils/op_kernel/operator/catlass/act/gemm/gemm_type.hpp"
 #include "../utils/op_kernel/operator/epilogue/dispatch_policy.h"
 #include "../utils/op_kernel/operator/gemm/dispatch_policy.h"
 #include "../utils/op_kernel/operator/epilogue/block/block_epilogue.h"
 #include "../utils/op_kernel/operator/gemm/block/block_mmad.h"
 #include "../utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h"

 #include "../utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h"
-#include "../utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h"

 #include "fused_deep_moe_tiling.h"
 #include "fused_deep_moe_base.h"

 #define ENABLE_GMM2_COMBINE
-constexpr uint32_t GMM1_HIDDEN_SIZE = 4096;
-constexpr uint32_t TOKEN_LENGTH = 7168;
+#define GMM1_HIDDEN_SIZE 4096
+#define TOKEN_LENGTH 7168

-using namespace Catlass;
+using namespace Act;

 using MmadAtlasA2Custom =
     Gemm::MmadAtlasA2PreloadAsyncWithCallback<CUSTOM_PRELOAD_STAGES, CUSTOM_L1_STAGES, CUSTOM_L0A_STAGES,
@@ -60,16 +59,16 @@ using Gmm2DispatchPolicy =

 template <uint32_t EXEC_FLAG, typename XType_, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
           class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
-CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
-                                      layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
-                                      layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
-                                      layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
-                                      GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale,
-                                      GM_ADDR gmWorkspace, GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds,
-                                      GM_ADDR gmExpandIdx, GM_ADDR gmEpSendCount, GM_ADDR gmResvered,
-                                      uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
-                                      uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum,
-                                      uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs)
+ACT_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
+                                  layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
+                                  layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
+                                  layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
+                                  GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
+                                  GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
+                                  GM_ADDR gmEpSendCount, GM_ADDR gmResvered, uint32_t epRankSize, uint32_t epRankId,
+                                  uint32_t moeExpertNum, uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum,
+                                  uint32_t sharedExpertRankNum, uint32_t quantMode, uint32_t globalBs, uint32_t bs,
+                                  uint32_t topK)
 {
     using ArchTag = Arch::AtlasA2;
     using DispatchPolicy = DispatchPolicy_;
@@ -149,7 +148,8 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
                                           sharedExpertRankNum,
                                           quantMode,
                                           globalBs,
-                                          bs};
+                                          bs,
+                                          topK};
     // call a kernel
     GemmKernel gemm;
     gemm(params);
@@ -178,11 +178,11 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun

 template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_, class BlockScheduler_,
           class DispatchPolicy_ = MmadAtlasA2Custom>
-CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
-                           layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
-                           layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
-                           layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
-                           GM_ADDR gmWorkspace, void *combiner)
+ACT_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
+                       layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale,
+                       layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale,
+                       layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
+                       GM_ADDR gmWorkspace, void *combiner)
 {
     using ArchTag = Arch::AtlasA2;
     using DispatchPolicy = DispatchPolicy_;
@@ -196,7 +196,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
     using BlockMmad = Gemm::Block::BlockMmad<DispatchPolicy, L1TileShape, L0TileShape, AType, BType, CType>;

     constexpr uint32_t ubStages = 1;
-    using EpilogueDispatchPolicy = Catlass::Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
+    using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages, EXEC_FLAG>;
     using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
     using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
     using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
@@ -214,23 +214,20 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
     using TileCopy = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, DType>;
     using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle;

-    using BlockEpilogue =
-        Catlass::Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType, DType,
-                                                TileRowBroadcastMul, TileBroadcastOneBlk, TileOneBlkColumnBroadcastMul,
-                                                TileCopy, TileScheduler>;
+    using BlockEpilogue = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy, CType, ScaleType, PerTokenScaleType,
+                                                         DType, TileRowBroadcastMul, TileBroadcastOneBlk,
+                                                         TileOneBlkColumnBroadcastMul, TileCopy, TileScheduler>;

     using BlockScheduler = BlockScheduler_;

     // kernel level
     using ElementGroupList = int64_t;
-    using GemmKernel =
-        Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<BlockMmad, BlockEpilogue, BlockScheduler,
-                                                                            WORKSPACE_STAGES, ElementGroupList>;
+    using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace<
+        TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;

     typename GemmKernel::Params params{
         problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale,
-        layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace,
-    };
+        layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner};

     // call a kernel
     GemmKernel gemm;
@@ -282,6 +279,7 @@ class FusedDeepMoe
     uint32_t quantMode_{0};
     uint32_t globalBs_{0};
     uint32_t bs_{0};
+    uint32_t maxBs_{0};
     uint32_t topK_{0};

     AscendC::TPipe *tpipe_{nullptr};
@@ -324,12 +322,13 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Init(
     globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
     bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs;
     topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k;
+    maxBs_ = globalBs_ / epRankSize_;

     bool isShareExpert = (epRankId_ < sharedExpertRankNum_);
     if (isShareExpert) {
-        m_ = bs_ * epRankSize_ / sharedExpertRankNum_;
+        m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_;
     } else {
-        m_ = bs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
+        m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_);
     }

     n_ = GMM1_HIDDEN_SIZE;
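
The hunk above decouples the GMM1 row bound from the batch actually submitted this step: m_ is now sized from maxBs_ = globalBs_ / epRankSize_, the largest per-rank batch, rather than the current bs_. A sketch of the sizing arithmetic, mirroring Init() above with illustrative inputs (the real values come from the tiling data):

    #include <algorithm>
    #include <cstdint>

    // Worst-case GMM1 row count for one rank, as computed in Init() above.
    uint32_t Gmm1Rows(uint32_t globalBs, uint32_t epRankSize, uint32_t epRankId,
                      uint32_t sharedExpertRankNum, uint32_t topK, uint32_t moeExpertNumPerRank)
    {
        uint32_t maxBs = globalBs / epRankSize;  // per-rank batch upper bound
        if (epRankId < sharedExpertRankNum) {
            // shared-expert ranks: the full global batch divided among them
            return maxBs * epRankSize / sharedExpertRankNum;
        }
        // MoE ranks: rows scale with min(topK, moeExpertNumPerRank), as in the hunk above
        return maxBs * epRankSize * std::min(topK, moeExpertNumPerRank);
    }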
@@ -421,8 +420,7 @@ __aicore__ inline void FusedDeepMoe<TemplateMC2TypeFunc>::Process()
                           layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, layoutPerTokenScale2,
                           gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount,
                           gmResvered, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_,
-                          sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_);
-
+                          sharedExpertNum_, sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_);
 #ifdef ENABLE_GMM2_COMBINE
     AscendC::PipeBarrier<PIPE_ALL>();
     Arch::CrossCoreFlag gmm1AivFinished{0};

csrc/deepep/ops/utils/.DS_Store

-6 KB
Binary file not shown.

csrc/deepep/ops/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h

Lines changed: 10 additions & 5 deletions
@@ -8,6 +8,7 @@
 */
 #ifndef CAM_MOE_DISTRIBUTE_COMBINE_H
 #define CAM_MOE_DISTRIBUTE_COMBINE_H
+#define OPT_RANK_OFFSET 512

 #include "kernel_operator.h"
 #include "kernel_tiling/kernel_tiling.h"
@@ -29,7 +30,6 @@ constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024;
 constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024;
 constexpr uint16_t SEND_SYNC_EVENT_ID = 9;
 constexpr uint16_t RECV_SYNC_EVENT_ID = 10;
-constexpr uint32_t OPT_RANK_OFFSET = 512;

 template <AscendC::HardEvent event>
 __aicore__ inline void SyncFunc()
@@ -246,7 +246,11 @@ __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::Init(
         selfDataStatusTensor[coreIdx_ * UB_ALIGN]);
     __asm__ __volatile__("");
     dataState_ = selfDataStatusTensor(coreIdx_ * UB_ALIGN);
-    selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1 - dataState_;
+    if (dataState_ == 0) {
+        selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1;
+    } else {
+        selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 0;
+    }
     __asm__ __volatile__("");
     DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
         selfDataStatusTensor[coreIdx_ * UB_ALIGN]);
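
For a flag that only ever holds 0 or 1, the new branch writes the same ping-pong sequence as the old 1 - dataState_; it just makes the two states explicit. A minimal equivalent, under that 0/1 assumption:

    #include <cstdint>

    // Ping-pong flag flip as in Init() above; for values in {0, 1} both
    // the branch and the subtraction produce the same result.
    uint32_t NextDataState(uint32_t dataState)
    {
        return (dataState == 0) ? 1 : 0;  // equivalent to: 1 - dataState
    }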
@@ -372,15 +376,16 @@ template <TemplateMC2TypeClass>
 __aicore__ inline void CamMoeDistributeCombine<TemplateMC2TypeFunc>::AlltoAllBuffInit()
 {
     tpipe_->Reset();
+    uint32_t bsMulTopkSizeAligned = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN;  // prevent unaligned UB buffers
     tpipe_->InitBuffer(readStateBuf_, UB_ALIGN);
     tpipe_->InitBuffer(statusBuf_, sendRankNum_ * UB_ALIGN);
-    tpipe_->InitBuffer(expertIdsBuf_, axisBS_ * axisK_ * sizeof(int32_t));
-    tpipe_->InitBuffer(expandScalesBuf_, axisBS_ * axisK_ * sizeof(float));
+    tpipe_->InitBuffer(expertIdsBuf_, bsMulTopkSizeAligned);
+    tpipe_->InitBuffer(expandScalesBuf_, bsMulTopkSizeAligned);
     tpipe_->InitBuffer(tokenBuf_, axisH_ * sizeof(ExpandXType));
     tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_);  // 7168 * 4 = 28672
     tpipe_->InitBuffer(mulBuf_, axisHFloatSize_);          // 7168 * 4 = 28672
     tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_);     // 7168 * 4 = 28672
-    tpipe_->InitBuffer(indexCountsBuf_, axisBS_ * axisK_ * sizeof(int32_t));
+    tpipe_->InitBuffer(indexCountsBuf_, bsMulTopkSizeAligned);
     tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_);
     tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float));
     tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t));  // 4
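
The new bsMulTopkSizeAligned rounds the BS * K byte count up to a multiple of UB_ALIGN before InitBuffer, so batch/top-k combinations whose product is not UB-aligned no longer produce misaligned unified-buffer allocations. A sketch of the round-up idiom, assuming Ceil is the usual divide-rounding-up helper and a 32-byte UB_ALIGN (both assumptions for this sketch, not values taken from this file):

    #include <cstdint>

    constexpr uint32_t UB_ALIGN = 32;  // assumed here; use the kernel's real constant

    // Round a byte count up to the next UB_ALIGN boundary, mirroring
    // Ceil(n, UB_ALIGN) * UB_ALIGN in AlltoAllBuffInit() above.
    constexpr uint32_t AlignUp(uint32_t bytes)
    {
        return (bytes + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN;
    }

    static_assert(AlignUp(96) == 96, "already aligned: unchanged");
    static_assert(AlignUp(100) == 128, "25 int32 entries round up to a full 32-byte line");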
