 from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
                                             MultiStreamPreTransformerLayer)
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
+                                               MultiStreamMetadata,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -698,13 +699,12 @@ def _forward_ms_layer(
         shared_outputs = []
         router_logits = []
         chunk_hidden_states = []
-        ''' block 1 : attention
-        block 2 : attn tp communication, currently we switch to the comm stream
-        in tensor_model_parallel_all_reduce;
-        the attn computation of microbatch 1 can be overlapped with the moe
-        communication in the previous layer, and the attn computation of microbatch
-        2 can be overlapped with the attn communication of microbatch 1
-        '''
+
+        # block 1 : attention
+        # block 2 : attn tp communication
+        # the attn computation of microbatch 1 can be overlapped with the moe
+        # communication in the previous layer, and the attn computation of microbatch 2
+        # can be overlapped with the attn communication of microbatch 1
         for i in range(num_micro_batchs):
             # wait last layer moe finishing communication
             ms_metadata.try_wait_event(layer_index - 1, i,
@@ -731,10 +731,10 @@ def _forward_ms_layer(
             hidden_states[i], residual[i] = self._forward_ms_op_attn(
                 positions[i], hidden_states[i], residual[i], kv_cache,
                 attn_metadata[i])
-        ''' block 3 : shared experts
-        if there is an allreduce ops in shared expert, we can overlap it with the computation of the
-        shared expert for next microbatch or moe gating
-        '''
+
+        # block 3 : shared experts
+        # if there is an allreduce ops in shared expert, we can overlap it with the computation of the
+        # shared expert for next microbatch or moe gating
         for i in range(num_micro_batchs):
             ms_metadata.try_wait_event(layer_index, i,
                                        MSEventKey.ATTN_AR_FINISH)
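The block comments above (blocks 1-3, and block 4 below) all describe the same mechanism: one microbatch communicates on a side stream while the other microbatch keeps computing on the default stream, with events ordering the hand-offs. Below is a minimal, self-contained sketch of that event/stream pattern. It uses torch.cuda streams as a stand-in for the torch.npu stream API used in this file, and the helper names and two-way split are illustrative only, not vllm-ascend code.

import torch

def overlapped_layer(micro_batches, compute_fn, comm_fn):
    # Overlap the communication of microbatch i with the computation of microbatch i+1.
    comm_stream = torch.cuda.Stream()
    outputs = [None] * len(micro_batches)
    for i, x in enumerate(micro_batches):
        done = torch.cuda.Event()
        outputs[i] = compute_fn(x)            # e.g. attention or MoE compute
        done.record()                         # compute finished on the default stream
        with torch.cuda.stream(comm_stream):
            done.wait()                       # comm may start only after its compute is done
            outputs[i] = comm_fn(outputs[i])  # e.g. tensor-parallel all-reduce
        # next iteration: compute of microbatch i+1 runs while comm of microbatch i proceeds
    torch.cuda.current_stream().wait_stream(comm_stream)
    return outputs

if torch.cuda.is_available():
    chunks = list(torch.randn(4, 16, device="cuda").chunk(2))
    out = overlapped_layer(chunks, lambda t: t * 2.0, lambda t: t + 1.0)
    print([tuple(o.shape) for o in out])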
@@ -763,7 +763,6 @@ def _forward_ms_layer(

         # block 4 : moe
         for i in range(num_micro_batchs):
-            #ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
             # when profile runs, force experts to load balanced tokens
             # to avoid high memory consumption on a single rank.
             # TODO: need a better flag to indicate whether in profile run or not.
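The comment above is about the profile run: router outputs are ignored and tokens are spread over the experts evenly so no single rank allocates a worst-case amount of memory. A toy sketch of what a load-balanced assignment looks like (round-robin expert ids; this illustrates the idea only, not the fused-MoE path taken by _forward_ms_fused_moe_comp):

import torch

def balanced_topk_ids(num_tokens: int, num_experts: int, top_k: int) -> torch.Tensor:
    # token t is routed to experts (t*top_k .. t*top_k + top_k-1) mod num_experts,
    # ignoring the router scores entirely
    base = torch.arange(num_tokens).unsqueeze(1) * top_k + torch.arange(top_k)
    return base % num_experts

ids = balanced_topk_ids(num_tokens=6, num_experts=4, top_k=2)
print(torch.bincount(ids.flatten(), minlength=4))  # tensor([3, 3, 3, 3]) -> balanced load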
@@ -776,13 +775,6 @@ def _forward_ms_layer(
                 enable_force_load_balance = False

             if self.mlp.tp_size > 1:
-                #if num_tokens[i] < self.mlp.tp_size:
-                #    target_size = self.mlp.tp_size
-                #    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
-                #                                    dtype=hidden_states[i].dtype,
-                #                                    device=hidden_states[i].device)
-                #    new_hidden_states[:num_tokens[i]] = hidden_states[i]
-                #    hidden_states[i] = new_hidden_states
                 num_token, _ = hidden_states[i].shape
                 padded_num_tokens = (self.mlp.tp_size - num_token %
                                      self.mlp.tp_size) % self.mlp.tp_size
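The expression (tp_size - num_token % tp_size) % tp_size above computes how many padding rows are needed to make the token count divisible by the tensor-parallel size, so the subsequent split across TP ranks is even. A quick standalone check of that arithmetic (tp_size = 8 is only an example value):

def pad_rows(num_token: int, tp_size: int) -> int:
    # number of extra rows needed to reach the next multiple of tp_size
    return (tp_size - num_token % tp_size) % tp_size

for n in (1, 7, 8, 9, 16):
    print(f"{n} tokens, tp_size=8 -> {pad_rows(n, 8)} padding rows")
# 1 -> 7, 7 -> 1, 8 -> 0, 9 -> 7, 16 -> 0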
@@ -805,18 +797,12 @@ def _forward_ms_layer(
             else:
                 real_top_k = self.mlp.experts.top_k

-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-
             hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(
                 local_hidden_states, router_logits[i], is_prefill, real_top_k,
                 enable_force_load_balance)

-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-            ''' the following kernels will be submitted to the comm stream to overlap the computation of the
-            moe computation of next microbatch and the attn computation of next layer
-            '''
+            # the following kernels will be submitted to the comm stream to overlap the computation of the
+            # moe computation of next microbatch and the attn computation of next layer
             context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][
@@ -826,15 +812,14 @@ def _forward_ms_layer(
             )
             context.before_comm_event.record()
             with torch.npu.stream(ms_metadata.communicate_stream):
-                #with set_multistream_context(context, i):
                 context.before_comm_event.wait()
                 if self.mlp.experts.reduce_results and (
                         self.mlp.experts.tp_size > 1
                         or self.mlp.experts.ep_size > 1):
                     hidden_states[i] = tensor_model_parallel_all_reduce(
                         hidden_states[i])
                 context.after_comm_event.record()
-            # check here
+
             hidden_states[
                 i] = hidden_states[i] * self.mlp.routed_scaling_factor
             context = MultiStreamStepMetadata(
@@ -959,21 +944,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 ["hidden_states", "residual"], config.hidden_size))

         # tbo related members
-        self.multistream_config: Optional[MultiStreamConfig] = None
-        if VLLM_ENABLE_DBO:
-            self.multistream_config = MultiStreamConfig()
-
         self.use_mla = model_config.use_mla
-        self.multistream_metadata = make_multistream_metadata_ds(
-            start_layer=self.start_layer + self.first_k_dense_replace,
-            end_layer=self.end_layer,
-            causal_lm=getattr(config, "causal_lm", True),
-            multistream_config=self.multistream_config,
-        )
-        self.ms_pre_layer = MultiStreamPreTransformerLayer(
-            self.multistream_metadata)
-        self.ms_post_layer = MultiStreamPostTransformerLayer(
-            self.multistream_metadata)
+        multistream_config: Optional[MultiStreamConfig] = None
+        self.multistream_metadata: Optional[MultiStreamMetadata] = None
+        if VLLM_ENABLE_DBO:
+            multistream_config = MultiStreamConfig()
+            self.multistream_metadata = make_multistream_metadata_ds(
+                start_layer=self.start_layer + self.first_k_dense_replace,
+                end_layer=self.end_layer,
+                causal_lm=getattr(config, "causal_lm", True),
+                multistream_config=multistream_config,
+            )
+            self.ms_pre_layer = MultiStreamPreTransformerLayer(
+                self.multistream_metadata)
+            self.ms_post_layer = MultiStreamPostTransformerLayer(
+                self.multistream_metadata)

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
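The __init__ change above makes MultiStreamConfig a local variable and keeps only the Optional multistream_metadata on the module, built (together with the pre/post transformer layers) only when VLLM_ENABLE_DBO is set; everything downstream then gates on the metadata. A condensed sketch of that construction pattern with deliberately fake types, not the vllm-ascend API:

from typing import Optional

class TinyModel:
    def __init__(self, enable_dbo: bool):
        # kept on the module: only the Optional metadata, mirroring the diff above
        self.multistream_metadata: Optional[dict] = None
        if enable_dbo:
            # stand-in for MultiStreamConfig(); stays local to __init__
            ms_config = {"min_total_tokens_to_split": 256}
            self.multistream_metadata = {"ms_config": ms_config}

    def can_run_ms(self) -> bool:
        # downstream code gates on the metadata, not on a separate config attribute
        return self.multistream_metadata is not None

print(TinyModel(True).can_run_ms(), TinyModel(False).can_run_ms())  # True False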
@@ -999,10 +984,10 @@ def forward(
             residual = intermediate_tensors["residual"]

         num_normal_layers = (self.first_k_dense_replace
-                             if self.multistream_config is not None
+                             if self.multistream_metadata is not None
                              and self.can_run_ms() else self.end_layer -
                              self.start_layer)
-        # if we enable multistream/dbo, only process dense layers here
+
         for i in range(self.start_layer, self.start_layer + num_normal_layers):
             layer = self.layers[i]
             hidden_states, residual = layer(
@@ -1012,13 +997,15 @@ def forward(
                 attn_metadata)

         moe_start_layer = self.start_layer + num_normal_layers
-        hidden_states, residual = self._forward_ms_layers(
-            positions=positions,
-            hidden_states=hidden_states,
-            residual=residual,
-            moe_start_layer=moe_start_layer,
-            kv_caches=kv_caches,
-        )
+        if moe_start_layer != self.end_layer:
+            # if we enable multistream/dbo, process sparse layers here
+            hidden_states, residual = self._forward_ms_layers(
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                moe_start_layer=moe_start_layer,
+                kv_caches=kv_caches,
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -1046,10 +1033,11 @@ def can_run_ms(self):
                                   attn_metadata.query_lens):
             return False

-        if self.multistream_config is None:
+        if self.multistream_metadata is None:
             return False
         # check whether the total tokens exceed the threshold
-        if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
+        ms_config = self.multistream_metadata.ms_config
+        if ms_config is None or attn_metadata.num_actual_tokens < ms_config.min_total_tokens_to_split:
             return False
         return True

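With can_run_ms rewritten as above, the split decision depends only on state reachable from the metadata: no metadata (or no config inside it) means no dual-batch overlap, and small batches are not worth splitting. A compact restatement of that gating with stand-in types; the 256-token default below is an assumption for illustration, not the actual MultiStreamConfig default:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeMSConfig:
    min_total_tokens_to_split: int = 256  # assumed value, for illustration only

def should_split(num_actual_tokens: int, ms_config: Optional[FakeMSConfig]) -> bool:
    if ms_config is None:
        return False  # metadata exists but carries no config -> fall back to the normal path
    return num_actual_tokens >= ms_config.min_total_tokens_to_split

print(should_split(64, FakeMSConfig()))    # False: batch too small to be worth splitting
print(should_split(1024, FakeMSConfig()))  # True: split into two microbatches
print(should_split(1024, None))            # False: dual-batch overlap not configured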