diff --git a/.github/Dockerfile.buildwheel b/.github/Dockerfile.buildwheel index dfe8a63f6d..4d9489ebc4 100644 --- a/.github/Dockerfile.buildwheel +++ b/.github/Dockerfile.buildwheel @@ -15,17 +15,16 @@ # This file is a part of the vllm-ascend project. # ARG PY_VERSION=3.10 -FROM quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py${PY_VERSION} +FROM quay.io/ascend/manylinux:8.0.0-910b-manylinux_2_28-py${PY_VERSION} ARG COMPILE_CUSTOM_KERNELS=1 # Define environments ENV DEBIAN_FRONTEND=noninteractive ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} -RUN apt-get update -y && \ - apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev && \ - rm -rf /var/cache/apt/* && \ - rm -rf /var/lib/apt/lists/* +RUN yum update -y && \ + yum install -y python3-pip git vim wget net-tools gcc gcc-c++ make cmake numactl-devel && \ + rm -rf /var/cache/yum WORKDIR /workspace @@ -41,8 +40,6 @@ RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ cd vllm-ascend && \ python3 setup.py bdist_wheel && \ - ls -l dist && \ - for f in dist/*.whl; do mv "$f" "$(echo "$f" | sed -e 's/-linux_x86_64\.whl$/-manylinux1_x86_64.whl/' -e 's/-linux_aarch64\.whl$/-manylinux2014_aarch64.whl/')"; done && \ ls -l dist CMD ["/bin/bash"] diff --git a/.github/workflows/release_whl.yml b/.github/workflows/release_whl.yml index f66a01588e..c17498ff90 100644 --- a/.github/workflows/release_whl.yml +++ b/.github/workflows/release_whl.yml @@ -71,16 +71,11 @@ jobs: --build-arg PY_VERSION=${{ matrix.python-version }} \ -t wheel:v1 . docker run --rm \ + -u $(id -u):$(id -g) \ -v $(pwd):/outpwd \ wheel:v1 \ bash -c "cp -r /workspace/vllm-ascend/dist /outpwd" ls dist - - - name: Archive wheel - uses: actions/upload-artifact@v4 - with: - name: vllm-ascend-${{ matrix.os }}-py${{ matrix.python-version }}-wheel - path: dist/* - name: Set up Python ${{ matrix.python-version }} if: startsWith(github.ref, 'refs/tags/') @@ -88,6 +83,40 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Repair wheels with auditwheel + run: | + python3 -m pip install auditwheel + python3 -m pip install patchelf + mkdir -p dist/repaired + for whl in dist/*.whl; do + auditwheel repair "$whl" -w dist/repaired/ \ + --exclude libplatform.so \ + --exclude libregister.so \ + --exclude libge_common_base.so \ + --exclude libc10.so \ + --exclude libc_sec.so \ + --exclude "libascend*.so" \ + --exclude "libtorch*.so" + done + rm -f dist/*.whl + mv dist/repaired/*.whl dist/ + rmdir dist/repaired + ls dist + + - name: Verify automatic platform tags + run: | + cd dist + for wheel in *.whl; do + echo "verification file: $wheel" + auditwheel show "$wheel" + done + + - name: Archive wheel + uses: actions/upload-artifact@v4 + with: + name: vllm-ascend-${{ matrix.os }}-py${{ matrix.python-version }}-wheel + path: dist/* + - name: Release if: startsWith(github.ref, 'refs/tags/') run: | diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8ea67994ea..fe4c60d632 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -63,6 +63,8 @@ def __init__(self, torchair_graph_config): self.enable_view_optimize = torchair_graph_config.get( "enable_view_optimize", True) self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) + self.enable_super_kernel = torchair_graph_config.get( + "enable_super_kernel", False) if not isinstance(self.graph_batch_sizes, list): raise 
TypeError("graph_batch_sizes must be list[int]") @@ -95,6 +97,15 @@ def __init__(self, torchair_graph_config): raise RuntimeError( "enable_kv_nz is valid only when Torchair graph mode is enabled" ) + if self.enable_super_kernel: + raise RuntimeError( + "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled" + ) + if not self.enable_multistream_moe: + if self.enable_super_kernel: + raise RuntimeError( + "enable_super_kernel is valid only when Torchair graph mode and enable_multistream_moe is enabled" + ) class AscendSchedulerConfig: diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index e4a9b5adce..aec82d426b 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -5,7 +5,7 @@ import torch from vllm.config import VllmConfig -from vllm.distributed import get_dp_group, get_tp_group +from vllm.distributed import get_dp_group, get_ep_group, get_tp_group from vllm.forward_context import get_forward_context, set_forward_context from vllm.platforms import current_platform @@ -63,7 +63,7 @@ def set_ascend_forward_context( ): forward_context = get_forward_context() forward_context.with_prefill = with_prefill - ep_size = (torch.distributed.get_world_size() if + ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) fused_moe_state = get_fused_moe_state(ep_size, with_prefill) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 226bea8a03..e72e698611 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -22,8 +22,7 @@ from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch, - npu_stream_switch, npu_wait_tensor) +from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -711,12 +710,6 @@ def get_and_maybe_dequant_weights(layer: LinearBase): self.W_UV = W_UV.transpose(0, 1).contiguous() # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() - if get_ascend_config().enable_weight_nz_layout: - # cast quantized weight tensors in NZ layout for higher inference speed - self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, - ACL_FORMAT_FRACTAL_NZ) - self.W_UK_T.data = torch_npu.npu_format_cast( - self.W_UK_T.data, ACL_FORMAT_FRACTAL_NZ) def _compute_prefill_context( self, diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 20dafdf7ac..e859915964 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -147,54 +147,12 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=not envs_ascend. - VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is separated in alltoallv for better overlap. 
+ reduce_results=True, prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok self.config = config - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - forward_context = get_forward_context() - if attn_metadata is None: - attn_metadata = forward_context.attn_metadata - - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - enable_force_load_balance = forward_context.in_profile_run - - is_prefill = forward_context.with_prefill - # If this node is kv_consumer, we force the moe always runs in decode path to make sure - # the behaviour aligned between dummy_run and normal model_execute. - if self.kv_consumer: - is_prefill = False - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - experts_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=CustomDeepseekDBOMoE.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=self.shared_experts) - - shared_experts_hidden = experts_hidden_states[1] - if not (self.shared_experts.down_proj.reduce_results - and self.shared_experts.down_proj.tp_size > 1): - shared_experts_hidden = tensor_model_parallel_all_reduce( - shared_experts_hidden) - - hidden_states = ( - experts_hidden_states[0] * self.routed_scaling_factor + - shared_experts_hidden) - - return hidden_states - # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_op_shared_expert( self, @@ -758,6 +716,7 @@ def _forward_ms_layer_alltoallv_finegrained( assert len(hidden_states) == num_micro_batchs assert residual is not None assert attn_metadata is not None + self.mlp.shared_experts.down_proj.reduce_results = False num_tokens = [None] * num_micro_batchs hidden_dims = [None] * num_micro_batchs topk_weights, topk_ids = [None] * num_micro_batchs, [ diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index fb1ed6f11b..1f6b670ae8 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -242,10 +242,11 @@ def __init__( config.n_routed_experts, bias=False, quant_config=None, + params_dtype=torch.float32, prefix=f"{prefix}.gate") if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) + torch.empty(config.n_routed_experts, dtype=torch.float)) else: self.gate.e_score_correction_bias = None diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py index b024898cdd..8efee8b37f 100644 --- a/vllm_ascend/models/qwen2.py +++ b/vllm_ascend/models/qwen2.py @@ -1,11 +1,12 @@ from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union import torch import torch.nn.functional as F import vllm.envs as envs from torch import nn from transformers import Qwen2Config +from vllm.attention import AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -14,11 +15,14 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter) from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor 
from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model +from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP, + Qwen2Model) from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -48,7 +52,59 @@ def maybe_pad_and_reduce_scatter( return hidden_states -class CustomQwen2DecoderLayer(Qwen2DecoderLayer): +class CustomQwen2Attention(Qwen2Attention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=prefix, + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config) + + def forward(self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if type(self.rotary_emb) is RotaryEmbedding: + # We optimized RotaryEmbedding by moving index_select of cos & sin outside. + # if cos & sin are provided, set is_cos_sin_cached to True to skip index_select. + q, k = self.rotary_emb(positions, + q, + k, + cos=cos, + sin=sin, + is_cos_sin_cached=True) + else: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen2DecoderLayer(nn.Module): def __init__( self, @@ -57,10 +113,49 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = CustomQwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.self_attn.o_proj.reduce_results = False @@ -73,6 +168,8 @@ def forward( residual: Optional[torch.Tensor], flashcomm_v1_enabled: bool, pad_size: int, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -89,10 +186,10 @@ def forward( if flashcomm_v1_enabled: hidden_states = all_gather_and_maybe_unpad( hidden_states, pad_size) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + cos=cos, + sin=sin) if flashcomm_v1_enabled: hidden_states = maybe_pad_and_reduce_scatter( hidden_states, pad_size) @@ -133,6 +230,7 @@ def __init__( prefix=prefix, decoder_layer_type=decoder_layer_type) self.tp_size = get_tensor_model_parallel_world_size() + self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache def forward( self, @@ -163,6 +261,19 @@ def forward( num_tokens = hidden_states.size(0) pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size + + # Generate cos and sin outside layers to avoid repeated calculation. 
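+        # Only the plain RotaryEmbedding path consumes the precomputed values;
+        # they are reshaped below to the BSNH-style (1, num_tokens, 1, head_dim)
+        # layout that rope_forward_oot feeds to npu_apply_rotary_pos_emb. Other
+        # rotary variants keep cos/sin as None and apply rope inside the layer.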
+        cos, sin = None, None
+        if type(self.layers[0].self_attn.rotary_emb) is RotaryEmbedding:
+            cos_sin = self.cos_sin_cache.index_select(0, positions)
+            last_dim = cos_sin.size()[-1]
+            cos, sin = cos_sin.reshape(-1, 2,
+                                       last_dim // 2).repeat(1, 1,
+                                                             2).chunk(2,
+                                                                      dim=-2)
+            cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+                1, -1, 1, last_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
@@ -170,6 +281,8 @@
                 residual,
                 flashcomm_v1_enabled,
                 pad_size,
+                cos=cos,
+                sin=sin,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py
index a45e040bcb..82eaa57676 100644
--- a/vllm_ascend/models/qwen3.py
+++ b/vllm_ascend/models/qwen3.py
@@ -170,7 +170,7 @@ def forward(
                                        k,
                                        cos=cos,
                                        sin=sin,
-                                       skip_index_select=True)
+                                       is_cos_sin_cached=True)
         attn_output = self.attn(q, k, v)
         pad_size = 0
         if self.enable_fc == 2:
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 37edb9767a..9a688988e1 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -48,7 +48,7 @@
                                        MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_ascend_soc_version, npu_stream_switch,
-                               npu_wait_tensor)
+                               npu_wait_tensor, super_kernel)

 VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER
@@ -1123,6 +1123,7 @@ def __init__(

         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
+        self.prefix = prefix

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1179,6 +1180,9 @@ def __init__(
         self.enable_multistream_moe = (
             ascend_config.torchair_graph_config.enable_multistream_moe
             and self.torchair_graph_enabled)
+        self.enable_super_kernel = (
+            ascend_config.torchair_graph_config.enable_super_kernel
+            and self.enable_multistream_moe)

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1264,6 +1268,7 @@ def forward(
         forward_context = get_forward_context()
         fused_moe_state = get_forward_context().fused_moe_state
+        is_prefill = get_forward_context().with_prefill
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
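+        # is_prefill gates the super_kernel scopes below: the fused super-kernel
+        # path is only taken on decode-only steps when enable_super_kernel is
+        # set; prefill steps fall back to the regular multistream handling.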
quantized_x_for_share, dynamic_scale_for_share = None, None from vllm_ascend.quantization.w8a8_dynamic import \ @@ -1271,13 +1276,17 @@ def forward( if self.enable_multistream_moe: assert gate is not None - router_logits, _ = gate(hidden_states) - if (isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod) - and fused_moe_state == FusedMoEState.MC2): - with npu_stream_switch("moe_secondary", 0): - quantized_x_for_share, dynamic_scale_for_share = ( - torch_npu.npu_dynamic_quant(hidden_states)) + with super_kernel(self.prefix, + "stream-fusion=1", + enabled=not is_prefill + and self.enable_super_kernel): + router_logits, _ = gate(hidden_states.float()) + if (isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod) + and fused_moe_state == FusedMoEState.MC2): + with npu_stream_switch("moe_secondary", 0): + quantized_x_for_share, dynamic_scale_for_share = ( + torch_npu.npu_dynamic_quant(hidden_states)) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: @@ -1354,6 +1363,8 @@ def forward( dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, token_dispatcher=self.token_dispatcher, + prefix=self.prefix, + enable_super_kernel=self.enable_super_kernel, ) if shared_experts: diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index c4c568d4d0..8add25df38 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -39,7 +39,7 @@ def rope_forward_oot( cos: torch.Tensor = None, sin: torch.Tensor = None, is_neox_style_override: Optional[bool] = None, - skip_index_select: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + is_cos_sin_cached: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: import torch_npu query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: @@ -64,16 +64,14 @@ def rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if skip_index_select and neox_style and self.head_size == self.rotary_dim: - # TODO: Remove the contiguous in the future. - # BSNH + if is_cos_sin_cached and neox_style and self.head_size == self.rotary_dim and self.head_size == 128: + # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. + # This method requires head_size and rotary_dim equal 128 and neox_style is True query = query.contiguous().view(1, query.shape[0], -1, self.head_size) key = key.contiguous().view(1, key.shape[0], -1, self.head_size) - # requires head_size=128 and neox_style=True torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin) else: - # TODO: Remove the contiguous in the future. 
query = query.contiguous().view(query.shape[0], -1) key = key.contiguous().view(key.shape[0], -1) torch_npu._npu_rotary_embedding( @@ -104,8 +102,12 @@ def native_rope_deepseek_forward(self, 2).reshape(b, h_q, d) b, h_k, d = key.shape key = key.view(b, h_k, d // 2, 2).transpose(3, 2).reshape(b, h_k, d) - q_pe, k_pe = rope_forward_oot(self, positions, query, key, offsets, - neox_style) + q_pe, k_pe = rope_forward_oot(self, + positions, + query, + key, + offsets=offsets, + is_neox_style_override=neox_style) return q_pe, k_pe diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 6e672728e2..4fd7468528 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -30,7 +30,8 @@ from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, dispose_tensor, get_ascend_soc_version, - npu_stream_switch, npu_wait_tensor) + npu_stream_switch, npu_wait_tensor, + super_kernel) CHUNK_SIZE: int = ascend_envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE @@ -853,77 +854,84 @@ def apply( shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + prefix: str = "", + enable_super_kernel: bool = False, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" - - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, # topk当前写8 - bias=e_score_correction_bias, - k_group=topk_group, # fix: 4 - group_count=num_expert_group, # fix 8 - group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; 第三个输出是否输出 - # y2_flag=False, # old api; 第三个输出是否输出 - routed_scaling_factor=1, - eps=float(1e-20)) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) - - fused_moe_state = get_forward_context().fused_moe_state - shared_gate_up, shared_dequant_scale = None, None - if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(quantized_x_for_share, router_logits) - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - - # this is a naive implementation for experts load balance so as - # to avoid accumulating too much tokens on a single rank. - # currently it is only activated when doing profile runs. 
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
-
-        topk_weights = topk_weights.to(x.dtype)
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=enable_super_kernel):
+            # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
+            if global_num_experts == 256:
+                topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                    router_logits,
+                    k=top_k,  # top_k is currently set to 8
+                    bias=e_score_correction_bias,
+                    k_group=topk_group,  # fixed: 4
+                    group_count=num_expert_group,  # fixed: 8
+                    group_select_mode=1,  # 0: max within the group; 1: topk2.sum (fixed)
+                    renorm=0,  # 0: softmax->topk (fixed); 1: topk->softmax
+                    norm_type=1,  # 0: softmax; 1: sigmoid (fixed)
+                    # out_flag=False,  # TODO: new API; whether to emit the third output
+                    # y2_flag=False,  # old API; whether to emit the third output
+                    routed_scaling_factor=1,
+                    eps=float(1e-20))
+            else:
+                topk_weights, topk_ids = select_experts(
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    top_k=top_k,
+                    use_grouped_topk=use_grouped_topk,
+                    renormalize=renormalize,
+                    topk_group=topk_group,
+                    num_expert_group=num_expert_group,
+                    custom_routing_function=custom_routing_function,
+                    scoring_func=scoring_func,
+                    e_score_correction_bias=e_score_correction_bias,
+                )
+
+            fused_moe_state = get_forward_context().fused_moe_state
+            shared_gate_up, shared_dequant_scale = None, None
+            if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    npu_wait_tensor(quantized_x_for_share, router_logits)
+                    share_up_out, _ = shared_experts.gate_up_proj(
+                        (quantized_x_for_share, dynamic_scale_for_share))
+                    shared_gate_up, shared_dequant_scale = share_up_out[
+                        0], share_up_out[1]
+
+            # This is a naive implementation of expert load balancing, used to
+            # avoid accumulating too many tokens on a single rank.
+            # It is currently only activated during profile runs.
+ if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + topk_weights = topk_weights.to(x.dtype) if fused_moe_state == FusedMoEState.MC2: - return fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled, - quantized_x_for_share=shared_gate_up, - dynamic_scale_for_share=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None)) + with super_kernel(prefix, + "stream-fusion=1", + enabled=enable_super_kernel): + return fused_experts_with_mc2( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_fp32, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + moe_all_to_all_group_name=self.moe_all_to_all_group_name, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + quantized_x_for_share=shared_gate_up, + dynamic_scale_for_share=shared_dequant_scale, + mc2_mask=kwargs.get("mc2_mask", None)) elif fused_moe_state == FusedMoEState.MC2_PREFILL: return fused_prefill_experts_with_mc2( hidden_states=x, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index c63594edc1..a02c7428be 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -30,6 +30,7 @@ import torchair # type: ignore[import] # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event +from torchair.scope import super_kernel as _super_kernel from vllm.logger import logger import vllm_ascend.envs as envs @@ -296,6 +297,10 @@ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True): return _npu_stream_switch(tag, priority) if enabled else nullcontext() +def super_kernel(prefix: str, stream: str, enabled: bool = True): + return _super_kernel(prefix, stream) if enabled else nullcontext() + + def npu_wait_tensor(self: torch.Tensor, dependency: torch.Tensor, *,