38 changes: 37 additions & 1 deletion tests/ut/ops/test_common_fused_moe.py
@@ -17,7 +17,7 @@
import torch

from tests.ut.base import TestBase
-from vllm_ascend.ops.common_fused_moe import fused_experts_moge
+from vllm_ascend.ops.common_fused_moe import AscendFusedMoE, fused_experts_moge


class TestFusedExpertsMoGE(TestBase):
@@ -67,3 +67,39 @@ def test_fused_experts_moge(self):
        )

        self.assertEqual(output.shape, (4, 128))


class TestLoadWeight(TestBase):

    def test_load_w13_transpose(self):
        with patch.object(AscendFusedMoE, "__init__",
                          lambda self, *args, **kwargs: None):
            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
            moe.hidden_size = 8
            expert_data = torch.randn(128, 8)
            loaded_weight = torch.randn(128, 4)
            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)

            expert_data = torch.randn(8, 128)
            loaded_weight = torch.randn(128, 4)
            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)

            expert_data = torch.randn(128, 8)
            loaded_weight = torch.randn(128, 4)
            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)

            expert_data = torch.randn(8, 128)
            loaded_weight = torch.randn(128, 4)
            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)

    def test_load_w2_transpose(self):
        with patch.object(AscendFusedMoE, "__init__",
                          lambda self, *args, **kwargs: None):
            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
            expert_data = torch.randn(128, 4)
            loaded_weight = torch.randn(128, 8)
            moe._load_w2(expert_data, 1, loaded_weight, 0)

            expert_data = torch.randn(4, 128)
            loaded_weight = torch.randn(128, 8)
            moe._load_w2(expert_data, 1, loaded_weight, 0)
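For reference, the cases above feed the shape check that the PR adds in AscendFusedMoE.transpose_weight: the loaded weight is flipped only when neither of its dimensions matches the destination expert slice, and the shard dimension moves with it. The standalone sketch below (maybe_transpose is a hypothetical helper written only for illustration, not part of the PR) reproduces that check so the test inputs can be traced by hand.

import torch


def maybe_transpose(loaded_weight: torch.Tensor, expert_data: torch.Tensor,
                    shard_dim: int):
    # Mirrors the PR's transpose_weight check: flip only when neither dim of
    # the loaded weight matches the destination slice, and move the shard dim.
    if (loaded_weight.shape[1] != expert_data.shape[1]
            and loaded_weight.shape[0] != expert_data.shape[0]):
        shard_dim = int(not shard_dim)
        loaded_weight = loaded_weight.transpose(0, 1).contiguous()
    return loaded_weight, shard_dim


# (128, 8) destination vs (128, 4) source: dim 0 already matches -> no flip.
w, d = maybe_transpose(torch.randn(128, 4), torch.randn(128, 8), 1)
assert w.shape == (128, 4) and d == 1
# (8, 128) destination vs (128, 4) source: nothing matches -> flip, shard dim 1 -> 0.
w, d = maybe_transpose(torch.randn(128, 4), torch.randn(8, 128), 1)
assert w.shape == (4, 128) and d == 0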
80 changes: 72 additions & 8 deletions vllm_ascend/ops/common_fused_moe.py
@@ -133,6 +133,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
    self.use_aclgraph = (vllm_config.compilation_config.level
                         == CompilationLevel.PIECEWISE
                         and not vllm_config.model_config.enforce_eager)
    self.transpose = True


def forward_oot_v01011(
@@ -260,13 +261,22 @@ def forward_oot(

def process_weights_after_loading(self, layer):
    super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
-    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
+    if self.transpose:
Contributor comment (critical):

The self in process_weights_after_loading refers to an instance of UnquantizedFusedMoEMethod, which does not have a transpose attribute. The transpose attribute is defined on the AscendFusedMoE layer instance, which is passed as the layer argument to this method. Accessing self.transpose will raise an AttributeError at runtime. You should access this attribute via the layer object instead.

Suggested change:
-    if self.transpose:
+    if layer.transpose:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

-    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
-        1, 2).contiguous()
-    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
+        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
+            1, 2).contiguous()
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

+        self.transpose = False
+    else:
+        w13_data = self._maybe_pad_weight(layer.w13_weight.data)
+        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

+        w2_data = self._maybe_pad_weight(layer.w2_weight.data)
+        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

    if not is_310p():
        layer.w13_weight.data = torch_npu.npu_format_cast(
@@ -357,12 +367,11 @@ def __init__(
            num_redundant_experts,
            has_bias,
        )

        setup_token_dispatchers(self.moe_config.ep_size,
                                top_k=self.top_k,
                                num_experts=self.global_num_experts,
                                num_local_experts=self.local_num_experts)

        self.hidden_size = hidden_size
        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
@@ -414,6 +423,61 @@ def forward_impl(self, hidden_states: torch.Tensor,

        return final_hidden_states

    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
        # Ensure training and inference weight shapes match during RL weight updates
        if (
            loaded_weight.shape[1] != expert_data.shape[1] and \
            loaded_weight.shape[0] != expert_data.shape[0]
        ):
            shard_dim = int(not shard_dim)
            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
        return loaded_weight, shard_dim

    def _load_w13(self,
                  expert_data: torch.Tensor,
                  shard_dim: int,
                  shard_id: str,
                  loaded_weight: torch.Tensor,
                  tp_rank: int,
                  load_full: bool = False):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        loaded_weight, shard_dim = self.transpose_weight(
            loaded_weight, expert_data, shard_dim)
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        loaded_weight, shard_dim = self.transpose_weight(
            loaded_weight, expert_data, shard_dim)
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
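For context on the self.transpose flag introduced above: it turns process_weights_after_loading into a one-shot transpose, so a later call (for instance after weights are re-synced during RL training) does not flip the already-reordered parameters a second time. Below is a minimal, hypothetical sketch of that guard with stand-in classes (FakeMethod and FakeLayer are illustration only); the real method also pads the weights and casts them to an NPU-friendly format, which is omitted here.

import torch


class FakeLayer:
    """Stand-in for the MoE layer; only carries the fused w13 parameter."""


class FakeMethod:
    """Stand-in for the quant method with the one-shot transpose guard."""

    def __init__(self):
        self.transpose = True  # cleared after the first post-load transpose

    def process_weights_after_loading(self, layer):
        if self.transpose:
            # First call: reorder (E, 2*I, H) -> (E, H, 2*I) for the kernel.
            w13 = layer.w13_weight.data.transpose(1, 2).contiguous()
            layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
            self.transpose = False
        # Later calls: weights already in kernel layout, leave them untouched.


layer = FakeLayer()
layer.w13_weight = torch.nn.Parameter(torch.randn(4, 16, 8),
                                      requires_grad=False)
method = FakeMethod()
method.process_weights_after_loading(layer)  # (4, 16, 8) -> (4, 8, 16)
method.process_weights_after_loading(layer)  # no second transpose
assert layer.w13_weight.shape == (4, 8, 16)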