Commit e1e87db

fix(granitemoe): Regenerate modeling_granitemoe*.py
Branch: GraniteMoeAsDenseFix
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 700ac97 commit e1e87db

3 files changed, +106 -288 lines changed

src/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 4 additions & 2 deletions
@@ -411,11 +411,13 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
-        self.block_sparse_moe = GraniteMoeMoE(config)
+        self.block_sparse_moe = (
+            GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.num_local_experts > 0 else None
+        )  # Diff with mixtral!
         self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral!
+        self.residual_multiplier = config.residual_multiplier  # Diff with mixtral!

     def forward(
         self,
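The regenerated layer now constructs `block_sparse_moe` only when `config.num_local_experts > 0` and stores `None` otherwise, so the forward pass only needs a `None` check. The toy sketch below illustrates that general config-gated-submodule pattern; the module, names, and the `Linear` standing in for the MoE block are my own, not code from this commit:

import torch
import torch.nn as nn


class ToyOptionalBlock(nn.Module):
    """Toy illustration of a config-gated submodule: built only when experts
    are requested, otherwise stored as None and skipped in forward."""

    def __init__(self, hidden_size: int, num_local_experts: int):
        super().__init__()
        # Mirrors the guard added in this commit, with a plain Linear standing in for the MoE block.
        self.block_sparse_moe = nn.Linear(hidden_size, hidden_size) if num_local_experts > 0 else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.block_sparse_moe is None:  # dense configuration: nothing to add
            return hidden_states
        return hidden_states + self.block_sparse_moe(hidden_states)


dense = ToyOptionalBlock(hidden_size=8, num_local_experts=0)   # block_sparse_moe is None
sparse = ToyOptionalBlock(hidden_size=8, num_local_experts=4)  # block_sparse_moe is a module
x = torch.randn(2, 8)
assert torch.equal(dense(x), x)  # dense path leaves hidden states untouched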

src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

Lines changed: 51 additions & 143 deletions
@@ -972,158 +972,21 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


-class GraniteMoeHybridParallelExperts(nn.Module):
-    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
-        """
-        Initialize the GraniteMoeHybridParallelExperts module.
-        The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
-        many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
-        [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
-        [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
-        used in vllm.
-
-        Args:
-            num_experts (int):
-                Number of experts.
-            input_size (int):
-                Size of the input.
-            output_size (int):
-                Size of the output.
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.output_size = output_size
-
-    def forward(self, inputs, expert_size):
-        """
-        Forward pass of the GraniteMoeHybridParallelExperts module.
-
-        Args:
-            inputs (Tensor):
-                Input tensor.
-            expert_size:
-                Expert size information.
-
-        Returns:
-            Tensor: Output tensor.
-        """
-        input_list = inputs.split(expert_size, dim=0)
-        output_list = []
-        for i in range(self.num_experts):
-            output_list.append(F.linear(input_list[i], self.weight[i]))
-        results = torch.cat(output_list, dim=0)
-        return results
-
-
-class GraniteMoeHybridTopKGating(nn.Module):
-    def __init__(self, input_size: int, num_experts: int, top_k: int):
-        """
-        Initialize the top-k gating mechanism.
-
-        Args:
-            input_size (`int`):
-                Size of the input.
-            num_experts (`int`):
-                Number of experts.
-            top_k (`int`):
-                Number of top experts to select.
-        """
-        super().__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.top_k = top_k
-
-        self.layer = nn.Linear(input_size, num_experts, bias=False)
-
-    def forward(self, hidden_states):
-        # compute the top_k routing decision
-        logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
-        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
-        top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
-
-        # compute number of input given to each expert
-        zeros = torch.zeros(
-            [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
-        )  # [num_tokens, num_experts]
-        gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
-        expert_size = gates.long().sum(0)  # [num_experts,]
-        # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
-        # (and `DataDependentOutputException`)
-        expert_size = expert_size.tolist()
-
-        # sort and group input tokens according to expert assignment
-        top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
-        _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
-        batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
-
-        # gather the gate values for grouped input tokens
-        top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
-        batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
-
-        return index_sorted_experts, batch_index, batch_gates, expert_size, logits
-
-
-class GraniteMoeHybridMoE(nn.Module):
-    """
-    A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
-
-    Args:
-        config:
-            Configuration object with model hyperparameters.
-    """
-
-    def __init__(self, config: GraniteMoeHybridConfig):
-        super().__init__()
-
-        self.input_size = config.hidden_size
-        self.hidden_size = config.intermediate_size
-        self.activation = ACT2FN[config.hidden_act]
-        self.input_linear = GraniteMoeHybridParallelExperts(
-            config.num_local_experts, self.input_size, self.hidden_size * 2
-        )
-        self.output_linear = GraniteMoeHybridParallelExperts(
-            config.num_local_experts, self.hidden_size, self.input_size
-        )
-
-        self.router = GraniteMoeHybridTopKGating(
-            input_size=self.input_size,
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-        )
-
-    def forward(self, layer_input):
-        bsz, length, emb_size = layer_input.size()
-        layer_input = layer_input.reshape(-1, emb_size)
-        _, batch_index, batch_gates, expert_size, _ = self.router(layer_input)
-
-        expert_inputs = layer_input[batch_index]
-        hidden_states = self.input_linear(expert_inputs, expert_size)
-        chunked_hidden_states = hidden_states.chunk(2, dim=-1)
-        hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
-        expert_outputs = self.output_linear(hidden_states, expert_size)
-
-        expert_outputs = expert_outputs * batch_gates[:, None]
-
-        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
-        layer_output = zeros.index_add(0, batch_index, expert_outputs)
-        layer_output = layer_output.view(bsz, length, self.input_size)
-        return layer_output
-
-
 class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+        self.block_sparse_moe = (
+            GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            if config.num_local_experts > 0
+            else None
+        )  # Diff with mixtral!
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral!
+        self.residual_multiplier = config.residual_multiplier  # Diff with mixtral!
         self.shared_mlp = GraniteMoeHybridMLP(config)
         self.mamba = None
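The routing bookkeeping in `GraniteMoeHybridTopKGating.forward` above (top-k selection, per-expert slot counts via a one-hot scatter, and the sort that groups token slots by expert) is easier to follow on a tiny standalone example. The snippet below is my own toy walk-through of the same tensor operations with made-up sizes, not code from the file:

import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 4, 3, 2
logits = torch.randn(num_tokens, num_experts)               # router scores for each token

top_k_logits, top_k_indices = logits.topk(top_k, dim=1)     # each token keeps its best 2 experts
top_k_gates = torch.softmax(top_k_logits, dim=1)            # per-token weights over those 2 experts

# One-hot scatter, then a column sum, gives how many token slots land on each expert.
gates = torch.zeros(num_tokens, num_experts).scatter(1, top_k_indices, 1)
expert_size = gates.long().sum(0).tolist()                  # counts sum to num_tokens * top_k

# Sorting the flattened expert assignments makes each expert's slots contiguous;
# integer division by top_k recovers which token each sorted slot came from.
top_k_experts = top_k_indices.flatten()
_, index_sorted_experts = top_k_experts.sort(0)
batch_index = index_sorted_experts.div(top_k, rounding_mode="trunc")
batch_gates = top_k_gates.flatten()[index_sorted_experts]

print(expert_size, batch_index.tolist())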

@@ -1183,6 +1046,51 @@ def forward(
         return hidden_states


+class GraniteMoeHybridParallelExperts(nn.Module):
+    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
+        """
+        Initialize the GraniteMoeHybridParallelExperts module.
+        The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
+        many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
+        [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
+        [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
+        used in vllm.
+
+        Args:
+            num_experts (int):
+                Number of experts.
+            input_size (int):
+                Size of the input.
+            output_size (int):
+                Size of the output.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.output_size = output_size
+
+    def forward(self, inputs, expert_size):
+        """
+        Forward pass of the GraniteMoeHybridParallelExperts module.
+
+        Args:
+            inputs (Tensor):
+                Input tensor.
+            expert_size:
+                Expert size information.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        input_list = inputs.split(expert_size, dim=0)
+        output_list = []
+        for i in range(self.num_experts):
+            output_list.append(F.linear(input_list[i], self.weight[i]))
+        results = torch.cat(output_list, dim=0)
+        return results
+
+
 @auto_docstring
 class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
     config: GraniteMoeHybridConfig
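For reference, the grouped per-expert matmul that `GraniteMoeHybridParallelExperts.forward` performs (split the rows by `expert_size`, apply one `F.linear` per expert, concatenate) can be reproduced with plain tensors. This is a hedged re-creation with made-up sizes, not an import of the class re-added above:

import torch
import torch.nn.functional as F

num_experts, input_size, output_size = 3, 4, 5
weight = torch.randn(num_experts, output_size, input_size)  # same [num_experts, output_size, input_size] layout

expert_size = [2, 0, 3]                                     # expert 1 happens to receive no rows here
inputs = torch.randn(sum(expert_size), input_size)          # rows already grouped by expert

chunks = inputs.split(expert_size, dim=0)                   # one (possibly empty) chunk per expert
outputs = torch.cat([F.linear(chunk, weight[i]) for i, chunk in enumerate(chunks)], dim=0)
print(outputs.shape)                                        # torch.Size([5, 5])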
