@@ -18,47 +18,39 @@
 from typing import Optional, Tuple, Union, cast
 
 import torch
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 
-class AddRMSNormW8A8Quant(RMSNorm):
-    # Fuse AddRmsNorm and W8A8 quantization ops together
-
-    def __init__(
-        self,
-        hidden_size: int,
-        layer: torch.nn.Module,
-        eps: float = 1e-6,
-        var_hidden_size: Optional[int] = None,
-        has_weight: bool = True,
-        dtype: Optional[torch.dtype] = None,
-    ) -> None:
-        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
-        self.layer = layer
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        import torch_npu
-
-        if residual is not None:
-            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
-            assert x.size(0) == residual.size(0)
-            x, _, residual = torch_npu.npu_add_rms_norm_quant(
-                x,
-                residual,
-                self.weight,
-                self.layer.aclnn_input_scale,
-                self.layer.aclnn_input_offset,
-                epsilon=self.variance_epsilon)
-            torch.ops.vllm.maybe_wait_prefetch_done(x)
-            return x, residual
-
-        x, residual = torch_npu.npu_rms_norm(x, self.weight,
-                                             self.variance_epsilon)
-        return x
+def _addrmsnorm_forward_oot(
+    self,
+    x: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    layer: Optional[torch.nn.Module] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    import torch_npu
+
+    if layer is not None:
+        x, _, residual = torch_npu.npu_add_rms_norm_quant(
+            x,
+            residual,
+            self.weight,
+            layer.aclnn_input_scale,
+            layer.aclnn_input_offset,
+            epsilon=self.variance_epsilon)
+    else:
+        from vllm_ascend.utils import is_310p
+        if is_310p():
+            orig_dtype = residual.dtype
+            x = x + residual.to(x.dtype)
+            residual = x.to(orig_dtype)
+            x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                          self.variance_epsilon)
+        else:
+            x, _, residual = torch_npu.npu_add_rms_norm(
+                x, residual, self.weight, self.variance_epsilon)
+    torch.ops.vllm.maybe_wait_prefetch_done(x)
+    return x, residual
 
 
 class AscendRMSNorm(RMSNorm):
@@ -70,26 +62,50 @@ def forward_oot(
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
-        from vllm_ascend.utils import is_310p
         if residual is not None:
             residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
             assert x.size(0) == residual.size(0)
-            if is_310p():
-                orig_dtype = residual.dtype
-                x = x + residual.to(x.dtype)
-                residual = x.to(orig_dtype)
-                x, _ = torch_npu.npu_rms_norm(x, self.weight,
-                                              self.variance_epsilon)
-            else:
-                x, _, residual = torch_npu.npu_add_rms_norm(
-                    x, residual, self.weight, self.variance_epsilon)
-            torch.ops.vllm.maybe_wait_prefetch_done(x)
+            x, residual = _addrmsnorm_forward_oot(
+                self, x, residual, self.next_need_quant_fusion_linear)
             return x, residual
-
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
         return x
 
+    @property
+    def next_need_quant_fusion_linear(self):
+        try:
+            forward_context = get_forward_context()
+            if not forward_context.addrmsnorm_quant_fusion_enabled or \
+                    forward_context.layer_idx == forward_context.num_hidden_layers:
+                return None
+        except AssertionError:
+            # get_forward_context() asserts outside of a forward pass
+            return None
+
+        next_linear = None
+        model_instance = forward_context.model_instance
+        layer_idx = forward_context.layer_idx
+        fusion_linear = forward_context.fusion_linear
+        if fusion_linear == "qkv_dense":
+            next_linear = model_instance.model.layers[layer_idx].self_attn.qkv_proj
+            forward_context.fusion_linear = "gate_up_dense"
+        elif fusion_linear == "gate_up_dense":
+            next_linear = model_instance.model.layers[layer_idx].mlp.gate_up_proj
+            forward_context.fusion_linear = "qkv_dense"
+            # if prefetch_mlp_weight is enabled, the layer_idx accumulation
+            # below is already performed on the prefetch path and does not
+            # need to be repeated here
+            if not forward_context.prefetch_mlp_enabled:
+                forward_context.layer_idx += 1
+        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
+        if next_linear is not None and \
+                not isinstance(next_linear.quant_method.quant_method,
+                               AscendW8A8LinearMethod):
+            # fuse only into linears quantized with the Ascend W8A8 method
+            next_linear = None
+        return next_linear
+
 
 class AscendQuantRMSNorm(AscendRMSNorm):
 
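A note for readers: below is a rough pure-PyTorch sketch of the three operations that `npu_add_rms_norm_quant` fuses in this diff (residual add, RMSNorm, static int8 activation quantization). The quantization convention `q = round(x * scale + offset)` and the two-output signature are simplifying assumptions for illustration; the actual aclnn kernel returns a third output as well and may use a different scale/offset convention.

```python
import torch


def add_rms_norm_quant_ref(
    x: torch.Tensor,          # current layer output, [num_tokens, hidden_size]
    residual: torch.Tensor,   # residual stream, same shape as x
    weight: torch.Tensor,     # RMSNorm gamma, [hidden_size]
    scale: torch.Tensor,      # static per-tensor activation scale (assumed)
    offset: torch.Tensor,     # static per-tensor zero point (assumed)
    epsilon: float = 1e-6,
):
    # 1) residual accumulation; the fused kernel also returns this sum so the
    #    caller can keep threading the updated residual stream
    residual = x + residual
    # 2) RMSNorm over the hidden dimension, computed in fp32 for stability
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + epsilon) * weight.float()
    # 3) static int8 quantization feeding the next W8A8 linear (rounding
    #    convention assumed, see note above)
    quantized = torch.round(normed * scale + offset).clamp(-128, 127)
    return quantized.to(torch.int8), residual
```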
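For the bookkeeping side, here is a standalone sketch of the state machine that `next_need_quant_fusion_linear` drives on the forward context: each decoder layer runs two residual RMSNorms, so the fusion target alternates between the pre-attention `qkv_proj` and the pre-MLP `gate_up_proj`, and `layer_idx` advances once the second norm of a layer has been handled. `FakeContext` and `next_fusion_target` are hypothetical stand-ins for the vLLM forward context, for illustration only.

```python
class FakeContext:
    """Minimal stand-in for the fields the property reads and updates."""

    def __init__(self, num_hidden_layers: int) -> None:
        self.layer_idx = 0
        self.num_hidden_layers = num_hidden_layers
        self.fusion_linear = "qkv_dense"  # first residual norm targets qkv


def next_fusion_target(ctx: FakeContext):
    # final norm after the last decoder layer: nothing left to fuse into
    if ctx.layer_idx == ctx.num_hidden_layers:
        return None
    target = (ctx.layer_idx, ctx.fusion_linear)
    if ctx.fusion_linear == "qkv_dense":
        ctx.fusion_linear = "gate_up_dense"
    else:
        ctx.fusion_linear = "qkv_dense"
        ctx.layer_idx += 1  # wrap around to the next decoder layer
    return target


ctx = FakeContext(num_hidden_layers=2)
print([next_fusion_target(ctx) for _ in range(5)])
# [(0, 'qkv_dense'), (0, 'gate_up_dense'),
#  (1, 'qkv_dense'), (1, 'gate_up_dense'), None]
```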