
Commit 4b2faea

migrate faquant from mindie-turbo to vllm-ascend
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
1 parent e0716c5 commit 4b2faea

File tree

10 files changed: +573 -63 lines changed


tests/singlecard/test_offline_inference.py

Lines changed: 3 additions & 2 deletions

@@ -45,6 +45,7 @@
 
 QUANTIZATION_MODELS = [
     "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8",
+    "vllm-ascend/Qwen2.5-0.5B-Instruct-fa3"
 ]
 
 
@@ -71,7 +72,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
 @pytest.mark.parametrize("max_tokens", [5])
 def test_quantization_models(model: str, max_tokens: int) -> None:
     prompt = "The following numbers of the sequence " + ", ".join(
-        str(i) for i in range(1024)) + " are:"
+        str(i) for i in range(256)) + " are:"
     example_prompts = [prompt]
 
     # NOTE: Using quantized model repo id from modelscope encounters an issue,
@@ -80,7 +81,7 @@ def test_quantization_models(model: str, max_tokens: int) -> None:
     model_path = snapshot_download(model)
 
     with VllmRunner(model_path,
-                    max_model_len=8192,
+                    max_model_len=4096,
                     enforce_eager=True,
                     dtype="auto",
                     gpu_memory_utilization=0.7,
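For orientation, here is a minimal sketch of the flow the updated test now exercises, written against the parameters visible in this diff. VllmRunner is the test suite's helper (its import path is not shown in this hunk), and generate_greedy is an assumed method name, not something this commit defines.

# Sketch only: VllmRunner and generate_greedy are assumed test-suite helpers;
# the context-manager variable and call signature are assumptions as well.
from modelscope import snapshot_download

model = "vllm-ascend/Qwen2.5-0.5B-Instruct-fa3"   # newly added FA-quant checkpoint
prompt = "The following numbers of the sequence " + ", ".join(
    str(i) for i in range(256)) + " are:"          # shortened from range(1024)

model_path = snapshot_download(model)              # download by repo id (modelscope workaround)
with VllmRunner(model_path,
                max_model_len=4096,                # lowered from 8192
                enforce_eager=True,
                dtype="auto",
                gpu_memory_utilization=0.7) as runner:
    runner.generate_greedy([prompt], 5)            # max_tokens=5, per the parametrize above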

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 1 deletion

@@ -331,7 +331,7 @@ def forward(
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
             pass
         # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            mask = attn_metadata.attn_mask
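The only change here is `elif` becoming `if`: with `elif`, the PrefillNoCache handling was skipped whenever the preceding (currently no-op) V1-scheduler branch matched; as a standalone `if`, it is evaluated regardless. A minimal illustration of that control-flow difference with hypothetical flags, since the first branch's condition is not shown in this hunk:

# Hypothetical conditions, for illustration only.
def with_elif(v1_branch_taken, is_prefill_no_cache):
    if v1_branch_taken:
        pass                      # TODO branch: currently does nothing
    elif is_prefill_no_cache:
        return "prefill"          # never reached when the first branch matched
    return "other"

def with_if(v1_branch_taken, is_prefill_no_cache):
    if v1_branch_taken:
        pass                      # TODO branch unchanged
    if is_prefill_no_cache:
        return "prefill"          # reached even when the first branch matched
    return "other"

assert with_elif(True, True) == "other"
assert with_if(True, True) == "prefill"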

vllm_ascend/quantization/faquant.py

Lines changed: 215 additions & 0 deletions (new file; full contents below)

@@ -0,0 +1,215 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List

import torch
import torch_npu

from .quant_utils import (SRC_DTYPE_TO_ACL_DTYPE, TYPE_QUANT_QKV_ONLINE,
                          quant_per_tensor)


class AscendFAQuantAttentionMethod:
    """Linear method for Ascend FAQuant
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def get_quant_param() -> List[str]:
        return [
            "fa_q.scale", "fa_q.offset", "fa_k.scale", "fa_k.offset",
            "fa_v.scale", "fa_v.offset"
        ]

    @staticmethod
    def get_extra_module_names() -> List[str]:
        return ["fa_q", "fa_k", "fa_v"]

    @staticmethod
    def process_weights_after_loading(layer):
        fa_qscale = layer.fa_q.scale
        fa_kscale = layer.fa_k.scale
        fa_vscale = layer.fa_v.scale
        repeated_query_scale = layer.fa_q.scale.repeat(1, layer.head_size)
        layer.fa_qscale = torch.nn.Parameter(repeated_query_scale,
                                             requires_grad=False)
        repeated_query_offset = layer.fa_q.offset.repeat(1, layer.head_size)
        layer.fa_qoffset = torch.nn.Parameter(repeated_query_offset,
                                              requires_grad=False)
        repeated_fa_kscale = layer.fa_k.scale.repeat(1, layer.head_size)
        layer.fa_kscale = torch.nn.Parameter(repeated_fa_kscale,
                                             requires_grad=False)
        repeated_fa_koffset = layer.fa_k.offset.repeat(1, layer.head_size)
        layer.fa_koffset = torch.nn.Parameter(repeated_fa_koffset,
                                              requires_grad=False)
        repeated_fa_vscale = layer.fa_v.scale.repeat(1, layer.head_size)
        layer.fa_vscale = torch.nn.Parameter(repeated_fa_vscale,
                                             requires_grad=False)
        repeated_fa_voffset = layer.fa_v.offset.repeat(1, layer.head_size)
        layer.fa_voffset = torch.nn.Parameter(repeated_fa_voffset,
                                              requires_grad=False)

        if fa_kscale.shape[0] <= 0:
            raise ValueError(
                "Expected size of fa_kscale in dimension 0 should be greater than 0"
                f"but got {fa_kscale.shape[0]}.")
        gqa_size = fa_qscale.shape[0] // fa_kscale.shape[0]
        fa3_k_scale, fa3_v_scale = fa_kscale.repeat(1, gqa_size).view(
            -1, 1), fa_vscale.repeat(1, gqa_size).view(-1, 1)
        qk_scale = torch.nn.Parameter(torch.squeeze(
            fa_qscale * fa3_k_scale).to(torch.float),
                                      requires_grad=False)
        layer.register_parameter("qk_scale", qk_scale)
        fa3_v_scale = torch.nn.Parameter(
            torch.squeeze(fa3_v_scale).contiguous().to(torch.float),
            requires_grad=False)
        layer.register_parameter("fa3_v_scale", fa3_v_scale)

    @classmethod
    def apply(cls, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, *extra_args,
              **optional_args) -> torch.Tensor:
        key_cache, value_cache, scale, block_tables, \
            is_prefill, mask, slots, output = extra_args
        seq_lens_tensor_cpu = optional_args.get("seq_lens_tensor_cpu", None)

        query_shape = query.shape
        key_shape = key.shape
        value_shape = value.shape

        query = query.view(query.shape[0], -1)
        key = key.view(key.shape[0], -1)
        value = value.view(value.shape[0], -1)

        if is_prefill:
            if key_cache is not None:
                key_int8 = quant_per_tensor(key, layer.fa_kscale,
                                            layer.fa_koffset, True)
                value_int8 = quant_per_tensor(value, layer.fa_vscale,
                                              layer.fa_voffset, True)
                key_int8 = key_int8.view(key_shape)
                value_int8 = value_int8.view(value_shape)
                torch_npu._npu_reshape_and_cache(key_int8, value_int8,
                                                 key_cache, value_cache, slots)
            if mask is None:
                raise ValueError(
                    "attn_metadata.attn_mask is Null. Please check.")
            query = query.view(query_shape)
            key = key.view(key_shape)
            value = value.view(value_shape)
            if output is not None:
                output = output.view(query.shape)
                torch_npu._npu_flash_attention(query,
                                               key,
                                               value,
                                               mask,
                                               torch.tensor(
                                                   seq_lens_tensor_cpu,
                                                   dtype=torch.int32),
                                               scale,
                                               layer.num_heads,
                                               layer.num_kv_heads,
                                               out=output)
            else:
                query = query.view(query_shape)
                key = key.view(key_shape)
                value = value.view(value_shape)
                output = torch.empty_like(query,
                                          dtype=query.dtype).to(query.device)
                torch_npu._npu_flash_attention(query,
                                               key,
                                               value,
                                               mask,
                                               torch.tensor(
                                                   seq_lens_tensor_cpu,
                                                   dtype=torch.int32),
                                               scale,
                                               layer.num_heads,
                                               layer.num_kv_heads,
                                               out=output)

        else:
            if key_cache is None:
                raise ValueError(
                    "KV Cache can't be None in decoding phase. Got None. Please check."
                )
            query_int8 = quant_per_tensor(query, layer.fa_qscale,
                                          layer.fa_qoffset, True)
            key_int8 = quant_per_tensor(key, layer.fa_kscale, layer.fa_koffset,
                                        True)
            value_int8 = quant_per_tensor(value, layer.fa_vscale,
                                          layer.fa_voffset, True)
            query_int8 = query_int8.view(query_shape)
            key_int8 = key_int8.view(key_shape)
            value_int8 = value_int8.view(value_shape)
            query = query.view(query_shape)
            torch_npu._npu_reshape_and_cache(key_int8, value_int8, key_cache,
                                             value_cache, slots)
            if output is not None:
                output = output.view(query.shape)
                torch_npu._npu_paged_attention_quant(
                    query_int8, key_cache, value_cache, layer.num_kv_heads,
                    layer.num_heads, scale, block_tables,
                    torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32),
                    TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype],
                    layer.qk_scale, layer.fa3_v_scale, output)
            else:
                output = torch.empty_like(query,
                                          dtype=query.dtype).to(query.device)
                torch_npu._npu_paged_attention_quant(
                    query_int8, key_cache, value_cache, layer.num_kv_heads,
                    layer.num_heads, scale, block_tables,
                    torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32),
                    TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype],
                    layer.qk_scale, layer.fa3_v_scale, output)

        output = torch.flatten(output, start_dim=-2)
        return output

    @classmethod
    def create_weights(cls, layer: torch.nn.Module) -> None:
        extra_module_names = cls.get_extra_module_names()
        for name in extra_module_names:
            setattr(layer, name, torch.nn.Module())

        params_dtype = torch.get_default_dtype()

        params_dict = {}

        params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1),
                                                dtype=params_dtype)
        params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1),
                                                 dtype=torch.int8)
        params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1),
                                                dtype=params_dtype)
        params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1),
                                                 dtype=torch.int8)
        params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1),
                                                dtype=params_dtype)
        params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1),
                                                 dtype=torch.int8)

        for name, weight in params_dict.items():
            module_name, weight_name = name.split('.')
            module = getattr(layer, module_name)
            module.register_parameter(
                weight_name, torch.nn.Parameter(weight, requires_grad=False))
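The least obvious part of process_weights_after_loading above is the grouped-query attention (GQA) handling: per-KV-head K/V scales are tiled up to the query head count, and the Q and K scales are folded into a single qk_scale vector. A standalone sketch of just that arithmetic, with made-up head counts (not taken from any real checkpoint):

# Standalone sketch of the GQA scale expansion; head counts are illustrative.
import torch

num_heads, num_kv_heads = 8, 2            # hypothetical query / KV head counts
fa_qscale = torch.rand(num_heads, 1)      # per-query-head quant scales
fa_kscale = torch.rand(num_kv_heads, 1)   # per-KV-head quant scales
fa_vscale = torch.rand(num_kv_heads, 1)

gqa_size = fa_qscale.shape[0] // fa_kscale.shape[0]           # query heads per KV head = 4
fa3_k_scale = fa_kscale.repeat(1, gqa_size).view(-1, 1)       # tiled to (num_heads, 1)
fa3_v_scale = fa_vscale.repeat(1, gqa_size).view(-1, 1)       # tiled to (num_heads, 1)

qk_scale = torch.squeeze(fa_qscale * fa3_k_scale).to(torch.float)  # one scale per query head
assert qk_scale.shape == (num_heads,)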

vllm_ascend/quantization/func_wrapper.py

Lines changed: 37 additions & 2 deletions

@@ -45,23 +45,26 @@ def _rmsnorm_forward_oot(
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not self.ignore_anti:
+            out = torch.empty_like(x, dtype=torch.int8).npu()
             if residual is not None:
                 residual += x
-                out = torch_npu._npu_quant_rms_norm(
+                torch_npu._npu_quant_rms_norm(
                     residual,
                     self.weight,
                     self.bias,
                     self.input_scale,
                     self.input_offset,
+                    out,
                     self.variance_epsilon,
                 )
                 return out, residual
-            out = torch_npu._npu_quant_rms_norm(
+            torch_npu._npu_quant_rms_norm(
                 x,
                 self.weight,
                 self.bias,
                 self.input_scale,
                 self.input_offset,
+                out,
                 self.variance_epsilon,
             )
             return out
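The pattern introduced by this hunk: instead of taking the return value of torch_npu._npu_quant_rms_norm, the caller now allocates an int8 output tensor once, before branching on residual, and the op writes into it via the new out argument placed between input_offset and variance_epsilon. A condensed form of the residual-free path; the argument order is taken from the hunk above, while `layer`, shapes, dtypes and the .npu() placement are assumptions:

# Condensed sketch of the new call convention (anti-outlier path, no residual).
import torch
import torch_npu

def quant_rmsnorm_forward(layer, x):
    out = torch.empty_like(x, dtype=torch.int8).npu()   # caller-owned int8 buffer
    torch_npu._npu_quant_rms_norm(x, layer.weight, layer.bias,
                                  layer.input_scale, layer.input_offset,
                                  out, layer.variance_epsilon)  # writes into out
    return out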
@@ -90,6 +93,20 @@ def _rmsnorm_forward_oot(
             "unquantized_type": UnquantizedLinearMethod,
         },
     },
+    "Qwen2Model": {
+        "attn": {
+            "layer_attr": "self_attn",
+            "proj_attr": "qkv_proj",
+            "norm_attr": "input_layernorm",
+            "unquantized_type": UnquantizedLinearMethod,
+        },
+        "mlp": {
+            "layer_attr": "mlp",
+            "proj_attr": "gate_up_proj",
+            "norm_attr": "post_attention_layernorm",
+            "unquantized_type": UnquantizedLinearMethod,
+        },
+    }
 }
 
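The new Qwen2Model entry mirrors the existing mappings: per decoder layer, the attn and mlp sections name the submodule to inspect (layer_attr), the projection whose quantization method matters (proj_attr), the RMSNorm feeding it (norm_attr), and the linear-method class that marks it as unquantized. A hedged sketch of how such an entry could be resolved against a Qwen2 decoder layer; process_module's actual body is not part of this diff, so the getattr chain and the .quant_method check are assumptions:

# Assumed resolution logic, for illustration only.
def resolve(layer, module_cfg):
    sub = getattr(layer, module_cfg["layer_attr"])      # e.g. layer.self_attn / layer.mlp
    proj = getattr(sub, module_cfg["proj_attr"])        # e.g. qkv_proj / gate_up_proj
    norm = getattr(layer, module_cfg["norm_attr"])      # e.g. input_layernorm
    is_unquantized = isinstance(proj.quant_method,      # assumed attribute on the projection
                                module_cfg["unquantized_type"])
    return proj, norm, is_unquantized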

@@ -133,6 +150,24 @@ def process_module(module_cfg, layer_obj):
             process_module(mapping.get("attn"), layer)
             process_module(mapping.get("mlp"), layer)
 
+        def is_enable(quant_description) -> bool:
+            need_activate = False
+            for name in quant_description.keys():
+                if "norm.bias" in name:
+                    need_activate = True
+                    return need_activate
+            return need_activate
+
+        # check if patch activated
+        try:
+            if not is_enable(self.model.quant_config.quant_description):
+                return
+        except AttributeError:
+            logger.info(
+                "Warning: load model patch do not enable, because it is not quantified and llm weights"
+            )
+            return
+
         model_type = self.model.model.__class__.__name__
         mapping = MODEL_LAYER_MAPPING.get(model_type)
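The new is_enable gate scans the checkpoint's quant_description keys and only activates the layernorm patch when an anti-outlier norm bias (a key containing "norm.bias") is present; otherwise the wrapper returns early. A condensed, behaviorally equivalent check with hypothetical key names:

# Hypothetical quant_description keys, mirroring the "norm.bias" test above.
def is_enable(quant_description) -> bool:
    for name in quant_description.keys():
        if "norm.bias" in name:
            return True
    return False

assert is_enable({"model.layers.0.input_layernorm.bias": "W8A8",
                  "model.layers.0.self_attn.qkv_proj.weight": "W8A8"})
assert not is_enable({"model.layers.0.self_attn.qkv_proj.weight": "W8A8"})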
