
Commit 1257ced

fix(granitemoehybrid): Regenerate modeling_granitemoehybrid.py
Branch: GraniteMoeAsDenseFix
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent a2e6e95 commit 1257ced

File tree

1 file changed (+49, -47 lines)


src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py

Lines changed: 49 additions & 47 deletions
@@ -926,52 +926,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class GraniteFlashAttentionKwargs(TypedDict, total=False):
-    """
-    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
-    Use cases include padding-free training and fewer `torch.compile` graph breaks.
-
-    Attributes:
-        cu_seq_lens_q (`torch.LongTensor`)
-            Gets cumulative sequence length for query state.
-        cu_seq_lens_k (`torch.LongTensor`)
-            Gets cumulative sequence length for key state.
-        max_length_q (`int`):
-            Maximum sequence length for query state.
-        max_length_k (`int`):
-            Maximum sequence length for key state.
-        seq_idx (`torch.IntTensor):
-            Index of each packed sequence.
-    """
-
-    cu_seq_lens_q: torch.LongTensor
-    cu_seq_lens_k: torch.LongTensor
-    max_length_q: int
-    max_length_k: int
-    seq_idx: torch.IntTensor
-
-
-@use_kernel_forward_from_hub("RMSNorm")
-class GraniteMoeHybridRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class GraniteMoeHybridParallelExperts(nn.Module):
     def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
         """
@@ -1113,13 +1067,61 @@ def forward(self, layer_input):
         return layer_output


+class GraniteFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`)
+            Gets cumulative sequence length for query state.
+        cu_seq_lens_k (`torch.LongTensor`)
+            Gets cumulative sequence length for key state.
+        max_length_q (`int`):
+            Maximum sequence length for query state.
+        max_length_k (`int`):
+            Maximum sequence length for key state.
+        seq_idx (`torch.IntTensor):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class GraniteMoeHybridRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+
+        # Allow non-MoE (dense)
+        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

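The moved `GraniteFlashAttentionKwargs` TypedDict describes the metadata that padding-free (packed) batches carry for the Flash Attention, causal-conv1d, and mamba_ssm kernels. As a rough illustration of what those fields contain, here is a small sketch that derives them for one packed batch; the helper name and the exact shape/dtype conventions expected by the downstream kernels are assumptions, not taken from the diff.

import torch


def build_packed_attention_kwargs(seq_lens: list[int]) -> dict:
    """Hypothetical helper: derive GraniteFlashAttentionKwargs-style fields for one packed batch.

    `seq_lens` holds the length of each sequence concatenated into a single padding-free row.
    """
    lens = torch.tensor(seq_lens, dtype=torch.int32)
    # Cumulative sequence lengths with a leading zero, the usual varlen-attention layout.
    cu_seq_lens = torch.nn.functional.pad(lens.cumsum(0), (1, 0)).to(torch.long)
    # Which packed sequence each token belongs to, e.g. [0, 0, 0, 1, 1, ...].
    seq_idx = torch.repeat_interleave(torch.arange(len(seq_lens), dtype=torch.int32), lens)
    return {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,  # self-attention: queries and keys share the layout
        "max_length_q": int(lens.max()),
        "max_length_k": int(lens.max()),
        "seq_idx": seq_idx,
    }


kwargs = build_packed_attention_kwargs([3, 5, 2])
print(kwargs["cu_seq_lens_q"])  # tensor([ 0,  3,  8, 10])
print(kwargs["seq_idx"])        # tensor([0, 0, 0, 1, 1, 1, 1, 1, 2, 2], dtype=torch.int32)
print(kwargs["max_length_q"])   # 5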

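`GraniteMoeHybridRMSNorm` itself is unchanged by the move. For reference, it performs plain RMS normalization computed in float32 and cast back: each token is scaled so its root mean square over the hidden dimension is roughly 1, then multiplied by a learned per-channel weight. A minimal numeric check of that behavior follows; the sizes and the fp16 input are arbitrary choices for the example.

import torch

hidden_size, eps = 8, 1e-6
weight = torch.ones(hidden_size, dtype=torch.float16)  # as it would be for an fp16 model
x = torch.randn(2, 4, hidden_size, dtype=torch.float16)

# Same math as GraniteMoeHybridRMSNorm.forward: compute in float32, cast back to the input dtype.
x32 = x.to(torch.float32)
variance = x32.pow(2).mean(-1, keepdim=True)
normed = weight * (x32 * torch.rsqrt(variance + eps)).to(x.dtype)

print(normed.dtype)                           # torch.float16 -- input dtype is preserved
print(normed.float().pow(2).mean(-1).sqrt())  # per-token RMS is ~1.0 with unit weight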