206 changes: 137 additions & 69 deletions tests/ut/ops/test_layernorm.py
@@ -1,19 +1,17 @@
from unittest.mock import patch
import unittest

import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.layernorm import RMSNorm


@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
from tests.ut.base import PytestBase
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod


def mock_maybe_chunk_residual(x, residual):
if x.size(0) != residual.size(0):
return residual[:4]

return residual


@@ -25,69 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
return 2 * x, None, 2 * residual


@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
mock_rmsnorm, is_310p_return, residual, dummy_tensor):

with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
epsilon):
x_out = 2 * x
residual_out = 2 * residual
x_out_quant = x_out.to(torch.int8)
residual_out_quant = residual_out.to(torch.int8)
return x_out_quant, None, residual_out_quant


class TestAscendRMSNorm(PytestBase):

@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
mocker.patch("torch_npu.npu_add_rms_norm",
side_effect=mock_add_rms_norm)
mocker.patch("torch_npu.npu_add_rms_norm_quant",
side_effect=mock_add_rms_norm_quant)
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
side_effect=lambda x: None)

# Test case for the most common and basic scenario
@pytest.mark.parametrize(
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
def test_forward_oot_basic(self, residual):
layer = RMSNorm(hidden_size=8, eps=1e-05)
x = torch.randn(4, 8, dtype=torch.float16)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)

if is_310p_return:
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)

mock_maybe_chunk_residual.assert_called_once()
mock_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_maybe_chunk_residual.assert_called_once()
mock_add_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
x_out, residual_out = layer.forward_oot(x, residual)

x_out_expected = 2 * x
residual_out_expected = 2 * residual

assert torch.allclose(x_out, x_out_expected)
assert torch.allclose(residual_out, residual_out_expected)
else:
out_x = layer.forward(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1

mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)


@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_add_rms_norm, mock_is310p):
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
layer = RMSNorm(hidden_size=512, eps=1e-05)

out_x, out_residual = layer.forward_oot(x, residual)

expected_out_x = 2 * x
expected_out_residual = 2 * residual[:4]

mock_maybe_chunk_residual.assert_called_once()
mock_add_rms_norm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert out_residual.size(0) == 4
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
x_out = layer.forward(x, residual)
x_out_expected = x + 1

assert torch.allclose(x_out, x_out_expected)

# Test case for flashcomm_v1 scenario
def test_forward_oot_with_flashcomm_v1(self):
layer = RMSNorm(hidden_size=512, eps=1e-05)
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)

x_out, residual_out = layer.forward_oot(x, residual)

x_out_expected = 2 * x
residual_out_expected = 2 * residual[:4]

assert residual_out.size(0) == 4
assert torch.allclose(x_out, x_out_expected)
assert torch.allclose(residual_out, residual_out_expected)

# Test case for addrmsnorm + w8a8 quant fusion
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
mock_is_310p.return_value = False
mock_get_forward_context = mocker.patch(
"vllm_ascend.ops.layernorm.get_forward_context")

# Simulating a scenario with quant_fusion enabled
mock_forward_context = mocker.MagicMock()

mock_model_instance = mocker.MagicMock()
mock_forward_context.model_instance = mock_model_instance
mock_model_instance.model.layers = [
mocker.MagicMock() for _ in range(2)
]

mock_layer_0 = mock_model_instance.model.layers[0]
mock_layer_0.self_attn.qkv_proj = mocker.MagicMock()
mock_layer_0.mlp.gate_up_proj = mocker.MagicMock()

mock_layer_1 = mock_model_instance.model.layers[1]
mock_layer_1.self_attn.qkv_proj = mocker.MagicMock()
mock_layer_1.mlp.gate_up_proj = mocker.MagicMock()

mock_quant_method_0_qkv = mocker.MagicMock()
mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod()
mock_quant_method_0_gate_up = mocker.MagicMock()
mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod()
mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv
mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up

mock_quant_method_1_qkv = mocker.MagicMock()
mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod()
mock_quant_method_1_gate_up = mocker.MagicMock()
mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod()
mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv
mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up

mock_get_forward_context.return_value = mock_forward_context

mock_forward_context.addrmsnorm_quant_fusion_enabled = True
mock_forward_context.prefetch_mlp_enabled = False
mock_forward_context.layer_idx = 0
mock_forward_context.num_hidden_layers = 2
mock_forward_context.fusion_linear = "gate_up_dense"

# Ensure fusion and layer_idx increment are handled correctly
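# Expected bookkeeping, per the assertions below: each forward_oot call
# alternates fusion_linear between "gate_up_dense" and "qkv_dense", layer_idx
# advances by one on the gate_up -> qkv transition, and both stop changing
# once layer_idx reaches num_hidden_layers.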
x = torch.randn(4, 8, dtype=torch.float16)
residual = torch.randn(4, 8, dtype=torch.float16)
layer = RMSNorm(hidden_size=8, eps=1e-05)

x_out, residual_out = layer.forward_oot(x, residual)

assert mock_get_forward_context.call_count == 1
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 1

x_out, residual_out = layer.forward_oot(x, residual)

assert mock_get_forward_context.call_count == 2
assert mock_forward_context.fusion_linear == "gate_up_dense"
assert mock_forward_context.layer_idx == 1

x_out, residual_out = layer.forward_oot(x, residual)

assert mock_get_forward_context.call_count == 3
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 2

x_out, residual_out = layer.forward_oot(x, residual)

assert mock_get_forward_context.call_count == 4
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 2


if __name__ == '__main__':
unittest.main()
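Taken together, the mocks and assertions above pin down the dispatch that the patched forward_oot is expected to perform. The following is a minimal illustrative sketch reconstructed from this test file; the helper name forward_oot_sketch and the fwd_ctx, quant_scale and quant_offset parameters are placeholders, and the actual AscendRMSNorm.forward_oot in vllm_ascend/ops/layernorm.py may differ.

import torch
import torch_npu  # Ascend NPU kernels; available only in NPU environments


def forward_oot_sketch(x, residual, weight, eps, fwd_ctx,
                       quant_scale=None, quant_offset=None):
    if residual is None:
        # Plain RMSNorm path: npu_rms_norm returns (normed_x, rstd).
        x_out, _ = torch_npu.npu_rms_norm(x, weight, eps)
        return x_out

    # The residual may come from a larger (padded) token count, e.g. under
    # flashcomm_v1; trim it to match x before the fused add.
    residual = torch.ops.vllm.maybe_chunk_residual(x, residual)

    if getattr(fwd_ctx, "addrmsnorm_quant_fusion_enabled", False) \
            and quant_scale is not None:
        # Fused add + rmsnorm + W8A8 quantization in a single NPU kernel.
        # In the real code the scale/offset would be looked up from the next
        # linear layer (qkv_proj or gate_up_proj) via fwd_ctx.fusion_linear.
        x_out, _, residual_out = torch_npu.npu_add_rms_norm_quant(
            x, residual, weight, quant_scale, quant_offset, eps)
    else:
        x_out, _, residual_out = torch_npu.npu_add_rms_norm(
            x, residual, weight, eps)

    torch.ops.vllm.maybe_wait_prefetch_done(x_out)
    return x_out, residual_out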
16 changes: 16 additions & 0 deletions vllm_ascend/ascend_forward_context.py
@@ -129,6 +129,22 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled

# TODO(rjg-lyh): The current implementation is somewhat brute-force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.
#
# Set up the addrmsnorm + quant fusion.
# This optimization currently supports only dense models, due to the specific operators used.
# Once the necessary conditions are met, support for MoE models will be added as well.
from vllm_ascend.quantization.quant_config import AscendQuantConfig
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
forward_context.layer_idx is not None
if addrmsnorm_quant_fusion_enabled:
forward_context.model_instance = model_instance
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled

if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens

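On the consumer side, the norm layer is expected to use the model_instance, layer_idx and fusion_linear fields stored above to locate the linear module whose W8A8 quant scale and offset feed the fused kernel. A hypothetical lookup helper, sketched from the unit test's mock layout (the name _next_fused_linear and the exact traversal are assumptions; the real code in vllm_ascend.ops.layernorm may differ):

def _next_fused_linear(forward_context):
    # Pick the decoder layer the forward pass is currently in.
    layer = forward_context.model_instance.model.layers[forward_context.layer_idx]
    if forward_context.fusion_linear == "qkv_dense":
        # The norm feeds the attention QKV projection.
        return layer.self_attn.qkv_proj
    # Otherwise ("gate_up_dense") it feeds the MLP gate/up projection.
    return layer.mlp.gate_up_proj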
3 changes: 0 additions & 3 deletions vllm_ascend/models/__init__.py
@@ -35,9 +35,6 @@ def register_model():
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")

ModelRegistry.register_model(
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")

# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This registration step can be removed once vLLM supports PanguProMoEForCausalLM.
ModelRegistry.register_model(