
Commit f92d22b

chenxu and evian committed
[Attention][Kernel]moe support for llama4 and mllama4
Co-authored-by: evian <eviantai@u.nus.edu>
Signed-off-by: chenxu <chenxu68@huawei.com>
1 parent 1fce70a · commit f92d22b

File tree

4 files changed: +26 −6 lines

vllm_ascend/attention/attention.py

Lines changed: 1 addition & 0 deletions

@@ -668,6 +668,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions

@@ -172,6 +172,7 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size

vllm_ascend/ops/common_fused_moe.py

Lines changed: 2 additions & 1 deletion

@@ -61,7 +61,8 @@ def forward_oot(
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         top_k=top_k,
-        expert_map=expert_map)
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input)


 UnquantizedFusedMoEMethod.forward_oot = forward_oot

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 5 deletions

@@ -150,6 +150,7 @@ def fused_experts(
     topk_ids: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
+    apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
     """
     Fused experts with top-k routing.
@@ -188,6 +189,15 @@ def fused_experts(
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     #                  ], "Only float32, float16, and bfloat16 are supported"

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+
     if expert_map is not None:
         # Generate token indices and flatten
         token_indices = (torch.arange(num_tokens,
@@ -289,14 +299,15 @@ def fused_experts(
             torch.zeros_like(weighted_down_out)).to(dtype)
         final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
     else:
+        scales = torch.ones_like(topk_weights) if apply_router_weight_on_input else topk_weights
         # TODO: Reorder device memory 2 times here, replace the current
         # implementation here when suitable operators become available.
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             down_out_list,
             skip1=None,
             skip2=None,
             bias=None,
-            scales=topk_weights,
+            scales=scales,
             expanded_src_to_dst_row=expanded_row_idx,
             export_for_source_row=topk_ids,
         )
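Taken together, the two hunks above implement apply_router_weight_on_input for the Ascend path: with top-1 routing, the router weight is folded into the token activations before the experts run, and npu_moe_finalize_routing is then given unit scales so the weight is not applied a second time. Below is a minimal, illustrative PyTorch sketch of that idea (made-up shapes and a stand-in expert; it is not the fused NPU kernel path):

import torch

# Illustrative sketch only, not the vllm-ascend kernel path.
num_tokens, hidden_size, top_k = 4, 8, 1            # top_k must be 1 for this mode
hidden_states = torch.randn(num_tokens, hidden_size)
topk_weights = torch.rand(num_tokens, top_k)        # router weights, shape (num_tokens, topk)

assert topk_weights.dim() == 2 and topk_weights.shape[1] == 1

# apply_router_weight_on_input=True: scale the tokens once, before the experts.
scaled_input = hidden_states * topk_weights.to(hidden_states.dtype)
expert_output = scaled_input * 2.0                  # stand-in for the expert MLP

# The finalize/combine step then uses unit scales; re-using topk_weights here
# would apply the router weight twice.
scales = torch.ones_like(topk_weights)
final_hidden_states = expert_output * scales

This is also why the new guard asserts topk == 1: with more than one expert per token, a single per-token scaling of the input cannot represent distinct per-expert weights.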
@@ -363,9 +374,6 @@ def select_experts(
     Raises:
         ValueError: If an unsupported scoring function is provided.
     """
-    if custom_routing_function is not None:
-        raise NotImplementedError(
-            "Custom routing function is not supported now")

     if scoring_func == "softmax":
         # NOTE: vLLM use dtype=torch.float here
@@ -402,9 +410,18 @@ def select_experts(
             k=top_k,
             dim=-1,
             sorted=False)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
         topk_weights = topk_weights.to(hidden_states.dtype)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize)
+        # Required by npu_moe_init_routing
+        topk_ids = topk_ids.to(torch.int32)
+        return topk_weights, topk_ids

     # Required by npu_moe_init_routing
     topk_ids = topk_ids.to(torch.int32)
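With the NotImplementedError removed, select_experts now dispatches to custom_routing_function using the keyword arguments shown in the hunk above and expects (topk_weights, topk_ids) back, casting the ids to int32 for npu_moe_init_routing. A hedged sketch of a callable with a compatible interface follows; the softmax top-k body and the name example_routing_function are illustrative, not the actual Llama 4 router:

import torch

def example_routing_function(hidden_states: torch.Tensor,
                             gating_output: torch.Tensor,
                             topk: int,
                             renormalize: bool):
    # Illustrative router: softmax scoring followed by top-k selection.
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = scores.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(hidden_states.dtype), topk_ids

# Hypothetical shapes: 4 tokens, hidden size 16, 8 experts, top-2 routing.
hidden_states = torch.randn(4, 16)
router_logits = torch.randn(4, 8)
weights, ids = example_routing_function(hidden_states=hidden_states,
                                        gating_output=router_logits,
                                        topk=2,
                                        renormalize=True)

Note that select_experts still converts topk_ids to int32 itself, so the callable may return the default int64 indices produced by torch.topk.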
