diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py
index cd65a24969..ce7e58f365 100644
--- a/tests/singlecard/test_offline_inference.py
+++ b/tests/singlecard/test_offline_inference.py
@@ -45,6 +45,7 @@
 
 QUANTIZATION_MODELS = [
     "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8",
+    "vllm-ascend/Qwen2.5-0.5B-Instruct-fa3"
 ]
 
 
@@ -71,7 +72,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
 @pytest.mark.parametrize("max_tokens", [5])
 def test_quantization_models(model: str, max_tokens: int) -> None:
     prompt = "The following numbers of the sequence " + ", ".join(
-        str(i) for i in range(1024)) + " are:"
+        str(i) for i in range(256)) + " are:"
     example_prompts = [prompt]
 
     # NOTE: Using quantized model repo id from modelscope encounters an issue,
@@ -80,7 +81,7 @@ def test_quantization_models(model: str, max_tokens: int) -> None:
     model_path = snapshot_download(model)
 
     with VllmRunner(model_path,
-                    max_model_len=8192,
+                    max_model_len=4096,
                     enforce_eager=True,
                     dtype="auto",
                     gpu_memory_utilization=0.7,
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 629fe73d5d..49497ffb58 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -331,7 +331,7 @@ def forward(
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
             pass
         # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             assert attn_metadata is not None
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
diff --git a/vllm_ascend/quantization/faquant.py b/vllm_ascend/quantization/faquant.py
new file mode 100644
index 0000000000..d60e9b4f62
--- /dev/null
+++ b/vllm_ascend/quantization/faquant.py
@@ -0,0 +1,215 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import List
+
+import torch
+import torch_npu
+
+from .quant_utils import (SRC_DTYPE_TO_ACL_DTYPE, TYPE_QUANT_QKV_ONLINE,
+                          quant_per_tensor)
+
+
+class AscendFAQuantAttentionMethod:
+    """Attention method for Ascend FAQuant.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @staticmethod
+    def get_quant_param() -> List[str]:
+        return [
+            "fa_q.scale", "fa_q.offset", "fa_k.scale", "fa_k.offset",
+            "fa_v.scale", "fa_v.offset"
+        ]
+
+    @staticmethod
+    def get_extra_module_names() -> List[str]:
+
+        return ["fa_q", "fa_k", "fa_v"]
+
+    @staticmethod
+    def process_weights_after_loading(layer):
+        fa_qscale = layer.fa_q.scale
+        fa_kscale = layer.fa_k.scale
+        fa_vscale = layer.fa_v.scale
+        repeated_query_scale = layer.fa_q.scale.repeat(1, layer.head_size)
+        layer.fa_qscale = torch.nn.Parameter(repeated_query_scale,
+                                             requires_grad=False)
+        repeated_query_offset = layer.fa_q.offset.repeat(1, layer.head_size)
+        layer.fa_qoffset = torch.nn.Parameter(repeated_query_offset,
+                                              requires_grad=False)
+        repeated_fa_kscale = layer.fa_k.scale.repeat(1, layer.head_size)
+        layer.fa_kscale = torch.nn.Parameter(repeated_fa_kscale,
+                                             requires_grad=False)
+        repeated_fa_koffset = layer.fa_k.offset.repeat(1, layer.head_size)
+        layer.fa_koffset = torch.nn.Parameter(repeated_fa_koffset,
+                                              requires_grad=False)
+        repeated_fa_vscale = layer.fa_v.scale.repeat(1, layer.head_size)
+        layer.fa_vscale = torch.nn.Parameter(repeated_fa_vscale,
+                                             requires_grad=False)
+        repeated_fa_voffset = layer.fa_v.offset.repeat(1, layer.head_size)
+        layer.fa_voffset = torch.nn.Parameter(repeated_fa_voffset,
+                                              requires_grad=False)
+
+        if fa_kscale.shape[0] <= 0:
+            raise ValueError(
+                "Expected size of fa_kscale in dimension 0 to be greater "
+                f"than 0, but got {fa_kscale.shape[0]}.")
+        gqa_size = fa_qscale.shape[0] // fa_kscale.shape[0]
+        fa3_k_scale, fa3_v_scale = fa_kscale.repeat(1, gqa_size).view(
+            -1, 1), fa_vscale.repeat(1, gqa_size).view(-1, 1)
+        qk_scale = torch.nn.Parameter(torch.squeeze(
+            fa_qscale * fa3_k_scale).to(torch.float),
+                                      requires_grad=False)
+        layer.register_parameter("qk_scale", qk_scale)
+        fa3_v_scale = torch.nn.Parameter(
+            torch.squeeze(fa3_v_scale).contiguous().to(torch.float),
+            requires_grad=False)
+        layer.register_parameter("fa3_v_scale", fa3_v_scale)
+
+    @classmethod
+    def apply(cls, layer: torch.nn.Module, query: torch.Tensor,
+              key: torch.Tensor, value: torch.Tensor, *extra_args,
+              **optional_args) -> torch.Tensor:
+        key_cache, value_cache, scale, block_tables, \
+            is_prefill, mask, slots, output = extra_args
+        seq_lens_tensor_cpu = optional_args.get("seq_lens_tensor_cpu", None)
+
+        query_shape = query.shape
+        key_shape = key.shape
+        value_shape = value.shape
+
+        query = query.view(query.shape[0], -1)
+        key = key.view(key.shape[0], -1)
+        value = value.view(value.shape[0], -1)
+
+        if is_prefill:
+            if key_cache is not None:
+
+                key_int8 = quant_per_tensor(key, layer.fa_kscale,
+                                            layer.fa_koffset, True)
+                value_int8 = quant_per_tensor(value, layer.fa_vscale,
+                                              layer.fa_voffset, True)
+                key_int8 = key_int8.view(key_shape)
+                value_int8 = value_int8.view(value_shape)
+                torch_npu._npu_reshape_and_cache(key_int8, value_int8,
+                                                 key_cache, value_cache, slots)
+            if mask is None:
+                raise ValueError(
+                    "attn_metadata.attn_mask is None. Please check.")
+            query = query.view(query_shape)
+            key = key.view(key_shape)
+            value = value.view(value_shape)
+            if output is not None:
+                output = output.view(query.shape)
+                torch_npu._npu_flash_attention(query,
+                                               key,
+                                               value,
+                                               mask,
+                                               torch.tensor(
+                                                   seq_lens_tensor_cpu,
+                                                   dtype=torch.int32),
+                                               scale,
+                                               layer.num_heads,
+                                               layer.num_kv_heads,
+                                               out=output)
+            else:
+                query = query.view(query_shape)
+                key = key.view(key_shape)
+                value = value.view(value_shape)
+                output = torch.empty_like(query,
+                                          dtype=query.dtype).to(query.device)
+                torch_npu._npu_flash_attention(query,
+                                               key,
+                                               value,
+                                               mask,
+                                               torch.tensor(
+                                                   seq_lens_tensor_cpu,
+                                                   dtype=torch.int32),
+                                               scale,
+                                               layer.num_heads,
+                                               layer.num_kv_heads,
+                                               out=output)
+
+        else:
+            if key_cache is None:
+                raise ValueError(
+                    "KV cache can't be None in the decoding phase. Please check."
+                )
+            query_int8 = quant_per_tensor(query, layer.fa_qscale,
+                                          layer.fa_qoffset, True)
+            key_int8 = quant_per_tensor(key, layer.fa_kscale, layer.fa_koffset,
+                                        True)
+            value_int8 = quant_per_tensor(value, layer.fa_vscale,
+                                          layer.fa_voffset, True)
+            query_int8 = query_int8.view(query_shape)
+            key_int8 = key_int8.view(key_shape)
+            value_int8 = value_int8.view(value_shape)
+            query = query.view(query_shape)
+            torch_npu._npu_reshape_and_cache(key_int8, value_int8, key_cache,
+                                             value_cache, slots)
+            if output is not None:
+                output = output.view(query.shape)
+                torch_npu._npu_paged_attention_quant(
+                    query_int8, key_cache, value_cache, layer.num_kv_heads,
+                    layer.num_heads, scale, block_tables,
+                    torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32),
+                    TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype],
+                    layer.qk_scale, layer.fa3_v_scale, output)
+            else:
+                output = torch.empty_like(query,
+                                          dtype=query.dtype).to(query.device)
+                torch_npu._npu_paged_attention_quant(
+                    query_int8, key_cache, value_cache, layer.num_kv_heads,
+                    layer.num_heads, scale, block_tables,
+                    torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32),
+                    TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype],
+                    layer.qk_scale, layer.fa3_v_scale, output)
+
+        output = torch.flatten(output, start_dim=-2)
+        return output
+
+    @classmethod
+    def create_weights(cls, layer: torch.nn.Module) -> None:
+        extra_module_names = cls.get_extra_module_names()
+        for name in extra_module_names:
+            setattr(layer, name, torch.nn.Module())
+
+        params_dtype = torch.get_default_dtype()
+
+        params_dict = {}
+
+        params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1),
+                                                dtype=params_dtype)
+        params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1),
+                                                 dtype=torch.int8)
+        params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1),
+                                                dtype=params_dtype)
+        params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1),
+                                                 dtype=torch.int8)
+        params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1),
+                                                dtype=params_dtype)
+        params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1),
+                                                 dtype=torch.int8)
+
+        for name, weight in params_dict.items():
+            module_name, weight_name = name.split('.')
+            module = getattr(layer, module_name)
+            module.register_parameter(
+                weight_name, torch.nn.Parameter(weight, requires_grad=False))
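# --- Illustration only, not part of the patch: a minimal sketch of the scale
# bookkeeping done by process_weights_after_loading() above, using plain torch
# and made-up sizes (num_heads=8, num_kv_heads=2, head_size=128).
import torch

num_heads, num_kv_heads, head_size = 8, 2, 128
fa_qscale = torch.rand(num_heads, 1)     # per-query-head scale from fa_q.scale
fa_kscale = torch.rand(num_kv_heads, 1)  # per-KV-head scale from fa_k.scale

# Scales are broadcast to one entry per element for online QKV quantization.
qscale_expanded = fa_qscale.repeat(1, head_size)           # (8, 128)

# Under GQA each KV head serves num_heads // num_kv_heads query heads, so the
# K scale is repeated per query head and fused with the Q scale into the
# qk_scale vector handed to the paged-attention kernel.
gqa_size = num_heads // num_kv_heads                        # 4
fa3_k_scale = fa_kscale.repeat(1, gqa_size).view(-1, 1)     # (8, 1)
qk_scale = torch.squeeze(fa_qscale * fa3_k_scale).float()   # (8,)
print(qscale_expanded.shape, fa3_k_scale.shape, qk_scale.shape)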
diff --git a/vllm_ascend/quantization/func_wrapper.py b/vllm_ascend/quantization/func_wrapper.py
index 77ecca2b17..ab2e0ce8a0 100644
--- a/vllm_ascend/quantization/func_wrapper.py
+++ b/vllm_ascend/quantization/func_wrapper.py
@@ -45,23 +45,26 @@ def _rmsnorm_forward_oot(
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not self.ignore_anti:
+            out = torch.empty_like(x, dtype=torch.int8).npu()
             if residual is not None:
                 residual += x
-                out = torch_npu._npu_quant_rms_norm(
+                torch_npu._npu_quant_rms_norm(
                     residual,
                     self.weight,
                     self.bias,
                     self.input_scale,
                     self.input_offset,
+                    out,
                     self.variance_epsilon,
                 )
                 return out, residual
-            out = torch_npu._npu_quant_rms_norm(
+            torch_npu._npu_quant_rms_norm(
                 x,
                 self.weight,
                 self.bias,
                 self.input_scale,
                 self.input_offset,
+                out,
                 self.variance_epsilon,
             )
             return out
@@ -90,6 +93,20 @@ def _rmsnorm_forward_oot(
             "unquantized_type": UnquantizedLinearMethod,
         },
     },
+    "Qwen2Model": {
+        "attn": {
+            "layer_attr": "self_attn",
+            "proj_attr": "qkv_proj",
+            "norm_attr": "input_layernorm",
+            "unquantized_type": UnquantizedLinearMethod,
+        },
+        "mlp": {
+            "layer_attr": "mlp",
+            "proj_attr": "gate_up_proj",
+            "norm_attr": "post_attention_layernorm",
+            "unquantized_type": UnquantizedLinearMethod,
+        },
+    }
 }
 
 
@@ -133,6 +150,24 @@ def process_module(module_cfg, layer_obj):
             process_module(mapping.get("attn"), layer)
             process_module(mapping.get("mlp"), layer)
 
+        def is_enable(quant_description) -> bool:
+            need_activate = False
+            for name in quant_description.keys():
+                if "norm.bias" in name:
+                    need_activate = True
+                    return need_activate
+            return need_activate
+
+        # Check whether the patch should be activated at all.
+        try:
+            if not is_enable(self.model.quant_config.quant_description):
+                return
+        except AttributeError:
+            logger.info(
+                "load_model patch is not enabled because the model weights are not quantized"
+            )
+            return
+
         model_type = self.model.model.__class__.__name__
         mapping = MODEL_LAYER_MAPPING.get(model_type)
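# --- Illustration only, not part of the patch: the load-model patch above is
# now gated on the checkpoint actually containing quantized norm biases. A toy
# quant_description (keys invented for this example) shows the intended check.
def patch_enabled(quant_description) -> bool:
    return any("norm.bias" in name for name in quant_description)

print(patch_enabled({"model.layers.0.input_layernorm.bias": "W8A8"}))  # True
print(patch_enabled({"fa_quant_type": "FAQuant"}))                     # False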
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 1b06a4294a..7ea53a2a4f 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -42,7 +42,7 @@
 from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
 
-from .quantizer import AscendQuantizer
+from .quantizer import VLLMAscendQuantizer
 
 
 @register_quantization_config(ASCEND_QUATIZATION_METHOD)
@@ -107,6 +107,8 @@ def get_quant_method(self, layer: torch.nn.Module,
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
         elif isinstance(layer, VocabParallelEmbedding):
+            if len(prefix) == 0:
+                return UnquantizedEmbeddingMethod()
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
                 return UnquantizedEmbeddingMethod()
@@ -151,7 +153,7 @@ def get_scaled_act_names(self) -> List[str]:
 class AscendLinearMethod(LinearMethodBase):
     """Linear method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for linear methods.
 
     Args:
@@ -160,7 +162,7 @@ class AscendLinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
         self.quant_method = self.quantizer.build_linear_method()
 
@@ -232,7 +234,7 @@ def apply(
 class AscendKVCacheMethod(BaseKVCacheMethod):
     """KVCache method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for kvcache methods.
 
    Args:
@@ -240,7 +242,7 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
     """
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix)
         self.quant_method = self.quantizer.build_attention_method()
 
@@ -285,7 +287,7 @@ def apply(self,
 class AscendFusedMoEMethod(FusedMoEMethodBase):
     """FusedMoE method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for kvcache methods.
 
     Args:
@@ -294,7 +296,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]):
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
         self.quant_method = self.quantizer.build_moe_method()
 
@@ -365,7 +367,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 class AscendEmbeddingMethod(AscendLinearMethod):
     """Embedding method for Ascend quantization.
 
-    This class calls AscendQuantizer to search a specific quantization
+    This class calls VLLMAscendQuantizer to search a specific quantization
     implementations supported on ascend hardware for Embedding methods.
 
     Args:
@@ -374,6 +376,6 @@ class AscendEmbeddingMethod(AscendLinearMethod):
 
     def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                  packed_modules_mapping: Dict[str, Any]) -> None:
-        self.quantizer = AscendQuantizer.get_quantizer(
+        self.quantizer = VLLMAscendQuantizer.get_quantizer(
             quant_config.quant_description, prefix, packed_modules_mapping)
         self.quant_method = self.quantizer.build_linear_method()
diff --git a/vllm_ascend/quantization/quant_utils.py b/vllm_ascend/quantization/quant_utils.py
new file mode 100644
index 0000000000..f9ee47d8a8
--- /dev/null
+++ b/vllm_ascend/quantization/quant_utils.py
@@ -0,0 +1,122 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Generator, List, Tuple
+
+import torch
+import torch_npu
+from vllm.attention.selector import get_attn_backend
+from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
+from vllm.utils import LayerBlockType
+
+TYPE_QUANT_QKV_ONLINE = 3
+
+SRC_DTYPE_TO_ACL_DTYPE = {
+    torch.float16: 1,
+    torch.bfloat16: 27,
+}
+
+
+def quant_per_tensor(in_tensor: torch.Tensor,
+                     input_scale: torch.Tensor,
+                     input_offset: torch.Tensor,
+                     function=False):
+    input_scale = input_scale.view(-1)
+    input_offset = input_offset.view(-1)
+    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
+                                  torch.qint8, -1, function)
+
+
+def wrapper_weights_iterator(func):
+
+    def _safetensors_weights_iterator(
+        hf_weights_files: List[str],
+        use_tqdm_on_load: bool,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        current_rank = get_tensor_model_parallel_rank()
+        world_size = get_tensor_model_parallel_world_size()
+        for name, weight in func(hf_weights_files, use_tqdm_on_load):
+            # The names of attention weights generated by msmodelslim
+            # must be modified so that these weights are loaded into
+            # the Attention module rather than the LlamaAttention module.
+            if "fa_" in name and ".attn." not in name:
+                name = name.split(".")
+                name.insert(name.index("self_attn") + 1, "attn")
+                name = ".".join(name)
+                # vLLM originally does not support splitting attention
+                # weights with respect to TP ranks, so we need to split
+                # these weights manually here.
+                if world_size <= 0:
+                    raise ValueError(
+                        "Expected world_size to be greater than 0, "
+                        f"but got {world_size}.")
+                split_size = weight.size(0) // world_size
+                weight = weight[current_rank * split_size:(current_rank + 1) *
+                                split_size]
+
+            yield name, weight
+
+    return _safetensors_weights_iterator
+
+
+# Replaces CacheEngine.__init__.
+# vLLM does not include an int8 cache dtype,
+# so we set it here.
+def cache_engine_init(
+    self,
+    cache_config: CacheConfig,
+    model_config: ModelConfig,
+    parallel_config: ParallelConfig,
+    device_config: DeviceConfig,
+) -> None:
+    self.cache_config = cache_config
+    self.model_config = model_config
+    self.parallel_config = parallel_config
+    self.device_config = device_config
+
+    self.head_size = model_config.get_head_size()
+    # Models like Jamba have mixed-type layers, e.g. Mamba.
+    self.num_attention_layers = model_config.get_num_layers_by_block_type(
+        parallel_config, LayerBlockType.attention)
+    self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+
+    self.block_size = cache_config.block_size
+    self.num_gpu_blocks = cache_config.num_gpu_blocks
+    if self.num_gpu_blocks:
+        self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
+    self.num_cpu_blocks = cache_config.num_cpu_blocks
+    if self.num_cpu_blocks:
+        self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
+
+    # Modified here: vLLM does not include an int8 cache dtype,
+    # so we set it here.
+    self.dtype = torch.int8
+
+    # Get attention backend.
+    self.attn_backend = get_attn_backend(self.head_size,
+                                         model_config.dtype,
+                                         cache_config.cache_dtype,
+                                         self.block_size,
+                                         model_config.is_attention_free,
+                                         use_mla=model_config.use_mla)
+
+    # Initialize the cache.
+    self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks,
+                                             self.device_config.device_type)
+    self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
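# --- Illustration only, not part of the patch: what the weights-iterator
# wrapper above does to an msmodelslim FA-quant weight. The weight name and
# the TP sizes here are made up for the example.
import torch

name = "model.layers.0.self_attn.fa_q.scale"
parts = name.split(".")
parts.insert(parts.index("self_attn") + 1, "attn")
print(".".join(parts))  # model.layers.0.self_attn.attn.fa_q.scale

# The per-head scale rows are then sliced across tensor-parallel ranks.
world_size, current_rank = 2, 0
weight = torch.rand(8, 1)  # e.g. fa_q.scale for 8 query heads
split_size = weight.size(0) // world_size
print(weight[current_rank * split_size:(current_rank + 1) * split_size].shape)
# torch.Size([4, 1])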
diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py
index 4d00e454ad..e5604349dd 100644
--- a/vllm_ascend/quantization/quantizer.py
+++ b/vllm_ascend/quantization/quantizer.py
@@ -21,9 +21,12 @@
 from typing import Any, Dict, List, Optional
 
 from vllm.logger import logger
+from vllm.model_executor.model_loader.weight_utils import \
+    safetensors_weights_iterator
 
-from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
-                           wrapper_rmsnorm_init)
+from .faquant import AscendFAQuantAttentionMethod
+from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
+from .quant_utils import cache_engine_init, wrapper_weights_iterator
 from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
                            AscendW4A8DynamicLinearMethod)
 from .w8a8 import AscendW8A8LinearMethod
@@ -32,33 +35,8 @@
 
 CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
 
-
-class AscendQuantizer:
-    """An interface to different quantization implementations for ascend hardwares."""
-
-    @classmethod
-    def get_quantizer(cls,
-                      quant_config: Dict[str, Any],
-                      prefix: str,
-                      packed_modules_mapping: Optional[Dict[str,
-                                                            Any]] = dict()):
-        # TODO: Need a param to choose quantization algorithms.
-        quantization_algorithm = ''
-
-        if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE:
-            return
-
-        return VLLMAscendQuantizer.get_quantizer(quant_config, prefix,
-                                                 packed_modules_mapping)
-
-    def build_linear_method(self):
-        raise NotImplementedError
-
-    def build_moe_method(self):
-        raise NotImplementedError
-
-    def build_attention_method(self):
-        raise NotImplementedError
+DECORATE = "decorate"
+REPLACE = "replace"
 
 
 class VLLMAscendQuantizer:
@@ -72,37 +50,51 @@ def __init__(self, quant_description):
             if "norm.bias" in name:
                 VLLMAscendQuantizer.apply_patch(
                     "vllm.model_executor.layers.layernorm.RMSNorm", "__init__",
-                    [wrapper_rmsnorm_init])
+                    wrapper_rmsnorm_init)
                 VLLMAscendQuantizer.apply_patch(
                     "vllm.model_executor.layers.layernorm.RMSNorm",
-                    "forward_oot", [wrapper_rmsnorm_forward_oot])
-                VLLMAscendQuantizer.apply_patch(
-                    "vllm_ascend.worker.model_runner.NPUModelRunnerBase",
-                    "load_model", [wrapper_load_model])
+                    "forward_oot", wrapper_rmsnorm_forward_oot)
                 break
+        if quant_description.get("fa_quant_type") == "FAQuant":
+            VLLMAscendQuantizer.apply_patch(
+                "vllm.model_executor.model_loader.weight_utils",
+                "safetensors_weights_iterator",
+                wrapper_weights_iterator(safetensors_weights_iterator),
+                REPLACE)
+            VLLMAscendQuantizer.apply_patch(
+                "vllm.worker.cache_engine.CacheEngine", "__init__",
+                cache_engine_init, REPLACE)
         VLLMAscendQuantizer.patched = True
         logger.info("Using the vLLM Ascend Quantizer version now!")
 
     @staticmethod
-    def apply_patch(target_module, target_function, wrappers):
+    def apply_patch(target_module, target_function, wrapper, method=DECORATE):
         original_module, original_function = VLLMAscendQuantizer.parse_path(
             target_module, target_function, False)
         original_function_id = id(original_function)
-        candidate = original_function
-        for wrapper in wrappers:
-            candidate = wrapper(candidate)
+        candidate = None
+        if method == DECORATE:
+            candidate = wrapper(original_function)
+        else:
+            candidate = wrapper
         if target_function is not None:
             setattr(original_module, target_function, candidate)
-        for key, value in sys.modules.copy().items():
-            if (target_function is not None
-                    and hasattr(value, target_function)
-                    and id(getattr(value,
-                                   target_function)) == original_function_id):
-                setattr(value, target_function, candidate)
+        for _, value in sys.modules.copy().items():
+            if target_function is None:
+                continue
+            try:
+                attr = getattr(value, target_function, None)
+                if attr is not None and id(attr) == original_function_id:
+                    setattr(value, target_function, candidate)
+            except ImportError:
+                logger.info(
+                    "modelscope overrides getattr(), which can cause an import error here, \
+                    see https://github.com/modelscope/modelscope/blob/7c9b89d24dddda5d6b9ef84e04e7fdbc0dbd8f8a/modelscope/__init__.py#L141"
+                )
 
     @staticmethod
     def parse_path(module_path, function_name, create_dummy):
@@ -246,6 +238,8 @@ def get_quantizer(cls,
         # Attention
         if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
             quant_type = quant_description['fa_quant_type']
+        elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
+            quant_type = quant_description['kv_quant_type']
         # Linear
         else:
             quant_type = cls.get_linear_quant_type(quant_description, prefix,
@@ -288,8 +282,16 @@ def build_moe_method():
         return AscendW8A8DynamicFusedMoEMethod()
 
 
+class FAQuantizer(VLLMAscendQuantizer):
+
+    @staticmethod
+    def build_attention_method():
+        return AscendFAQuantAttentionMethod()
+
+
 SUPPORT_ASCEND_QUANTIZER_TYPE = {
     "W4A8_DYNAMIC": W4A8DYNAMICQuantizer,
     "W8A8": W8A8Quantizer,
     "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
+    "FAQuant": FAQuantizer
 }
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index 9574f50c1f..973354db42 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -23,11 +23,7 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
 
-
-def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
-                     input_offset: torch.Tensor):
-    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
-                                  torch.qint8, -1, False)
+from .quant_utils import quant_per_tensor
 
 
 class AscendW8A8LinearMethod:
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py
index 7846f655d1..90d1189ae1 100644
--- a/vllm_ascend/worker/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -67,6 +67,8 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 
+from ..quantization.func_wrapper import wrapper_load_model
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 
@@ -991,6 +993,7 @@ def __init__(
     def get_model(self) -> nn.Module:
         return self.model
 
+    @wrapper_load_model
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 6db94a99f1..2a4cc1663b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -86,6 +86,8 @@
 from vllm_ascend.utils import ProfileExecuteDuration
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 
+from ..quantization.func_wrapper import wrapper_load_model
+
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -1776,6 +1778,7 @@ def profile_run(self) -> None:
         self.encoder_cache.clear()
         gc.collect()
 
+    @wrapper_load_model
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
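# --- Illustration only, not part of the patch: how the pieces are wired up.
# With a checkpoint whose quant_description carries "fa_quant_type": "FAQuant",
# SUPPORT_ASCEND_QUANTIZER_TYPE resolves to FAQuantizer, which builds
# AscendFAQuantAttentionMethod, and apply_patch() swaps in cache_engine_init
# and the wrapped safetensors iterator via the new REPLACE mode. A reduced
# sketch of the two patch modes:
DECORATE, REPLACE = "decorate", "replace"

def pick_candidate(original, wrapper, method=DECORATE):
    # DECORATE wraps the original callable; REPLACE substitutes it entirely.
    return wrapper(original) if method == DECORATE else wrapper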