Changes from 11 commits
4 changes: 2 additions & 2 deletions .github/workflows/pr-test-npu.yml
@@ -66,7 +66,7 @@ jobs:
        run: |
          python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py

-      - name: Run test deepep eplb
+      - name: Run test fused deep moe
        timeout-minutes: 10
        env:
          HCCL_BUFFSIZE: 2000
@@ -121,7 +121,7 @@ jobs:
        run: |
          python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py

-      - name: Run test deepep eplb
+      - name: Run test fused deep moe
        timeout-minutes: 10
        env:
          HCCL_BUFFSIZE: 2000
57 changes: 3 additions & 54 deletions csrc/deepep/deep_ep.cpp
@@ -608,46 +608,6 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te
    EP_HOST_ASSERT(expert_scales_optional.dim() == 2);

    this->is_padding = false;
-    at::Tensor new_x = x;
-    this->new_topk_idx = expert_ids;
-    at::Tensor new_scales = expert_scales_optional;
-
-    if (expert_ids.size(0) < PADDING_SIZE) {
-        this->is_padding = true;
-        this->padding_cnt = PADDING_SIZE - expert_ids.size(0);
-
-        std::vector<at::Tensor> x_blocks;
-        std::vector<at::Tensor> idx_blocks;
-
-        if (expert_ids.size(0) != 0) {
-            x_blocks.emplace_back(x);
-            idx_blocks.emplace_back(expert_ids);
-        } else {
-            this->ori_x = x.clone();  // store the original input when the batch is completely empty
-        }
-
-        int topk = static_cast<int>(expert_ids.size(1));
-        for (int i = 0; i < this->padding_cnt; i++) {
-            at::Tensor tmp_x = torch::ones({1, x.size(1)}, x.options());
-            at::Tensor tmp_idx =
-                torch::randperm(num_experts, expert_ids.options()).slice(0, 0, topk).reshape({1, topk});
-            x_blocks.emplace_back(tmp_x);
-            idx_blocks.emplace_back(tmp_idx);
-        }
-        new_x = torch::cat(x_blocks, 0);
-        this->new_topk_idx = torch::cat(idx_blocks, 0);
-
-        // padding expert_scales_optional
-        std::vector<at::Tensor> scales_blocks;
-        if (this->padding_cnt != PADDING_SIZE) {
-            scales_blocks.emplace_back(expert_scales_optional);
-        }
-        for (int i = 0; i < this->padding_cnt; i++) {
-            at::Tensor tmp_scales = torch::zeros({1, expert_scales_optional.size(1)}, expert_scales_optional.options());
-            scales_blocks.emplace_back(tmp_scales);
-        }
-        new_scales = torch::cat(scales_blocks, 0);
-    }

    char hcom_ep_name[128];
    if (!moe_all_to_all_group_name.empty()) {
@@ -657,10 +617,9 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te
    }

    int64_t global_bs = std::max(new_topk_idx.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;

    auto x_shape = x.sizes();
    int h = x_shape[1];
-    int bs = this->new_topk_idx.size(0);
+    int bs = expert_ids.size(0);
Contributor:

critical

You've correctly updated this line to use expert_ids instead of this->new_topk_idx. A similar change is required on line 619 for the global_bs calculation. Since the padding logic that set new_topk_idx was removed, it now holds a stale value within this function, which will lead to incorrect behavior.

Please update line 619 to use expert_ids as well:

int64_t global_bs = std::max(expert_ids.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;

Contributor Author:

fixed, int64_t global_bs = std::max(expert_ids.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;
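
For context, the interaction the two comments describe can be reproduced in isolation. The sketch below is illustrative only: tokens stands in for expert_ids.size(0), and the values of num_max_dispatch_tokens_per_rank and num_ranks are invented rather than taken from a real Buffer.

// Sketch: why bs and global_bs must both derive from expert_ids.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    int64_t tokens = 3;                              // expert_ids.size(0): live batch on this rank
    int64_t num_max_dispatch_tokens_per_rank = 128;  // per-rank dispatch capacity (assumed value)
    int64_t num_ranks = 16;                          // EP world size (assumed value)

    // Corrected form from the thread: both sizes come from expert_ids, so a
    // stale this->new_topk_idx left over from the deleted padding path cannot
    // leak into the collective's global batch size.
    int64_t global_bs = std::max(tokens, num_max_dispatch_tokens_per_rank) * num_ranks;
    int64_t bs = tokens;

    std::printf("bs=%lld global_bs=%lld\n", (long long)bs, (long long)global_bs);
    return 0;
}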


    at::Tensor output = at::empty({bs, h}, x.options());

@@ -670,24 +629,14 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te

    EXEC_NPU_CMD(aclnnFusedDeepMoe,
                 // input
-                new_x, this->new_topk_idx, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
-                gmm2_weight_scale, static_cast<const std::nullptr_t &>(nullptr), new_scales,
+                x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
+                gmm2_weight_scale, static_cast<const std::nullptr_t &>(nullptr), expert_scales_optional,
                 // attr
                 hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num, shared_expert_rank_num, quant_mode,
                 global_bs,
                 // output
                 output, ep_recv_count);

-    // ---------- unpadding ----------
-    if (this->is_padding) {
-        if (expert_ids.size(0) == 0) {
-            output = this->ori_x;
-        } else {
-            output = output.slice(0, 0, PADDING_SIZE - this->padding_cnt);
-        }
-        this->is_padding = false;
-    }

    return {output, ep_recv_count};
}
} // namespace deep_ep
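
For reference, the deleted padding/unpadding path can be reproduced as a standalone libtorch program. This is a sketch under assumptions: PADDING_SIZE is given a made-up value of 256 (the real constant is defined elsewhere in the tree), pad_batch and unpad are hypothetical free functions rather than the Buffer members, and the fused kernel call itself is elided.

#include <torch/torch.h>

#include <cstdio>
#include <tuple>
#include <vector>

constexpr int64_t PADDING_SIZE = 256;  // assumed value for illustration

// Pad (x, expert_ids, scales) up to PADDING_SIZE rows, mirroring the removed
// code: dummy tokens of ones, random distinct expert ids per dummy row, and
// zero scales so the padded rows get zero combine weight.
std::tuple<at::Tensor, at::Tensor, at::Tensor> pad_batch(const at::Tensor &x, const at::Tensor &expert_ids,
                                                         const at::Tensor &scales, int64_t num_experts) {
    int64_t bs = expert_ids.size(0);
    if (bs >= PADDING_SIZE) return {x, expert_ids, scales};
    int64_t pad = PADDING_SIZE - bs;
    int64_t topk = expert_ids.size(1);

    std::vector<at::Tensor> xs, idxs, scs;
    if (bs != 0) {
        xs.push_back(x);
        idxs.push_back(expert_ids);
        scs.push_back(scales);
    }
    for (int64_t i = 0; i < pad; ++i) {
        xs.push_back(torch::ones({1, x.size(1)}, x.options()));
        idxs.push_back(torch::randperm(num_experts, expert_ids.options()).slice(0, 0, topk).reshape({1, topk}));
        scs.push_back(torch::zeros({1, scales.size(1)}, scales.options()));
    }
    return {torch::cat(xs, 0), torch::cat(idxs, 0), torch::cat(scs, 0)};
}

// Undo the padding on the kernel output, mirroring the removed unpadding block.
// Keeping the caller's original input around replaces the old this->ori_x member.
at::Tensor unpad(const at::Tensor &output, int64_t original_bs, const at::Tensor &original_x) {
    if (original_bs == 0) return original_x;  // batch was completely empty: hand back the saved input
    return output.slice(0, 0, original_bs);   // drop the dummy rows
}

int main() {
    auto x = torch::randn({3, 8});
    auto ids = torch::randint(0, 16, {3, 4}, torch::kInt64);
    auto sc = torch::rand({3, 4});

    auto [px, pids, psc] = pad_batch(x, ids, sc, /*num_experts=*/16);
    // ... the fused MoE kernel would consume (px, pids, psc) here ...
    at::Tensor out = unpad(px, x.size(0), x);  // px stands in for the kernel output

    std::printf("padded: x=%ld ids=%ld scales=%ld; unpadded: %ld\n",
                (long)px.size(0), (long)pids.size(0), (long)psc.size(0), (long)out.size(0));
    return 0;
}

The randperm draw gives each dummy row topk distinct expert ids, presumably to spread the padding tokens across experts rather than piling them onto one; the zero scales are what keep the padded rows numerically inert in the combine step.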