From 8071af74fb422dcd6d1b399e704eb74e644d893d Mon Sep 17 00:00:00 2001
From: Pr0Wh1teGivee
Date: Fri, 29 Aug 2025 12:00:45 +0800
Subject: [PATCH] fix with_quant logic

Signed-off-by: Pr0Wh1teGivee
---
 tests/ut/ops/test_fused_ops.py             | 42 +++++++------------
 tests/ut/ops/test_token_dispatcher.py      | 13 +++---
 vllm_ascend/ascend_forward_context.py      |  2 -
 vllm_ascend/ops/fused_moe.py               | 42 ++++++++++---------
 .../ops/moe_dispatcher/token_dispatcher.py | 22 +++++++---
 vllm_ascend/quantization/w4a8_dynamic.py   |  3 +-
 vllm_ascend/quantization/w8a8_dynamic.py   |  3 +-
 7 files changed, 62 insertions(+), 65 deletions(-)

diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 04a16591ee..8c4c7f416b 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -543,7 +543,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_get_forward_context):
 
         mock_forward_context = MagicMock()
-        mock_forward_context.with_quant = True
         mock_forward_context.fused_moe_state = FusedMoEState.MC2
         mock_get_forward_context.return_value = mock_forward_context
 
@@ -587,10 +586,10 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                          group_list_type=1,
                                          w1_scale_bias=None,
                                          w2_scale_bias=None,
-                                         topk_scales=None)
+                                         topk_scales=None,
+                                         with_quant=True)
 
         mock_get_forward_context.assert_called()
-        self.assertTrue(mock_forward_context.with_quant)
 
         self.assertEqual(mock_forward_context.fused_moe_state,
                          FusedMoEState.MC2)
@@ -602,19 +601,15 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
 
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
     @patch('vllm_ascend.ops.fused_moe.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
-    def test_unified_apply_mlp_without_quantization(
-            self, mock_npu_dynamic_quant, mock_npu_swiglu,
-            mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
-
-        mock_forward_context = MagicMock()
-        mock_forward_context.with_quant = False
-        mock_get_forward_context.return_value = mock_forward_context
-
+    def test_unified_apply_mlp_without_quantization(self,
+                                                    mock_npu_dynamic_quant,
+                                                    mock_npu_swiglu,
+                                                    mock_npu_grouped_matmul,
+                                                    mock_is_310p):
         mock_is_310p.return_value = False
 
         mock_npu_grouped_matmul.side_effect = [[
@@ -639,10 +634,8 @@ def test_unified_apply_mlp_without_quantization(
                                          group_list_type=1,
                                          w1_scale_bias=None,
                                          w2_scale_bias=None,
-                                         topk_scales=topk_scales)
-
-        mock_get_forward_context.assert_called()
-        self.assertFalse(mock_forward_context.with_quant)
+                                         topk_scales=topk_scales,
+                                         with_quant=False)
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -698,10 +691,10 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
                                          group_list_type=1,
                                          w1_scale_bias=w1_scale_bias,
                                          w2_scale_bias=w2_scale_bias,
-                                         topk_scales=None)
+                                         topk_scales=None,
+                                         with_quant=True)
 
         mock_get_forward_context.assert_called()
-        self.assertTrue(mock_forward_context.with_quant)
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -710,19 +703,13 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
     @patch('vllm_ascend.ops.fused_moe.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
     def test_unified_apply_mlp_without_quantization_310p(
             self, mock_npu_dynamic_quant, mock_npu_swiglu,
-            mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
-
-        mock_forward_context = MagicMock()
-        mock_forward_context.with_quant = False
-        mock_get_forward_context.return_value = mock_forward_context
-
+            mock_npu_grouped_matmul, mock_is_310p):
         mock_is_310p.return_value = True
 
         mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
@@ -750,10 +737,9 @@ def test_unified_apply_mlp_without_quantization_310p(
                                          group_list_type=1,
                                          w1_scale_bias=None,
                                          w2_scale_bias=None,
-                                         topk_scales=topk_scales)
+                                         topk_scales=topk_scales,
+                                         with_quant=False)
 
-        mock_get_forward_context.assert_called()
-        self.assertFalse(mock_forward_context.with_quant)
         mock_is_310p.assert_called_once()
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index 77f40fae30..be0a4f97fc 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -263,7 +263,6 @@ def test_token_dispatch_with_quant(self):
             "max_num_tokens": 100,
             "ep_size": 2,
             "num_experts": 128,
-            "with_quant": True,
         }
         self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
 
@@ -460,8 +459,7 @@ def test_token_combine(self):
     def test_token_dispatch_with_quant(self):
         self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                       num_experts=4,
-                                                      num_local_experts=2,
-                                                      with_quant=True)
+                                                      num_local_experts=2)
 
         hidden_states = torch.randn(8, 16)
         topk_weights = torch.rand(8, 4)
@@ -476,7 +474,8 @@ def test_token_dispatch_with_quant(self):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             row_idx=self.row_idx,
-            expert_map=expert_map)
+            expert_map=expert_map,
+            with_quant=True)
 
         self.assertIsNotNone(result["hidden_states"])
         self.assertIsNotNone(result["group_list"])
@@ -486,8 +485,7 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
         self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                       num_experts=4,
-                                                      num_local_experts=2,
-                                                      with_quant=True)
+                                                      num_local_experts=2)
         self.mock_repeat_interleave.return_value = torch.tensor(
             [], dtype=torch.long)
 
@@ -505,7 +503,8 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             row_idx=self.row_idx,
-            expert_map=expert_map)
+            expert_map=expert_map,
+            with_quant=True)
 
         self.assertIsNotNone(result["hidden_states"])
         self.assertIsNotNone(result["group_list"])
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 3e48cf7308..7ddbc82058 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -99,8 +99,6 @@ def set_ascend_forward_context(
         forward_context.fused_moe_state = fused_moe_state
         forward_context.in_profile_run = in_profile_run
 
-        with_quant = vllm_config.quant_config is not None
-        forward_context.with_quant = with_quant
         from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
             get_token_dispatcher
         dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 24a1667bb5..5f85b36503 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -408,19 +408,19 @@ def unquant_apply_mlp(
     return hidden_states
 
 
-def unified_apply_mlp(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2: torch.Tensor,
-        w2_scale: torch.Tensor,
-        group_list: torch.Tensor,
-        dynamic_scale: torch.Tensor = None,
-        group_list_type: int = 1,
-        w1_scale_bias: torch.Tensor = None,
-        w2_scale_bias: torch.Tensor = None,
-        topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
-    if get_forward_context().with_quant:
+def unified_apply_mlp(hidden_states: torch.Tensor,
+                      w1: torch.Tensor,
+                      w1_scale: torch.Tensor,
+                      w2: torch.Tensor,
+                      w2_scale: torch.Tensor,
+                      group_list: torch.Tensor,
+                      dynamic_scale: torch.Tensor = None,
+                      group_list_type: int = 1,
+                      w1_scale_bias: torch.Tensor = None,
+                      w2_scale_bias: torch.Tensor = None,
+                      topk_scales: Optional[torch.Tensor] = None,
+                      with_quant: bool = False) -> torch.Tensor:
+    if with_quant:
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
                                w1_scale=w1_scale,
@@ -457,7 +457,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
                                 shared_gate_up: Optional[Any] = None,
                                 shared_dequant_scale: Optional[Any] = None,
                                 mc2_mask: Optional[torch.Tensor] = None,
-                                apply_router_weight_on_input: bool = False):
+                                apply_router_weight_on_input: bool = False,
+                                with_quant: bool = False):
     token_dispatcher = get_forward_context().token_dispatcher
 
     results = token_dispatcher.token_dispatch(
@@ -472,7 +473,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
         shared_gate_up=shared_gate_up,
         shared_dequant_scale=shared_dequant_scale,
         mc2_mask=mc2_mask,
-        apply_router_weight_on_input=apply_router_weight_on_input)
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        with_quant=with_quant)
 
     expert_output = unified_apply_mlp(
         hidden_states=results["hidden_states"],
@@ -485,7 +487,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
         group_list_type=results.get("group_list_type"),
         w1_scale_bias=w1_scale_bias,
         w2_scale_bias=w2_scale_bias,
-        topk_scales=results.get("topk_scales"))
+        topk_scales=results.get("topk_scales"),
+        with_quant=with_quant)
 
     final_hidden_states = token_dispatcher.token_combine(expert_output)
     return final_hidden_states
@@ -577,7 +580,8 @@ def apply(
                                             expert_map=expert_map,
                                             shared_experts=shared_experts,
                                             mc2_mask=kwargs.get(
-                                                "mc2_mask", None))
+                                                "mc2_mask", None),
+                                            with_quant=False)
 
 
 class AscendFusedMoE(FusedMoE):
@@ -761,7 +765,6 @@ def __init__(
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
 
-        with_quant = quant_config is not None
         from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
             setup_token_dispatchers
         setup_token_dispatchers(
@@ -769,8 +772,7 @@ def __init__(
             top_k=self.top_k,
             num_experts=self.global_num_experts,
             num_global_redundant_experts=self.global_redundant_expert_num,
-            num_local_experts=self.local_num_experts,
-            with_quant=with_quant)
+            num_local_experts=self.local_num_experts)
 
     def naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
index c0d85bb6bd..a5ca03afd0 100644
--- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
+++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
@@ -490,7 +490,6 @@ def __init__(self, **kwargs) -> None:
         """
         self.top_k = kwargs.get("top_k", 0)
         self.num_experts = kwargs.get("num_experts", 0)
-        self.with_quant = kwargs.get("with_quant", False)
 
     @property
     def ep_group(self):
@@ -518,7 +517,8 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
         raise NotImplementedError("Dispatch function not implemented.")
 
     @abstractmethod
@@ -555,6 +555,7 @@ def __init__(self, **kwargs):
         self.topk_weights = None
         self.shared_experts = None
         self.mc2_mask = None
+        self.with_quant = False
 
     def get_dispatch_mc2_kwargs(
         self,
@@ -615,7 +616,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.expert_map = expert_map
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights
@@ -738,6 +741,7 @@ def __init__(self, **kwargs):
         self.expert_map = None
         self.topk_weights = None
         self.topk_ids = None
+        self.with_quant = False
 
     def token_dispatch(self,
                        hidden_states: torch.Tensor,
@@ -751,7 +755,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.original_shape = hidden_states.shape
 
         num_tokens = hidden_states.shape[:-1].numel()
@@ -922,7 +928,8 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
         self.apply_router_weight_on_input = apply_router_weight_on_input
         if self.apply_router_weight_on_input:
             assert (topk_weights.dim() == 2
@@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        self.with_quant = False
         self.num_local_experts = kwargs.get("num_local_experts", 0)
         self.num_global_redundant_experts = kwargs.get(
             "num_global_redundant_experts", 0)
@@ -1032,7 +1040,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
         self.topk_weights = topk_weights
         assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index 207c5e1d18..72f956d1d2 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -308,7 +308,8 @@ def apply(
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
             shared_dequant_scale=shared_dequant_scale,
-            mc2_mask=kwargs.get("mc2_mask", None))
+            mc2_mask=kwargs.get("mc2_mask", None),
+            with_quant=True)
 
     def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
         group_num, k, n = weight.shape
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 1177af8507..8438b33690 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -406,7 +406,8 @@ def apply(
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
             shared_dequant_scale=shared_dequant_scale,
-            mc2_mask=kwargs.get("mc2_mask", None))
+            mc2_mask=kwargs.get("mc2_mask", None),
+            with_quant=True)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight: