Commit f025bfa

[main] addrmsnorm + quant fusion optim
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 1bbb20e

File tree

4 files changed: +61 −5 lines


vllm_ascend/ascend_forward_context.py

Lines changed: 15 additions & 1 deletion
@@ -66,7 +66,8 @@ def set_ascend_forward_context(
         moe_comm_method: str = "",
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        prefetch_model: Optional[torch.nn.Module] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -81,6 +82,19 @@ def set_ascend_forward_context(
            batch_descriptor=batch_descriptor,
    ):
        forward_context = get_forward_context()
+        if envs_ascend.VLLM_ASCEND_ENABLE_ADDRMSNORM_QUANT_FUSION:
+            model_type = vllm_config.model_config.hf_config.model_type
+            forward_context.prefetch_model = prefetch_model
+            forward_context.layer_idx = 0
+            # dense model
+            if model_type in ["llama", "qwen2", "qwen3"]:
+                forward_context.fusion_linear = "gate_up_dense"
+            # moe model
+            elif model_type == "qwen3_moe":
+                forward_context.fusion_linear = "qkv_moe"
+            else:
+                raise ValueError(f"AddRmsNorm+Quant fusion unsupported model type: {model_type}")
+
        forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
        forward_context.with_prefill = with_prefill
        tp_world_size = get_tensor_model_parallel_world_size()
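
Note that this file only populates the new fields; nothing here consumes them yet. A minimal sketch of how a fused layer could read them back, assuming the attribute names set above (the dispatch function, its arguments, and the gate_up_proj lookup are illustrative placeholders, not part of this commit):

from vllm.forward_context import get_forward_context

def pick_norm_path(norm, x, residual):
    # Hypothetical consumer sketch: fusion_linear / layer_idx /
    # prefetch_model are the fields populated by the diff above; the
    # gate_up_proj lookup is illustrative only.
    ctx = get_forward_context()
    if getattr(ctx, "fusion_linear", None) == "gate_up_dense":
        layer = ctx.prefetch_model.model.layers[ctx.layer_idx]
        ctx.layer_idx += 1  # advance one decoder layer per call
        return norm._addrmsnorm_w8a8_quant_forward_oot(
            x, residual, layer.mlp.gate_up_proj)
    return norm.forward_oot(x, residual)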

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
@@ -144,6 +144,9 @@
     # this feature in eager mode will get better performance.
     "VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLP_OPTIMIZE", '0'))),
+    # Whether to enable AddRMSNorm + quant fusion
+    "VLLM_ASCEND_ENABLE_ADDRMSNORM_QUANT_FUSION":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_ADDRMSNORM_QUANT_FUSION", '0'))),
     # Determine the number of physical devices in a non-full-use scenario
     # caused by the initialization of the Mooncake connector.
     "PHYSICAL_DEVICES":

vllm_ascend/ops/layernorm.py

Lines changed: 39 additions & 2 deletions
@@ -18,7 +18,9 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+import vllm_ascend.envs as envs_ascend
 
 
 class AddRMSNormW8A8Quant(RMSNorm):
@@ -64,6 +66,42 @@ def forward(
                                           self.variance_epsilon)
             return x
 
+    def _addrmsnorm_w8a8_quant_forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+        layer: torch.nn.Module,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        x, _, residual = torch_npu.npu_add_rms_norm_quant(
+            x,
+            residual,
+            self.weight,
+            layer.aclnn_input_scale,
+            layer.aclnn_input_offset,
+            epsilon=self.variance_epsilon)
+        return x, residual
+
+
+def _addrmsnorm_forward_oot(
+    self,
+    x: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    import torch_npu
+
+    from vllm_ascend.utils import is_310p
+    if is_310p():
+        orig_dtype = residual.dtype
+        x = x + residual.to(x.dtype)
+        residual = x.to(orig_dtype)
+        x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                      self.variance_epsilon)
+    else:
+        x, _, residual = torch_npu.npu_add_rms_norm(
+            x, residual, self.weight, self.variance_epsilon)
+    return x, residual
 
 class AscendRMSNorm(RMSNorm):
 
@@ -90,8 +128,7 @@ def forward_oot(
                 x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                               self.variance_epsilon)
             else:
-                x, _, residual = torch_npu.npu_add_rms_norm(
-                    x, residual, self.weight, self.variance_epsilon)
+                x, residual = _addrmsnorm_forward_oot(self, x, residual)
             return x, residual
 
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
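
For reference, the fused npu_add_rms_norm_quant call above collapses three eager-mode steps into one kernel: residual add, RMSNorm, and static int8 activation quantization. A rough pure-PyTorch approximation of that computation, under assumptions: standard RMSNorm semantics, per-tensor scale/offset corresponding to layer.aclnn_input_scale / layer.aclnn_input_offset, and a divide-by-scale convention, which may differ from the Ascend op's exact contract:

import torch

def add_rms_norm_quant_reference(x, residual, weight, scale, offset, eps):
    # 1) residual add (also returned, matching the fused op's third output)
    added = (x + residual).float()
    # 2) RMSNorm over the hidden dimension
    inv_rms = torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
    normed = added * inv_rms * weight.float()
    # 3) static int8 quantization; scale convention here is an assumption
    q = torch.clamp(torch.round(normed / scale + offset), -128, 127)
    return q.to(torch.int8), added.to(x.dtype)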

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 2 deletions
@@ -1563,7 +1563,8 @@ def execute_model(
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
-                total_num_scheduled_tokens):
+                total_num_scheduled_tokens,
+                prefetch_model=self.model):
             self.maybe_setup_kv_connector(scheduler_output)
 
             hidden_states = self._generate_process_reqs_hidden_states(
@@ -2009,7 +2010,8 @@ def dummy_compute_logits(hidden_states):
                 moe_comm_method=moe_comm_method,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                batch_descriptor=batch_descriptor):
+                batch_descriptor=batch_descriptor,
+                prefetch_model=self.model):
             hidden_states = self._generate_dummy_run_hidden_states(
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
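
Both call sites thread self.model into the forward context so that, when the fusion flag is on, the norm layers can presumably look up quantization parameters on upcoming linears. A condensed sketch of the resulting execute_model call pattern (positional arguments abbreviated and assumed; only prefetch_model is new in this commit):

with set_ascend_forward_context(
        attn_metadata,
        self.vllm_config,
        num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
        aclgraph_runtime_mode=aclgraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        prefetch_model=self.model):  # new: exposes the model to layers
    hidden_states = self._generate_process_reqs_hidden_states(...)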
