From 8cb0ffee05176e9c0485586c544bc0f4abce54dc Mon Sep 17 00:00:00 2001 From: NNUCJ <616151263@qq.com> Date: Fri, 18 Jul 2025 14:20:03 +0800 Subject: [PATCH 1/9] Add super kernel in moe Signed-off-by: NNUCJ <616151263@qq.com> --- vllm_ascend/ascend_config.py | 12 ++ vllm_ascend/models/deepseek_v2.py | 1 + vllm_ascend/ops/fused_moe.py | 27 +++-- vllm_ascend/quantization/w8a8_dynamic.py | 140 ++++++++++++----------- vllm_ascend/utils.py | 5 + 5 files changed, 111 insertions(+), 74 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8ea67994ea..12e6ee34f8 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -63,6 +63,8 @@ def __init__(self, torchair_graph_config): self.enable_view_optimize = torchair_graph_config.get( "enable_view_optimize", True) self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) + self.enable_super_kernel = torchair_graph_config.get( + "enable_super_kernel", False) if not isinstance(self.graph_batch_sizes, list): raise TypeError("graph_batch_sizes must be list[int]") @@ -95,6 +97,16 @@ def __init__(self, torchair_graph_config): raise RuntimeError( "enable_kv_nz is valid only when Torchair graph mode is enabled" ) + if self.enable_super_kernel: + raise RuntimeError( + "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled" + ) + + if not self.enable_multistream_moe: + if self.enable_super_kernel: + raise RuntimeError( + "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled" + ) class AscendSchedulerConfig: diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index fb1ed6f11b..b8c3534e89 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -242,6 +242,7 @@ def __init__( config.n_routed_experts, bias=False, quant_config=None, + params_dtype=torch.float32, prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 37edb9767a..9a688988e1 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -48,7 +48,7 @@ MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, npu_stream_switch, - npu_wait_tensor) + npu_wait_tensor, super_kernel) VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER @@ -1123,6 +1123,7 @@ def __init__( AscendFusedMoE.moe_counter += 1 self.moe_instance_id = AscendFusedMoE.moe_counter + self.prefix = prefix if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -1179,6 +1180,9 @@ def __init__( self.enable_multistream_moe = ( ascend_config.torchair_graph_config.enable_multistream_moe and self.torchair_graph_enabled) + self.enable_super_kernel = ( + ascend_config.torchair_graph_config.super_kernel + and self.enable_multistream_moe) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -1264,6 +1268,7 @@ def forward( forward_context = get_forward_context() fused_moe_state = get_forward_context().fused_moe_state + is_prefill = get_forward_context().with_prefill # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. quantized_x_for_share, dynamic_scale_for_share = None, None from vllm_ascend.quantization.w8a8_dynamic import \ @@ -1271,13 +1276,17 @@ def forward( if self.enable_multistream_moe: assert gate is not None - router_logits, _ = gate(hidden_states) - if (isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod) - and fused_moe_state == FusedMoEState.MC2): - with npu_stream_switch("moe_secondary", 0): - quantized_x_for_share, dynamic_scale_for_share = ( - torch_npu.npu_dynamic_quant(hidden_states)) + with super_kernel(self.prefix, + "stream-fusion=1", + enabled=not is_prefill + and self.enable_super_kernel): + router_logits, _ = gate(hidden_states.float()) + if (isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod) + and fused_moe_state == FusedMoEState.MC2): + with npu_stream_switch("moe_secondary", 0): + quantized_x_for_share, dynamic_scale_for_share = ( + torch_npu.npu_dynamic_quant(hidden_states)) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: @@ -1354,6 +1363,8 @@ def forward( dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, token_dispatcher=self.token_dispatcher, + prefix=self.prefix, + enable_super_kernel=self.enable_super_kernel, ) if shared_experts: diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 6e672728e2..7c33032565 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -30,7 +30,8 @@ from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, dispose_tensor, get_ascend_soc_version, - npu_stream_switch, npu_wait_tensor) + npu_stream_switch, npu_wait_tensor, + super_kernel) CHUNK_SIZE: int = ascend_envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE @@ -853,77 +854,84 @@ def apply( shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + prefix: str = "", + enable_super_kernel: bool = False, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk当前写8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; 第三个输出是否输出 - # y2_flag=False, # old api; 第三个输出是否输出 - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - - fused_moe_state = get_forward_context().fused_moe_state - shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(quantized_x_for_share, router_logits) - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. - if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - - topk_weights = topk_weights.to(x.dtype) + with super_kernel(prefix, + "stream-fusion=1", + enabled=enable_super_kernel): + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits.float(), + k=top_k, # topk当前写8 + bias=e_score_correction_bias, + k_group=topk_group, # fix: 4 + group_count=num_expert_group, # fix 8 + group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; 第三个输出是否输出 + # y2_flag=False, # old api; 第三个输出是否输出 + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + fused_moe_state = get_forward_context().fused_moe_state + shared_gate_up, shared_dequant_scale = None, None + if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, router_logits) + share_up_out, _ = shared_experts.gate_up_proj( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_gate_up, shared_dequant_scale = share_up_out[ + 0], share_up_out[1] + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + topk_weights = topk_weights.to(x.dtype) if fused_moe_state == FusedMoEState.MC2: - return fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled, - quantized_x_for_share=shared_gate_up, - dynamic_scale_for_share=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None)) + with super_kernel(prefix, + "stream-fusion=1", + enabled=enable_super_kernel): + return fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_fp32, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + quantized_x_for_share=shared_gate_up, + dynamic_scale_for_share=shared_dequant_scale, + mc2_mask=kwargs.get("mc2_mask", None)) elif fused_moe_state == FusedMoEState.MC2_PREFILL: return fused_prefill_experts_with_mc2( hidden_states=x, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index c63594edc1..a02c7428be 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -30,6 +30,7 @@ import torchair # type: ignore[import] # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event +from torchair.scope import super_kernel as _super_kernel from vllm.logger import logger import vllm_ascend.envs as envs @@ -296,6 +297,10 @@ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True): return _npu_stream_switch(tag, priority) if enabled else nullcontext() +def super_kernel(prefix: str, stream: str, enabled: bool = True): + return _super_kernel(prefix, stream) if enabled else nullcontext() + + def npu_wait_tensor(self: torch.Tensor, dependency: torch.Tensor, *, From a50ed5b8a98d5c67eff2f08141bc842e9f31068c Mon Sep 17 00:00:00 2001 From: David9857 <30687415+David9857@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:52:45 +0800 Subject: [PATCH 2/9] [0.9.1]optmize rope in qwen2 (#1782) ### What this PR does / why we need it? Optimize rope by extracting index_select from layers into model, which can reduce (layer_num -1) * 2 Gather ops in each prefill/decode stage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass cos and sin and set skip_index_select=True to enable this optimization. As shown in the following code : `q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)` **Performance results:** **origin:** Successful requests: 400 Benchmark duration (s): 243.10 Total input tokens: 1200000 Total generated tokens: 60000 Request throughput (req/s): 1.65 Output token throughput (tok/s): 246.81 Total Token throughput (tok/s): 5183.02 **optimized:** Successful requests: 400 Benchmark duration (s): 237.42 Total input tokens: 1200000 Total generated tokens: 60000 Request throughput (req/s): 1.68 Output token throughput (tok/s): 252.72 Total Token throughput (tok/s): 5307.03 Signed-off-by: David9857 <985700846@qq.com> --- vllm_ascend/models/qwen2.py | 135 +++++++++++++++++++++++++--- vllm_ascend/models/qwen3.py | 2 +- vllm_ascend/ops/rotary_embedding.py | 10 +-- 3 files changed, 129 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py index b024898cdd..8efee8b37f 100644 --- a/vllm_ascend/models/qwen2.py +++ b/vllm_ascend/models/qwen2.py @@ -1,11 +1,12 @@ from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union import torch import torch.nn.functional as F import vllm.envs as envs from torch import nn from transformers import Qwen2Config +from vllm.attention import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -14,11 +15,14 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model +from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP, + Qwen2Model) from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -48,7 +52,59 @@ def maybe_pad_and_reduce_scatter( return hidden_states -class CustomQwen2DecoderLayer(Qwen2DecoderLayer): +class CustomQwen2Attention(Qwen2Attention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=prefix, + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config) + + def forward(self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if type(self.rotary_emb) is RotaryEmbedding: + # We optimized RotaryEmbedding by moving index_select of cos & sin outside. + # if cos & sin are provided, set is_cos_sin_cached to True to skip index_select. + q, k = self.rotary_emb(positions, + q, + k, + cos=cos, + sin=sin, + is_cos_sin_cached=True) + else: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen2DecoderLayer(nn.Module): def __init__( self, @@ -57,10 +113,49 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = CustomQwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.self_attn.o_proj.reduce_results = False @@ -73,6 +168,8 @@ def forward( residual: Optional[torch.Tensor], flashcomm_v1_enabled: bool, pad_size: int, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -89,10 +186,10 @@ def forward( if flashcomm_v1_enabled: hidden_states = all_gather_and_maybe_unpad( hidden_states, pad_size) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + cos=cos, + sin=sin) if flashcomm_v1_enabled: hidden_states = maybe_pad_and_reduce_scatter( hidden_states, pad_size) @@ -133,6 +230,7 @@ def __init__( prefix=prefix, decoder_layer_type=decoder_layer_type) self.tp_size = get_tensor_model_parallel_world_size() + self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache def forward( self, @@ -163,6 +261,19 @@ def forward( num_tokens = hidden_states.size(0) pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size + + # Generate cos and sin outside layers to avoid repeated calculation. + cos, sin = None, None + if type(self.layers[0].self_attn.rotary_emb) is RotaryEmbedding: + cos_sin = self.cos_sin_cache.index_select(0, positions) + last_dim = cos_sin.size()[-1] + cos, sin = cos_sin.reshape(-1, 2, + last_dim // 2).repeat(1, 1, + 2).chunk(2, + dim=-2) + cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view( + 1, -1, 1, last_dim).contiguous() + for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, @@ -170,6 +281,8 @@ def forward( residual, flashcomm_v1_enabled, pad_size, + cos=cos, + sin=sin, ) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py index a45e040bcb..82eaa57676 100644 --- a/vllm_ascend/models/qwen3.py +++ b/vllm_ascend/models/qwen3.py @@ -170,7 +170,7 @@ def forward( k, cos=cos, sin=sin, - skip_index_select=True) + is_cos_sin_cached=True) attn_output = self.attn(q, k, v) pad_size = 0 if self.enable_fc == 2: diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index c4c568d4d0..60802bebc5 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -39,7 +39,7 @@ def rope_forward_oot( cos: torch.Tensor = None, sin: torch.Tensor = None, is_neox_style_override: Optional[bool] = None, - skip_index_select: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + is_cos_sin_cached: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: import torch_npu query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: @@ -64,16 +64,14 @@ def rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if skip_index_select and neox_style and self.head_size == self.rotary_dim: - # TODO: Remove the contiguous in the future. - # BSNH + if is_cos_sin_cached and neox_style and self.head_size == self.rotary_dim and self.head_size == 128: + # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. + # This method requires head_size and rotary_dim equal 128 and neox_style is True query = query.contiguous().view(1, query.shape[0], -1, self.head_size) key = key.contiguous().view(1, key.shape[0], -1, self.head_size) - # requires head_size=128 and neox_style=True torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin) else: - # TODO: Remove the contiguous in the future. query = query.contiguous().view(query.shape[0], -1) key = key.contiguous().view(key.shape[0], -1) torch_npu._npu_rotary_embedding( From cbd7cf840fb0f9626b24acc79c4132a31b573a26 Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Fri, 18 Jul 2025 12:06:30 +0800 Subject: [PATCH 3/9] =?UTF-8?q?[BUGFIX][v0.9.1]=20fix=20enable=5Fmultistre?= =?UTF-8?q?am=5Fmoe=20bug=20when=20DBO=20is=20enabled=20(=E2=80=A6=20(#182?= =?UTF-8?q?7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [BUGFIX][v0.9.1] fix enable_multistream_moe bug when DBO is enabled Code merging introduces issues --------- Signed-off-by: weijinqian_v1 Co-authored-by: weijinqian_v1 --- vllm_ascend/models/deepseek_dbo.py | 41 ------------------------------ vllm_ascend/ops/fused_moe.py | 3 +++ 2 files changed, 3 insertions(+), 41 deletions(-) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 20dafdf7ac..a33a69b80a 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -154,47 +154,6 @@ def __init__( CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok self.config = config - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - if attn_metadata is None: - attn_metadata = forward_context.attn_metadata - - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - enable_force_load_balance = forward_context.in_profile_run - - is_prefill = forward_context.with_prefill - # If this node is kv_consumer, we force the moe always runs in decode path to make sure - # the behaviour aligned between dummy_run and normal model_execute. - if self.kv_consumer: - is_prefill = False - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - experts_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=CustomDeepseekDBOMoE.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=self.shared_experts) - - shared_experts_hidden = experts_hidden_states[1] - if not (self.shared_experts.down_proj.reduce_results - and self.shared_experts.down_proj.tp_size > 1): - shared_experts_hidden = tensor_model_parallel_all_reduce( - shared_experts_hidden) - - hidden_states = ( - experts_hidden_states[0] * self.routed_scaling_factor + - shared_experts_hidden) - - return hidden_states - # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_op_shared_expert( self, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 9a688988e1..16b5e67b1f 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1291,6 +1291,9 @@ def forward( if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) + if not shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1: + shared_hidden_states = tensor_model_parallel_all_reduce( + shared_hidden_states) mc2_mask = forward_context.mc2_mask tp_size = get_tensor_model_parallel_world_size() From a1ed8d1c418812815f3b0e17fb5d0d5ab707c76b Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Fri, 18 Jul 2025 15:37:58 +0800 Subject: [PATCH 4/9] [BUGFIX][v0.9.1] ep_group is not equal to word_size in some cases. (#1862) [BUGFIX][v0.9.1] ep_group is not equal to word_size in some cases. for examples: external_dp. --------- Signed-off-by: weijinqian_v1 Co-authored-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index e4a9b5adce..aec82d426b 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -5,7 +5,7 @@ import torch from vllm.config import VllmConfig -from vllm.distributed import get_dp_group, get_tp_group +from vllm.distributed import get_dp_group, get_ep_group, get_tp_group from vllm.forward_context import get_forward_context, set_forward_context from vllm.platforms import current_platform @@ -63,7 +63,7 @@ def set_ascend_forward_context( ): forward_context = get_forward_context() forward_context.with_prefill = with_prefill - ep_size = (torch.distributed.get_world_size() if + ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) fused_moe_state = get_fused_moe_state(ep_size, with_prefill) From 685ac51ed7c583cf8690349a6547a2d99233d6b7 Mon Sep 17 00:00:00 2001 From: weijinqian0 <1184188277@qq.com> Date: Fri, 18 Jul 2025 19:54:26 +0800 Subject: [PATCH 5/9] [BUGFIX][v0.9.1] repair moe error when set multistream. (#1882) [BUGFIX][v0.9.1] repair moe error when set multistream. Signed-off-by: weijinqian_v1 Co-authored-by: weijinqian_v1 --- vllm_ascend/models/deepseek_dbo.py | 4 ++-- vllm_ascend/ops/fused_moe.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index a33a69b80a..e859915964 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -147,8 +147,7 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=not envs_ascend. - VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is separated in alltoallv for better overlap. + reduce_results=True, prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok @@ -717,6 +716,7 @@ def _forward_ms_layer_alltoallv_finegrained( assert len(hidden_states) == num_micro_batchs assert residual is not None assert attn_metadata is not None + self.mlp.shared_experts.down_proj.reduce_results = False num_tokens = [None] * num_micro_batchs hidden_dims = [None] * num_micro_batchs topk_weights, topk_ids = [None] * num_micro_batchs, [ diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 16b5e67b1f..9a688988e1 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1291,9 +1291,6 @@ def forward( if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) - if not shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1: - shared_hidden_states = tensor_model_parallel_all_reduce( - shared_hidden_states) mc2_mask = forward_context.mc2_mask tp_size = get_tensor_model_parallel_world_size() From ee37b63dcb6cba9ea5c831f52fd6f8739c79f4db Mon Sep 17 00:00:00 2001 From: Icey <1790571317@qq.com> Date: Sat, 19 Jul 2025 11:37:31 +0800 Subject: [PATCH 6/9] [0.9.1] Fix wheel glibc version incompatibility (#1808) ### What this PR does / why we need it? - Fixes https://github.com/vllm-project/vllm-ascend/issues/1533 ### How was this patch tested? Backported: https://github.com/vllm-project/vllm-ascend/pull/1582 Signed-off-by: Icey <1790571317@qq.com> --- .github/Dockerfile.buildwheel | 11 +++------ .github/workflows/release_whl.yml | 41 ++++++++++++++++++++++++++----- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/.github/Dockerfile.buildwheel b/.github/Dockerfile.buildwheel index dfe8a63f6d..4d9489ebc4 100644 --- a/.github/Dockerfile.buildwheel +++ b/.github/Dockerfile.buildwheel @@ -15,17 +15,16 @@ # This file is a part of the vllm-ascend project. # ARG PY_VERSION=3.10 -FROM quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py${PY_VERSION} +FROM quay.io/ascend/manylinux:8.0.0-910b-manylinux_2_28-py${PY_VERSION} ARG COMPILE_CUSTOM_KERNELS=1 # Define environments ENV DEBIAN_FRONTEND=noninteractive ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} -RUN apt-get update -y && \ - apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev && \ - rm -rf /var/cache/apt/* && \ - rm -rf /var/lib/apt/lists/* +RUN yum update -y && \ + yum install -y python3-pip git vim wget net-tools gcc gcc-c++ make cmake numactl-devel && \ + rm -rf /var/cache/yum WORKDIR /workspace @@ -41,8 +40,6 @@ RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ cd vllm-ascend && \ python3 setup.py bdist_wheel && \ - ls -l dist && \ - for f in dist/*.whl; do mv "$f" "$(echo "$f" | sed -e 's/-linux_x86_64\.whl$/-manylinux1_x86_64.whl/' -e 's/-linux_aarch64\.whl$/-manylinux2014_aarch64.whl/')"; done && \ ls -l dist CMD ["/bin/bash"] diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml index f66a01588e..c17498ff90 100644 --- a/.github/workflows/release_whl.yml +++ b/.github/workflows/release_whl.yml @@ -71,16 +71,11 @@ jobs: --build-arg PY_VERSION=${{ matrix.python-version }} \ -t wheel:v1 . docker run --rm \ + -u $(id -u):$(id -g) \ -v $(pwd):/outpwd \ wheel:v1 \ bash -c "cp -r /workspace/vllm-ascend/dist /outpwd" ls dist - - - name: Archive wheel - uses: actions/upload-artifact@v4 - with: - name: vllm-ascend-${{ matrix.os }}-py${{ matrix.python-version }}-wheel - path: dist/* - name: Set up Python ${{ matrix.python-version }} if: startsWith(github.ref, 'refs/tags/') @@ -88,6 +83,40 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Repair wheels with auditwheel + run: | + python3 -m pip install auditwheel + python3 -m pip install patchelf + mkdir -p dist/repaired + for whl in dist/*.whl; do + auditwheel repair "$whl" -w dist/repaired/ \ + --exclude libplatform.so \ + --exclude libregister.so \ + --exclude libge_common_base.so \ + --exclude libc10.so \ + --exclude libc_sec.so \ + --exclude "libascend*.so" \ + --exclude "libtorch*.so" + done + rm -f dist/*.whl + mv dist/repaired/*.whl dist/ + rmdir dist/repaired + ls dist + + - name: Verify automatic platform tags + run: | + cd dist + for wheel in *.whl; do + echo "verification file: $wheel" + auditwheel show "$wheel" + done + + - name: Archive wheel + uses: actions/upload-artifact@v4 + with: + name: vllm-ascend-${{ matrix.os }}-py${{ matrix.python-version }}-wheel + path: dist/* + - name: Release if: startsWith(github.ref, 'refs/tags/') run: | From 50c18fa96aa992a9484fa44fae80443846d5e665 Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Sat, 19 Jul 2025 14:21:50 +0800 Subject: [PATCH 7/9] [0.9.1][bugfix] V0.9.1 fix rope accruracy bug for deepseek model (#1887) ### What this PR does / why we need it? Fix the accuracy problem of deepseek model with eager mode introduced by an interface change of rope in https://github.com/vllm-project/vllm-ascend/pull/1719. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? End to end testing and CI passed. Signed-off-by: linfeng-yuan <1102311262@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 60802bebc5..8add25df38 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -102,8 +102,12 @@ def native_rope_deepseek_forward(self, 2).reshape(b, h_q, d) b, h_k, d = key.shape key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, - neox_style) + q_pe, k_pe = rope_forward_oot(self, + positions, + query, + key, + offsets=offsets, + is_neox_style_override=neox_style) return q_pe, k_pe From fd547bc90d4ad30e1ba5e10110ffdcbb0d1399c9 Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Sat, 19 Jul 2025 15:12:30 +0800 Subject: [PATCH 8/9] [0.9.1][Perf]Remove NZ of kv_b_proj in Deepseek MLA. (#1872) This PR removes NZ transformation of weights of kv_b_proj. This is because we find that this matmul weight is not quantized and will fall back to ND calculation in runtime (because currently float bmm nz is not supported in torchair graph), which causes two redundant transData operations (trans weight from NZ back to ND). Removing these two operations will provide an optimization of about 40us per layer. Signed-off-by: whx-sjtu <2952154980@qq.com> --- vllm_ascend/attention/mla_v1.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 226bea8a03..e72e698611 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -22,8 +22,7 @@ from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch, - npu_stream_switch, npu_wait_tensor) +from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -711,12 +710,6 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UV = W_UV.transpose(0, 1).contiguous() # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - if get_ascend_config().enable_weight_nz_layout: - # cast quantized weight tensors in NZ layout for higher inference speed - self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, - ACL_FORMAT_FRACTAL_NZ) - self.W_UK_T.data = torch_npu.npu_format_cast( - self.W_UK_T.data, ACL_FORMAT_FRACTAL_NZ) def _compute_prefill_context( self, From 51e2abac951c0b74f90d9bf643bce04c1e98927e Mon Sep 17 00:00:00 2001 From: NNUCJ <616151263@qq.com> Date: Fri, 18 Jul 2025 14:20:03 +0800 Subject: [PATCH 9/9] Add super kernel in moe Signed-off-by: NNUCJ <616151263@qq.com> --- vllm_ascend/ascend_config.py | 1 - vllm_ascend/models/deepseek_v2.py | 2 +- vllm_ascend/quantization/w8a8_dynamic.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 12e6ee34f8..fe4c60d632 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -101,7 +101,6 @@ def __init__(self, torchair_graph_config): raise RuntimeError( "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled" ) - if not self.enable_multistream_moe: if self.enable_super_kernel: raise RuntimeError( diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index b8c3534e89..1f6b670ae8 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -246,7 +246,7 @@ def __init__( prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=torch.float)) else: self.gate.e_score_correction_bias = None diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 7c33032565..4fd7468528 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -866,7 +866,7 @@ def apply( # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern if global_num_experts == 256: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits.float(), + router_logits, k=top_k, # topk当前写8 bias=e_score_correction_bias, k_group=topk_group, # fix: 4