From 31b891f1625c8b69845f26d836a96cae535e35cd Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Thu, 10 Jul 2025 22:47:55 +0800 Subject: [PATCH 1/2] [V0.9.1] Add support for flashcomm_v1 in Qwen2.5 Signed-off-by: rjg-lyh <1318825571@qq.com> --- vllm_ascend/envs.py | 2 + vllm_ascend/models/__init__.py | 4 + vllm_ascend/models/qwen2.py | 253 ++++++++++++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 4 +- 4 files changed, 261 insertions(+), 2 deletions(-) create mode 100644 vllm_ascend/models/qwen2.py diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 1f0b6ff4b1..a2b2358669 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -106,6 +106,8 @@ "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) ), + "VLLM_ASCEND_ENABLE_FLASHCOMM": + lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')), # VLLM_ASCEND_MOE_ALL2ALL_BUFFER: # 0: default, normal init. # 1: enable moe_all2all_buffer. diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index b2da242106..e07dbcdddd 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -11,12 +11,16 @@ def register_model(): from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 + from .qwen2 import CustomQwen2ForCausalLM # noqa: F401 from .qwen3 import CustomQwen3ForCausalLM # noqa: F401 ModelRegistry.register_model( "DeepSeekMTPModel", "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP") + ModelRegistry.register_model( + "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM") + ModelRegistry.register_model( "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py new file mode 100644 index 0000000000..4f9f53c24a --- /dev/null +++ b/vllm_ascend/models/qwen2.py @@ -0,0 +1,253 @@ +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import Qwen2Config + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) + +from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm.forward_context import get_forward_context +import vllm_ascend.envs as ascend_envs + + +def all_gather_and_maybe_unpad( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + hidden_states = 
tensor_model_parallel_all_gather(hidden_states, 0) + if pad_size > 0: + return hidden_states[:-pad_size, :] + return hidden_states + +def maybe_pad_and_reduce_scatter( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + if pad_size > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size)) + hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) + return hidden_states + +class CustomQwen2DecoderLayer(Qwen2DecoderLayer): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.self_attn.o_proj.reduce_results=False + self.mlp.down_proj.reduce_results=False + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + flashcomm_v1_enabled: bool, + pad_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + if flashcomm_v1_enabled: + if pad_size > 0: + residual = F.pad(residual, (0, 0, 0, pad_size)) + residual = torch.chunk(residual, self.tp_size, dim=0)[self.tp_rank] + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + if flashcomm_v1_enabled: + hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + if flashcomm_v1_enabled: + hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + if flashcomm_v1_enabled: + hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) + hidden_states = self.mlp(hidden_states) + if flashcomm_v1_enabled: + hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class CustomQwen2Model(Qwen2Model): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type) + self.tp_size = get_tensor_model_parallel_world_size() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + pad_size = 0 + flashcomm_v1_enabled = False + attn_metadata = get_forward_context().attn_metadata + if ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM == 1 and \ + attn_metadata is not None and \ + attn_metadata.attn_state != AscendAttentionState.DecodeOnly: + flashcomm_v1_enabled = True + if flashcomm_v1_enabled: + num_tokens = hidden_states.size(0) + pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + flashcomm_v1_enabled, + pad_size, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + if flashcomm_v1_enabled: + hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) + return hidden_states + + +class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + # add `CustomQwen2Model` to init self.model + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = CustomQwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: 
torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1207618706..3e4dfed306 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2033,8 +2033,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.aclgraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, skip_attn=skip_attn) - self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False) + self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False) else: logger.info("Skipping NPU graph capture for eager mode.") return From c15449ba0ed7caf1652a3b97dc8b1fbdbe27da0e Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Tue, 15 Jul 2025 15:19:04 +0800 Subject: [PATCH 2/2] [V0.9.1] Patch compilation.decorator to support flashcomm_v1 in aclgraph Signed-off-by: rjg-lyh <1318825571@qq.com> --- .github/workflows/vllm_ascend_test.yaml | 2 + .../test_offline_inference_distributed.py | 21 ++- vllm_ascend/__init__.py | 2 + vllm_ascend/envs.py | 2 - vllm_ascend/models/__init__.py | 2 +- vllm_ascend/models/qwen2.py | 66 ++++---- .../patch/platform/patch_0_9_1/__init__.py | 5 + .../platform/patch_0_9_1/patch_decorator.py | 152 ++++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 4 +- 9 files changed, 218 insertions(+), 38 deletions(-) create mode 100644 vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index d9856bef3c..c3634559bd 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -205,6 +205,8 @@ jobs: VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ_with_flashcomm_v1 + VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_with_flashcomm_v2 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py fi diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index 5fee7e4a17..7593a4afd7 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -195,8 +195,26 @@ def test_models_distributed_DeepSeek_W8A8(): vllm_model.generate_greedy(example_prompts, max_tokens) +@pytest.mark.parametrize("enforce_eager", [True, False]) 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"}) +def test_models_distributed_QwQ_with_flashcomm_v1(enforce_eager: bool): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download("Qwen/QwQ-32B"), + max_model_len=8192, + enforce_eager=enforce_eager, + dtype="auto", + tensor_parallel_size=4, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) + + @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "2"}) -def test_models_distributed_Qwen3_with_flashcomm2(): +def test_models_distributed_Qwen3_with_flashcomm_v2(): example_prompts = [ "Hello, my name is", ] @@ -208,6 +226,5 @@ def test_models_distributed_Qwen3_with_flashcomm2(): enforce_eager=True, dtype="auto", tensor_parallel_size=2, - quantization="ascend", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) \ No newline at end of file diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index c8f33313ba..664f4fea2c 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -28,4 +28,6 @@ def register_model(): import vllm_ascend.patch.worker.patch_common.patch_utils # noqa: F401 from .models import register_model + + import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator # isort: skip # noqa: F401 register_model() diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index a2b2358669..1f0b6ff4b1 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -106,8 +106,6 @@ "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) ), - "VLLM_ASCEND_ENABLE_FLASHCOMM": - lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')), # VLLM_ASCEND_MOE_ALL2ALL_BUFFER: # 0: default, normal init. # 1: enable moe_all2all_buffer. 
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index e07dbcdddd..d7ce77e1d9 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -8,10 +8,10 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 + from .qwen2 import CustomQwen2ForCausalLM # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 - from .qwen2 import CustomQwen2ForCausalLM # noqa: F401 from .qwen3 import CustomQwen3ForCausalLM # noqa: F401 ModelRegistry.register_model( diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py index 4f9f53c24a..adc03e2af2 100644 --- a/vllm_ascend/models/qwen2.py +++ b/vllm_ascend/models/qwen2.py @@ -1,34 +1,30 @@ from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Optional, Union import torch -from torch import nn import torch.nn.functional as F +from torch import nn from transformers import Qwen2Config - from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model +from vllm.model_executor.models.utils import (AutoWeightsLoader, + PPMissingLayer, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP -from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix) - -from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer -from vllm.distributed import ( - get_pp_group, - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm.forward_context import get_forward_context import vllm_ascend.envs as ascend_envs +from vllm_ascend.attention.attention_v1 import AscendAttentionState def all_gather_and_maybe_unpad( @@ -40,6 +36,7 @@ def all_gather_and_maybe_unpad( return hidden_states[:-pad_size, :] return hidden_states + def maybe_pad_and_reduce_scatter( hidden_states: torch.Tensor, pad_size: int, @@ -49,6 +46,7 @@ def maybe_pad_and_reduce_scatter( hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) return hidden_states + class CustomQwen2DecoderLayer(Qwen2DecoderLayer): def __init__( @@ -64,9 +62,9 @@ def __init__( prefix=prefix) self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() 
- self.self_attn.o_proj.reduce_results=False - self.mlp.down_proj.reduce_results=False - + self.self_attn.o_proj.reduce_results = False + self.mlp.down_proj.reduce_results = False + def forward( self, positions: torch.Tensor, @@ -81,19 +79,22 @@ def forward( if flashcomm_v1_enabled: if pad_size > 0: residual = F.pad(residual, (0, 0, 0, pad_size)) - residual = torch.chunk(residual, self.tp_size, dim=0)[self.tp_rank] + residual = torch.chunk(residual, self.tp_size, + dim=0)[self.tp_rank] hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) if flashcomm_v1_enabled: - hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) + hidden_states = all_gather_and_maybe_unpad( + hidden_states, pad_size) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) if flashcomm_v1_enabled: - hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size) + hidden_states = maybe_pad_and_reduce_scatter( + hidden_states, pad_size) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) # Fully Connected @@ -103,7 +104,8 @@ def forward( hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size) hidden_states = self.mlp(hidden_states) if flashcomm_v1_enabled: - hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size) + hidden_states = maybe_pad_and_reduce_scatter( + hidden_states, pad_size) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) return hidden_states, residual @@ -120,11 +122,12 @@ def forward( }) class CustomQwen2Model(Qwen2Model): - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): super().__init__(vllm_config=vllm_config, prefix=prefix, decoder_layer_type=decoder_layer_type) @@ -156,7 +159,8 @@ def forward( flashcomm_v1_enabled = True if flashcomm_v1_enabled: num_tokens = hidden_states.size(0) - pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size + pad_size = (self.tp_size - + (num_tokens % self.tp_size)) % self.tp_size for layer in self.layers[self.start_layer:self.end_layer]: hidden_states, residual = layer( positions, @@ -201,7 +205,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.model = CustomQwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: if config.tie_word_embeddings: diff --git a/vllm_ascend/patch/platform/patch_0_9_1/__init__.py b/vllm_ascend/patch/platform/patch_0_9_1/__init__.py index 116c73c06c..c031b31558 100644 --- a/vllm_ascend/patch/platform/patch_0_9_1/__init__.py +++ b/vllm_ascend/patch/platform/patch_0_9_1/__init__.py @@ -14,3 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +# patch_utils should be the first import, because it will be used by other +# patch files. 
+import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip +import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator # noqa diff --git a/vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py b/vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py new file mode 100644 index 0000000000..0a8cdd8d10 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +from typing import TypeVar, Union +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch._dynamo.symbolic_convert import InliningInstructionTranslator +from vllm.compilation import decorators +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import start_monitoring_torch_compile +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel, VllmConfig +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo + +from vllm_ascend.attention.attention_v1 import AscendAttentionState + +logger = init_logger(__name__) + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def _ascend_support_torch_compile( + cls: _T, + dynamic_arg_dims: dict[str, Union[int, list[int]]], +) -> _T: + """ + A decorator to add support for compiling the forward method of a class. + """ + if TorchCompileWrapperWithCustomDispatcher in cls.__bases__: + # support decorating multiple times + return cls + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): + old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. + self.do_not_compile = \ + vllm_config.compilation_config.level in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ] or not supports_dynamo() + if self.do_not_compile: + return + compilation_counter.num_models_seen += 1 + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) + + cls.__init__ = __init__ + + def __call__(self, *args, **kwargs): + # torch.compiler.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. 
+ attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None and attn_metadata.attn_state != AscendAttentionState.DecodeOnly: + return self.forward(*args, **kwargs) + + if self.do_not_compile or torch.compiler.is_compiling(): + return self.forward(*args, **kwargs) + + # the first compilation needs to have dynamic shapes marked + if len(self.compiled_codes) < 1: + sig = inspect.signature(self.__class__.forward) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config) + logger.debug("Start compiling function %s", + self.original_code_object) + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + # it seems Dynamo reuse the compilation across instances, + # while we need to make sure the compiled code is not reused. + # we need to control all the compilation of the model. + torch._dynamo.eval_frame.remove_from_cache( + self.original_code_object) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + self.original_code_object.co_filename) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call + + def patched_inline_call(parent, func, args, kwargs): + code = func.get_code() + self.vllm_config.compilation_config.traced_files.add( + code.co_filename) + return inline_call(parent, func, args, kwargs) + + with patch.object(InliningInstructionTranslator, 'inline_call', + patched_inline_call): + output = self.compiled_callable(*args, **kwargs) + return output + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. 
+ with self.dispatch_to_code(0): + model_output = self.forward(*args, **kwargs) + return model_output + + cls.__call__ = __call__ + return cls + + +decorators._support_torch_compile = _ascend_support_torch_compile diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3e4dfed306..1207618706 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2033,8 +2033,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.aclgraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False) - self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False) + self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, skip_attn=skip_attn) else: logger.info("Skipping NPU graph capture for eager mode.") return
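
A quick way to exercise the new switch end to end, as a minimal sketch: it assumes a vllm-ascend build with both patches applied, and the model name, tensor-parallel size, prompt, and max_model_len simply mirror the new test_models_distributed_QwQ_with_flashcomm_v1. The LLM/SamplingParams calls are the standard vLLM offline-inference API, not anything added by these patches.

    import os

    # flashcomm_v1 is gated by the env switch added in vllm_ascend/envs.py.
    # Set it before vLLM / vllm-ascend spawn worker processes so every rank
    # reads "1".
    os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"

    from vllm import LLM, SamplingParams

    # Tensor parallelism is what makes the all_gather / reduce_scatter path in
    # CustomQwen2DecoderLayer differ from the default all_reduce path; tp=4 and
    # the QwQ-32B checkpoint follow the new multicard test.
    llm = LLM(model="Qwen/QwQ-32B",
              tensor_parallel_size=4,
              max_model_len=8192,
              dtype="auto")

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)

    # Note: flashcomm_v1 only activates outside DecodeOnly attention states
    # (see CustomQwen2Model.forward), so it is the prefill of this prompt that
    # takes the gather/scatter path; decode steps fall back to all_reduce.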