@@ -926,52 +926,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class GraniteFlashAttentionKwargs(TypedDict, total=False):
-    """
-    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
-    Use cases include padding-free training and fewer `torch.compile` graph breaks.
-
-    Attributes:
-        cu_seq_lens_q (`torch.LongTensor`)
-            Gets cumulative sequence length for query state.
-        cu_seq_lens_k (`torch.LongTensor`)
-            Gets cumulative sequence length for key state.
-        max_length_q (`int`):
-            Maximum sequence length for query state.
-        max_length_k (`int`):
-            Maximum sequence length for key state.
-        seq_idx (`torch.IntTensor):
-            Index of each packed sequence.
-    """
-
-    cu_seq_lens_q: torch.LongTensor
-    cu_seq_lens_k: torch.LongTensor
-    max_length_q: int
-    max_length_k: int
-    seq_idx: torch.IntTensor
-
-
-@use_kernel_forward_from_hub("RMSNorm")
-class GraniteMoeHybridRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class GraniteMoeHybridParallelExperts(nn.Module):
     def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
         """
@@ -1113,13 +1067,61 @@ def forward(self, layer_input):
         return layer_output


+class GraniteFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`)
+            Gets cumulative sequence length for query state.
+        cu_seq_lens_k (`torch.LongTensor`)
+            Gets cumulative sequence length for key state.
+        max_length_q (`int`):
+            Maximum sequence length for query state.
+        max_length_k (`int`):
+            Maximum sequence length for key state.
+        seq_idx (`torch.IntTensor):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class GraniteMoeHybridRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+
+        # Allow non-MoE (dense)
+        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

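For context on the relocated `GraniteFlashAttentionKwargs`, the sketch below shows one way the padding-free tensors could be assembled for a packed batch before being passed as keyword arguments. It is a minimal illustration, not part of this diff: the sequence lengths, the plain-dict construction, and the `(1, total_tokens)` shape assumed for `seq_idx` are assumptions for the example.

```python
import torch

# Hypothetical packed batch: three sequences of lengths 3, 5, and 2 flattened
# into a single row (10 tokens total) for padding-free attention.
seq_lens = torch.tensor([3, 5, 2])

# Cumulative boundaries [0, 3, 8, 10]; queries and keys share them here.
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.long), seq_lens.cumsum(0)])

# Per-token index of the sequence each position belongs to: [0, 0, 0, 1, 1, 1, 1, 1, 2, 2]
seq_idx = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens).to(torch.int32)

flash_kwargs = {
    "cu_seq_lens_q": cu_seq_lens,
    "cu_seq_lens_k": cu_seq_lens,
    "max_length_q": int(seq_lens.max()),
    "max_length_k": int(seq_lens.max()),
    "seq_idx": seq_idx.unsqueeze(0),  # assumed shape: (1, total_tokens)
}
```

Passing tensors like these alongside the flattened inputs is what lets the varlen flash-attention, causal-conv1d, and mamba_ssm kernels respect sequence boundaries without padding, which is the use case the TypedDict's docstring describes.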