|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# This file is a part of the vllm-ascend project. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | + |
| 18 | +from typing import List |
| 19 | + |
| 20 | +import torch |
| 21 | +import torch_npu |
| 22 | + |
| 23 | +from .quant_utils import (SRC_DTYPE_TO_ACL_DTYPE, TYPE_QUANT_QKV_ONLINE, |
| 24 | + quant_per_tensor) |
| 25 | + |
| 26 | + |
| 27 | +class AscendFAQuantAttentionMethod: |
| 28 | + """Linear method for Ascend FAQuant |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__(self) -> None: |
| 32 | + super().__init__() |
| 33 | + |
| 34 | + @staticmethod |
| 35 | + def get_quant_param() -> List[str]: |
| 36 | + return [ |
| 37 | + "fa_q.scale", "fa_q.offset", "fa_k.scale", "fa_k.offset", |
| 38 | + "fa_v.scale", "fa_v.offset" |
| 39 | + ] |
| 40 | + |
| 41 | + @staticmethod |
| 42 | + def get_extra_module_names() -> List[str]: |
| 43 | + |
| 44 | + return ["fa_q", "fa_k", "fa_v"] |
| 45 | + |
| 46 | + @staticmethod |
| 47 | + def process_weights_after_loading(layer): |
| 48 | + fa_qscale = layer.fa_q.scale |
| 49 | + fa_kscale = layer.fa_k.scale |
| 50 | + fa_vscale = layer.fa_v.scale |
| 51 | + repeated_query_scale = layer.fa_q.scale.repeat(1, layer.head_size) |
| 52 | + layer.fa_qscale = torch.nn.Parameter(repeated_query_scale, |
| 53 | + requires_grad=False) |
| 54 | + repeated_query_offset = layer.fa_q.offset.repeat(1, layer.head_size) |
| 55 | + layer.fa_qoffset = torch.nn.Parameter(repeated_query_offset, |
| 56 | + requires_grad=False) |
| 57 | + repeated_fa_kscale = layer.fa_k.scale.repeat(1, layer.head_size) |
| 58 | + layer.fa_kscale = torch.nn.Parameter(repeated_fa_kscale, |
| 59 | + requires_grad=False) |
| 60 | + repeated_fa_koffset = layer.fa_k.offset.repeat(1, layer.head_size) |
| 61 | + layer.fa_koffset = torch.nn.Parameter(repeated_fa_koffset, |
| 62 | + requires_grad=False) |
| 63 | + repeated_fa_vscale = layer.fa_v.scale.repeat(1, layer.head_size) |
| 64 | + layer.fa_vscale = torch.nn.Parameter(repeated_fa_vscale, |
| 65 | + requires_grad=False) |
| 66 | + repeated_fa_voffset = layer.fa_v.offset.repeat(1, layer.head_size) |
| 67 | + layer.fa_voffset = torch.nn.Parameter(repeated_fa_voffset, |
| 68 | + requires_grad=False) |
| 69 | + |
| 70 | + if fa_kscale.shape[0] <= 0: |
| 71 | + raise ValueError( |
| 72 | + "Expected size of fa_kscale in dimension 0 should be greater than 0" |
| 73 | + f"but got {fa_kscale.shape[0]}.") |
| 74 | + gqa_size = fa_qscale.shape[0] // fa_kscale.shape[0] |
| 75 | + fa3_k_scale, fa3_v_scale = fa_kscale.repeat(1, gqa_size).view( |
| 76 | + -1, 1), fa_vscale.repeat(1, gqa_size).view(-1, 1) |
| 77 | + qk_scale = torch.nn.Parameter(torch.squeeze( |
| 78 | + fa_qscale * fa3_k_scale).to(torch.float), |
| 79 | + requires_grad=False) |
| 80 | + layer.register_parameter("qk_scale", qk_scale) |
| 81 | + fa3_v_scale = torch.nn.Parameter( |
| 82 | + torch.squeeze(fa3_v_scale).contiguous().to(torch.float), |
| 83 | + requires_grad=False) |
| 84 | + layer.register_parameter("fa3_v_scale", fa3_v_scale) |
| 85 | + |
| 86 | + @classmethod |
| 87 | + def apply(cls, layer: torch.nn.Module, query: torch.Tensor, |
| 88 | + key: torch.Tensor, value: torch.Tensor, *extra_args, |
| 89 | + **optional_args) -> torch.Tensor: |
| 90 | + key_cache, value_cache, scale, block_tables, \ |
| 91 | + is_prefill, mask, slots, output = extra_args |
| 92 | + seq_lens_tensor_cpu = optional_args.get("seq_lens_tensor_cpu", None) |
| 93 | + |
| 94 | + query_shape = query.shape |
| 95 | + key_shape = key.shape |
| 96 | + value_shape = value.shape |
| 97 | + |
| 98 | + query = query.view(query.shape[0], -1) |
| 99 | + key = key.view(key.shape[0], -1) |
| 100 | + value = value.view(value.shape[0], -1) |
| 101 | + |
| 102 | + if is_prefill: |
| 103 | + if key_cache is not None: |
| 104 | + |
| 105 | + key_int8 = quant_per_tensor(key, layer.fa_kscale, |
| 106 | + layer.fa_koffset, True) |
| 107 | + value_int8 = quant_per_tensor(value, layer.fa_vscale, |
| 108 | + layer.fa_voffset, True) |
| 109 | + key_int8 = key_int8.view(key_shape) |
| 110 | + value_int8 = value_int8.view(value_shape) |
| 111 | + torch_npu._npu_reshape_and_cache(key_int8, value_int8, |
| 112 | + key_cache, value_cache, slots) |
| 113 | + if mask is None: |
| 114 | + raise ValueError( |
| 115 | + "attn_metadata.attn_mask is Null. Please check.") |
| 116 | + query = query.view(query_shape) |
| 117 | + key = key.view(key_shape) |
| 118 | + value = value.view(value_shape) |
| 119 | + if output is not None: |
| 120 | + output = output.view(query.shape) |
| 121 | + torch_npu._npu_flash_attention(query, |
| 122 | + key, |
| 123 | + value, |
| 124 | + mask, |
| 125 | + torch.tensor( |
| 126 | + seq_lens_tensor_cpu, |
| 127 | + dtype=torch.int32), |
| 128 | + scale, |
| 129 | + layer.num_heads, |
| 130 | + layer.num_kv_heads, |
| 131 | + out=output) |
| 132 | + else: |
| 133 | + query = query.view(query_shape) |
| 134 | + key = key.view(key_shape) |
| 135 | + value = value.view(value_shape) |
| 136 | + output = torch.empty_like(query, |
| 137 | + dtype=query.dtype).to(query.device) |
| 138 | + torch_npu._npu_flash_attention(query, |
| 139 | + key, |
| 140 | + value, |
| 141 | + mask, |
| 142 | + torch.tensor( |
| 143 | + seq_lens_tensor_cpu, |
| 144 | + dtype=torch.int32), |
| 145 | + scale, |
| 146 | + layer.num_heads, |
| 147 | + layer.num_kv_heads, |
| 148 | + out=output) |
| 149 | + |
| 150 | + else: |
| 151 | + if key_cache is None: |
| 152 | + raise ValueError( |
| 153 | + "KV Cache can't be None in decoding phase. Got None. Please check." |
| 154 | + ) |
| 155 | + query_int8 = quant_per_tensor(query, layer.fa_qscale, |
| 156 | + layer.fa_qoffset, True) |
| 157 | + key_int8 = quant_per_tensor(key, layer.fa_kscale, layer.fa_koffset, |
| 158 | + True) |
| 159 | + value_int8 = quant_per_tensor(value, layer.fa_vscale, |
| 160 | + layer.fa_voffset, True) |
| 161 | + query_int8 = query_int8.view(query_shape) |
| 162 | + key_int8 = key_int8.view(key_shape) |
| 163 | + value_int8 = value_int8.view(value_shape) |
| 164 | + query = query.view(query_shape) |
| 165 | + torch_npu._npu_reshape_and_cache(key_int8, value_int8, key_cache, |
| 166 | + value_cache, slots) |
| 167 | + if output is not None: |
| 168 | + output = output.view(query.shape) |
| 169 | + torch_npu._npu_paged_attention_quant( |
| 170 | + query_int8, key_cache, value_cache, layer.num_kv_heads, |
| 171 | + layer.num_heads, scale, block_tables, |
| 172 | + torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32), |
| 173 | + TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype], |
| 174 | + layer.qk_scale, layer.fa3_v_scale, output) |
| 175 | + else: |
| 176 | + output = torch.empty_like(query, |
| 177 | + dtype=query.dtype).to(query.device) |
| 178 | + torch_npu._npu_paged_attention_quant( |
| 179 | + query_int8, key_cache, value_cache, layer.num_kv_heads, |
| 180 | + layer.num_heads, scale, block_tables, |
| 181 | + torch.tensor(seq_lens_tensor_cpu, dtype=torch.int32), |
| 182 | + TYPE_QUANT_QKV_ONLINE, SRC_DTYPE_TO_ACL_DTYPE[query.dtype], |
| 183 | + layer.qk_scale, layer.fa3_v_scale, output) |
| 184 | + |
| 185 | + output = torch.flatten(output, start_dim=-2) |
| 186 | + return output |
| 187 | + |
| 188 | + @classmethod |
| 189 | + def create_weights(cls, layer: torch.nn.Module) -> None: |
| 190 | + extra_module_names = cls.get_extra_module_names() |
| 191 | + for name in extra_module_names: |
| 192 | + setattr(layer, name, torch.nn.Module()) |
| 193 | + |
| 194 | + params_dtype = torch.get_default_dtype() |
| 195 | + |
| 196 | + params_dict = {} |
| 197 | + |
| 198 | + params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), |
| 199 | + dtype=params_dtype) |
| 200 | + params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1), |
| 201 | + dtype=torch.int8) |
| 202 | + params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), |
| 203 | + dtype=params_dtype) |
| 204 | + params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1), |
| 205 | + dtype=torch.int8) |
| 206 | + params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), |
| 207 | + dtype=params_dtype) |
| 208 | + params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1), |
| 209 | + dtype=torch.int8) |
| 210 | + |
| 211 | + for name, weight in params_dict.items(): |
| 212 | + module_name, weight_name = name.split('.') |
| 213 | + module = getattr(layer, module_name) |
| 214 | + module.register_parameter( |
| 215 | + weight_name, torch.nn.Parameter(weight, requires_grad=False)) |
0 commit comments