[Refactor][MOE] remove redundant code. #2597
@@ -18,20 +18,23 @@
 from typing import Any, Callable, Optional

 import torch
 import torch_npu
 from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.fused_moe.config import \
+    FusedMoEParallelConfig  # isort: skip

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      DummyCommImpl,
                                                      MC2CommImpl,
                                                      MoECommMethod)
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
 from vllm_ascend.ops.layers.experts_selector import select_experts
+from vllm_ascend.ops.layers.moe_mlp import unquant_apply_mlp
 from vllm_ascend.utils import is_310p

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -81,7 +84,7 @@ def fused_experts(

     permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
         hidden_states, topk_ids, topk_weights, expert_map, num_experts)
-    mlp_output = apply_mlp(
+    mlp_output = unquant_apply_mlp(
         permuted_hidden_states,
         w1,
         w2,
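The behavioral change in this hunk is the switch from apply_mlp to unquant_apply_mlp, so the unquantized grouped expert MLP now comes from vllm_ascend.ops.layers.moe_mlp. As a rough mental model only (the helper name and the plain-PyTorch loop below are illustrative assumptions, not the actual unquant_apply_mlp), an unquantized grouped MLP over expert-sorted tokens computes a per-expert gate/up projection, a SwiGLU activation, and a down projection:

```python
import torch

def grouped_mlp_reference(sorted_hidden: torch.Tensor,
                          w1: torch.Tensor,  # (num_experts, hidden, 2 * intermediate)
                          w2: torch.Tensor,  # (num_experts, intermediate, hidden)
                          group_list: torch.Tensor) -> torch.Tensor:
    """Hypothetical CPU stand-in for an unquantized grouped expert MLP.

    group_list holds cumulative token counts per expert (the group_list_type=0
    convention of npu_grouped_matmul), so expert e owns the row range
    [group_list[e-1], group_list[e]) of the expert-sorted input.
    """
    out = torch.empty(sorted_hidden.shape[0], w2.shape[-1],
                      dtype=sorted_hidden.dtype)
    start = 0
    for e, end in enumerate(group_list.tolist()):
        rows = sorted_hidden[start:end]
        gate_up = rows @ w1[e]                      # fused gate/up projection
        gate, up = gate_up.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up   # SwiGLU activation
        out[start:end] = act @ w2[e]                # down projection
        start = end
    return out
```

On NPU this per-expert loop is what the single npu_grouped_matmul call batches into one kernel; the fused_experts_moge function added below follows exactly this gate/up, SwiGLU, down sequence.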
@@ -93,6 +96,97 @@ def fused_experts(
     return hidden_states


+def fused_experts_moge(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    moe_parallel_config: FusedMoEParallelConfig,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+
+    Args:
+        hidden_states: Hidden states of shape (num_tokens, hidden_size).
+        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
+        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
+        topk_weights: Routing weights of shape (num_tokens, top_k).
+        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
+        top_k: Number of experts to select.
+        expert_map: Expert mapping of shape (num_experts,).
+
+    Returns:
+        hidden_states: Hidden states after routing.
+    """
+    ep_size = moe_parallel_config.ep_size
+    local_num_experts = global_num_experts // ep_size
+    local_num_group = top_k // ep_size
+
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+
+    bsz, _ = hidden_states.shape
+    flatten_topk_ids = topk_ids.view(-1)
+    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
+    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
+    sorted_hidden_states = hidden_states.index_select(
+        0, sorted_topk_ids // local_num_group)
+
+    experts_id = torch.arange(0,
+                              local_num_experts,
+                              dtype=topk_ids.dtype,
+                              device=topk_ids.device)
+    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
+        torch.float32).sum(0)
+    topk_scales = topk_weights.view(-1).index_select(
+        0, sorted_topk_ids).unsqueeze(-1)
+    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
+
+    w1 = w1.transpose(1, 2)
+    gate_up_out = torch_npu.npu_grouped_matmul(
+        x=[sorted_hidden_states],
+        weight=[w1],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=group_list,
+    )[0]
+
+    if is_310p():
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
+            torch.float16)
+    else:
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+    gate_up_out *= topk_scales
+
+    w2 = w2.transpose(1, 2)
+    down_out_list = torch_npu.npu_grouped_matmul(
+        x=[gate_up_out],
+        weight=[w2],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=group_list,
+    )[0]
+
+    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
+    unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
+    final_hidden_states = unsorted_hidden_states.reshape(
+        bsz, top_k // ep_size, -1).sum(1)
+
+    return final_hidden_states
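The heart of the function added above is the sort-by-expert permutation plus the cumulative group_list handed to npu_grouped_matmul. That part is plain PyTorch, so it can be sanity-checked on CPU. A small sketch with made-up numbers, assuming ep_size == 1 (so local_num_group == top_k):

```python
import torch

# Toy routing: 3 tokens, 4 local experts, top_k = 2, ep_size = 1.
top_k = 2
local_num_group = top_k  # top_k // ep_size with ep_size == 1
local_num_experts = 4
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 3]])  # (num_tokens, top_k)

flat = topk_ids.view(-1)                                  # [0, 2, 1, 2, 0, 3]
sorted_idx = torch.argsort(flat.float()).to(torch.int32)  # slots grouped by expert

# Each token owns local_num_group consecutive slots in the flattened view,
# so integer division recovers the source token row for every sorted slot:
src_rows = sorted_idx // local_num_group

# Per-expert token counts, then the cumulative group_list that
# npu_grouped_matmul expects with group_list_type=0:
counts = (flat.unsqueeze(-1) == torch.arange(local_num_experts)).sum(0)
group_list = counts.cumsum(0).to(torch.int64)

print(group_list.tolist())  # [2, 3, 5, 6]
# src_rows groups token rows by expert, e.g. [0, 2, 1, 0, 1, 2]
# (tie order within one expert may vary; argsort is not guaranteed stable).

# After the expert MLP, argsort of sorted_idx inverts the permutation:
unsorted_idx = torch.argsort(sorted_idx.float()).to(torch.int32)
```

The final reshape(bsz, top_k, -1).sum(1) in the function then collapses the top_k expert outputs back into one hidden state per token.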
Comment on lines +98 to +187:

> This function
 def unquantized_fused_moe_init_func(self, *args, **kwargs):
     original_unquantized_fused_moe_init_func(self, *args, **kwargs)
     vllm_config = get_current_vllm_config()
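This trailing context shows the patching idiom the file already uses: the original UnquantizedFusedMoEMethod.__init__ is captured at import time, and a wrapper runs extra Ascend-side setup after delegating to it. A self-contained sketch of the same idiom (the class and attribute names here are made up for illustration):

```python
class FusedMoEMethod:  # stand-in for UnquantizedFusedMoEMethod
    def __init__(self, *args, **kwargs):
        self.initialized = True

# Keep a reference to the original so the wrapper can delegate to it.
_original_init = FusedMoEMethod.__init__

def patched_init(self, *args, **kwargs):
    _original_init(self, *args, **kwargs)
    # Backend-specific setup goes here, e.g. reading the current
    # vllm config as the real hunk does via get_current_vllm_config().
    self.ascend_ready = True

# Install the wrapper in place of the original __init__.
FusedMoEMethod.__init__ = patched_init

m = FusedMoEMethod()
assert m.initialized and m.ascend_ready
```

The advantage of this pattern is that the upstream class needs no modification; all Ascend-specific state is layered on after the stock initialization completes.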
Review comment:

> do not remove release note

Reply:

> ok