 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
@@ -211,13 +212,15 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
 
         self.params_dtype = torch.get_default_dtype()
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
     def forward(
             self,
@@ -245,16 +248,12 @@ def forward(
         old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
+            if (VLLM_ENABLE_MC2
+                    and not is_prefill) or not (self.torchair_graph_enabled or
+                                                self.ep_group.world_size == 1):
+                if num_tokens < self.tp_size:
                     hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
+                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
                 chunk_hidden_states = torch.tensor_split(hidden_states,
                                                          self.tp_size,
                                                          dim=0)
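
For context only (not part of the diff): the new branch pads the token dimension up to tp_size when there are fewer tokens than TP ranks, then splits the tensor evenly so each rank processes one chunk. A minimal standalone sketch of that scatter step, assuming illustrative shapes (3 tokens, 4 ranks) rather than the model's real ones:

    import torch
    import torch.nn as nn

    tp_size, hidden_size = 4, 8          # illustrative values, not from the diff
    hidden_states = torch.randn(3, hidden_size)
    num_tokens = hidden_states.size(0)

    # Pad the token dimension up to tp_size so every rank receives a non-empty chunk.
    if num_tokens < tp_size:
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, tp_size - num_tokens))

    # tensor_split yields tp_size chunks along dim 0; rank r would keep chunk r.
    chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
    assert len(chunk_hidden_states) == tp_size
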
@@ -284,24 +283,16 @@ def forward(
         hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
+            if (VLLM_ENABLE_MC2
+                    and not is_prefill) or not (self.torchair_graph_enabled or
+                                                self.ep_group.world_size == 1):
                 dist.all_gather(list(chunk_hidden_states), hidden_states,
                                 self.tp_group)
                 hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                if num_tokens < self.tp_size:
+                    hidden_states = hidden_states[:num_tokens]
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if self.n_shared_experts is not None:
             if not multistream:
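
For context only (not part of the diff): on the combine side each rank contributes its processed chunk to an all_gather, the chunks are concatenated in rank order, and any rows added as padding are trimmed so the original token count is restored. A single-process sketch of that recombination, with the collective replaced by pre-filled stand-in chunks:

    import torch

    tp_size, hidden_size, num_tokens = 4, 8, 3   # illustrative values

    # Stand-ins for what dist.all_gather would populate: one chunk per rank.
    chunk_hidden_states = [torch.randn(1, hidden_size) for _ in range(tp_size)]

    # Concatenate in rank order, then drop the rows that were only padding.
    hidden_states = torch.cat(chunk_hidden_states, dim=0)
    if num_tokens < tp_size:
        hidden_states = hidden_states[:num_tokens]
    assert hidden_states.size(0) == num_tokens
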
 | 