diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
index 28be2b7e30..6bcc28bb9f 100644
--- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py
+++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
@@ -160,12 +160,13 @@ def forward(
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
+        forward_context = get_forward_context()
         if self.reduce_results and self.tp_size > 1:
             num_tokens = output_parallel.shape[0]
             if is_force_scatter and num_tokens % self.tp_size:
                 output_parallel = nn.functional.pad(
                     output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
-            if is_force_scatter or (not is_prefill
+            if is_force_scatter or (not forward_context.with_prefill
                                     and output_parallel.shape[0] %
                                     self.tp_size == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
@@ -726,7 +727,8 @@ def forward(
         replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
-        if attn_metadata is not None and attn_metadata.num_decodes > 0:
+        forward_context = get_forward_context()
+        if attn_metadata is not None and not forward_context.with_prefill:
             mla_moe_communication = self.mla_moe_communication and replace_allreduce
         else:
             mla_moe_communication = False