diff --git a/tests/ut/ops/test_common_fused_moe.py b/tests/ut/ops/test_common_fused_moe.py
index 409a301e73..11058a051e 100644
--- a/tests/ut/ops/test_common_fused_moe.py
+++ b/tests/ut/ops/test_common_fused_moe.py
@@ -17,7 +17,9 @@
+from unittest.mock import patch
+
 import torch
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.common_fused_moe import fused_experts_moge
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge
 
 
 class TestFusedExpertsMoGE(TestBase):
@@ -67,3 +67,39 @@ def test_fused_experts_moge(self):
         )
 
         self.assertEqual(output.shape, (4, 128))
+
+
+class TestLoadWeight(TestBase):
+
+    def test_load_w13_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+            moe.hidden_size = 8
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+    def test_load_w2_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+            expert_data = torch.randn(128, 4)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
+
+            expert_data = torch.randn(4, 128)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 3142bc8323..be11a39b32 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -133,6 +133,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
     self.use_aclgraph = (vllm_config.compilation_config.level
                          == CompilationLevel.PIECEWISE
                          and not vllm_config.model_config.enforce_eager)
+    self.transpose = True  # transpose w13/w2 on the first load only
 
 
 def forward_oot_v01011(
@@ -260,13 +261,22 @@ def forward_oot(
 def process_weights_after_loading(self, layer):
     super(UnquantizedFusedMoEMethod,
           self).process_weights_after_loading(layer)
-    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
-
-    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+    if self.transpose:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+
+        self.transpose = False  # buffers now hold the transposed layout
+    else:  # reloaded weights (e.g. RL updates) are already transposed
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data)
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
 
     if not is_310p():
         layer.w13_weight.data = torch_npu.npu_format_cast(
@@ -357,12 +367,11 @@ def __init__(
             num_redundant_experts,
             has_bias,
         )
-        setup_token_dispatchers(self.moe_config.ep_size,
-                                top_k=self.top_k,
-                                num_experts=self.global_num_experts,
-                                num_local_experts=self.local_num_experts)
-
+        self.hidden_size = hidden_size
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
@@ -414,6 +423,61 @@ def forward_impl(self, hidden_states: torch.Tensor,
 
         return final_hidden_states
 
+    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
+        # Ensure training and inference weight shapes match during RL weight
+        # updates: if neither dim of the incoming weight lines up with the
+        # destination buffer, it arrived in the opposite layout.
+        if (
+            loaded_weight.shape[1] != expert_data.shape[1]
+            and loaded_weight.shape[0] != expert_data.shape[0]
+        ):
+            shard_dim = int(not shard_dim)
+            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
+        return loaded_weight, shard_dim
+
+    def _load_w13(self,
+                  expert_data: torch.Tensor,
+                  shard_dim: int,
+                  shard_id: str,
+                  loaded_weight: torch.Tensor,
+                  tp_rank: int,
+                  load_full: bool = False):
+        # Index the loaded weight for tp sharding.
+        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        loaded_weight, shard_dim = self.transpose_weight(
+            loaded_weight, expert_data, shard_dim)
+        shard_size = expert_data.shape[shard_dim] // 2
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
+        # Narrow parameter and load.
+        # w1, gate_proj: Load into first logical weight of w13.
+        if shard_id == "w1":
+            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
+        # w3, up_proj: Load into second logical weight of w13.
+        else:
+            assert shard_id == "w3"
+            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+        expert_data.copy_(loaded_weight)
+
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
+        # Index the loaded weight for tp sharding.
+        # down_proj: "RowParallel", so tp sharding on input_dim
+        loaded_weight, shard_dim = self.transpose_weight(
+            loaded_weight, expert_data, shard_dim)
+        shard_size = expert_data.shape[shard_dim]
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
+        # Narrow parameter and load.
+        # w2, down_proj: Load into only logical weight of w2.
+        expert_data.copy_(loaded_weight)
+
 
 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
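
For reference, a minimal standalone sketch (plain PyTorch, not part of the patch; the shapes, `tp_rank`, and the variable names `w`, `dim`, `shard_size` are illustrative) of what `transpose_weight` does during an RL weight reload. After the first `process_weights_after_loading`, `w2` is stored per expert as (intermediate_per_rank, hidden), while a training shard arrives as (hidden, intermediate); the helper flips the tensor and the shard dimension together so the subsequent narrow-and-copy in `_load_w2` indexes the correct axis:

    import torch

    def transpose_weight(loaded_weight, expert_data, shard_dim):
        # Mirror of AscendFusedMoE.transpose_weight: if neither dimension of
        # the incoming weight matches the destination buffer, the weight is
        # in the opposite layout; transpose it and flip the shard dim.
        if (loaded_weight.shape[1] != expert_data.shape[1]
                and loaded_weight.shape[0] != expert_data.shape[0]):
            shard_dim = int(not shard_dim)
            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
        return loaded_weight, shard_dim

    tp_rank = 0
    expert_data = torch.empty(4, 128)    # stored w2: (intermediate / tp, hidden)
    loaded_weight = torch.randn(128, 8)  # incoming w2: (hidden, intermediate)
    w, dim = transpose_weight(loaded_weight, expert_data, shard_dim=1)
    assert w.shape == (8, 128) and dim == 0
    shard_size = expert_data.shape[dim]  # rows owned by this TP rank
    expert_data.copy_(w.narrow(dim, shard_size * tp_rank, shard_size))

When either dimension already matches (e.g. a first load from a checkpoint already in the inference layout), the helper is a no-op and the original shard_dim is kept, so the existing load path is unchanged.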