Skip to content

Commit 1648906

Browse files
[long_seq_optim] all_gather optim
1 parent d7edd10 commit 1648906

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -712,18 +712,12 @@ def forward(
712712
hidden_states_or_q_c, 0)
713713
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
714714

715-
# kv_c_k_pe = self.kv_a_proj_with_mqa(hidden_states)[0]
716715
if self.enable_sp and is_prefill:
717-
chunk_kv_no_split = [torch.empty_like(kv_no_split) for _ in range(self.sp_size)]
718-
dist.all_gather(list(chunk_kv_no_split), kv_no_split, self.sp_group)
719-
kv_no_split = torch.cat(chunk_kv_no_split, dim=0)
716+
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
720717
kv_no_split = kv_no_split[:original_len]
721718

722-
chunk_hidden_states_or_q_c = [torch.empty_like(hidden_states_or_q_c) for _ in range(self.sp_size)]
723-
dist.all_gather(list(chunk_hidden_states_or_q_c), hidden_states_or_q_c, self.sp_group)
724-
hidden_states_or_q_c = torch.cat(chunk_hidden_states_or_q_c, dim=0)
719+
hidden_states_or_q_c = get_tp_group().all_gather(hidden_states_or_q_c, 0)
725720
hidden_states_or_q_c = hidden_states_or_q_c[:original_len]
726-
# kv_c, k_pe = kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
727721

728722
kv_c, k_pe = kv_no_split.split(
729723
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -1038,14 +1032,10 @@ def forward(
10381032

10391033
hidden_states, _ = self.norm(hidden_states, residual)
10401034
if self.enable_sp and is_prefill:
1041-
chunk_hidden_states = [torch.empty_like(hidden_states) for _ in range(self.sp_size)]
1042-
dist.all_gather(list(chunk_hidden_states), hidden_states, self.sp_group)
1043-
hidden_states = torch.cat(chunk_hidden_states, dim=0)
1035+
hidden_states = get_tp_group().all_gather(hidden_states, 0)
10441036
hidden_states = hidden_states[:original_len]
10451037
if self.cp_size > 1 and is_prefill:
1046-
chunk_hidden_states = [torch.empty_like(hidden_states) for _ in range(self.cp_size)]
1047-
dist.all_gather(list(chunk_hidden_states), hidden_states, self.cp_group)
1048-
hidden_states = torch.cat(chunk_hidden_states, dim=0)
1038+
hidden_states = get_cp_group().all_gather(hidden_states, 0)
10491039
hidden_states = torch.index_select(hidden_states, 0, attn_metadata.prefill.cp_kv_recover_idx)
10501040
return hidden_states
10511041

0 commit comments

Comments
 (0)