Commit f74b5d4

refine logic
Signed-off-by: Angazenn <supperccell@163.com>
1 parent 7c32e60 commit f74b5d4

File tree

1 file changed: +23 -10 lines changed


vllm_ascend/worker/model_runner_v1.py

Lines changed: 23 additions & 10 deletions
@@ -1786,16 +1786,18 @@ def _pool(
         )
 
     def _initialize_mc2(self):
-        """Initialization of mc2-related parameters."""
+        """Initialize MC2-related parameters and verify their validity."""
 
-        self.mc2_tokens_capacity = 0
         self.reserved_mc2_mask = None
 
         # For models that contain no MoE modules, we simply skip the
-        # initialization of mc2.
+        # initialization of MC2.
         if not is_moe_model(self.vllm_config):
+            self.mc2_tokens_capacity = 0
             return
 
+        # For MoE models, we first assume that the model will use MC2 and
+        # compute self.mc2_tokens_capacity.
         # NOTE: To be clear, we need to make sure that during graph capture, the number of
         # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
         # the max number of tokens in graph is min(max_num_seqs * 2, 512).
@@ -1806,11 +1808,19 @@ def _initialize_mc2(self):
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        # A larger number of input tokens for MC2 introduces much higher
+        # HBM consumption. Therefore, self.mc2_tokens_capacity is set to the
+        # maximum possible number of input tokens in graph or decode cases.
+        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
 
-        # Additional check for MC2 restrictions on specific hardwares.
-        if self._select_moe_comm_method(
-                mc2_tokens_capacity) == MoECommType.MC2:
+        # We then check whether MC2 is actually needed and
+        # verify the validity of self.mc2_tokens_capacity on
+        # different hardware.
+        if self._select_moe_comm_method(self.mc2_tokens_capacity,
+                                        with_prefill=False) == MoECommType.MC2:
+
+            # MC2 will be applied at runtime, so we check whether the
+            # number of input tokens exceeds the MC2 limit.
 
             soc_version = get_ascend_soc_version()
             limit = None
@@ -1826,14 +1836,17 @@ def _initialize_mc2(self):
                 f"(current: {self.max_num_reqs}) or increase `tp_size` (current: {tp_size})."
             )
 
-            # Only set these parameters if mc2 is actually needed
-            # and the above check is passed.
-            self.mc2_tokens_capacity = mc2_tokens_capacity
+            # All checks have passed, so we finally initialize self.reserved_mc2_mask.
             self.reserved_mc2_mask = torch.zeros(
                 self.mc2_tokens_capacity,
                 dtype=torch.bool,
                 device=self.device,
             )
+        else:
+            # MC2 is not needed for this MoE model on certain hardware
+            # (such as a single A2 node), so self.mc2_tokens_capacity
+            # falls back to 0.
+            self.mc2_tokens_capacity = 0
 
     def _select_moe_comm_method(self, num_tokens: int,
                                 with_prefill: bool) -> MoECommType:
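
To make the capacity arithmetic in the middle hunk concrete, here is a small self-contained sketch of the ceiling-division rounding described in the new comments. The values of max_num_tokens and tp_size are illustrative assumptions, not taken from the commit.

# Illustrative sketch of the rounding in _initialize_mc2: the per-TP-rank
# token count is rounded up so that every rank receives an equal share.
max_num_tokens = 512  # e.g. min(max_num_seqs * 2, 512) during graph capture
tp_size = 3           # hypothetical tensor-parallel size

# Ceiling division using integer arithmetic, as in the diff.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size  # 171
mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size              # 513

# Rounding up means the capacity can slightly exceed max_num_tokens, which
# is why the new comments stress the HBM cost of a larger capacity.
assert (num_tokens_per_tp_rank, mc2_tokens_capacity) == (171, 513)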

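For readers skimming the diff, the refactored function now has three outcomes: skip MC2 entirely for non-MoE models, validate the capacity and allocate the mask when MC2 is selected, or fall back to a capacity of 0 otherwise. The sketch below is a minimal stand-alone rendering of that control flow under stated assumptions, not the actual vllm_ascend code: is_moe, comm_type, and limit are stand-ins for is_moe_model, _select_moe_comm_method, and the per-SoC limit, and a plain bool list replaces the torch mask.

from enum import Enum
from typing import List, Optional, Tuple


class MoECommType(Enum):
    MC2 = "mc2"
    ALLGATHER = "allgather"


def initialize_mc2(is_moe: bool, comm_type: MoECommType, capacity: int,
                   limit: Optional[int]) -> Tuple[int, Optional[List[bool]]]:
    """Returns (mc2_tokens_capacity, reserved_mc2_mask)."""
    # Outcome 1: non-MoE models never need MC2.
    if not is_moe:
        return 0, None

    # Outcome 2: MC2 is selected, so validate the capacity against the
    # hardware limit and allocate the reserved mask.
    if comm_type == MoECommType.MC2:
        if limit is not None and capacity > limit:
            raise ValueError(
                f"capacity {capacity} exceeds the MC2 limit {limit}; "
                "decrease max_num_seqs or increase tp_size.")
        return capacity, [False] * capacity  # torch.zeros(..., bool) in the real code

    # Outcome 3: MoE model, but MC2 is not selected on this hardware;
    # the capacity falls back to 0.
    return 0, None

For example, initialize_mc2(True, MoECommType.ALLGATHER, 513, None) returns (0, None), matching the new else branch in the diff.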