diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index ec2d9e7ee2..155ee78eb5 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -72,6 +72,7 @@ def mock_dist_env(mocker: MockerFixture): return_value=MagicMock( torchair_graph_config=MagicMock(enabled=False), enable_multistream_moe=False, + enable_shared_expert_dp=False, expert_map_path=None )), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map', diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index edd80d8bde..10cfcdd6db 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -60,6 +60,8 @@ def __init__( self.torchair_compiled_models = {} # type: ignore self.torchair_graph_enabled = get_ascend_config( ).torchair_graph_config.enabled + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + @@ -79,7 +81,9 @@ def load_model(self, model) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - if self.torchair_graph_enabled: + if self.torchair_graph_enabled or ( + self.enable_shared_expert_dp + and self.vllm_config.model_config.use_mla): self.model = TorchairDeepSeekMTP( vllm_config=self.vllm_config).to(target_device) else: diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 70877cb919..65c0288464 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -814,7 +814,7 @@ def forward( attn_metadata = get_forward_context().attn_metadata if attn_metadata is not None and isinstance(attn_metadata, dict): - attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] + attn_metadata = next(iter(attn_metadata.values()), None) if attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens else: diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 2377b5082f..0c85c85086 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -803,6 +803,7 @@ def __init__(self, moe: FusedMoEConfig = None): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp try: device_group = get_mc2_group().device_group @@ -884,6 +885,8 @@ def apply( topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All if fused_moe_state == FusedMoEState.MC2: return torchair_fused_experts_with_mc2( @@ -1155,6 +1158,8 @@ def forward(self, forward_context = get_forward_context() fused_moe_state = forward_context.fused_moe_state mc2_mask = forward_context.mc2_mask + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. quantized_x_for_share, dynamic_scale_for_share = None, None from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 832cbc56a0..34624fb2e7 100644 --- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -823,6 +823,7 @@ def __init__(self): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp try: device_group = get_mc2_group().device_group @@ -936,6 +937,8 @@ def apply( ) fused_moe_state = get_forward_context().fused_moe_state + if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2: + fused_moe_state = FusedMoEState.All2All shared_gate_up, shared_dequant_scale = None, None if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: with npu_stream_switch("moe_secondary", 0): diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 0c715fd245..b39495ed7d 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -87,8 +87,8 @@ def _sync_metadata_across_dp( ) -> tuple[int, Optional[torch.Tensor], bool, bool]: """Override from NPUModelRunner to pad num_tokens""" if self.enable_shared_expert_dp: - return super()._sync_metadata_across_dp(num_tokens, with_prefill, - enable_dbo) + # Padding is not required for shared_expert_dp cases in eager mode. + return num_tokens, None, with_prefill, enable_dbo if self.dp_size == 1: if not with_prefill: maybe_padded_num_tokens = self.select_torchair_padded_batch_size( diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py index 2c8c4584f8..dbee800354 100644 --- a/vllm_ascend/torchair/torchair_worker.py +++ b/vllm_ascend/torchair/torchair_worker.py @@ -35,26 +35,25 @@ def determine_available_memory(self) -> int: ascend_config = get_ascend_config() if ascend_config.enable_shared_expert_dp: return available_kv_cache_memory - if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes and check_kv_cache_bytes_cache_exist( - ): - old_kv_cache_bytes = read_kv_cache_bytes_from_file( - torch.distributed.get_rank()) - if 0 < old_kv_cache_bytes <= available_kv_cache_memory: - logger.info( - f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" - ) - self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes - return old_kv_cache_bytes - else: - logger.info( - "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" - ) - delete_torchair_cache_file() - bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE - available_kv_cache_memory -= bytes_floating_tolerance - logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") - self.model_runner.new_kv_cache_bytes = available_kv_cache_memory - + if ascend_config.torchair_graph_config.use_cached_kv_cache_bytes: + if check_kv_cache_bytes_cache_exist(): + old_kv_cache_bytes = read_kv_cache_bytes_from_file( + torch.distributed.get_rank()) + if 0 < old_kv_cache_bytes <= available_kv_cache_memory: + logger.info( + f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}" + ) + self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes + return old_kv_cache_bytes + else: + logger.info( + "Cached torchair kv_cache_bytes is too big, invalidate old torchair_cache" + ) + delete_torchair_cache_file() + bytes_floating_tolerance = 1024 * 1024 * envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE + available_kv_cache_memory -= bytes_floating_tolerance + logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}") + self.model_runner.new_kv_cache_bytes = available_kv_cache_memory return available_kv_cache_memory def init_device(self):