|
| 1 | +from collections.abc import Iterable |
| 2 | +from typing import Optional, Union |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn.functional as F |
| 6 | +from torch import nn |
| 7 | +from transformers import Qwen2Config |
| 8 | +from vllm.compilation.decorators import support_torch_compile |
| 9 | +from vllm.config import CacheConfig, VllmConfig |
| 10 | +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, |
| 11 | + get_tensor_model_parallel_world_size, |
| 12 | + tensor_model_parallel_all_gather, |
| 13 | + tensor_model_parallel_all_reduce, |
| 14 | + tensor_model_parallel_reduce_scatter) |
| 15 | +from vllm.forward_context import get_forward_context |
| 16 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 17 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 18 | +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead |
| 19 | +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP |
| 20 | +from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model |
| 21 | +from vllm.model_executor.models.utils import (AutoWeightsLoader, |
| 22 | + PPMissingLayer, maybe_prefix) |
| 23 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 24 | +from vllm.sequence import IntermediateTensors |
| 25 | + |
| 26 | +import vllm_ascend.envs as ascend_envs |
| 27 | +from vllm_ascend.attention.attention_v1 import AscendAttentionState |
| 28 | + |
| 29 | + |
| 30 | +def all_gather_and_maybe_unpad( |
| 31 | + hidden_states: torch.Tensor, |
| 32 | + pad_size: int, |
| 33 | +) -> torch.Tensor: |
| 34 | + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) |
| 35 | + if pad_size > 0: |
| 36 | + return hidden_states[:-pad_size, :] |
| 37 | + return hidden_states |
| 38 | + |
| 39 | + |
| 40 | +def maybe_pad_and_reduce_scatter( |
| 41 | + hidden_states: torch.Tensor, |
| 42 | + pad_size: int, |
| 43 | +) -> torch.Tensor: |
| 44 | + if pad_size > 0: |
| 45 | + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size)) |
| 46 | + hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) |
| 47 | + return hidden_states |
| 48 | + |
| 49 | + |
| 50 | +class CustomQwen2DecoderLayer(Qwen2DecoderLayer): |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + config: Qwen2Config, |
| 55 | + cache_config: Optional[CacheConfig] = None, |
| 56 | + quant_config: Optional[QuantizationConfig] = None, |
| 57 | + prefix: str = "", |
| 58 | + ) -> None: |
| 59 | + super().__init__(config=config, |
| 60 | + cache_config=cache_config, |
| 61 | + quant_config=quant_config, |
| 62 | + prefix=prefix) |
| 63 | + self.tp_rank = get_tensor_model_parallel_rank() |
| 64 | + self.tp_size = get_tensor_model_parallel_world_size() |
| 65 | + self.self_attn.o_proj.reduce_results = False |
| 66 | + self.mlp.down_proj.reduce_results = False |
| 67 | + |
| 68 | + def forward( |
| 69 | + self, |
| 70 | + positions: torch.Tensor, |
| 71 | + hidden_states: torch.Tensor, |
| 72 | + residual: Optional[torch.Tensor], |
| 73 | + flashcomm_v1_enabled: bool, |
| 74 | + pad_size: int, |
| 75 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 76 | + # Self Attention |
| 77 | + if residual is None: |
| 78 | + residual = hidden_states |
| 79 | + if flashcomm_v1_enabled: |
| 80 | + if pad_size > 0: |
| 81 | + residual = F.pad(residual, (0, 0, 0, pad_size)) |
| 82 | + residual = torch.chunk(residual, self.tp_size, |
| 83 | + dim=0)[self.tp_rank] |
| 84 | + hidden_states = self.input_layernorm(hidden_states) |
| 85 | + else: |
| 86 | + hidden_states, residual = self.input_layernorm( |
| 87 | + hidden_states, residual) |
| 88 | + if flashcomm_v1_enabled: |
| 89 | + hidden_states = all_gather_and_maybe_unpad( |
| 90 | + hidden_states, pad_size) |
| 91 | + hidden_states = self.self_attn( |
| 92 | + positions=positions, |
| 93 | + hidden_states=hidden_states, |
| 94 | + ) |
| 95 | + if flashcomm_v1_enabled: |
| 96 | + hidden_states = maybe_pad_and_reduce_scatter( |
| 97 | + hidden_states, pad_size) |
| 98 | + else: |
| 99 | + hidden_states = tensor_model_parallel_all_reduce(hidden_states) |
| 100 | + # Fully Connected |
| 101 | + hidden_states, residual = self.post_attention_layernorm( |
| 102 | + hidden_states, residual) |
| 103 | + if flashcomm_v1_enabled: |
| 104 | + hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) |
| 105 | + hidden_states = self.mlp(hidden_states) |
| 106 | + if flashcomm_v1_enabled: |
| 107 | + hidden_states = maybe_pad_and_reduce_scatter( |
| 108 | + hidden_states, pad_size) |
| 109 | + else: |
| 110 | + hidden_states = tensor_model_parallel_all_reduce(hidden_states) |
| 111 | + return hidden_states, residual |
| 112 | + |
| 113 | + |
| 114 | +@support_torch_compile( |
| 115 | + dynamic_arg_dims={ |
| 116 | + "input_ids": 0, |
| 117 | + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, |
| 118 | + # otherwise (seq_len, ). |
| 119 | + "positions": -1, |
| 120 | + "intermediate_tensors": 0, |
| 121 | + "inputs_embeds": 0, |
| 122 | + }) |
| 123 | +class CustomQwen2Model(Qwen2Model): |
| 124 | + |
| 125 | + def __init__( |
| 126 | + self, |
| 127 | + *, |
| 128 | + vllm_config: VllmConfig, |
| 129 | + prefix: str = "", |
| 130 | + decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): |
| 131 | + super().__init__(vllm_config=vllm_config, |
| 132 | + prefix=prefix, |
| 133 | + decoder_layer_type=decoder_layer_type) |
| 134 | + self.tp_size = get_tensor_model_parallel_world_size() |
| 135 | + |
| 136 | + def forward( |
| 137 | + self, |
| 138 | + input_ids: torch.Tensor, |
| 139 | + positions: torch.Tensor, |
| 140 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 141 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 142 | + ) -> Union[torch.Tensor, IntermediateTensors]: |
| 143 | + if get_pp_group().is_first_rank: |
| 144 | + if inputs_embeds is not None: |
| 145 | + hidden_states = inputs_embeds |
| 146 | + else: |
| 147 | + hidden_states = self.get_input_embeddings(input_ids) |
| 148 | + residual = None |
| 149 | + else: |
| 150 | + assert intermediate_tensors is not None |
| 151 | + hidden_states = intermediate_tensors["hidden_states"] |
| 152 | + residual = intermediate_tensors["residual"] |
| 153 | + pad_size = 0 |
| 154 | + flashcomm_v1_enabled = False |
| 155 | + attn_metadata = get_forward_context().attn_metadata |
| 156 | + if ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM == 1 and \ |
| 157 | + attn_metadata is not None and \ |
| 158 | + attn_metadata.attn_state != AscendAttentionState.DecodeOnly: |
| 159 | + flashcomm_v1_enabled = True |
| 160 | + if flashcomm_v1_enabled: |
| 161 | + num_tokens = hidden_states.size(0) |
| 162 | + pad_size = (self.tp_size - |
| 163 | + (num_tokens % self.tp_size)) % self.tp_size |
| 164 | + for layer in self.layers[self.start_layer:self.end_layer]: |
| 165 | + hidden_states, residual = layer( |
| 166 | + positions, |
| 167 | + hidden_states, |
| 168 | + residual, |
| 169 | + flashcomm_v1_enabled, |
| 170 | + pad_size, |
| 171 | + ) |
| 172 | + if not get_pp_group().is_last_rank: |
| 173 | + return IntermediateTensors({ |
| 174 | + "hidden_states": hidden_states, |
| 175 | + "residual": residual |
| 176 | + }) |
| 177 | + hidden_states, _ = self.norm(hidden_states, residual) |
| 178 | + if flashcomm_v1_enabled: |
| 179 | + hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) |
| 180 | + return hidden_states |
| 181 | + |
| 182 | + |
| 183 | +class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): |
| 184 | + # add `CustomQwen2Model` to init self.model |
| 185 | + packed_modules_mapping = { |
| 186 | + "qkv_proj": [ |
| 187 | + "q_proj", |
| 188 | + "k_proj", |
| 189 | + "v_proj", |
| 190 | + ], |
| 191 | + "gate_up_proj": [ |
| 192 | + "gate_proj", |
| 193 | + "up_proj", |
| 194 | + ], |
| 195 | + } |
| 196 | + |
| 197 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 198 | + super().__init__() |
| 199 | + config = vllm_config.model_config.hf_config |
| 200 | + quant_config = vllm_config.quant_config |
| 201 | + lora_config = vllm_config.lora_config |
| 202 | + |
| 203 | + self.config = config |
| 204 | + self.lora_config = lora_config |
| 205 | + |
| 206 | + self.quant_config = quant_config |
| 207 | + self.model = CustomQwen2Model(vllm_config=vllm_config, |
| 208 | + prefix=maybe_prefix(prefix, "model")) |
| 209 | + |
| 210 | + if get_pp_group().is_last_rank: |
| 211 | + if config.tie_word_embeddings: |
| 212 | + self.lm_head = self.model.embed_tokens |
| 213 | + else: |
| 214 | + self.lm_head = ParallelLMHead(config.vocab_size, |
| 215 | + config.hidden_size, |
| 216 | + quant_config=quant_config, |
| 217 | + prefix=maybe_prefix( |
| 218 | + prefix, "lm_head")) |
| 219 | + else: |
| 220 | + self.lm_head = PPMissingLayer() |
| 221 | + |
| 222 | + self.logits_processor = LogitsProcessor(config.vocab_size) |
| 223 | + |
| 224 | + self.make_empty_intermediate_tensors = ( |
| 225 | + self.model.make_empty_intermediate_tensors) |
| 226 | + |
| 227 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 228 | + return self.model.get_input_embeddings(input_ids) |
| 229 | + |
| 230 | + def forward( |
| 231 | + self, |
| 232 | + input_ids: torch.Tensor, |
| 233 | + positions: torch.Tensor, |
| 234 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 235 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 236 | + ) -> Union[torch.Tensor, IntermediateTensors]: |
| 237 | + hidden_states = self.model(input_ids, positions, intermediate_tensors, |
| 238 | + inputs_embeds) |
| 239 | + return hidden_states |
| 240 | + |
| 241 | + def compute_logits( |
| 242 | + self, |
| 243 | + hidden_states: torch.Tensor, |
| 244 | + sampling_metadata: SamplingMetadata, |
| 245 | + ) -> Optional[torch.Tensor]: |
| 246 | + logits = self.logits_processor(self.lm_head, hidden_states, |
| 247 | + sampling_metadata) |
| 248 | + return logits |
| 249 | + |
| 250 | + def load_weights(self, weights: Iterable[tuple[str, |
| 251 | + torch.Tensor]]) -> set[str]: |
| 252 | + loader = AutoWeightsLoader( |
| 253 | + self, |
| 254 | + skip_prefixes=(["lm_head."] |
| 255 | + if self.config.tie_word_embeddings else None), |
| 256 | + ) |
| 257 | + return loader.load_weights(weights) |
0 commit comments