39 | 39 | from vllm.distributed import (get_dp_group, get_pp_group,
40 | 40 |                               get_tensor_model_parallel_rank,
41 | 41 |                               get_tensor_model_parallel_world_size,
42 |    | -                             tensor_model_parallel_reduce_scatter,
43 |    | -                             split_tensor_along_last_dim, get_tp_group)
   | 42 | +                             get_tp_group, split_tensor_along_last_dim,
   | 43 | +                             tensor_model_parallel_reduce_scatter)
44 | 44 | from vllm.forward_context import get_forward_context
45 | 45 | from vllm.model_executor.layers.activation import SiluAndMul
46 | 46 | from vllm.model_executor.layers.layernorm import RMSNorm

68 | 68 | from vllm.sequence import IntermediateTensors
69 | 69 |
70 | 70 | from vllm_ascend.ascend_config import get_ascend_config
   | 71 | +from vllm_ascend.attention.attention_v1 import AscendAttentionState
71 | 72 | from vllm_ascend.ops.fused_moe import AscendFusedMoE
72 | 73 | from vllm_ascend.quantization.quant_config import AscendLinearMethod
73 | 74 | from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -539,9 +540,24 @@ def forward(
539 | 540 |         else:
540 | 541 |             hidden_states_or_q_c = hidden_states
541 | 542 |         is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
542 |     | -       if self.torchair_graph_enabled and not is_mtp_model:
    | 543 | +       with_decode = attn_metadata is not None and attn_metadata.attn_state in [
    | 544 | +           AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
    | 545 | +       ]
    | 546 | +       with_optimization_prefill = self.enable_prefill_optimizations and not with_decode
    | 547 | +       if self.torchair_graph_enabled and not is_mtp_model and not with_optimization_prefill:
    | 548 | +           if self.enable_prefill_optimizations and self.debug_layer_idx > 3 and self.debug_layer_idx < 61:
    | 549 | +               hidden_states_or_q_c = get_tp_group().all_gather(
    | 550 | +                   hidden_states_or_q_c, 0)
    | 551 | +               hidden_states = get_tp_group().all_gather(hidden_states, 0)
543 | 552 |         if envs.VLLM_USE_V1:
544 |     | -           output_shape = hidden_states.shape
    | 553 | +           if not self.enable_prefill_optimizations or self.debug_layer_idx < 3:
    | 554 | +               output_shape = hidden_states.shape
    | 555 | +           else:
    | 556 | +               num_tokens = hidden_states.shape[0]
    | 557 | +               rows = num_tokens // self.tp_size
    | 558 | +               if num_tokens % self.tp_size:
    | 559 | +                   rows += 1
    | 560 | +               output_shape = (rows, hidden_states.shape[1])
545 | 561 |             output = torch.empty(output_shape,
546 | 562 |                                  dtype=hidden_states_or_q_c.dtype,
547 | 563 |                                  device=hidden_states_or_q_c.device)
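Note on the new output_shape branch above: when the prefill-optimization path is taken, each tensor-parallel rank keeps only its reduce-scattered slice of the token dimension, so the output buffer is sized by a ceil division of the token count by the TP world size. Below is a minimal standalone sketch of just that sizing rule; tp_size and the tensor shapes here are illustrative assumptions, not values taken from this PR.

# Minimal sketch (not part of the PR): per-rank output sizing after a
# reduce-scatter over the token dimension, mirroring the diff above.
import torch

def reduce_scatter_rows(hidden_states: torch.Tensor, tp_size: int) -> tuple:
    # rows = ceil(num_tokens / tp_size); uneven token counts are rounded up.
    num_tokens = hidden_states.shape[0]
    rows = num_tokens // tp_size
    if num_tokens % tp_size:
        rows += 1
    return (rows, hidden_states.shape[1])

# Example: 10 tokens across 4 ranks -> every rank allocates 3 rows.
print(reduce_scatter_rows(torch.empty(10, 4096), tp_size=4))  # (3, 4096)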