 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -211,13 +212,15 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
 
         self.params_dtype = torch.get_default_dtype()
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
     def forward(
         self,
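For readers skimming the diff: the combined condition introduced here is used twice in `forward` below. Read in isolation it amounts to the predicate sketched next (illustrative only; the helper name and flat parameters are hypothetical and not part of this PR).

```python
def use_chunked_all_gather(enable_mc2: bool, is_prefill: bool,
                           torchair_graph_enabled: bool,
                           ep_world_size: int) -> bool:
    """Illustrative mirror of the new branch condition.

    The chunk/all-gather path is taken when MC2 is enabled for a decode
    step, or when neither torchair graph mode nor a trivial (size-1)
    expert-parallel group applies; otherwise the code falls back to a
    plain tensor-model-parallel all-reduce.
    """
    return (enable_mc2 and not is_prefill) or not (
        torchair_graph_enabled or ep_world_size == 1)


# Decode step with MC2 on -> chunked all-gather path.
assert use_chunked_all_gather(True, False, True, 8) is True
# Prefill with torchair graph and MC2 off -> all-reduce fallback.
assert use_chunked_all_gather(False, True, True, 8) is False
```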
@@ -245,16 +248,10 @@ def forward(
         old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
+            if (VLLM_ENABLE_MC2 and not is_prefill) or not (self.torchair_graph_enabled or self.ep_group.world_size == 1):
+                if num_tokens < self.tp_size:
                     hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
+                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
                 chunk_hidden_states = torch.tensor_split(hidden_states,
                                                          self.tp_size,
                                                          dim=0)
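A rough single-process sketch of what this padding buys (not code from the PR; the toy shapes are made up): when the batch has fewer tokens than `tp_size`, right-padding the token dimension keeps `torch.tensor_split` from handing the trailing ranks an empty chunk.

```python
import torch
import torch.nn as nn

tp_size, tp_rank = 4, 1          # hypothetical TP configuration
num_tokens, hidden_size = 3, 8   # fewer tokens than ranks
hidden_states = torch.randn(num_tokens, hidden_size)

# Pad the token dim up to tp_size so every rank receives a non-empty row.
if num_tokens < tp_size:
    hidden_states = nn.functional.pad(
        hidden_states, (0, 0, 0, tp_size - num_tokens))

chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
local_chunk = chunk_hidden_states[tp_rank]  # this rank's slice of the batch

assert len(chunk_hidden_states) == tp_size
assert local_chunk.shape == (1, hidden_size)
```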
@@ -284,24 +281,14 @@ def forward(
         hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
+            if (VLLM_ENABLE_MC2 and not is_prefill) or not (self.torchair_graph_enabled or self.ep_group.world_size == 1):
                 dist.all_gather(list(chunk_hidden_states), hidden_states,
                                 self.tp_group)
                 hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_padding_tokens > 0:
-                    hidden_states = hidden_states[:-num_padding_tokens]
+                if num_tokens < self.tp_size:
+                    hidden_states = hidden_states[:num_tokens]
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if self.n_shared_experts is not None:
             if not multistream:
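The gather side can be sketched the same way (again a local stand-in for the real `dist.all_gather` across the TP group; the chunk shapes are illustrative): concatenating the per-rank chunks and slicing back to `num_tokens` undoes the padding from the dispatch step.

```python
import torch

tp_size = 4
num_tokens, hidden_size = 3, 8
# Stand-ins for the per-rank outputs that dist.all_gather would collect.
chunk_hidden_states = [torch.randn(1, hidden_size) for _ in range(tp_size)]

hidden_states = torch.cat(chunk_hidden_states, dim=0)   # (tp_size, hidden)
if num_tokens < tp_size:
    hidden_states = hidden_states[:num_tokens]          # drop the pad rows

assert hidden_states.shape == (num_tokens, hidden_size)
```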