61 | 61 | fused_moe_pallas = None  # type: ignore
62 | 62 | logger = init_logger(__name__)
63 | 63 |
64 |    | -# Note: this limit is somewhat arbitrary and might be changed later.
65 |    | -# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
66 |    | -MOE_DP_CHUNK_SIZE = 256
67 |    | -
68 | 64 |
69 | 65 | @dataclass
70 | 66 | class FusedMoEParallelConfig:
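
The removed comment pins the activation footprint at roughly E x MOE_DP_CHUNK_SIZE x hidden_dim. A back-of-the-envelope check of what the old 256-token chunk implies, using illustrative model dimensions (the expert count, hidden size, and dtype below are assumptions, not values from this change):

    # Rough sizing implied by the removed comment (activations ~ E x chunk x hidden).
    # All model dimensions are illustrative assumptions, not values from this PR.
    num_experts = 64      # E, assumed
    chunk_size = 256      # the old MOE_DP_CHUNK_SIZE default
    hidden_dim = 4096     # assumed hidden size
    bytes_per_elem = 2    # bfloat16 activations

    activation_bytes = num_experts * chunk_size * hidden_dim * bytes_per_elem
    print(f"{activation_bytes / 2**20:.0f} MiB")  # 128 MiB under these assumptions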

@@ -218,7 +214,12 @@ class MoEConfig:
218 | 214 |     # TODO: add more quantization params, blocked, per-token, etc.
219 | 215 |     block_size: int = 128
220 | 216 |
221 |     | -    max_num_tokens: int = MOE_DP_CHUNK_SIZE
    | 217 | +    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
    | 218 | +
    | 219 | +    def __post_init__(self):
    | 220 | +        if self.dp_size > 1:
    | 221 | +            logger.debug("Using MOEConfig::max_num_tokens=%d",
    | 222 | +                         self.max_num_tokens)
222 | 223 |
223 | 224 |     @property
224 | 225 |     def tp_size(self):
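
With this hunk the default no longer comes from a module-level constant but from vllm.envs. As a hedged sketch only (the actual registration in vllm/envs.py, including its default value and parsing, may differ), an environment-backed knob like this boils down to reading and converting an environment variable:

    # Hedged sketch of an environment-backed default for the chunk size; the
    # real definition in vllm/envs.py may differ in default value and parsing.
    import os

    VLLM_MOE_DP_CHUNK_SIZE: int = int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256"))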

@@ -913,7 +914,7 @@ def __init__(
913 | 914 |             moe_parallel_config=self.moe_parallel_config,
914 | 915 |             in_dtype=params_dtype,
915 | 916 |             quant_dtype=quant_dtype,
916 |     | -            max_num_tokens=MOE_DP_CHUNK_SIZE,
    | 917 | +            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
917 | 918 |         )
918 | 919 |         self.moe_config = moe
919 | 920 |         self.quant_config = quant_config

@@ -952,12 +953,12 @@ def __init__(
952 | 953 |                 or self.moe_parallel_config.use_deepep_ll_kernels):
953 | 954 |             act_dtype = vllm_config.model_config.dtype
954 | 955 |             self.batched_hidden_states = torch.zeros(
955 |     | -                (MOE_DP_CHUNK_SIZE, self.hidden_size),
    | 956 | +                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
956 | 957 |                 dtype=act_dtype,
957 | 958 |                 device=torch.cuda.current_device())
958 | 959 |
959 | 960 |             self.batched_router_logits = torch.zeros(
960 |     | -                (MOE_DP_CHUNK_SIZE, self.global_num_experts),
    | 961 | +                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.global_num_experts),
961 | 962 |                 dtype=act_dtype,
962 | 963 |                 device=torch.cuda.current_device())
963 | 964 |
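
Because the staging buffers above are now sized from the environment, the chunk size can be tuned per deployment without touching code. A minimal hypothetical override (the value 512 is arbitrary, and it must be in the environment before vLLM first reads it):

    # Hypothetical per-run override of the MoE DP chunk size; the value is
    # arbitrary and must be set before vLLM reads VLLM_MOE_DP_CHUNK_SIZE.
    import os
    os.environ["VLLM_MOE_DP_CHUNK_SIZE"] = "512"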