@@ -1786,16 +1786,18 @@ def _pool(
         )
 
     def _initialize_mc2(self):
-        """Initialization of mc2-related parameters."""
+        """Initialize MC2-related parameters and verify their validity."""
 
-        self.mc2_tokens_capacity = 0
         self.reserved_mc2_mask = None
 
         # For models contains no moe modules, we simply skip the
-        # initialization of mc2.
+        # initialization of MC2.
         if not is_moe_model(self.vllm_config):
+            self.mc2_tokens_capacity = 0
             return
 
+        # For MoE models, we first assume that MC2 will be used and compute
+        # self.mc2_tokens_capacity.
         # NOTE: To be clear, we need to make sure that during graph capture, the number of
         # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
         # the max number of tokens in graph is min(max_num_seqs * 2, 512).
@@ -1806,11 +1808,19 @@ def _initialize_mc2(self):
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        # A larger number of input tokens for MC2 introduces much more HBM consumption.
+        # Therefore, self.mc2_tokens_capacity is set to the maximum possible number of
+        # input tokens in the graph-capture or decode cases.
+        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
 
-        # Additional check for MC2 restrictions on specific hardwares.
-        if self._select_moe_comm_method(
-                mc2_tokens_capacity) == MoECommType.MC2:
+        # We then check whether it is really necessary to run MC2
+        # and verify the validity of self.mc2_tokens_capacity on
+        # different hardware.
+        if self._select_moe_comm_method(self.mc2_tokens_capacity,
+                                         with_prefill=False) == MoECommType.MC2:
+
+            # MC2 will be applied at runtime, so we check whether
+            # the number of input tokens exceeds the MC2 limit.
 
             soc_version = get_ascend_soc_version()
             limit = None
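As a rough, standalone illustration of the capacity arithmetic introduced in this hunk (the helper name and the example values are invented for this sketch; only the `min(max_num_seqs * 2, 512)` rule and the ceiling division over `tp_size` come from the code above):

```python
# Hypothetical sketch, not part of the diff: how mc2_tokens_capacity is derived.
def mc2_tokens_capacity_sketch(max_num_seqs: int, tp_size: int) -> int:
    # Max number of tokens captured in a graph, per the NOTE about _set_cudagraph_sizes.
    max_num_tokens = min(max_num_seqs * 2, 512)
    # Ceiling division, so every TP rank gets an equal share that covers all tokens.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    # The capacity is rounded back up to a multiple of tp_size.
    return num_tokens_per_tp_rank * tp_size


assert mc2_tokens_capacity_sketch(24, 8) == 48     # 48 tokens already divide evenly
assert mc2_tokens_capacity_sketch(100, 16) == 208  # 200 tokens rounded up to 13 * 16
```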
@@ -1826,14 +1836,17 @@ def _initialize_mc2(self):
                     f"(current: {self.max_num_reqs}) or increase `tp_size` (current: {tp_size})."
                 )
 
-            # Only set these parameters if mc2 is actually needed
-            # and the above check is passed.
-            self.mc2_tokens_capacity = mc2_tokens_capacity
+            # All checks have passed, so we finally initialize self.reserved_mc2_mask.
             self.reserved_mc2_mask = torch.zeros(
                 self.mc2_tokens_capacity,
                 dtype=torch.bool,
                 device=self.device,
             )
+        else:
+            # MC2 is not needed for this MoE model on certain hardware
+            # (such as a single A2 node), so self.mc2_tokens_capacity
+            # falls back to 0.
+            self.mc2_tokens_capacity = 0
 
     def _select_moe_comm_method(self, num_tokens: int,
                                 with_prefill: bool) -> MoECommType:
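For readers skimming the change, the sketch below mirrors the new shape of `_initialize_mc2`: compute a capacity, fall back to 0 when MC2 is not selected, validate against a hardware limit, and only then allocate the boolean reservation mask. The enum stand-in, the simplified error message, the `limit` argument, and the `"cpu"` device are assumptions for illustration; `torch.zeros(..., dtype=torch.bool)` matches the `reserved_mc2_mask` allocation shown in the diff.

```python
# Hypothetical sketch, not part of the diff: the decision flow of _initialize_mc2.
from enum import Enum, auto

import torch


class MoECommTypeSketch(Enum):
    # Minimal stand-in for the real MoECommType enum.
    MC2 = auto()
    OTHER = auto()


def init_mc2_sketch(capacity: int, comm_type: MoECommTypeSketch,
                    limit: int | None) -> tuple[int, torch.Tensor | None]:
    if comm_type != MoECommTypeSketch.MC2:
        # MC2 is not selected on this hardware/config: capacity falls back to 0.
        return 0, None
    if limit is not None and capacity > limit:
        # Simplified stand-in for the ValueError raised in the diff.
        raise ValueError(f"MC2 supports at most {limit} input tokens.")
    # All checks passed: allocate the boolean reservation mask.
    return capacity, torch.zeros(capacity, dtype=torch.bool, device="cpu")


capacity, mask = init_mc2_sketch(48, MoECommTypeSketch.MC2, limit=None)
assert capacity == 48 and mask.shape == (48,)
```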