
Commit f29401d

[bugfix] fix deepseek accuracy
Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent a2552e1 commit f29401d

File tree

2 files changed: 14 additions, 26 deletions

vllm_ascend/models/deepseek_v2.py

Lines changed: 12 additions & 25 deletions
@@ -66,6 +66,7 @@
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -211,13 +212,15 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
 
         self.params_dtype = torch.get_default_dtype()
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
     def forward(
         self,
@@ -245,16 +248,10 @@ def forward(
         old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
+            if (VLLM_ENABLE_MC2 and not is_prefill) or not (self.torchair_graph_enabled or self.ep_group.world_size == 1):
+                if num_tokens < self.tp_size:
                     hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
+                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
                 chunk_hidden_states = torch.tensor_split(hidden_states,
                                                          self.tp_size,
                                                          dim=0)
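
For reference, the rewritten branch pads the batch up to tp_size whenever it carries fewer tokens than there are tensor-parallel ranks, so torch.tensor_split always hands every rank a non-empty chunk. A minimal standalone sketch of that pad-then-split pattern (tp_size and the tensor shapes are illustrative, not taken from a real run):

    import torch
    import torch.nn as nn

    tp_size = 4                          # tensor-parallel world size (illustrative)
    hidden_states = torch.randn(3, 8)    # num_tokens = 3 < tp_size, hidden_size = 8
    num_tokens = hidden_states.shape[0]

    if num_tokens < tp_size:
        # Pad the token dimension so every TP rank receives a non-empty chunk.
        hidden_states = nn.functional.pad(hidden_states,
                                          (0, 0, 0, tp_size - num_tokens))

    chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
    print([tuple(c.shape) for c in chunk_hidden_states])  # [(1, 8), (1, 8), (1, 8), (1, 8)]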
@@ -284,24 +281,14 @@ def forward(
         hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
+            if (VLLM_ENABLE_MC2 and not is_prefill) or not (self.torchair_graph_enabled or self.ep_group.world_size == 1):
                 dist.all_gather(list(chunk_hidden_states), hidden_states,
                                 self.tp_group)
                 hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_padding_tokens > 0:
-                    hidden_states = hidden_states[:-num_padding_tokens]
+                if num_tokens < self.tp_size:
+                    hidden_states = hidden_states[:num_tokens]
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if self.n_shared_experts is not None:
             if not multistream:
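
On the gather side, the chunks collected from all ranks are concatenated along the token dimension and the padding rows added above are trimmed off; in every other case a plain tensor-parallel all-reduce is used instead. A matching sketch of the reassembly step, with ordinary Python lists standing in for the output of dist.all_gather (shapes again illustrative):

    import torch

    tp_size, num_tokens, hidden_size = 4, 3, 8
    # Pretend these are the per-rank chunks filled in by dist.all_gather.
    chunk_hidden_states = [torch.randn(1, hidden_size) for _ in range(tp_size)]

    # Reassemble the padded batch, then drop the rows that were added as padding.
    hidden_states = torch.cat(chunk_hidden_states, dim=0)
    if num_tokens < tp_size:
        hidden_states = hidden_states[:num_tokens]
    print(tuple(hidden_states.shape))  # (3, 8)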

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -771,8 +771,9 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
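
The same gating is applied here as in deepseek_v2.py: the multistream shared-expert path is now only taken when the torchair config flag and VLLM_ENABLE_MC2 are both set. A small illustration of that flag computation, using a hypothetical dataclass as a stand-in for the real torchair_graph_config:

    from dataclasses import dataclass

    @dataclass
    class TorchairGraphConfig:
        # Hypothetical stand-in for ascend_config.torchair_graph_config.
        enabled: bool = True
        enable_multistream_shared_expert: bool = True

    VLLM_ENABLE_MC2 = False  # normally read from the environment via vllm_ascend.envs

    torchair_graph_config = TorchairGraphConfig()
    enable_multistream_shared_expert = (
        torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2)
    print(enable_multistream_shared_expert)  # False until VLLM_ENABLE_MC2 is enabled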
