Commit d36f590

[fix]: optimize the dbo execution and fix minor issues
Signed-off-by: zhuohuan <zxdu1997@gmail.com>
Parent: 29bd484

6 files changed: +51 -68 lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 42 additions & 54 deletions
@@ -74,6 +74,7 @@
 from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
                                             MultiStreamPreTransformerLayer)
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
+                                              MultiStreamMetadata,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -698,13 +699,12 @@ def _forward_ms_layer(
         shared_outputs = []
         router_logits = []
         chunk_hidden_states = []
-        ''' block 1 : attention
-            block 2 : attn tp communication, currently we switch to the comm stream
-            in tensor_model_parallel_all_reduce;
-            the attn computation of microbatch 1 can be overlapped with the moe
-            communication in the previous layer, and the attn computation of microbatch
-            2 can be overlapped with the attn communication of microbatch 1
-        '''
+
+        # block 1 : attention
+        # block 2 : attn tp communication
+        # the attn computation of microbatch 1 can be overlapped with the moe
+        # communication in the previous layer, and the attn computation of microbatch 2
+        # can be overlapped with the attn communication of microbatch 1
         for i in range(num_micro_batchs):
             # wait last layer moe finishing communication
             ms_metadata.try_wait_event(layer_index - 1, i,
@@ -731,10 +731,10 @@ def _forward_ms_layer(
             hidden_states[i], residual[i] = self._forward_ms_op_attn(
                 positions[i], hidden_states[i], residual[i], kv_cache,
                 attn_metadata[i])
-        ''' block 3 : shared experts
-            if there is an allreduce ops in shared expert, we can overlap it with the computation of the
-            shared expert for next microbatch or moe gating
-        '''
+
+        # block 3 : shared experts
+        # if there is an allreduce ops in shared expert, we can overlap it with the computation of the
+        # shared expert for next microbatch or moe gating
         for i in range(num_micro_batchs):
             ms_metadata.try_wait_event(layer_index, i,
                                        MSEventKey.ATTN_AR_FINISH)
@@ -763,7 +763,6 @@ def _forward_ms_layer(

         # block 4 : moe
         for i in range(num_micro_batchs):
-            #ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
             # when profile runs, force experts to load balanced tokens
             # to avoid high memory consumption on a single rank.
             # TODO: need a better flag to indicate whether in profile run or not.
@@ -776,13 +775,6 @@ def _forward_ms_layer(
                 enable_force_load_balance = False

             if self.mlp.tp_size > 1:
-                #if num_tokens[i] < self.mlp.tp_size:
-                #    target_size = self.mlp.tp_size
-                #    new_hidden_states = torch.empty([target_size, hidden_dims[i]],
-                #                                    dtype=hidden_states[i].dtype,
-                #                                    device=hidden_states[i].device)
-                #    new_hidden_states[:num_tokens[i]] = hidden_states[i]
-                #    hidden_states[i] = new_hidden_states
                 num_token, _ = hidden_states[i].shape
                 padded_num_tokens = (self.mlp.tp_size - num_token %
                                      self.mlp.tp_size) % self.mlp.tp_size
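
The padding kept here rounds each microbatch up to a multiple of the MLP tensor-parallel size so that every rank receives an equal shard; the outer `% self.mlp.tp_size` makes the pad zero when the token count already divides evenly. A minimal standalone sketch of that arithmetic (the helper name and the use of `torch.nn.functional.pad` are illustrative, not taken from this diff):

import torch
import torch.nn.functional as F

def pad_to_tp_multiple(hidden_states: torch.Tensor, tp_size: int) -> torch.Tensor:
    # Pad dim 0 so tensor-parallel collectives see equal shards on every rank.
    num_token = hidden_states.shape[0]
    # Double modulo: the pad is 0 when num_token % tp_size == 0.
    padded_num_tokens = (tp_size - num_token % tp_size) % tp_size
    if padded_num_tokens > 0:
        # (0, 0, 0, n) leaves the hidden dim alone and appends n zero rows.
        hidden_states = F.pad(hidden_states, (0, 0, 0, padded_num_tokens))
    return hidden_states

# e.g. 5 tokens with tp_size=4 are padded to 8; 8 tokens pass through unchanged.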
@@ -805,18 +797,12 @@ def _forward_ms_layer(
             else:
                 real_top_k = self.mlp.experts.top_k

-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-
             hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(
                 local_hidden_states, router_logits[i], is_prefill, real_top_k,
                 enable_force_load_balance)

-            if VLLM_ENABLE_MC2 and not is_prefill:
-                ...
-            ''' the following kernels will be submitted to the comm stream to overlap the computation of the
-                moe computation of next microbatch and the attn computation of next layer
-            '''
+            # the following kernels will be submitted to the comm stream to overlap the computation of the
+            # moe computation of next microbatch and the attn computation of next layer
             context = MultiStreamStepMetadata(
                 comm_stream=ms_metadata.communicate_stream,
                 before_comm_event=ms_metadata.ms_events[layer_index][i][
@@ -826,15 +812,14 @@ def _forward_ms_layer(
             )
             context.before_comm_event.record()
             with torch.npu.stream(ms_metadata.communicate_stream):
-                #with set_multistream_context(context, i):
                 context.before_comm_event.wait()
                 if self.mlp.experts.reduce_results and (
                         self.mlp.experts.tp_size > 1
                         or self.mlp.experts.ep_size > 1):
                     hidden_states[i] = tensor_model_parallel_all_reduce(
                         hidden_states[i])
                 context.after_comm_event.record()
-            # check here
+
             hidden_states[
                 i] = hidden_states[i] * self.mlp.routed_scaling_factor
             context = MultiStreamStepMetadata(
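
The record/wait pair in this hunk is the standard two-stream handoff: the compute stream records an event once the MoE output is ready, the comm stream waits on it before launching the all-reduce, and a second event fences later consumers of the result. A self-contained sketch of the same pattern, written against `torch.cuda` as a stand-in for `torch.npu` (the stream/event APIs mirror each other; the function name is illustrative and `torch.distributed.all_reduce` assumes an initialized process group):

import torch
import torch.distributed as dist

def overlapped_all_reduce(x: torch.Tensor,
                          comm_stream: torch.cuda.Stream) -> torch.cuda.Event:
    before = torch.cuda.Event()
    after = torch.cuda.Event()
    before.record()                      # compute stream: x is produced
    with torch.cuda.stream(comm_stream):
        before.wait()                    # comm stream: wait for the producer
        dist.all_reduce(x)               # overlaps with later compute-stream work
        after.record()                   # comm stream: result is ready
    return after                         # consumers call after.wait() first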
@@ -959,21 +944,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
959944
["hidden_states", "residual"], config.hidden_size))
960945

961946
# tbo related members
962-
self.multistream_config: Optional[MultiStreamConfig] = None
963-
if VLLM_ENABLE_DBO:
964-
self.multistream_config = MultiStreamConfig()
965-
966947
self.use_mla = model_config.use_mla
967-
self.multistream_metadata = make_multistream_metadata_ds(
968-
start_layer=self.start_layer + self.first_k_dense_replace,
969-
end_layer=self.end_layer,
970-
causal_lm=getattr(config, "causal_lm", True),
971-
multistream_config=self.multistream_config,
972-
)
973-
self.ms_pre_layer = MultiStreamPreTransformerLayer(
974-
self.multistream_metadata)
975-
self.ms_post_layer = MultiStreamPostTransformerLayer(
976-
self.multistream_metadata)
948+
multistream_config: Optional[MultiStreamConfig] = None
949+
self.multistream_metadata: Optional[MultiStreamMetadata] = None
950+
if VLLM_ENABLE_DBO:
951+
multistream_config = MultiStreamConfig()
952+
self.multistream_metadata = make_multistream_metadata_ds(
953+
start_layer=self.start_layer + self.first_k_dense_replace,
954+
end_layer=self.end_layer,
955+
causal_lm=getattr(config, "causal_lm", True),
956+
multistream_config=multistream_config,
957+
)
958+
self.ms_pre_layer = MultiStreamPreTransformerLayer(
959+
self.multistream_metadata)
960+
self.ms_post_layer = MultiStreamPostTransformerLayer(
961+
self.multistream_metadata)
977962

978963
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
979964
return self.embed_tokens(input_ids)
@@ -999,10 +984,10 @@ def forward(
             residual = intermediate_tensors["residual"]

         num_normal_layers = (self.first_k_dense_replace
-                             if self.multistream_config is not None
+                             if self.multistream_metadata is not None
                              and self.can_run_ms() else self.end_layer -
                              self.start_layer)
-        # if we enable multistream/dbo, only process dense layers here
+
         for i in range(self.start_layer, self.start_layer + num_normal_layers):
             layer = self.layers[i]
             hidden_states, residual = layer(
@@ -1012,13 +997,15 @@ def forward(
                 attn_metadata)

         moe_start_layer = self.start_layer + num_normal_layers
-        hidden_states, residual = self._forward_ms_layers(
-            positions=positions,
-            hidden_states=hidden_states,
-            residual=residual,
-            moe_start_layer=moe_start_layer,
-            kv_caches=kv_caches,
-        )
+        if moe_start_layer != self.end_layer:
+            # if we enable multistream/dbo, process sparse layers here
+            hidden_states, residual = self._forward_ms_layers(
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                moe_start_layer=moe_start_layer,
+                kv_caches=kv_caches,
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
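
A worked example of how the dense/sparse split above behaves after this change (the layer counts are illustrative, not from the diff):

# DBO active: only the first_k_dense_replace dense layers run in the normal
# loop; the remaining sparse (MoE) layers go through _forward_ms_layers.
# DBO inactive: every layer runs in the normal loop, moe_start_layer equals
# end_layer, and the new guard skips _forward_ms_layers entirely (the old
# code called it unconditionally, even with nothing left to process).
start_layer, end_layer, first_k_dense_replace = 0, 61, 3

for dbo_active in (True, False):
    num_normal_layers = (first_k_dense_replace if dbo_active
                         else end_layer - start_layer)
    moe_start_layer = start_layer + num_normal_layers
    runs_ms_path = moe_start_layer != end_layer
    print(dbo_active, num_normal_layers, moe_start_layer, runs_ms_path)
# -> True  3  3  True   (layers 0-2 in the dense loop, 3-60 in the multistream path)
# -> False 61 61 False  (multistream path skipped)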
@@ -1046,10 +1033,11 @@ def can_run_ms(self):
                 attn_metadata.query_lens):
             return False

-        if self.multistream_config is None:
+        if self.multistream_metadata is None:
             return False
         # check whether the total tokens exceed the threshold
-        if attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
+        ms_config = self.multistream_metadata.ms_config
+        if ms_config is None or attn_metadata.num_actual_tokens < ms_config.min_total_tokens_to_split:
             return False
         return True

vllm_ascend/multistream/base.py

Lines changed: 0 additions & 1 deletion

@@ -2,7 +2,6 @@
 from enum import Enum


-# TODO: move this part to vllm
 class MSEventKey(Enum):
     ATTN_COM_FINISH = 0
     ATTN_AR_FINISH = 1

vllm_ascend/multistream/context.py

Lines changed: 0 additions & 2 deletions

@@ -1,8 +1,6 @@
 from contextlib import contextmanager
 from typing import Any

-# TODO: move this part to vllm
-
 _ms_comm_context: Any = None
 _cur_micro_batch_num: int = -1
 _ms_layer_index_context: int = -1

vllm_ascend/multistream/decorator.py

Lines changed: 0 additions & 2 deletions

@@ -3,8 +3,6 @@
 from .context import (get_multistream_layer_context,
                       get_multistream_microbatch_context)

-# TODO: move this part to vllm
-
 logger = init_logger(__name__)

vllm_ascend/multistream/layers.py

Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@
 from .metadata import MultiStreamMetadata


-# TODO: move this part to vllm
 class MultiStreamPreTransformerLayer(torch.nn.Module):

     def __init__(self, multistream_metadata: MultiStreamMetadata):

vllm_ascend/multistream/metadata.py

Lines changed: 9 additions & 8 deletions

@@ -2,9 +2,10 @@
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
+
 from .base import MSAttentionMetadataSplitConfig, MSEventKey


@@ -111,19 +112,19 @@ def try_record_event(self, layer_index: int, micro_batch_index: int,

     def split_micro_batch(
         self,
-        attn_metadata: "AttentionMetadata",
+        attn_metadata: "AscendMLAMetadata",
         intput_tensors: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         intermediate_tensors_keys: Optional[List[str]] = None,
-    ) -> Tuple[bool, Union[AttentionMetadata, List[AttentionMetadata]], Union[
+    ) -> Tuple[bool, List[AscendMLAMetadata], Union[
             List[torch.Tensor], List[List[torch.Tensor]]], Union[
                 IntermediateTensors, List[IntermediateTensors]]]:
-        attn_metadata = attn_metadata.split_metadata_for_multistream(
+        attn_metadata_list = attn_metadata.split_metadata_for_multistream(
             self.ms_split_config)
-        if len(attn_metadata) == 1:
-            return False, attn_metadata[
+        if len(attn_metadata_list) == 1:
+            return False, attn_metadata_list[
                 0], intput_tensors, intermediate_tensors
-        split_index = attn_metadata[0].slot_mapping.shape[0]
+        split_index = attn_metadata_list[0].slot_mapping.shape[0]
         input_tensors = split_micro_batches_tensors(intput_tensors,
                                                     split_index)
         if intermediate_tensors is not None:

@@ -134,7 +135,7 @@ def split_micro_batch(
             IntermediateTensors(inter_tensors)
             for inter_tensors in inter_tensors_list
         ]
-        return True, attn_metadata, input_tensors, intermediate_tensors
+        return True, attn_metadata_list, input_tensors, intermediate_tensors

     def merge_micro_batches(
         self, input_tensors: Union[List[torch.Tensor],
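
On the caller side, the boolean now tells whether a split happened: on a split, the second slot is a list of per-microbatch `AscendMLAMetadata`; otherwise the single unsplit metadata comes straight back. A hedged sketch of consuming it (the wrapper function and its arguments are illustrative; the real caller is presumably `MultiStreamPreTransformerLayer`):

import torch

def run_with_optional_split(ms_metadata, attn_metadata,
                            hidden_states: torch.Tensor,
                            residual: torch.Tensor):
    did_split, metadata, tensors, _ = ms_metadata.split_micro_batch(
        attn_metadata, [hidden_states, residual])
    if did_split:
        # metadata is a list of AscendMLAMetadata, one per microbatch;
        # tensors holds the matching slices of hidden_states and residual.
        for i, micro_meta in enumerate(metadata):
            pass  # run microbatch i with micro_meta and its tensor slices
    else:
        pass  # single batch: metadata is the original AscendMLAMetadata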
