diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index d9856bef3c..c3634559bd 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -205,6 +205,8 @@ jobs:
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ_with_flashcomm_v1
+            VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_with_flashcomm_v2
             VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
           fi
diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py
index 5fee7e4a17..7593a4afd7 100644
--- a/tests/multicard/test_offline_inference_distributed.py
+++ b/tests/multicard/test_offline_inference_distributed.py
@@ -195,8 +195,26 @@ def test_models_distributed_DeepSeek_W8A8():
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+def test_models_distributed_QwQ_with_flashcomm_v1(enforce_eager: bool):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("Qwen/QwQ-32B"),
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            dtype="auto",
+            tensor_parallel_size=4,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "2"})
-def test_models_distributed_Qwen3_with_flashcomm2():
+def test_models_distributed_Qwen3_with_flashcomm_v2():
     example_prompts = [
         "Hello, my name is",
     ]
@@ -208,6 +226,5 @@ def test_models_distributed_Qwen3_with_flashcomm2():
             enforce_eager=True,
             dtype="auto",
             tensor_parallel_size=2,
-            quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
\ No newline at end of file
diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py
index c8f33313ba..664f4fea2c 100644
--- a/vllm_ascend/__init__.py
+++ b/vllm_ascend/__init__.py
@@ -28,4 +28,6 @@ def register_model():
     import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401
 
     from .models import register_model
+
+    import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # isort: skip  # noqa: F401
     register_model()
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index b2da242106..d7ce77e1d9 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -8,6 +8,7 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
@@ -17,6 +18,9 @@ def register_model():
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
 
+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
+
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py
new file mode 100644
index 0000000000..adc03e2af2
--- /dev/null
+++ b/vllm_ascend/models/qwen2.py
@@ -0,0 +1,257 @@
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import Qwen2Config
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.utils import (AutoWeightsLoader,
+                                              PPMissingLayer, maybe_prefix)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+import vllm_ascend.envs as ascend_envs
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
+
+def all_gather_and_maybe_unpad(
+    hidden_states: torch.Tensor,
+    pad_size: int,
+) -> torch.Tensor:
+    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+    if pad_size > 0:
+        return hidden_states[:-pad_size, :]
+    return hidden_states
+
+
+def maybe_pad_and_reduce_scatter(
+    hidden_states: torch.Tensor,
+    pad_size: int,
+) -> torch.Tensor:
+    if pad_size > 0:
+        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
+    hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
+    return hidden_states
+
+
+class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.self_attn.o_proj.reduce_results = False
+        self.mlp.down_proj.reduce_results = False
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        flashcomm_v1_enabled: bool,
+        pad_size: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            if flashcomm_v1_enabled:
+                if pad_size > 0:
+                    residual = F.pad(residual, (0, 0, 0, pad_size))
+                residual = torch.chunk(residual, self.tp_size,
+                                       dim=0)[self.tp_rank]
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+            if flashcomm_v1_enabled:
+                hidden_states = all_gather_and_maybe_unpad(
+                    hidden_states, pad_size)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+        if flashcomm_v1_enabled:
+            hidden_states = maybe_pad_and_reduce_scatter(
+                hidden_states, pad_size)
+        else:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        if flashcomm_v1_enabled:
+            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
+        hidden_states = self.mlp(hidden_states)
+        if flashcomm_v1_enabled:
+            hidden_states = maybe_pad_and_reduce_scatter(
+                hidden_states, pad_size)
+        else:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        return hidden_states, residual
+
+
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    })
+class CustomQwen2Model(Qwen2Model):
+
+    def __init__(
+            self,
+            *,
+            vllm_config: VllmConfig,
+            prefix: str = "",
+            decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         decoder_layer_type=decoder_layer_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        pad_size = 0
+        flashcomm_v1_enabled = False
+        attn_metadata = get_forward_context().attn_metadata
+        if ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM == 1 and \
+            attn_metadata is not None and \
+            attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
+            flashcomm_v1_enabled = True
+        if flashcomm_v1_enabled:
+            num_tokens = hidden_states.size(0)
+            pad_size = (self.tp_size -
+                        (num_tokens % self.tp_size)) % self.tp_size
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+                flashcomm_v1_enabled,
+                pad_size,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        if flashcomm_v1_enabled:
+            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
+        return hidden_states
+
+
+class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    # add `CustomQwen2Model` to init self.model
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = CustomQwen2Model(vllm_config=vllm_config,
+                                      prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
diff --git a/vllm_ascend/patch/platform/patch_0_9_1/__init__.py b/vllm_ascend/patch/platform/patch_0_9_1/__init__.py
index 116c73c06c..c031b31558 100644
--- a/vllm_ascend/patch/platform/patch_0_9_1/__init__.py
+++ b/vllm_ascend/patch/platform/patch_0_9_1/__init__.py
@@ -14,3 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+# patch_utils should be the first import, because it will be used by other
+# patch files.
+import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
+import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # noqa
diff --git a/vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py b/vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py
new file mode 100644
index 0000000000..0a8cdd8d10
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py
@@ -0,0 +1,152 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import inspect
+from typing import TypeVar, Union
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+from torch._dynamo.symbolic_convert import InliningInstructionTranslator
+from vllm.compilation import decorators
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.monitor import start_monitoring_torch_compile
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import CompilationLevel, VllmConfig
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+from vllm.utils import supports_dynamo
+
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
+logger = init_logger(__name__)
+
+_T = TypeVar("_T", bound=type[nn.Module])
+
+
+def _ascend_support_torch_compile(
+    cls: _T,
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
+) -> _T:
+    """
+    A decorator to add support for compiling the forward method of a class.
+    """
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
+        return cls
+
+    # take care of method resolution order
+    # make sure super().__init__ is called on the base class
+    # other than TorchCompileWrapperWithCustomDispatcher
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
+
+    old_init = cls.__init__
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        self.vllm_config = vllm_config
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = \
+            vllm_config.compilation_config.level in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
+        compilation_counter.num_models_seen += 1
+        TorchCompileWrapperWithCustomDispatcher.__init__(
+            self, compilation_level=vllm_config.compilation_config.level)
+
+    cls.__init__ = __init__
+
+    def __call__(self, *args, **kwargs):
+        # torch.compiler.is_compiling() means we are inside the compilation
+        # e.g. TPU has the compilation logic in model runner, so we don't
+        # need to compile the model inside.
+        attn_metadata = get_forward_context().attn_metadata
+        if attn_metadata is not None and attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
+            return self.forward(*args, **kwargs)
+
+        if self.do_not_compile or torch.compiler.is_compiling():
+            return self.forward(*args, **kwargs)
+
+        # the first compilation needs to have dynamic shapes marked
+        if len(self.compiled_codes) < 1:
+            sig = inspect.signature(self.__class__.forward)
+            bound_args = sig.bind(self, *args, **kwargs)
+            bound_args.apply_defaults()
+            for k, dims in dynamic_arg_dims.items():
+                arg = bound_args.arguments.get(k)
+                if arg is not None:
+                    dims = [dims] if isinstance(dims, int) else dims
+                    if isinstance(arg, torch.Tensor):
+                        # In case dims is specified with negative indexing
+                        dims = [
+                            arg.ndim + dim if dim < 0 else dim for dim in dims
+                        ]
+                        torch._dynamo.mark_dynamic(arg, dims)
+                    elif isinstance(arg, IntermediateTensors):
+                        for tensor in arg.tensors.values():
+                            # In case dims is specified with negative indexing
+                            dims = [
+                                tensor.ndim + dim if dim < 0 else dim
+                                for dim in dims
+                            ]
+                            torch._dynamo.mark_dynamic(tensor, dims)
+                    else:
+                        raise ValueError(
+                            "Unsupported dynamic dimensions"
+                            f" {dims} for argument {k} with type {type(arg)}.")
+            # here, it is the starting point of the `torch.compile` process
+            start_monitoring_torch_compile(self.vllm_config)
+            logger.debug("Start compiling function %s",
+                         self.original_code_object)
+
+        # if we don't use custom dispatcher, we can directly call the
+        # compiled function and let torch.compile handle the dispatching,
+        # with the overhead of guard evaluation and recompilation.
+        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
+            # it seems Dynamo reuse the compilation across instances,
+            # while we need to make sure the compiled code is not reused.
+            # we need to control all the compilation of the model.
+            torch._dynamo.eval_frame.remove_from_cache(
+                self.original_code_object)
+
+            # collect all relevant files traced by Dynamo,
+            # so that the compilation cache can trigger re-compilation
+            # properly when any of these files change.
+
+            # 1. the file containing the top-level forward function
+            self.vllm_config.compilation_config.traced_files.add(
+                self.original_code_object.co_filename)
+
+            # 2. every time Dynamo sees a function call, it will inline
+            # the function by calling InliningInstructionTranslator.inline_call
+            # we hijack this function to know all the functions called
+            # during Dynamo tracing, and their corresponding files
+            inline_call = InliningInstructionTranslator.inline_call
+
+            def patched_inline_call(parent, func, args, kwargs):
+                code = func.get_code()
+                self.vllm_config.compilation_config.traced_files.add(
+                    code.co_filename)
+                return inline_call(parent, func, args, kwargs)
+
+            with patch.object(InliningInstructionTranslator, 'inline_call',
+                              patched_inline_call):
+                output = self.compiled_callable(*args, **kwargs)
+            return output
+
+        # usually, capturing the model once is enough, and then we can
+        # dispatch to the compiled code directly, without going through
+        # the Dynamo guard mechanism.
+        with self.dispatch_to_code(0):
+            model_output = self.forward(*args, **kwargs)
+        return model_output
+
+    cls.__call__ = __call__
+    return cls
+
+
+decorators._support_torch_compile = _ascend_support_torch_compile
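For reference, a minimal offline-inference sketch that exercises the FlashComm v1 path added above, mirroring what the new test does through VllmRunner. Assumptions: the standard vLLM LLM/SamplingParams API is available on the Ascend build, and VLLM_ASCEND_ENABLE_FLASHCOMM is read from the environment at engine startup as in the patched tests; this is illustrative, not part of the patch.

    import os

    # Assumption: "1" selects FlashComm v1, as patched into the new QwQ test.
    os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"

    from vllm import LLM, SamplingParams

    # Model and parallel settings taken from test_models_distributed_QwQ_with_flashcomm_v1.
    llm = LLM(model="Qwen/QwQ-32B", tensor_parallel_size=4, max_model_len=8192)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)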