From 224efd40840b1092750593e2d5e04ab73ef14de0 Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Fri, 5 Sep 2025 13:33:52 +0800 Subject: [PATCH] [main] addrmsnorm + quant fusion optim Signed-off-by: rjg-lyh <1318825571@qq.com> --- tests/ut/ops/test_layernorm.py | 206 +++++++++++++++++--------- vllm_ascend/ascend_forward_context.py | 16 ++ vllm_ascend/models/__init__.py | 3 - vllm_ascend/models/qwen3.py | 156 ------------------- vllm_ascend/ops/layernorm.py | 116 ++++++++------- 5 files changed, 219 insertions(+), 278 deletions(-) delete mode 100644 vllm_ascend/models/qwen3.py diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index 3bed0781f6..b0c05a2033 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -1,19 +1,17 @@ -from unittest.mock import patch +import unittest import pytest import torch +from pytest_mock import MockerFixture from vllm.model_executor.layers.layernorm import RMSNorm - -@pytest.fixture -def dummy_tensor(): - return torch.randn(4, 8, dtype=torch.float16) +from tests.ut.base import PytestBase +from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod def mock_maybe_chunk_residual(x, residual): if x.size(0) != residual.size(0): return residual[:4] - return residual @@ -25,69 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps): return 2 * x, None, 2 * residual -@pytest.mark.parametrize("is_310p_return", [True, False]) -@pytest.mark.parametrize("residual", - [None, torch.randn(4, 8, dtype=torch.float32)]) -@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) -@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) -@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) -@patch("torch.ops.vllm.maybe_chunk_residual", - side_effect=mock_maybe_chunk_residual) -def test_RMSNorm_forward(mock_maybe_chunk_residual, - mock_maybe_wait_prefetch_done, mock_add_rmsnorm, - mock_rmsnorm, is_310p_return, residual, dummy_tensor): - - with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return): +def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset, + epsilon): + x_out = 2 * x + residual_out = 2 * residual + x_out_quant = x_out.to(torch.int8) + residual_out_quant = residual_out.to(torch.int8) + return x_out_quant, None, residual_out_quant + + +class TestAscendRMSNorm(PytestBase): + + @pytest.fixture(autouse=True) + def context(self, mocker: MockerFixture): + mocker.patch("torch.ops.vllm.maybe_chunk_residual", + side_effect=mock_maybe_chunk_residual) + mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm) + mocker.patch("torch_npu.npu_add_rms_norm", + side_effect=mock_add_rms_norm) + mocker.patch("torch_npu.npu_add_rms_norm_quant", + side_effect=mock_add_rms_norm_quant) + mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done", + side_effect=lambda x: None) + + # Test case for the most common and basic scenario + @pytest.mark.parametrize( + "residual", [None, torch.randn(4, 8, dtype=torch.float16)]) + def test_forward_oot_basic(self, residual): layer = RMSNorm(hidden_size=8, eps=1e-05) + x = torch.randn(4, 8, dtype=torch.float16) if residual is not None: - out_x, out_residual = layer.forward_oot(dummy_tensor, residual) - - if is_310p_return: - expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype) - expected_out_x = expected_arg_x + 1 - expected_out_residual = expected_arg_x.to(residual.dtype) - - mock_maybe_chunk_residual.assert_called_once() - mock_rmsnorm.assert_called_once() - 
mock_maybe_wait_prefetch_done.assert_called_once() - assert torch.allclose(out_x, expected_out_x) - assert torch.allclose(out_residual, expected_out_residual) - else: - expected_out_x = 2 * dummy_tensor - expected_out_residual = 2 * residual - mock_maybe_chunk_residual.assert_called_once() - mock_add_rmsnorm.assert_called_once() - mock_maybe_wait_prefetch_done.assert_called_once() - assert torch.allclose(out_x, expected_out_x) - assert torch.allclose(out_residual, expected_out_residual) + x_out, residual_out = layer.forward_oot(x, residual) + + x_out_expected = 2 * x + residual_out_expected = 2 * residual + + assert torch.allclose(x_out, x_out_expected) + assert torch.allclose(residual_out, residual_out_expected) else: - out_x = layer.forward(dummy_tensor, residual) - expected_out_x = dummy_tensor + 1 - - mock_rmsnorm.assert_called_once() - assert torch.allclose(out_x, expected_out_x) - - -@patch("vllm_ascend.utils.is_310p", return_value=False) -@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm) -@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None) -@patch("torch.ops.vllm.maybe_chunk_residual", - side_effect=mock_maybe_chunk_residual) -def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual, - mock_maybe_wait_prefetch_done, - mock_add_rms_norm, mock_is310p): - x = torch.randn(4, 512, dtype=torch.bfloat16) - residual = torch.randn(16, 512, dtype=torch.bfloat16) - layer = RMSNorm(hidden_size=512, eps=1e-05) - - out_x, out_residual = layer.forward_oot(x, residual) - - expected_out_x = 2 * x - expected_out_residual = 2 * residual[:4] - - mock_maybe_chunk_residual.assert_called_once() - mock_add_rms_norm.assert_called_once() - mock_maybe_wait_prefetch_done.assert_called_once() - assert out_residual.size(0) == 4 - assert torch.allclose(out_x, expected_out_x) - assert torch.allclose(out_residual, expected_out_residual) + x_out = layer.forward(x, residual) + x_out_expected = x + 1 + + assert torch.allclose(x_out, x_out_expected) + + # Test case for flashcomm_v1 scenario + def test_forward_oot_with_flashcomm_v1(self): + layer = RMSNorm(hidden_size=512, eps=1e-05) + x = torch.randn(4, 512, dtype=torch.bfloat16) + residual = torch.randn(16, 512, dtype=torch.bfloat16) + + x_out, residual_out = layer.forward_oot(x, residual) + + x_out_expected = 2 * x + residual_out_expected = 2 * residual[:4] + + assert residual_out.size(0) == 4 + assert torch.allclose(x_out, x_out_expected) + assert torch.allclose(residual_out, residual_out_expected) + + # Test case for addrmsnorm + w8a8 quant fusion + def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture): + mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p") + mock_is_310p.return_value = False + mock_get_forward_context = mocker.patch( + "vllm_ascend.ops.layernorm.get_forward_context") + + # Simulating a scenario with quant_fusion enabled + mock_forward_context = mocker.MagicMock() + + mock_model_instance = mocker.MagicMock() + mock_forward_context.model_instance = mock_model_instance + mock_model_instance.model.layers = [ + mocker.MagicMock() for _ in range(2) + ] + + mock_layer_0 = mock_model_instance.model.layers[0] + mock_layer_0.self_attn.qkv_proj = mocker.MagicMock() + mock_layer_0.mlp.gate_up_proj = mocker.MagicMock() + + mock_layer_1 = mock_model_instance.model.layers[1] + mock_layer_1.self_attn.qkv_proj = mocker.MagicMock() + mock_layer_1.mlp.gate_up_proj = mocker.MagicMock() + + mock_quant_method_0_qkv = mocker.MagicMock() + mock_quant_method_0_qkv.quant_method = 
AscendW8A8LinearMethod() + mock_quant_method_0_gate_up = mocker.MagicMock() + mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod() + mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv + mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up + + mock_quant_method_1_qkv = mocker.MagicMock() + mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod() + mock_quant_method_1_gate_up = mocker.MagicMock() + mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod() + mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv + mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up + + mock_get_forward_context.return_value = mock_forward_context + + mock_forward_context.addrmsnorm_quant_fusion_enabled = True + mock_forward_context.prefetch_mlp_enabled = False + mock_forward_context.layer_idx = 0 + mock_forward_context.num_hidden_layers = 2 + mock_forward_context.fusion_linear = "gate_up_dense" + + # Ensure fusion and layer_idx increment are handled correctly + x = torch.randn(4, 8, dtype=torch.float16) + residual = torch.randn(4, 8, dtype=torch.float16) + layer = RMSNorm(hidden_size=8, eps=1e-05) + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 1 + assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.layer_idx == 1 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 2 + assert mock_forward_context.fusion_linear == "gate_up_dense" + assert mock_forward_context.layer_idx == 1 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 3 + assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.layer_idx == 2 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 4 + assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.layer_idx == 2 + + +if __name__ == '__main__': + unittest.main() diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index b368feb73f..a8cbf832c9 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -129,6 +129,22 @@ def set_ascend_forward_context( forward_context.prefetch_mlp_down_proj = False forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled + # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. + # It will be improved later by implementing operator fusion through the FX graph. + # + # set for addrmsnorm+quant fusion. + # this optim now just support dense models due to the specific operators used. + # Once the necessary conditions are met, support for MOE models will also be added. 
+ from vllm_ascend.quantization.quant_config import AscendQuantConfig + addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ + vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \ + forward_context.layer_idx is not None + if addrmsnorm_quant_fusion_enabled: + forward_context.model_instance = model_instance + forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers + forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" + forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 996ebfab25..34529d52f7 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -35,9 +35,6 @@ def register_model(): "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") - ModelRegistry.register_model( - "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM") - # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. ModelRegistry.register_model( diff --git a/vllm_ascend/models/qwen3.py b/vllm_ascend/models/qwen3.py deleted file mode 100644 index a05106f228..0000000000 --- a/vllm_ascend/models/qwen3.py +++ /dev/null @@ -1,156 +0,0 @@ -from collections.abc import Iterable -from typing import Optional, Union - -import torch -from torch import nn -from transformers import Qwen3Config -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer -from vllm.model_executor.models.utils import (AutoWeightsLoader, - PPMissingLayer, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant - - -class CustomQwen3DecoderLayer(Qwen3DecoderLayer): - - def __init__( - self, - config: Qwen3Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) - if quant_config is None: - return - - from vllm_ascend.quantization.quant_config import AscendQuantConfig - from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod - - assert isinstance(quant_config, AscendQuantConfig), \ - "Expected quant_config to be an instance of AscendQuantConfig" - - if isinstance(self.self_attn.qkv_proj.quant_method.quant_method, - AscendW8A8LinearMethod): - self.input_layernorm = AddRMSNormW8A8Quant( - config.hidden_size, - layer=self.self_attn.qkv_proj, - eps=config.rms_norm_eps) - if 
isinstance(self.mlp.gate_up_proj.quant_method.quant_method, - AscendW8A8LinearMethod): - self.post_attention_layernorm = AddRMSNormW8A8Quant( - config.hidden_size, - layer=self.mlp.gate_up_proj, - eps=config.rms_norm_eps) - - -ALL_DECODER_LAYER_TYPES = { - "attention": CustomQwen3DecoderLayer, -} - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }) -class CustomQwen3Model(Qwen2Model): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - decoder_layer_type=CustomQwen3DecoderLayer) - - -class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - # add `CustomQwen3Model` to init self.model - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = CustomQwen3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - ) - return loader.load_weights(weights) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index ccd031ccdf..da48362f46 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -18,47 +18,40 @@ from typing import Optional, Tuple, Union, cast import torch +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm -class AddRMSNormW8A8Quant(RMSNorm): - # Fuse AddRmsNorm and W8A8 quantization ops together - - def __init__( - self, - hidden_size: int, - layer: torch.nn.Module, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__(hidden_size, 
eps, var_hidden_size, has_weight, dtype) - self.layer = layer - - def forward( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - import torch_npu - - if residual is not None: - residual = torch.ops.vllm.maybe_chunk_residual(x, residual) - assert x.size(0) == residual.size(0) - x, _, residual = torch_npu.npu_add_rms_norm_quant( - x, - residual, - self.weight, - self.layer.aclnn_input_scale, - self.layer.aclnn_input_offset, - epsilon=self.variance_epsilon) - torch.ops.vllm.maybe_wait_prefetch_done(x) - return x, residual - - x, residual = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) - return x +def _addrmsnorm_forward_oot( + self, + x: torch.Tensor, + residual: torch.Tensor, + layer: Optional[torch.nn.Module] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + import torch_npu + + from vllm_ascend.utils import is_310p + + if layer is not None and not is_310p(): + x, _, residual = torch_npu.npu_add_rms_norm_quant( + x, + residual, + self.weight, + layer.aclnn_input_scale, + layer.aclnn_input_offset, + epsilon=self.variance_epsilon) + else: + if is_310p(): + orig_dtype = residual.dtype + x = x + residual.to(x.dtype) + residual = x.to(orig_dtype) + x, _ = torch_npu.npu_rms_norm(x, self.weight, + self.variance_epsilon) + else: + x, _, residual = torch_npu.npu_add_rms_norm( + x, residual, self.weight, self.variance_epsilon) + torch.ops.vllm.maybe_wait_prefetch_done(x) + return x, residual class AscendRMSNorm(RMSNorm): @@ -70,26 +63,49 @@ def forward_oot( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu - from vllm_ascend.utils import is_310p if residual is not None: residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) - if is_310p(): - orig_dtype = residual.dtype - x = x + residual.to(x.dtype) - residual = x.to(orig_dtype) - x, _ = torch_npu.npu_rms_norm(x, self.weight, - self.variance_epsilon) - else: - x, _, residual = torch_npu.npu_add_rms_norm( - x, residual, self.weight, self.variance_epsilon) - torch.ops.vllm.maybe_wait_prefetch_done(x) + x, residual = _addrmsnorm_forward_oot( + self, x, residual, self.next_need_quant_fusion_linear) return x, residual - x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) return x + @property + def next_need_quant_fusion_linear(self): + try: + forward_context = get_forward_context() + if not forward_context.addrmsnorm_quant_fusion_enabled or \ + forward_context.layer_idx == forward_context.num_hidden_layers: + return None + except AssertionError: + return None + + next_linear = None + model_instance = forward_context.model_instance + layer_idx = forward_context.layer_idx + fusion_linear = forward_context.fusion_linear + next_linear = None + if fusion_linear == "qkv_dense": + next_linear = model_instance.model.layers[ + layer_idx].self_attn.qkv_proj + forward_context.fusion_linear = "gate_up_dense" + elif fusion_linear == "gate_up_dense": + next_linear = model_instance.model.layers[ + layer_idx].mlp.gate_up_proj + forward_context.fusion_linear = "qkv_dense" + # if prefetch_mlp_weight enabled, following accumulation operation + # does not need to be repeated + if not forward_context.prefetch_mlp_enabled: + forward_context.layer_idx += 1 + from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod + if next_linear is not None and \ + not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): + next_linear = None + 
        return next_linear
+
+
 class AscendQuantRMSNorm(AscendRMSNorm):
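
Note on the fused path (illustrative sketch, not part of the patch): the new _addrmsnorm_forward_oot helper calls torch_npu.npu_add_rms_norm_quant so that the residual add, the RMSNorm, and the W8A8 input quantization for the next linear layer (qkv_proj or gate_up_proj, as selected by next_need_quant_fusion_linear) run as one fused kernel instead of npu_add_rms_norm followed by a separate quantize. The pure-PyTorch reference below sketches the semantics that fused call is expected to replace; the exact rounding and scale/offset convention of the Ascend kernel, and the helper name add_rms_norm_quant_reference, are assumptions made only for illustration.

import torch

def add_rms_norm_quant_reference(x: torch.Tensor,
                                 residual: torch.Tensor,
                                 weight: torch.Tensor,
                                 scale: torch.Tensor,
                                 offset: torch.Tensor,
                                 epsilon: float = 1e-6):
    # Residual add + RMSNorm, i.e. the npu_add_rms_norm part:
    # the sum becomes the new residual, the normed sum feeds the next linear.
    added = x + residual
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (added.float() * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight
    # W8A8 input quantization folded into the same step (assumed
    # convention: q = round(y * scale + offset), clamped to int8).
    quantized = torch.clamp(torch.round(normed.float() * scale + offset),
                            -128, 127).to(torch.int8)
    # Mirror the 3-tuple unpacked at the call site:
    # x, _, residual = torch_npu.npu_add_rms_norm_quant(...)
    return quantized, None, added

For example, with x = residual = torch.ones(2, 8, dtype=torch.float16), weight = torch.ones(8), scale = torch.tensor(127.0) and offset = torch.tensor(0.0), every normalized element is 1.0, so the sketch returns an int8 tensor of 127s for the next quantized linear together with the float residual of 2s, matching the shape of the values the fused kernel hands back in forward_oot.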