# This file is a part of the vllm-ascend project.
#

-from typing import Optional, Tuple, Union, cast
+from typing import Optional, Tuple, Union

 import torch
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -27,6 +28,7 @@ def _addrmsnorm_forward_oot(
     x: torch.Tensor,
     residual: torch.Tensor,
     layer: Optional[torch.nn.Module] = None,
+    bias: Optional[torch.nn.Parameter] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     import torch_npu
@@ -39,6 +41,7 @@ def _addrmsnorm_forward_oot(
             self.weight,
             layer.aclnn_input_scale,
             layer.aclnn_input_offset,
+            beta=bias,
             epsilon=self.variance_epsilon)
     else:
         if is_310p():
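
For orientation, here is a rough plain-PyTorch sketch of what the fused add-RMSNorm path above computes once the new bias is threaded through. The real work is done by torch_npu.npu_add_rms_norm_quant / npu_add_rms_norm on the NPU; this reference function (its name and the float32 accumulation are assumptions, not part of the diff) only illustrates the numerics:

import torch

def add_rmsnorm_ref(x, residual, weight, eps, bias=None):
    # Fused residual add followed by RMSNorm, mirroring the NPU kernels above.
    residual = x + residual
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    out = (residual.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    if bias is not None:
        # The quantized kernel folds this in via beta=bias; the non-quantized
        # path in the next hunk adds it afterwards with x.add_(bias).
        out = out + bias
    return out, residual
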
@@ -50,12 +53,31 @@ def _addrmsnorm_forward_oot(
         else:
             x, _, residual = torch_npu.npu_add_rms_norm(
                 x, residual, self.weight, self.variance_epsilon)
+        if bias is not None:
+            x.add_(bias)
         torch.ops.vllm.maybe_wait_prefetch_done(x)
     return x, residual


 class AscendRMSNorm(RMSNorm):

+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        vllm_config = get_current_vllm_config()
+        self.bias = None
+        # m4
+        if vllm_config is not None and vllm_config.quant_config is not None and \
+                any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
+            self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
+                                           requires_grad=False)
+
     def forward_oot(
         self,
         x: torch.Tensor,
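
The bias is now allocated conditionally in AscendRMSNorm itself, which is what lets the AscendQuantRMSNorm subclass (removed at the bottom of this diff) go away: the layer only carries a bias when the quantized checkpoint's description actually contains a norm-bias entry. A minimal illustration of that check, with a made-up quant_description (the real keys come from the checkpoint's quantization config loaded by vllm-ascend):

import torch

# Hypothetical quant_description, for illustration only.
quant_description = {
    "model.layers.0.input_layernorm.bias": "W8A8",
    "model.layers.0.self_attn.qkv_proj.weight": "W8A8",
}

hidden_size = 4096
# Same substring test as in __init__ above.
if any("norm.bias" in name for name in quant_description):
    bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
else:
    bias = None
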
@@ -67,10 +89,13 @@ def forward_oot(
             residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
             x, residual = _addrmsnorm_forward_oot(
-                self, x, residual, self.next_need_quant_fusion_linear)
+                self, x, residual, self.next_need_quant_fusion_linear,
+                self.bias)
             return x, residual
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
+        if self.bias is not None:
+            x.add_(self.bias)
         return x

     @property
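
The no-residual branch mirrors the fused path: npu_rms_norm produces the normalized tensor and the bias, if present, is added in place. As a plain-PyTorch reference (again only an approximation of the NPU kernel, not its implementation):

import torch

def rmsnorm_ref(x, weight, eps, bias=None):
    # Plain RMSNorm plus the optional in-place bias add from forward_oot.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if bias is None else out.add_(bias)
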
@@ -100,33 +125,15 @@ def next_need_quant_fusion_linear(self):
                 # does not need to be repeated
                 if not forward_context.prefetch_mlp_enabled:
                     forward_context.layer_idx += 1
+            elif fusion_linear == "qkv_moe":
+                next_linear = model_instance.model.layers[
+                    layer_idx].self_attn.qkv_proj
+                forward_context.fusion_linear = "gate_moe"
+            elif fusion_linear == "gate_moe":
+                forward_context.fusion_linear = "qkv_moe"
+                forward_context.layer_idx += 1
         from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
         if next_linear is not None and \
                 not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
             next_linear = None
         return next_linear
-
-
-class AscendQuantRMSNorm(AscendRMSNorm):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-        var_hidden_size: Optional[int] = None,
-        has_weight: bool = True,
-        dtype: Optional[torch.dtype] = None,
-    ) -> None:
-        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
-        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
-                                       requires_grad=False)
-
-    def forward_oot(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            x, residual = super().forward_oot(x, residual)
-            return x.add_(self.bias), residual
-        return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
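
The new elif branches extend the quant-fusion bookkeeping to MoE layers: each decoder layer consults the property twice, first returning its qkv_proj and flipping the state from "qkv_moe" to "gate_moe", then returning no linear on the gate step while flipping back and advancing layer_idx. A small standalone illustration of that alternation (a toy walk-through with names taken from the diff, not runtime code):

fusion_linear, layer_idx = "qkv_moe", 0
for step in range(4):  # two norm calls per layer, two layers
    if fusion_linear == "qkv_moe":
        print(f"layer {layer_idx}: fuse RMSNorm quant into qkv_proj")
        fusion_linear = "gate_moe"
    elif fusion_linear == "gate_moe":
        # No linear to fuse on the gate step; move on to the next layer.
        fusion_linear = "qkv_moe"
        layer_idx += 1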