Commit 9d880f5

[Misc] Turn MOE_DP_CHUNK_SIZE into an env var (#19506)

1 parent 017ef64

File tree: 2 files changed, +18 -8 lines

  vllm/envs.py (+9, -0)
  vllm/model_executor/layers/fused_moe/layer.py (+9, -8)

vllm/envs.py (9 additions, 0 deletions)

@@ -112,6 +112,7 @@
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_MOE_DP_CHUNK_SIZE: int = 256
     VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False

@@ -773,6 +774,14 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_DP_MASTER_PORT":
     lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),

+    # In the context of executing MoE models with Data-Parallel, Expert-Parallel
+    # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE
+    # dictates the quantum of tokens that can be dispatched from a DP
+    # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE
+    # units.
+    "VLLM_MOE_DP_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
+
     # Randomize inputs during dummy runs when using Data Parallel
     "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
     lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
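
For context, entries in vllm/envs.py are resolved lazily: each name maps to a zero-argument lambda, and a module-level __getattr__ invokes the lambda when envs.VLLM_MOE_DP_CHUNK_SIZE is read, so the value tracks the process environment at access time rather than at import time. Below is a minimal, self-contained sketch of that pattern; the two-entry table is illustrative, not the full file.

import os

# Illustrative two-entry version of the environment_variables table in
# vllm/envs.py; each value is a zero-arg lambda evaluated on access.
environment_variables = {
    "VLLM_MOE_DP_CHUNK_SIZE":
    lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")),
    "VLLM_DP_SIZE":
    lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
}


def __getattr__(name: str):
    # PEP 562: a module-level __getattr__ resolves attribute access such
    # as envs.VLLM_MOE_DP_CHUNK_SIZE by re-reading the environment.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


if __name__ == "__main__":
    os.environ["VLLM_MOE_DP_CHUNK_SIZE"] = "512"
    print(environment_variables["VLLM_MOE_DP_CHUNK_SIZE"]())  # -> 512

The practical upshot of the commit: instead of editing a hard-coded MOE_DP_CHUNK_SIZE constant in the source, users can export VLLM_MOE_DP_CHUNK_SIZE=512 (for example) before launching, and every use site below picks the value up through envs.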

vllm/model_executor/layers/fused_moe/layer.py (9 additions, 8 deletions)

@@ -61,10 +61,6 @@
     fused_moe_pallas = None  # type: ignore
 logger = init_logger(__name__)

-# Note: this limit is somewhat arbitrary and might be changed later.
-# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
-MOE_DP_CHUNK_SIZE = 256
-

 @dataclass
 class FusedMoEParallelConfig:

@@ -218,7 +214,12 @@ class MoEConfig:
     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128

-    max_num_tokens: int = MOE_DP_CHUNK_SIZE
+    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
+
+    def __post_init__(self):
+        if self.dp_size > 1:
+            logger.debug("Using MOEConfig::max_num_tokens=%d",
+                         self.max_num_tokens)

     @property
     def tp_size(self):

@@ -913,7 +914,7 @@
             moe_parallel_config=self.moe_parallel_config,
             in_dtype=params_dtype,
             quant_dtype=quant_dtype,
-            max_num_tokens=MOE_DP_CHUNK_SIZE,
+            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
         self.quant_config = quant_config

@@ -952,12 +953,12 @@
                 or self.moe_parallel_config.use_deepep_ll_kernels):
             act_dtype = vllm_config.model_config.dtype
             self.batched_hidden_states = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.hidden_size),
+                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())

             self.batched_router_logits = torch.zeros(
-                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
+                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
                 dtype=act_dtype,
                 device=torch.cuda.current_device())
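
To make the knob's effect concrete, here is a hedged, standalone sketch of the chunked loop the envs.py comment describes; it is not vLLM's actual code, and moe_forward_chunked and run_experts are hypothetical names. Each DP rank walks its activations in slices of at most VLLM_MOE_DP_CHUNK_SIZE tokens, and all ranks derive the iteration count from the largest per-rank batch so the batched all-to-all dispatch/combine collectives are entered in lockstep even after a rank runs out of real tokens.

import math

import torch


def run_experts(chunk: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real dispatch -> expert compute -> combine
    # sequence; an identity keeps the sketch runnable.
    return chunk


def moe_forward_chunked(hidden_states: torch.Tensor,
                        router_logits: torch.Tensor,
                        chunk_size: int,
                        max_tokens_across_dp: int) -> torch.Tensor:
    # Process activations in chunk_size (VLLM_MOE_DP_CHUNK_SIZE) units.
    # All DP ranks compute num_chunks from the largest per-rank batch so
    # every rank enters the collective kernels the same number of times.
    out = torch.empty_like(hidden_states)
    num_chunks = max(1, math.ceil(max_tokens_across_dp / chunk_size))
    for i in range(num_chunks):
        start = min(i * chunk_size, hidden_states.shape[0])
        end = min(start + chunk_size, hidden_states.shape[0])
        # A rank that has exhausted its tokens still participates,
        # just with an empty slice.
        out[start:end] = run_experts(hidden_states[start:end],
                                     router_logits[start:end])
    return out


# Example: this rank has 300 tokens, the busiest rank has 900, so all
# ranks take ceil(900 / 256) = 4 iterations.
x = torch.randn(300, 4096)
logits = torch.randn(300, 8)
y = moe_forward_chunked(x, logits, chunk_size=256, max_tokens_across_dp=900)
assert torch.equal(x, y)

The preallocated batched_hidden_states and batched_router_logits buffers above are sized by the same value, so staging memory scales linearly with it: with the default 256 tokens, a hypothetical hidden_size of 4096, and bfloat16 activations, the hidden-state buffer is 256 x 4096 x 2 B = 2 MiB per rank.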
