
Commit 31ba211

[V0.9.1] Patch compilation.decorator to support flashcomm_v1 in aclgraph
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 1542a75 commit 31ba211

6 files changed: 194 additions, 34 deletions


vllm_ascend/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -23,6 +23,8 @@ def register():


 def register_model():
+    import vllm  # noqa: F401
+    import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # noqa: F401
     # fix pytorch schema check error, remove this line after pytorch
     # is upgraded to 2.7.0
     import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401
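
Note: the patch import sits at the top of register_model() so that it runs before any model module applies @support_torch_compile; the decorator resolves vLLM's private _support_torch_compile hook through the decorators module attribute, so replacing that attribute early reroutes every later decoration. Below is a minimal, self-contained sketch of that import-order effect, using illustrative stand-in names (fake_decorators, upstream_hook, ascend_hook, TinyModel) rather than code from this commit:

import types

# Stand-in for the vllm.compilation.decorators module (hypothetical names).
fake_decorators = types.SimpleNamespace()


def upstream_hook(cls, dynamic_arg_dims):
    cls.compile_backend = "upstream"
    return cls


def ascend_hook(cls, dynamic_arg_dims):
    cls.compile_backend = "ascend"
    return cls


fake_decorators._support_torch_compile = upstream_hook


def support_torch_compile(dynamic_arg_dims):
    # The private hook is looked up when the decorator is applied, so whatever
    # is installed on the module at that moment wins.
    def wrapper(cls):
        return fake_decorators._support_torch_compile(cls, dynamic_arg_dims)
    return wrapper


# The patch module's import-time side effect, in miniature.
fake_decorators._support_torch_compile = ascend_hook


@support_torch_compile(dynamic_arg_dims={"input_ids": 0})
class TinyModel:
    pass


print(TinyModel.compile_backend)  # "ascend"

If the patch were imported after the model modules, the classes would already have been decorated through the upstream hook, which is why both this file and patch_0_9_1/__init__.py pull the import forward.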

vllm_ascend/models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -8,10 +8,10 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
-    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401

     ModelRegistry.register_model(

vllm_ascend/models/qwen2.py

Lines changed: 35 additions & 31 deletions

@@ -1,34 +1,30 @@
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 from transformers import Qwen2Config
-
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.utils import (AutoWeightsLoader,
+                                              PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix)
-
-from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer
-from vllm.distributed import (
-    get_pp_group,
-    get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    tensor_model_parallel_all_gather,
-    tensor_model_parallel_all_reduce,
-    tensor_model_parallel_reduce_scatter)
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm.forward_context import get_forward_context
 import vllm_ascend.envs as ascend_envs
+from vllm_ascend.attention.attention_v1 import AscendAttentionState


 def all_gather_and_maybe_unpad(

@@ -40,6 +36,7 @@ def all_gather_and_maybe_unpad(
         return hidden_states[:-pad_size, :]
     return hidden_states

+
 def maybe_pad_and_reduce_scatter(
     hidden_states: torch.Tensor,
     pad_size: int,

@@ -49,6 +46,7 @@ def maybe_pad_and_reduce_scatter(
     hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
     return hidden_states

+
 class CustomQwen2DecoderLayer(Qwen2DecoderLayer):

     def __init__(

@@ -64,9 +62,9 @@ def __init__(
                          prefix=prefix)
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.self_attn.o_proj.reduce_results=False
-        self.mlp.down_proj.reduce_results=False
-
+        self.self_attn.o_proj.reduce_results = False
+        self.mlp.down_proj.reduce_results = False
+
     def forward(
         self,
         positions: torch.Tensor,

@@ -81,19 +79,22 @@ def forward(
         if flashcomm_v1_enabled:
             if pad_size > 0:
                 residual = F.pad(residual, (0, 0, 0, pad_size))
-            residual = torch.chunk(residual, self.tp_size, dim=0)[self.tp_rank]
+            residual = torch.chunk(residual, self.tp_size,
+                                   dim=0)[self.tp_rank]
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
         if flashcomm_v1_enabled:
-            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
+            hidden_states = all_gather_and_maybe_unpad(
+                hidden_states, pad_size)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
         )
         if flashcomm_v1_enabled:
-            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
+            hidden_states = maybe_pad_and_reduce_scatter(
+                hidden_states, pad_size)
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         # Fully Connected

@@ -103,7 +104,8 @@ def forward(
             hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
         hidden_states = self.mlp(hidden_states)
         if flashcomm_v1_enabled:
-            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
+            hidden_states = maybe_pad_and_reduce_scatter(
+                hidden_states, pad_size)
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         return hidden_states, residual

@@ -120,11 +122,12 @@ def forward(
     })
 class CustomQwen2Model(Qwen2Model):

-    def __init__(self,
-                 *,
-                 vllm_config: VllmConfig,
-                 prefix: str = "",
-                 decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
+    def __init__(
+            self,
+            *,
+            vllm_config: VllmConfig,
+            prefix: str = "",
+            decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
                          decoder_layer_type=decoder_layer_type)

@@ -156,7 +159,8 @@ def forward(
             flashcomm_v1_enabled = True
         if flashcomm_v1_enabled:
             num_tokens = hidden_states.size(0)
-            pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size
+            pad_size = (self.tp_size -
+                        (num_tokens % self.tp_size)) % self.tp_size
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,

@@ -201,7 +205,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.quant_config = quant_config
         self.model = CustomQwen2Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
+                                      prefix=maybe_prefix(prefix, "model"))

         if get_pp_group().is_last_rank:
             if config.tie_word_embeddings:
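
For context on the flashcomm_v1 path edited above: instead of an all-reduce after attention and after the MLP, the layer reduce-scatters hidden states along the token dimension and all-gathers them again before the next projection, which only works when the token count is a multiple of the TP world size; pad_size and the two helpers above supply that guarantee. Below is a minimal single-process sketch of just the pad/chunk/unpad arithmetic (tp_size, num_tokens and hidden_size are made-up values, and the collectives themselves are not simulated):

import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
tp_size = 4
num_tokens = 10          # not divisible by tp_size
hidden_size = 8

hidden_states = torch.randn(num_tokens, hidden_size)

# Same arithmetic as CustomQwen2Model.forward: pad the token dim up to a
# multiple of tp_size (pad_size is 0 when it already divides evenly).
pad_size = (tp_size - (num_tokens % tp_size)) % tp_size
padded = F.pad(hidden_states, (0, 0, 0, pad_size))
assert padded.size(0) % tp_size == 0

# Each TP rank keeps one token shard, mirroring how the residual is chunked
# in CustomQwen2DecoderLayer.forward (rank 0 shown here).
tp_rank = 0
shard = torch.chunk(padded, tp_size, dim=0)[tp_rank]
assert shard.size(0) == padded.size(0) // tp_size

# After the all-gather that reassembles the shards, the padding rows are
# dropped again, as in all_gather_and_maybe_unpad.
gathered = torch.cat(torch.chunk(padded, tp_size, dim=0), dim=0)
unpadded = gathered[:-pad_size, :] if pad_size > 0 else gathered
assert torch.equal(unpadded, hidden_states)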

vllm_ascend/patch/platform/patch_0_9_1/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -14,3 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # noqa
vllm_ascend/patch/platform/patch_0_9_1/patch_decorator.py

Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import inspect
+from typing import TypeVar, Union
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+from torch._dynamo.symbolic_convert import InliningInstructionTranslator
+from vllm.compilation import decorators
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.monitor import start_monitoring_torch_compile
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import CompilationLevel, VllmConfig
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.sequence import IntermediateTensors
+from vllm.utils import supports_dynamo
+
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+
+logger = init_logger(__name__)
+
+_T = TypeVar("_T", bound=type[nn.Module])
+
+
+def _ascend_support_torch_compile(
+    cls: _T,
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
+) -> _T:
+    """
+    A decorator to add support for compiling the forward method of a class.
+    """
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
+        return cls
+
+    # take care of method resolution order
+    # make sure super().__init__ is called on the base class
+    # other than TorchCompileWrapperWithCustomDispatcher
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
+
+    old_init = cls.__init__
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        self.vllm_config = vllm_config
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = \
+            vllm_config.compilation_config.level in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
+        compilation_counter.num_models_seen += 1
+        TorchCompileWrapperWithCustomDispatcher.__init__(
+            self, compilation_level=vllm_config.compilation_config.level)
+
+    cls.__init__ = __init__
+
+    def __call__(self, *args, **kwargs):
+        # torch.compiler.is_compiling() means we are inside the compilation
+        # e.g. TPU has the compilation logic in model runner, so we don't
+        # need to compile the model inside.
+        attn_metadata = get_forward_context().attn_metadata
+        if attn_metadata is not None and attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
+            return self.forward(*args, **kwargs)
+
+        if self.do_not_compile or torch.compiler.is_compiling():
+            return self.forward(*args, **kwargs)
+
+        # the first compilation needs to have dynamic shapes marked
+        if len(self.compiled_codes) < 1:
+            sig = inspect.signature(self.__class__.forward)
+            bound_args = sig.bind(self, *args, **kwargs)
+            bound_args.apply_defaults()
+            for k, dims in dynamic_arg_dims.items():
+                arg = bound_args.arguments.get(k)
+                if arg is not None:
+                    dims = [dims] if isinstance(dims, int) else dims
+                    if isinstance(arg, torch.Tensor):
+                        # In case dims is specified with negative indexing
+                        dims = [
+                            arg.ndim + dim if dim < 0 else dim for dim in dims
+                        ]
+                        torch._dynamo.mark_dynamic(arg, dims)
+                    elif isinstance(arg, IntermediateTensors):
+                        for tensor in arg.tensors.values():
+                            # In case dims is specified with negative indexing
+                            dims = [
+                                tensor.ndim + dim if dim < 0 else dim
+                                for dim in dims
+                            ]
+                            torch._dynamo.mark_dynamic(tensor, dims)
+                    else:
+                        raise ValueError(
+                            "Unsupported dynamic dimensions"
+                            f" {dims} for argument {k} with type {type(arg)}.")
+            # here, it is the starting point of the `torch.compile` process
+            start_monitoring_torch_compile(self.vllm_config)
+            logger.debug("Start compiling function %s",
+                         self.original_code_object)
+
+        # if we don't use custom dispatcher, we can directly call the
+        # compiled function and let torch.compile handle the dispatching,
+        # with the overhead of guard evaluation and recompilation.
+        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
+            # it seems Dynamo reuse the compilation across instances,
+            # while we need to make sure the compiled code is not reused.
+            # we need to control all the compilation of the model.
+            torch._dynamo.eval_frame.remove_from_cache(
+                self.original_code_object)
+
+            # collect all relevant files traced by Dynamo,
+            # so that the compilation cache can trigger re-compilation
+            # properly when any of these files change.
+
+            # 1. the file containing the top-level forward function
+            self.vllm_config.compilation_config.traced_files.add(
+                self.original_code_object.co_filename)
+
+            # 2. every time Dynamo sees a function call, it will inline
+            # the function by calling InliningInstructionTranslator.inline_call
+            # we hijack this function to know all the functions called
+            # during Dynamo tracing, and their corresponding files
+            inline_call = InliningInstructionTranslator.inline_call
+
+            def patched_inline_call(parent, func, args, kwargs):
+                code = func.get_code()
+                self.vllm_config.compilation_config.traced_files.add(
+                    code.co_filename)
+                return inline_call(parent, func, args, kwargs)
+
+            with patch.object(InliningInstructionTranslator, 'inline_call',
+                              patched_inline_call):
+                output = self.compiled_callable(*args, **kwargs)
+            return output
+
+        # usually, capturing the model once is enough, and then we can
+        # dispatch to the compiled code directly, without going through
+        # the Dynamo guard mechanism.
+        with self.dispatch_to_code(0):
+            model_output = self.forward(*args, **kwargs)
+        return model_output
+
+    cls.__call__ = __call__
+    return cls
+
+
+decorators._support_torch_compile = _ascend_support_torch_compile
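
The functional difference from upstream _support_torch_compile is the early fallback at the top of __call__: when attention metadata is present and its state is not AscendAttentionState.DecodeOnly, the model runs self.forward eagerly before any compilation logic, so only decode-only batches reach the compiled / ACL graph path. A toy sketch of that dispatch order (AttnState and the lambdas are illustrative stand-ins, not the real vllm_ascend types):

from enum import Enum


# Illustrative stand-ins; the real state comes from vllm_ascend.attention.attention_v1
# and the metadata from vllm.forward_context.
class AttnState(Enum):
    PrefillNoCache = 0
    DecodeOnly = 1


def dispatch(attn_state, compiled_forward, eager_forward):
    # Mirrors the ordering in the patched __call__: non-decode batches take the
    # eager path before any do_not_compile / is_compiling checks are reached.
    if attn_state is not None and attn_state != AttnState.DecodeOnly:
        return eager_forward()
    return compiled_forward()


print(dispatch(AttnState.PrefillNoCache, lambda: "aclgraph", lambda: "eager"))  # eager
print(dispatch(AttnState.DecodeOnly, lambda: "aclgraph", lambda: "eager"))      # aclgraph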

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions

@@ -2033,8 +2033,8 @@ def capture_model(self) -> None:
             for num_tokens in reversed(self.aclgraph_batch_sizes):
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False)
-                self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                self._dummy_run(num_tokens, skip_attn=skip_attn)
         else:
             logger.info("Skipping NPU graph capture for eager mode.")
             return
