Commit 1d2e23e

tmp code
1 parent 9d6808e commit 1d2e23e

File tree

6 files changed (+517, -17 lines)

tests/singlecard/test_offline_inference.py

Lines changed: 3 additions & 2 deletions
@@ -45,6 +45,7 @@
 
 QUANTIZATION_MODELS = [
     "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8",
+    "vllm-ascend/Qwen2.5-0.5B-Instruct-fa3"
 ]
 
 
@@ -71,7 +72,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
 @pytest.mark.parametrize("max_tokens", [5])
 def test_quantization_models(model: str, max_tokens: int) -> None:
     prompt = "The following numbers of the sequence " + ", ".join(
-        str(i) for i in range(1024)) + " are:"
+        str(i) for i in range(256)) + " are:"
     example_prompts = [prompt]
 
     # NOTE: Using quantized model repo id from modelscope encounters an issue,
@@ -80,7 +81,7 @@ def test_quantization_models(model: str, max_tokens: int) -> None:
     model_path = snapshot_download(model)
 
     with VllmRunner(model_path,
-                    max_model_len=8192,
+                    max_model_len=4096,
                     enforce_eager=True,
                     dtype="auto",
                     gpu_memory_utilization=0.7,
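
The two tweaks above shrink the test prompt from 1024 to 256 numbers and halve max_model_len to 4096, presumably so the prompt comfortably fits the smaller context. A quick standalone sketch (build_prompt is a hypothetical helper name, not part of the test file) to compare the raw prompt sizes:

# Rough size comparison of the test prompt before and after this change.
# Character counts only; actual token counts depend on the model tokenizer.
def build_prompt(n: int) -> str:
    return ("The following numbers of the sequence " +
            ", ".join(str(i) for i in range(n)) + " are:")


for n in (1024, 256):
    print(n, len(build_prompt(n)))  # roughly 5.1k chars vs 1.2k chars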

vllm_ascend/quantization/faquant.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List

import torch
import torch_npu

from .quant_utils import (SRC_DTYPE_TO_ACL_DTYPE, TYPE_QUANT_QKV_ONLINE,
                          quant_per_tensor)


class AscendFAQuantAttentionMethod:
    """Attention method for Ascend FAQuant (quantized flash attention)."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def get_quant_param() -> List[str]:
        return [
            "fa_q.scale", "fa_q.offset", "fa_k.scale", "fa_k.offset",
            "fa_v.scale", "fa_v.offset"
        ]

    @staticmethod
    def get_extra_module_names() -> List[str]:
        return ["fa_q", "fa_k", "fa_v"]

    @staticmethod
    def process_weights_after_loading(layer):
        fa_qscale = layer.fa_q.scale
        fa_kscale = layer.fa_k.scale
        fa_vscale = layer.fa_v.scale
        # Expand the per-head scales/offsets so quant_per_tensor can apply
        # them elementwise to the flattened Q/K/V tensors.
        repeated_query_scale = layer.fa_q.scale.repeat(1, 64)
        layer.fa_qscale = torch.nn.Parameter(repeated_query_scale,
                                             requires_grad=False)
        repeated_query_offset = layer.fa_q.offset.repeat(1, 64)
        layer.fa_qoffset = torch.nn.Parameter(repeated_query_offset,
                                              requires_grad=False)
        repeated_fa_kscale = layer.fa_k.scale.repeat(1, 64)
        layer.fa_kscale = torch.nn.Parameter(repeated_fa_kscale,
                                             requires_grad=False)
        repeated_fa_koffset = layer.fa_k.offset.repeat(1, 64)
        layer.fa_koffset = torch.nn.Parameter(repeated_fa_koffset,
                                              requires_grad=False)
        repeated_fa_vscale = layer.fa_v.scale.repeat(1, 64)
        layer.fa_vscale = torch.nn.Parameter(repeated_fa_vscale,
                                             requires_grad=False)
        repeated_fa_voffset = layer.fa_v.offset.repeat(1, 64)
        layer.fa_voffset = torch.nn.Parameter(repeated_fa_voffset,
                                              requires_grad=False)

        if fa_kscale.shape[0] <= 0:
            raise ValueError(
                "Expected size of fa_kscale in dimension 0 to be greater "
                f"than 0, but got {fa_kscale.shape[0]}.")
        # Map the per-KV-head K/V scales to per-query-head scales (GQA) for
        # the quantized paged attention kernel.
        gqa_size = fa_qscale.shape[0] // fa_kscale.shape[0]
        fa3_k_scale, fa3_v_scale = fa_kscale.repeat(1, gqa_size).view(
            -1, 1), fa_vscale.repeat(1, gqa_size).view(-1, 1)
        qk_scale = torch.nn.Parameter(torch.squeeze(
            fa_qscale * fa3_k_scale).to(torch.float),
                                      requires_grad=False)
        layer.register_parameter("qk_scale", qk_scale)
        fa3_v_scale = torch.nn.Parameter(
            torch.squeeze(fa3_v_scale).contiguous().to(torch.float),
            requires_grad=False)
        layer.register_parameter("fa3_v_scale", fa3_v_scale)

    @classmethod
    def apply(cls, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, *extra_args,
              **optional_args) -> torch.Tensor:
        # Cache tensors and attention metadata are passed positionally.
        key_cache, value_cache, scale, block_tables, \
            is_prefill, mask, slots, output = extra_args
        seq_lens_tensor_cpu = optional_args.get("seq_lens_tensor_cpu", None)

        query_shape = query.shape
        key_shape = key.shape
        value_shape = value.shape

        query = query.view(query.shape[0], -1)
        key = key.view(key.shape[0], -1)
        value = value.view(value.shape[0], -1)

        if is_prefill:
            if key_cache is not None:
                # Prefill: quantize K/V to int8 and write them into the
                # paged KV cache for later decode steps.
                key_int8 = quant_per_tensor(key, layer.fa_kscale,
                                            layer.fa_koffset, True)
                value_int8 = quant_per_tensor(value, layer.fa_vscale,
                                              layer.fa_voffset, True)
                key_int8 = key_int8.view(key_shape)
                value_int8 = value_int8.view(value_shape)
                query = query.view(query_shape)
                torch_npu._npu_reshape_and_cache(key_int8, value_int8,
                                                 key_cache, value_cache, slots)
            if mask is None:
                raise ValueError(
                    "attn_metadata.attn_mask is None. Please check.")
            # Prefill attention itself runs on the unquantized Q/K/V.
            if output is not None:
                key = key.view(key_shape)
                value = value.view(value_shape)
                query = query.view(query_shape)
                output = output.view(query.shape)
                torch_npu._npu_flash_attention(query,
                                               key,
                                               value,
                                               mask,
                                               torch.tensor(
                                                   seq_lens_tensor_cpu,
                                                   dtype=torch.int32),
                                               scale,
                                               layer.num_heads,
                                               layer.num_kv_heads,
                                               out=output)
            else:
                key = key.view(key_shape)
                value = value.view(value_shape)
                query = query.view(query_shape)
                output = torch.empty_like(query,
                                          dtype=query.dtype).to(query.device)
                torch_npu._npu_flash_attention(query,
                                               key,
                                               value,
                                               mask,
                                               torch.tensor(
                                                   seq_lens_tensor_cpu,
                                                   dtype=torch.int32),
                                               scale,
                                               layer.num_heads,
                                               layer.num_kv_heads,
                                               out=output)

        else:
            if key_cache is None:
                raise ValueError(
                    "KV cache can't be None in the decoding phase. Please check."
                )
            # Decode: quantize Q/K/V, update the paged KV cache, and run the
            # quantized paged attention kernel.
            query_int8 = quant_per_tensor(query, layer.fa_qscale,
                                          layer.fa_qoffset, True)
            key_int8 = quant_per_tensor(key, layer.fa_kscale, layer.fa_koffset,
                                        True)
            value_int8 = quant_per_tensor(value, layer.fa_vscale,
                                          layer.fa_voffset, True)
            key_int8 = key_int8.view(key_shape)
            value_int8 = value_int8.view(value_shape)
            query = query.view(query_shape)
            query_int8 = query_int8.view(query_shape)
            torch_npu._npu_reshape_and_cache(key_int8, value_int8, key_cache,
                                             value_cache, slots)
            if output is not None:
                output = output.view(query.shape)
                torch_npu._npu_paged_attention_quant(
                    query_int8, key_cache, value_cache, layer.num_kv_heads,
                    layer.num_heads, scale, block_tables,
                    torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32),
                    TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype],
                    layer.qk_scale, layer.fa3_v_scale, output)
            else:
                output = torch.empty_like(query,
                                          dtype=query.dtype).to(query.device)
                torch_npu._npu_paged_attention_quant(
                    query_int8, key_cache, value_cache, layer.num_kv_heads,
                    layer.num_heads, scale, block_tables,
                    torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32),
                    TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype],
                    layer.qk_scale, layer.fa3_v_scale, output)

        # Merge the head dimensions back into a single hidden dimension.
        output = torch.flatten(output, start_dim=-2)
        return output

    @classmethod
    def create_weights(cls, layer: torch.nn.Module) -> None:
        extra_module_names = cls.get_extra_module_names()
        for name in extra_module_names:
            setattr(layer, name, torch.nn.Module())

        params_dtype = torch.get_default_dtype()

        # Per-head scale/offset placeholders for Q/K/V; the real values are
        # loaded from the quantized checkpoint.
        params_dict = {}

        params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1),
                                                dtype=params_dtype)
        params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1),
                                                 dtype=torch.int8)
        params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1),
                                                dtype=params_dtype)
        params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1),
                                                 dtype=torch.int8)
        params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1),
                                                dtype=params_dtype)
        params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1),
                                                 dtype=torch.int8)

        for name, weight in params_dict.items():
            module_name, weight_name = name.split('.')
            module = getattr(layer, module_name)
            module.register_parameter(
                weight_name, torch.nn.Parameter(weight, requires_grad=False))
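
quant_per_tensor, TYPE_QUANT_QKV_ONLINE, and SRC_DTYPE_TO_ACL_DTYPE come from quant_utils.py, which is not part of the hunks shown here. As a rough mental model only, the sketch below pairs a hypothetical per-tensor int8 quantization helper (quant_per_tensor_sketch is an assumed stand-in, not the real Ascend implementation) with the same GQA scale expansion that process_weights_after_loading performs; the head counts and head size are made up for illustration.

import torch

# Hypothetical stand-in for quant_utils.quant_per_tensor: asymmetric int8
# quantization with a scale and zero-point offset (assumption only; the real
# Ascend kernel may differ in rounding, broadcasting, and layout).
def quant_per_tensor_sketch(x: torch.Tensor, scale: torch.Tensor,
                            offset: torch.Tensor) -> torch.Tensor:
    q = torch.round(x / scale + offset)
    return q.clamp(-128, 127).to(torch.int8)


# Hypothetical GQA setup: 8 query heads sharing 2 KV heads, head size 64.
num_heads, num_kv_heads, head_dim = 8, 2, 64
fa_qscale = torch.rand(num_heads, 1)
fa_kscale = torch.rand(num_kv_heads, 1)
fa_vscale = torch.rand(num_kv_heads, 1)

# Same expansion as process_weights_after_loading: repeat the per-KV-head
# scales across the GQA group so each query head gets a matching scale.
gqa_size = fa_qscale.shape[0] // fa_kscale.shape[0]        # 4 query heads per KV head
fa3_k_scale = fa_kscale.repeat(1, gqa_size).view(-1, 1)    # (2, 1) -> (8, 1)
fa3_v_scale = fa_vscale.repeat(1, gqa_size).view(-1, 1)
qk_scale = torch.squeeze(fa_qscale * fa3_k_scale).to(torch.float)  # shape (8,)

# Quantize a fake flattened key tensor the way apply() flattens K before caching.
key = torch.randn(4, num_kv_heads * head_dim)              # 4 tokens
key_int8 = quant_per_tensor_sketch(key,
                                   fa_kscale.repeat(1, head_dim).view(1, -1),
                                   torch.zeros(1))
print(key_int8.dtype, fa3_k_scale.shape, qk_scale.shape)

Printing the shapes shows the (num_kv_heads, 1) K scale becoming a (num_heads, 1) fa3_k_scale and a flat (num_heads,) qk_scale, matching the shapes registered on the layer above.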

vllm_ascend/quantization/quant_config.py

Lines changed: 9 additions & 9 deletions
@@ -42,7 +42,7 @@
 from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
 
-from .quantizer import AscendQuantizer
+from .quantizer import VLLMAscendQuantizer
 
 
 @register_quantization_config(ASCEND_QUATIZATION_METHOD)
@@ -151,7 +151,7 @@ def get_scaled_act_names(self) -> List[str]:
 class AscendLinearMethod(LinearMethodBase):
     """Linear method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for linear methods.
 
     Args:
@@ -160,7 +160,7 @@ class AscendLinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
         self.quant_method = self.quantizer.build_linear_method()
 
@@ -232,15 +232,15 @@ def apply(
 class AscendKVCacheMethod(BaseKVCacheMethod):
     """KVCache method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for kvcache methods.
 
     Args:
         quant_config: The Ascend quantization config.
     """
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix)
         self.quant_method = self.quantizer.build_attention_method()
 
@@ -285,7 +285,7 @@ def apply(self,
 class AscendFusedMoEMethod(FusedMoEMethodBase):
     """FusedMoE method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for kvcache methods.
 
     Args:
@@ -294,7 +294,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]):
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
         self.quant_method = self.quantizer.build_moe_method()
 
@@ -365,7 +365,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 class AscendEmbeddingMethod(AscendLinearMethod):
     """Embedding method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for Embedding methods.
 
     Args:
@@ -374,6 +374,6 @@ class AscendEmbeddingMethod(AscendLinearMethod):
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
         self.quant_method = self.quantizer.build_linear_method()
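
All of the wrapper classes touched in this file follow the same pattern after the rename: ask VLLMAscendQuantizer.get_quantizer() for a quantizer matching the layer's quant description and prefix, then build the concrete linear/attention/MoE method from it. Below is a minimal sketch of that dispatch shape using toy stand-ins; ToyQuantizer and ToyLinearMethod are hypothetical, and only the get_quantizer()/build_*_method() call pattern mirrors this diff.

from typing import Any, Dict, Optional


class ToyLinearMethod:
    """Hypothetical stand-in for a concrete quantized linear method."""

    def apply(self, x):
        # A real method would run the quantized matmul on the NPU here.
        return x


class ToyQuantizer:
    """Hypothetical stand-in for VLLMAscendQuantizer."""

    @staticmethod
    def get_quantizer(quant_description: Dict[str, Any],
                      prefix: str,
                      packed_modules_mapping: Optional[Dict[str, Any]] = None):
        # A real quantizer would inspect quant_description for this prefix
        # (e.g. W8A8 vs. FAQuant) and return a matching implementation.
        return ToyQuantizer()

    def build_linear_method(self):
        return ToyLinearMethod()


quantizer = ToyQuantizer.get_quantizer({"quant_method": "W8A8"},
                                       prefix="model.layers.0.mlp.down_proj",
                                       packed_modules_mapping={})
quant_method = quantizer.build_linear_method()
print(type(quant_method).__name__)  # ToyLinearMethod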
