
Commit 2a2aef2

oagniqgnat authored
Fix the severe performance degradation issue of the top9 dispatch in normal mode compared to top8. (#117)
Co-authored-by: oagniqgnat <tangqingao@huawei.com>
1 parent dd005e3 commit 2a2aef2
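
The fix is mechanical and applied uniformly below: buffers that at::zeros used to allocate and zero-fill are switched to torch::empty, which only allocates. Presumably the zero-fill was pure overhead because the subsequent kernels (or host-side loops) overwrite every element before it is read. A minimal sketch of the pattern, assuming only libtorch; make_dispatch_output and its parameters are illustrative names, not code from this repo:

#include <torch/torch.h>

// at::zeros allocates *and* launches a fill; torch::empty only allocates.
// When a later kernel writes every element anyway, the fill is wasted work
// on each dispatch call.
torch::Tensor make_dispatch_output(int64_t num_tokens, int64_t hidden,
                                   const torch::Device &device) {
    // Before: auto out = at::zeros({num_tokens, hidden}, at::dtype(at::kInt).device(device));
    // After: allocation only; contents are undefined until written.
    return torch::empty({num_tokens, hidden}, at::dtype(at::kInt).device(device));
}

The trade-off is that a torch::empty buffer must be fully written before it is read; an uninitialized read would return garbage rather than zeros.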

File tree

2 files changed: +12 −12 lines

csrc/deepep/deep_ep.cpp

Lines changed: 11 additions & 11 deletions
@@ -77,7 +77,7 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:
     auto device = new_topk_idx.device();
     auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
     auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
-    auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
+    auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
 
     EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, num_tokens_per_rank,
                  num_tokens_per_expert, is_token_in_rank);
@@ -183,11 +183,11 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
 
     int send_per_group = 3; // (send_to_expert_num, send_to_expert_offset, send_rank_tokens)
 
-    auto send_data = at::zeros({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
+    auto send_data = torch::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
     int64_t send_count = send_per_group * num_local_experts * num_ranks;
 
-    auto send_data_offset = at::zeros({num_experts}, at::dtype(at::kInt).device(x.device()));
-    at::Tensor recv_data = at::zeros({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
+    auto send_data_offset = torch::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor recv_data = torch::empty({num_experts * send_per_group}, at::dtype(at::kInt).device(x.device()));
 
     // get ep name
     char hcom_ep_name[HCOMM_NAME_LEN];
@@ -209,7 +209,7 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
 
     auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
     std::vector<int32_t> local_expert_acc(num_experts, 0);
-    auto send_token_idx_cpu = at::zeros({num_tokens, num_topk}, options_cpu);
+    auto send_token_idx_cpu = torch::empty({num_tokens, num_topk}, options_cpu);
     auto send_token_idx_ptr = send_token_idx_cpu.data_ptr<int>();
 
     auto topk_idx_cpu = new_topk_idx.to(at::kCPU);
@@ -227,8 +227,8 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
 
     EP_HOST_ASSERT(recv_data.dim() == 1 and recv_data.is_contiguous());
     EP_HOST_ASSERT(recv_data.size(0) % num_experts == 0);
-    at::Tensor recv_offset_cpu = at::zeros({num_experts}, options_cpu);
-    at::Tensor recv_count_cpu = at::zeros({num_experts}, options_cpu);
+    at::Tensor recv_offset_cpu = torch::empty({num_experts}, options_cpu);
+    at::Tensor recv_count_cpu = torch::empty({num_experts}, options_cpu);
     auto recv_data_cpu = recv_data.to(at::kCPU);
     auto recv_data_ptr = recv_data_cpu.data_ptr<int>();
     auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
@@ -269,10 +269,10 @@ Buffer::intranode_dispatch(const at::Tensor &x, const std::optional<at::Tensor>
     auto recv_count = recv_count_cpu.to(x.device());
 
     int num_recv_tokens = (total_recv_tokens == 0) ? 1 : total_recv_tokens;
-    auto expandx_out = use_quant ? at::zeros({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
-                                 : at::zeros({num_recv_tokens, hidden}, x.options());
-    auto dynamic_scales_out = at::zeros({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
-    auto expand_idx_out = at::zeros({num_recv_tokens * 3}, at::dtype(at::kInt).device(x.device()));
+    auto expandx_out = use_quant ? torch::empty({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
+                                 : torch::empty({num_recv_tokens, hidden}, x.options());
+    auto dynamic_scales_out = torch::empty({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
+    auto expand_idx_out = torch::empty({num_recv_tokens * 3}, at::dtype(at::kInt).device(x.device()));
 
     EXEC_NPU_CMD(aclnnCamMoeDispatchNormal, new_x, expert_ids, send_data_offset, send_token_idx, recv_offset,
                  recv_count, hcom_ep_name,
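
The swap also covers the CPU-side staging tensors (send_token_idx_cpu, recv_offset_cpu, recv_count_cpu), presumably safe because the dispatch path assigns every element through data_ptr before anything reads them. A sketch of that host-side pattern under the same assumption; count_per_expert and recv_data_vals are illustrative names, not the repo's code:

#include <torch/torch.h>
#include <vector>

// Mirrors the recv_count_cpu pattern above: torch::empty skips the memset,
// and the loop writes every slot exactly once before the tensor is used.
torch::Tensor count_per_expert(const std::vector<int> &recv_data_vals, int num_experts) {
    auto options_cpu = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
    auto recv_count_cpu = torch::empty({num_experts}, options_cpu);
    auto recv_count_ptr = recv_count_cpu.data_ptr<int>();
    for (int e = 0; e < num_experts; ++e) {
        recv_count_ptr[e] = recv_data_vals.at(e);  // every element written
    }
    return recv_count_cpu;
}

The device-side outputs (expandx_out, dynamic_scales_out, expand_idx_out) rely on the aclnnCamMoeDispatchNormal kernel to fill them, so their pre-kernel contents are never observed.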

tests/python/deepep/test_intranode.py

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ def test_diagnose(
     for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x)):
         if local_rank == 0:
             print(
-                f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, with top-k ...',
+                f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, with top-k {num_topk} ...',
                 flush=True,
             )
         dispatch_args = {
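
Including the actual top-k value in the log line makes top-8 and top-9 runs distinguishable in the test output, which matters when reproducing the regression this commit addresses.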
