42 changes: 14 additions & 28 deletions tests/ut/ops/test_fused_ops.py
@@ -543,7 +543,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_get_forward_context):

mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_get_forward_context.return_value = mock_forward_context

@@ -587,10 +586,10 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=None)
topk_scales=None,
with_quant=True)

mock_get_forward_context.assert_called()
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)

@@ -602,19 +601,15 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,

self.assertEqual(result.dtype, torch.bfloat16)

@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):

mock_forward_context = MagicMock()
mock_forward_context.with_quant = False
mock_get_forward_context.return_value = mock_forward_context

def test_unified_apply_mlp_without_quantization(self,
mock_npu_dynamic_quant,
mock_npu_swiglu,
mock_npu_grouped_matmul,
mock_is_310p):
mock_is_310p.return_value = False

mock_npu_grouped_matmul.side_effect = [[
@@ -639,10 +634,8 @@ def test_unified_apply_mlp_without_quantization(
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales)

mock_get_forward_context.assert_called()
self.assertFalse(mock_forward_context.with_quant)
topk_scales=topk_scales,
with_quant=False)

self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
@@ -698,10 +691,10 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None)
topk_scales=None,
with_quant=True)

mock_get_forward_context.assert_called()
self.assertTrue(mock_forward_context.with_quant)

self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
@@ -710,19 +703,13 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)

@patch('vllm_ascend.ops.fused_moe.get_forward_context')
@patch('vllm_ascend.ops.fused_moe.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization_310p(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):

mock_forward_context = MagicMock()
mock_forward_context.with_quant = False
mock_get_forward_context.return_value = mock_forward_context

mock_npu_grouped_matmul, mock_is_310p):
mock_is_310p.return_value = True

mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
@@ -750,10 +737,9 @@ def test_unified_apply_mlp_without_quantization_310p(
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales)
topk_scales=topk_scales,
with_quant=False)

mock_get_forward_context.assert_called()
self.assertFalse(mock_forward_context.with_quant)
mock_is_310p.assert_called_once()

self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
13 changes: 6 additions & 7 deletions tests/ut/ops/test_token_dispatcher.py
@@ -263,7 +263,6 @@ def test_token_dispatch_with_quant(self):
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
"with_quant": True,
}
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

@@ -460,8 +459,7 @@ def test_token_combine(self):
def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2,
with_quant=True)
num_local_experts=2)

hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
@@ -476,7 +474,8 @@ def test_token_dispatch_with_quant(self):
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
expert_map=expert_map,
with_quant=True)

self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
@@ -486,8 +485,7 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2,
with_quant=True)
num_local_experts=2)

self.mock_repeat_interleave.return_value = torch.tensor(
[], dtype=torch.long)
@@ -505,7 +503,8 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
expert_map=expert_map,
with_quant=True)

self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
2 changes: 0 additions & 2 deletions vllm_ascend/ascend_forward_context.py
@@ -99,8 +99,6 @@ def set_ascend_forward_context(
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run

with_quant = vllm_config.quant_config is not None
forward_context.with_quant = with_quant
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
get_token_dispatcher
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
42 changes: 22 additions & 20 deletions vllm_ascend/ops/fused_moe.py
@@ -408,19 +408,19 @@ def unquant_apply_mlp(
return hidden_states


def unified_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
if get_forward_context().with_quant:
def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False) -> torch.Tensor:
if with_quant:
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
@@ -457,7 +457,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
token_dispatcher = get_forward_context().token_dispatcher

results = token_dispatcher.token_dispatch(
@@ -472,7 +473,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=with_quant)

expert_output = unified_apply_mlp(
hidden_states=results["hidden_states"],
@@ -485,7 +487,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
group_list_type=results.get("group_list_type"),
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=results.get("topk_scales"))
topk_scales=results.get("topk_scales"),
with_quant=with_quant)
final_hidden_states = token_dispatcher.token_combine(expert_output)
return final_hidden_states

@@ -577,7 +580,8 @@ def apply(
expert_map=expert_map,
shared_experts=shared_experts,
mc2_mask=kwargs.get(
"mc2_mask", None))
"mc2_mask", None),
with_quant=False)


class AscendFusedMoE(FusedMoE):
@@ -761,16 +765,14 @@ def __init__(

ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
with_quant = quant_config is not None
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
setup_token_dispatchers(
ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_global_redundant_experts=self.global_redundant_expert_num,
num_local_experts=self.local_num_experts,
with_quant=with_quant)
num_local_experts=self.local_num_experts)

def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
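The net effect in fused_moe.py: unified_apply_mlp no longer reads get_forward_context().with_quant but receives the flag as an explicit keyword, and unified_fused_experts_eager threads that same flag into both token_dispatch and unified_apply_mlp. A minimal, self-contained sketch of the before/after pattern, using toy names rather than the real vllm-ascend API:

class _ForwardContext:
    # Stand-in for the old per-forward global context that carried the flag.
    with_quant = True

_CTX = _ForwardContext()

def apply_mlp_old(x):
    # Old style: behaviour depends on hidden global state.
    return [round(v) for v in x] if _CTX.with_quant else x

def apply_mlp_new(x, with_quant=False):
    # New style: the caller states the flag explicitly, which is easier to
    # test (no forward-context patching) and to trace through the call chain.
    return [round(v) for v in x] if with_quant else x

print(apply_mlp_old([0.4, 1.6]))                   # [0, 2], decided by the global
print(apply_mlp_new([0.4, 1.6], with_quant=True))  # [0, 2], decided by the caller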
22 changes: 16 additions & 6 deletions vllm_ascend/ops/moe_dispatcher/token_dispatcher.py
@@ -490,7 +490,6 @@ def __init__(self, **kwargs) -> None:
"""
self.top_k = kwargs.get("top_k", 0)
self.num_experts = kwargs.get("num_experts", 0)
self.with_quant = kwargs.get("with_quant", False)

@property
def ep_group(self):
@@ -518,7 +517,8 @@ def token_dispatch(self,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
raise NotImplementedError("Dispatch function not implemented.")

@abstractmethod
@@ -555,6 +555,7 @@ def __init__(self, **kwargs):
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False

def get_dispatch_mc2_kwargs(
self,
@@ -615,7 +616,9 @@ def token_dispatch(self,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
@@ -738,6 +741,7 @@ def __init__(self, **kwargs):
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.with_quant = False

def token_dispatch(self,
hidden_states: torch.Tensor,
@@ -751,7 +755,9 @@ def token_dispatch(self,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.original_shape = hidden_states.shape

num_tokens = hidden_states.shape[:-1].numel()
@@ -922,7 +928,8 @@ def token_dispatch(self,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
Expand Down Expand Up @@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.num_global_redundant_experts = kwargs.get(
"num_global_redundant_experts", 0)
@@ -1032,7 +1040,9 @@ def token_dispatch(self,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
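The dispatcher classes follow the same pattern: with_quant is dropped from the constructor kwargs and instead passed per token_dispatch call, then cached on the instance for the rest of the forward pass. A toy sketch of that caching pattern, assuming nothing about the real dispatcher beyond what the diff shows:

class ToyDispatcher:
    def __init__(self):
        # No with_quant constructor argument; start from a neutral default.
        self.with_quant = False

    def token_dispatch(self, tokens, with_quant=False):
        # Remember the caller's choice so later stages can consult it.
        self.with_quant = with_quant
        return sorted(tokens)

    def token_combine(self, tokens):
        # Illustrative use of the cached flag in a later stage.
        return [t * 2 for t in tokens] if self.with_quant else tokens

d = ToyDispatcher()
print(d.token_combine(d.token_dispatch([3, 1, 2], with_quant=True)))  # [2, 4, 6]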
3 changes: 2 additions & 1 deletion vllm_ascend/quantization/w4a8_dynamic.py
@@ -308,7 +308,8 @@ def apply(
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask", None),
with_quant=True)

def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
group_num, k, n = weight.shape
3 changes: 2 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
@@ -406,7 +406,8 @@ def apply(
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask", None),
with_quant=True)

def process_weights_after_loading(self, layer):
if self.transpose_weight: