Commit dc697af

[main] addrmsnorm + quant fusion optim
Signed-off-by: rjg-lyh <1318825571@qq.com>

1 parent 382c29f commit dc697af

File tree: 4 files changed, +77 −209 lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 14 additions & 0 deletions
@@ -145,6 +145,20 @@ def set_ascend_forward_context(
             forward_context.prefetch_mlp_down_proj = False
         forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
 
+        # Set up addrmsnorm + quant fusion.
+        # This optimization currently only supports dense models due to the specific operators used.
+        # Once the necessary conditions are met, support for MoE models will also be added.
+        addrmsnorm_quant_fusion_enabled = vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
+            forward_context.layer_idx is not None
+        if addrmsnorm_quant_fusion_enabled:
+            from vllm_ascend.quantization.quant_config import AscendQuantConfig
+            assert isinstance(vllm_config.quant_config, AscendQuantConfig), \
+                "Expected quant_config to be an instance of AscendQuantConfig"
+            forward_context.model_instance = model_instance
+            forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
+            forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
+            forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
+
         if num_tokens is None and attn_metadata is not None:
             num_tokens = attn_metadata.num_actual_tokens
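The gating above can be read as a standalone predicate: the fusion is armed only for dense llama/qwen2/qwen3 checkpoints, only when the forward context tracks a layer index, and only when an Ascend quant config is present. A minimal sketch of that check (not part of the commit; the SimpleNamespace config and the helper name are illustrative stand-ins for vllm_config.model_config.hf_config and vllm_config.quant_config):

from types import SimpleNamespace
from typing import Optional

def addrmsnorm_quant_fusion_eligible(model_type: str,
                                     layer_idx: Optional[int],
                                     quant_config: Optional[object]) -> bool:
    # Dense-only support, mirroring the comment in the hunk above:
    # the fused path is wired up for llama/qwen2/qwen3 decoder layers.
    if model_type not in ("llama", "qwen2", "qwen3"):
        return False
    # layer_idx is what later lets the norm locate the next qkv/gate_up linear.
    if layer_idx is None:
        return False
    # The commit additionally asserts isinstance(quant_config, AscendQuantConfig);
    # this stand-in only checks that some quant config is configured.
    return quant_config is not None

hf_config = SimpleNamespace(model_type="qwen3", num_hidden_layers=32)
print(addrmsnorm_quant_fusion_eligible(hf_config.model_type, 0, object()))     # True
print(addrmsnorm_quant_fusion_eligible("qwen3_moe", 0, object()))              # False (MoE)
print(addrmsnorm_quant_fusion_eligible(hf_config.model_type, None, object()))  # False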

vllm_ascend/models/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -44,9 +44,6 @@ def register_model():
         "Qwen3MoeForCausalLM",
         "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
 
-    ModelRegistry.register_model(
-        "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
-
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
     ModelRegistry.register_model(

vllm_ascend/models/qwen3.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

vllm_ascend/ops/layernorm.py

Lines changed: 63 additions & 50 deletions
@@ -18,47 +18,39 @@
 from typing import Optional, Tuple, Union, cast
 
 import torch
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
-class AddRMSNormW8A8Quant(RMSNorm):
-    # Fuse AddRmsNorm and W8A8 quantization ops together
-
-    def __init__(
-        self,
-        hidden_size: int,
-        layer: torch.nn.Module,
-        eps: float = 1e-6,
-        var_hidden_size: Optional[int] = None,
-        has_weight: bool = True,
-        dtype: Optional[torch.dtype] = None,
-    ) -> None:
-        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
-        self.layer = layer
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        import torch_npu
-
-        if residual is not None:
-            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
-            assert x.size(0) == residual.size(0)
-            x, _, residual = torch_npu.npu_add_rms_norm_quant(
-                x,
-                residual,
-                self.weight,
-                self.layer.aclnn_input_scale,
-                self.layer.aclnn_input_offset,
-                epsilon=self.variance_epsilon)
-            torch.ops.vllm.maybe_wait_prefetch_done(x)
-            return x, residual
-
-        x, residual = torch_npu.npu_rms_norm(x, self.weight,
-                                             self.variance_epsilon)
-        return x
+def _addrmsnorm_forward_oot(
+    self,
+    x: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    layer: Optional[torch.nn.Module] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    import torch_npu
+
+    if layer is not None:
+        x, _, residual = torch_npu.npu_add_rms_norm_quant(
+            x,
+            residual,
+            self.weight,
+            layer.aclnn_input_scale,
+            layer.aclnn_input_offset,
+            epsilon=self.variance_epsilon)
+    else:
+        from vllm_ascend.utils import is_310p
+        if is_310p():
+            orig_dtype = residual.dtype
+            x = x + residual.to(x.dtype)
+            residual = x.to(orig_dtype)
+            x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                          self.variance_epsilon)
+        else:
+            x, _, residual = torch_npu.npu_add_rms_norm(
+                x, residual, self.weight, self.variance_epsilon)
+    torch.ops.vllm.maybe_wait_prefetch_done(x)
+    return x, residual
 
 
 class AscendRMSNorm(RMSNorm):
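The _addrmsnorm_forward_oot helper above keeps the existing fallback paths (npu_add_rms_norm, or an explicit add plus npu_rms_norm on 310P) and, when a quantized linear layer is supplied, folds the residual add, the RMSNorm, and the static int8 activation quantization into a single npu_add_rms_norm_quant call. A rough unfused reference of what that fused path computes, in plain PyTorch (not part of the commit; the scale/offset convention and the rounding/saturation behaviour of the NPU kernel are assumptions here):

import torch

def add_rmsnorm_quant_reference(x: torch.Tensor,
                                residual: torch.Tensor,
                                weight: torch.Tensor,
                                input_scale: torch.Tensor,
                                input_offset: torch.Tensor,
                                epsilon: float = 1e-6):
    # 1. residual add; the summed value is also returned as the new residual
    new_residual = x + residual
    # 2. RMSNorm over the hidden dimension, scaled by the learned weight
    variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = new_residual.float() * torch.rsqrt(variance + epsilon) * weight.float()
    # 3. static per-tensor quantization to int8 for the following W8A8 linear
    #    (q = round(x / scale) + offset is an assumed convention, not verified
    #    against torch_npu.npu_add_rms_norm_quant)
    q = torch.round(normed / input_scale + input_offset).clamp(-128, 127).to(torch.int8)
    return q, new_residual

hidden = 8
x = torch.randn(2, hidden, dtype=torch.float16)
res = torch.randn(2, hidden, dtype=torch.float16)
w = torch.ones(hidden, dtype=torch.float16)
q, new_res = add_rmsnorm_quant_reference(x, res, w,
                                          input_scale=torch.tensor(0.05),
                                          input_offset=torch.tensor(0.0))
print(q.dtype, new_res.dtype)  # torch.int8 torch.float16

The sketch is only meant to make the data flow legible; on the NPU all three steps execute as one kernel, which is the point of the fusion.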
@@ -70,26 +62,47 @@ def forward_oot(
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
-        from vllm_ascend.utils import is_310p
         if residual is not None:
             residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
-            if is_310p():
-                orig_dtype = residual.dtype
-                x = x + residual.to(x.dtype)
-                residual = x.to(orig_dtype)
-                x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                              self.variance_epsilon)
-            else:
-                x, _, residual = torch_npu.npu_add_rms_norm(
-                    x, residual, self.weight, self.variance_epsilon)
-            torch.ops.vllm.maybe_wait_prefetch_done(x)
+            x, residual = _addrmsnorm_forward_oot(self, x, residual,
+                                                  self.next_need_quant_fusion_linear)
             return x, residual
-
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
         return x
 
+    @property
+    def next_need_quant_fusion_linear(self):
+        try:
+            forward_context = get_forward_context()
+            if not forward_context.addrmsnorm_quant_fusion_enabled or \
+                    forward_context.layer_idx == forward_context.num_hidden_layers:
+                return None
+        except AssertionError:
+            return None
+
+        next_linear = None
+        model_instance = forward_context.model_instance
+        layer_idx = forward_context.layer_idx
+        fusion_linear = forward_context.fusion_linear
+        next_linear = None
+        if fusion_linear == "qkv_dense":
+            next_linear = model_instance.model.layers[layer_idx].self_attn.qkv_proj
+            forward_context.fusion_linear = "gate_up_dense"
+        elif fusion_linear == "gate_up_dense":
+            next_linear = model_instance.model.layers[layer_idx].mlp.gate_up_proj
+            forward_context.fusion_linear = "qkv_dense"
+            # if prefetch_mlp_weight is enabled, the following accumulation operation
+            # does not need to be repeated
+            if not forward_context.prefetch_mlp_enabled:
+                forward_context.layer_idx += 1
+        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+        if next_linear is not None and \
+                not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
+            next_linear = None
+        return next_linear
+
 
 class AscendQuantRMSNorm(AscendRMSNorm):
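The new property walks the decoder stack in lockstep with the norms that call it: layer 0 starts at gate_up_dense (presumably because its input_layernorm carries no residual, so the first fusable norm is the one feeding mlp.gate_up_proj), and from there the target alternates between the next layer's qkv_proj and that layer's gate_up_proj, advancing layer_idx after each gate_up step until num_hidden_layers is reached. A small standalone trace of that schedule (not part of the commit; the forward context is mocked with SimpleNamespace, and the real code returns linear modules rather than path strings):

from types import SimpleNamespace

# Mocked forward context mirroring the fields set in ascend_forward_context.py.
ctx = SimpleNamespace(layer_idx=0,
                      num_hidden_layers=3,
                      fusion_linear="gate_up_dense",
                      prefetch_mlp_enabled=False)

def next_fusion_target(ctx):
    # Replays the qkv/gate_up alternation of next_need_quant_fusion_linear,
    # returning a path string instead of the actual linear module.
    if ctx.layer_idx == ctx.num_hidden_layers:
        return None
    if ctx.fusion_linear == "qkv_dense":
        target = f"layers[{ctx.layer_idx}].self_attn.qkv_proj"
        ctx.fusion_linear = "gate_up_dense"
    else:
        target = f"layers[{ctx.layer_idx}].mlp.gate_up_proj"
        ctx.fusion_linear = "qkv_dense"
        if not ctx.prefetch_mlp_enabled:
            ctx.layer_idx += 1  # advance to the next decoder layer
    return target

# One call per residual-carrying RMSNorm in a 3-layer model:
for _ in range(6):
    print(next_fusion_target(ctx))
# layers[0].mlp.gate_up_proj
# layers[1].self_attn.qkv_proj
# layers[1].mlp.gate_up_proj
# layers[2].self_attn.qkv_proj
# layers[2].mlp.gate_up_proj
# None  (layer_idx reached num_hidden_layers)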
