@@ -712,18 +712,12 @@ def forward(
712
712
hidden_states_or_q_c , 0 )
713
713
kv_no_split = get_tp_group ().all_gather (kv_no_split , 0 )
714
714
715
- # kv_c_k_pe = self.kv_a_proj_with_mqa(hidden_states)[0]
716
715
if self .enable_sp and is_prefill :
717
- chunk_kv_no_split = [torch .empty_like (kv_no_split ) for _ in range (self .sp_size )]
718
- dist .all_gather (list (chunk_kv_no_split ), kv_no_split , self .sp_group )
719
- kv_no_split = torch .cat (chunk_kv_no_split , dim = 0 )
716
+ kv_no_split = get_tp_group ().all_gather (kv_no_split , 0 )
720
717
kv_no_split = kv_no_split [:original_len ]
721
718
722
- chunk_hidden_states_or_q_c = [torch .empty_like (hidden_states_or_q_c ) for _ in range (self .sp_size )]
723
- dist .all_gather (list (chunk_hidden_states_or_q_c ), hidden_states_or_q_c , self .sp_group )
724
- hidden_states_or_q_c = torch .cat (chunk_hidden_states_or_q_c , dim = 0 )
719
+ hidden_states_or_q_c = get_tp_group ().all_gather (hidden_states_or_q_c , 0 )
725
720
hidden_states_or_q_c = hidden_states_or_q_c [:original_len ]
726
- # kv_c, k_pe = kv_c_k_pe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
727
721
728
722
kv_c , k_pe = kv_no_split .split (
729
723
[self .kv_lora_rank , self .qk_rope_head_dim ], dim = - 1 )
@@ -1038,14 +1032,10 @@ def forward(
1038
1032
1039
1033
hidden_states , _ = self .norm (hidden_states , residual )
1040
1034
if self .enable_sp and is_prefill :
1041
- chunk_hidden_states = [torch .empty_like (hidden_states ) for _ in range (self .sp_size )]
1042
- dist .all_gather (list (chunk_hidden_states ), hidden_states , self .sp_group )
1043
- hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
1035
+ hidden_states = get_tp_group ().all_gather (hidden_states , 0 )
1044
1036
hidden_states = hidden_states [:original_len ]
1045
1037
if self .cp_size > 1 and is_prefill :
1046
- chunk_hidden_states = [torch .empty_like (hidden_states ) for _ in range (self .cp_size )]
1047
- dist .all_gather (list (chunk_hidden_states ), hidden_states , self .cp_group )
1048
- hidden_states = torch .cat (chunk_hidden_states , dim = 0 )
1038
+ hidden_states = get_cp_group ().all_gather (hidden_states , 0 )
1049
1039
hidden_states = torch .index_select (hidden_states , 0 , attn_metadata .prefill .cp_kv_recover_idx )
1050
1040
return hidden_states
1051
1041
0 commit comments