Add graph mode for Qwen2.5 and Qwen3 #1787
Conversation
Force-pushed from 9afb7cf to c985f62
@NeverRaR please review
Force-pushed from 871c038 to 5a93bc3
I got the following error when running the Qwen2.5-32B and Qwen3-30B-A3B models:
And the following error when running the Qwen3-32B model:
All models were run with TP 2. The first error occurred during inference, and the second occurred during startup.
Does this PR support all_gather's DP?
Force-pushed from 16a59c4 to 4c85d7c
@@ -188,6 +217,41 @@ def build(self,
    slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
    attn_mask = self.runner.attn_mask
    attn_state = self.runner.attn_state
    query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
This code was changed by mistake; please remove it.
Fixed.
vllm_config = get_current_vllm_config()
self.full_graph = vllm_config.compilation_config.full_cuda_graph
self.block_size = vllm_config.cache_config.block_size

def update_kv_cache( |
Please add a UT; the result should match `_npu_reshape_and_cache`.
UT added.
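A minimal sketch of what such a UT could look like, assuming a (num_blocks, block_size, num_kv_heads, head_dim) cache layout; `reference_update_kv_cache` is a hypothetical golden scatter that only illustrates the behaviour `update_kv_cache` (and, per the review, `torch_npu._npu_reshape_and_cache` on the NPU) should reproduce:

```python
import torch


def reference_update_kv_cache(key, value, key_cache, value_cache, slot_indices):
    # Reference scatter: write each token's K/V into the paged cache at its slot.
    # Assumed cache layout: (num_blocks, block_size, num_kv_heads, head_dim).
    block_size = key_cache.shape[1]
    block_ids = slot_indices // block_size
    block_offsets = slot_indices % block_size
    key_cache[block_ids, block_offsets] = key
    value_cache[block_ids, block_offsets] = value


def test_reference_scatter_roundtrip():
    num_tokens, num_kv_heads, head_dim = 8, 4, 128
    num_blocks, block_size = 16, 128
    key = torch.randn(num_tokens, num_kv_heads, head_dim)
    value = torch.randn(num_tokens, num_kv_heads, head_dim)
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)
    value_cache = torch.zeros_like(key_cache)
    slots = torch.randperm(num_blocks * block_size)[:num_tokens]

    reference_update_kv_cache(key, value, key_cache, value_cache, slots)

    # Reading back through the same slot indices must return the original K/V.
    torch.testing.assert_close(key_cache[slots // block_size, slots % block_size], key)
    torch.testing.assert_close(value_cache[slots // block_size, slots % block_size], value)
```

In the real UT the same inputs would be fed to `update_kv_cache` on an NPU and compared against this reference.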
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled: |
Suggest changing `with_prefill_across_dp` to `attn_metadata.attn_state == AscendAttentionState.DecodeOnly`, so that decode uses a single set of operators internally.
The two currently differ slightly in semantics; will change after a more detailed analysis.
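To make the suggested change concrete, a tiny self-contained sketch of the gating predicate under discussion; `AscendAttentionState` is stubbed out here and the helper name is hypothetical:

```python
from enum import Enum


class AscendAttentionState(Enum):
    # Stub of the real enum in vllm_ascend; only the member used here matters.
    PrefillNoCache = 0
    DecodeOnly = 1


def use_graph_decode_kernel(attn_state, torchair_graph_enabled: bool) -> bool:
    # Reviewer's suggestion: gate on DecodeOnly rather than on the
    # with_prefill_across_dp flag, so decode always takes one operator path.
    return torchair_graph_enabled and attn_state == AscendAttentionState.DecodeOnly


assert use_graph_decode_kernel(AscendAttentionState.DecodeOnly, True)
assert not use_graph_decode_kernel(AscendAttentionState.PrefillNoCache, True)
```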
key_cache = self.key_cache.view(*self.key_cache.shape[:-2], -1)
value_cache = self.value_cache.view(*self.value_cache.shape[:-2], -1)

output = torch_npu.npu_incre_flash_attention( |
Suggest switching to the `npu_fused_infer_attention_score` operator and adding a UT.
After discussion, we will keep using `npu_incre_flash_attention`.
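For readers following along, a hedged sketch of this decode path, assuming a (num_blocks, block_size, num_kv_heads, head_dim) cache layout and BSH input layout; the keyword arguments of `npu_incre_flash_attention` are quoted from the Ascend documentation as best recalled and should be checked against the installed torch_npu version:

```python
import torch
import torch_npu  # requires an Ascend NPU environment


def graph_mode_decode_attention(query, key_cache, value_cache, block_table,
                                seq_lens, num_heads, num_kv_heads, scale,
                                block_size):
    # Flatten (num_kv_heads, head_dim) into one hidden dim so the paged cache
    # matches the BSH layout expected by the incremental flash-attention op.
    key_cache = key_cache.view(*key_cache.shape[:-2], -1)
    value_cache = value_cache.view(*value_cache.shape[:-2], -1)
    return torch_npu.npu_incre_flash_attention(
        query,                       # (batch, 1, num_heads * head_dim)
        key_cache,
        value_cache,
        num_heads=num_heads,
        num_key_value_heads=num_kv_heads,
        scale_value=scale,
        input_layout="BSH",
        actual_seq_lengths=seq_lens,
        block_table=block_table,
        block_size=block_size,
    )
```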
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled |
Could this be moved into `forward` and check `ascend_config` directly there, to reduce code?
TODO: the code will be moved into `forward` and merged after verification.
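A rough standalone illustration of the trade-off being discussed (the real config object lives in `vllm_ascend.ascend_config`; the toy below only mirrors its shape):

```python
class _TorchairGraphConfig:
    def __init__(self, enabled: bool):
        self.enabled = enabled


_ASCEND_CONFIG = _TorchairGraphConfig(enabled=True)


def get_ascend_config() -> _TorchairGraphConfig:
    # Toy stand-in for vllm_ascend.ascend_config.get_ascend_config().
    return _ASCEND_CONFIG


class AttentionWithCachedFlag:
    def __init__(self):
        # Current PR: snapshot the flag once in __init__.
        self.torchair_graph_enabled = get_ascend_config().enabled

    def forward(self) -> bool:
        return self.torchair_graph_enabled


class AttentionWithInlineCheck:
    def forward(self) -> bool:
        # Reviewer's suggestion: read the config directly in forward,
        # dropping the extra attribute and the __init__ lookup.
        return get_ascend_config().enabled


assert AttentionWithCachedFlag().forward() == AttentionWithInlineCheck().forward()
```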
vllm_ascend/ops/rotary_embedding.py (Outdated)
def rope_forward(
    self,
Please add a UT and verify numerical accuracy against the ATB operator.
UT added.
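A minimal sketch of such a precision UT, assuming neox-style rotation, a (num_tokens, num_heads * head_dim) layout, and the default base of 10000; `reference_rope` is a naive golden implementation, not the ATB operator itself, which would be called on the NPU side and compared with `torch.testing.assert_close`:

```python
import torch


def reference_rope(positions, query, key, head_dim, base=10000.0):
    # Naive neox-style rotary embedding used as the golden reference.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]   # (T, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()

    def rotate(x):
        x = x.view(x.shape[0], -1, head_dim)
        x1, x2 = x[..., :head_dim // 2], x[..., head_dim // 2:]
        out = torch.cat([x1 * cos[:, None] - x2 * sin[:, None],
                         x2 * cos[:, None] + x1 * sin[:, None]], dim=-1)
        return out.view(x.shape[0], -1)

    return rotate(query), rotate(key)


def test_rope_identity_at_position_zero():
    q = torch.randn(4, 2 * 64)
    k = torch.randn(4, 2 * 64)
    pos = torch.zeros(4, dtype=torch.long)
    q_out, k_out = reference_rope(pos, q, k, head_dim=64)
    # At position 0, cos = 1 and sin = 0, so the rotation is the identity.
    torch.testing.assert_close(q_out, q)
    torch.testing.assert_close(k_out, k)
```

The actual test would run `rope_forward` (and the ATB operator) on the NPU with the same inputs and compare the outputs against this reference within a small tolerance.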
@@ -992,7 +992,8 @@ def _process_reqs(
     # Use host tensor, other wise error: tensor.hostData is null
     common_attn_metadata = CommonAttentionMetadata(
         query_start_loc=query_start_loc,
-        seq_lens=self.seq_lens_cpu[:num_reqs])
+        seq_lens=self.seq_lens_cpu[:num_reqs],
Changed by mistake?
Not a mistake; the logic works correctly in the latest PR.
@@ -1112,6 +1114,20 @@ def _process_reqs(
    if envs_ascend.VLLM_ASCEND_ENABLE_DBO and with_prefill:
        model_kwargs["graph_enable"] = False  # type: ignore
    if self.torchair_graph_enabled and not with_prefill:
        torch._dynamo.mark_static(input_ids)
        torch._dynamo.mark_static(positions)
This is here to track down the repeated-compilation issue; remove it once the issue is resolved.
TODO
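For context on what the flagged lines do, a small standalone sketch (no NPU needed); `torch._dynamo.mark_static` is a private Dynamo helper, so its exact behaviour may vary across torch releases:

```python
import torch


def mark_decode_inputs_static(input_ids: torch.Tensor, positions: torch.Tensor) -> None:
    # Pin the decode inputs as static so Dynamo does not speculate dynamic
    # shapes for them, which is one way repeated recompilation can be triggered.
    torch._dynamo.mark_static(input_ids)
    torch._dynamo.mark_static(positions)


mark_decode_inputs_static(torch.zeros(8, dtype=torch.long),
                          torch.zeros(8, dtype=torch.long))
```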
Force-pushed from a0e44e3 to 076e767
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Force-pushed from b6c1124 to e7c0013
Signed-off-by: taoyuxiang <t30002884@china.huawei.com>
What this PR does / why we need it?
Add graph mode support for Qwen2.5 and Qwen3.
Does this PR introduce any user-facing change?
No
How was this patch tested?
Tested the single-operator mode and graph mode of Qwen2.5, Qwen3, and DeepSeek.