@@ -972,158 +972,21 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-class GraniteMoeHybridParallelExperts(nn.Module):
-    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
-        """
-        Initialize the GraniteMoeHybridParallelExperts module.
-        The expert weights are stored in [num_experts, output_size, input_size] format, so that the module is
-        compatible with many MoE libraries, such as [Megablocks](https://github.com/databricks/megablocks) and
-        [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
-        [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
-        used in vllm.
-
-        Args:
-            num_experts (int):
-                Number of experts.
-            input_size (int):
-                Size of the input.
-            output_size (int):
-                Size of the output.
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.output_size = output_size
-
-    def forward(self, inputs, expert_size):
-        """
-        Forward pass of the GraniteMoeHybridParallelExperts module.
-
-        Args:
-            inputs (Tensor):
-                Input tensor, with rows already grouped by expert along dim 0.
-            expert_size (list[int]):
-                Number of rows of `inputs` routed to each expert.
-
-        Returns:
-            Tensor: Output tensor with the same row order as `inputs`.
-        """
-        input_list = inputs.split(expert_size, dim=0)
-        output_list = []
-        for i in range(self.num_experts):
-            output_list.append(F.linear(input_list[i], self.weight[i]))
-        results = torch.cat(output_list, dim=0)
-        return results
-
-
-class GraniteMoeHybridTopKGating(nn.Module):
-    def __init__(self, input_size: int, num_experts: int, top_k: int):
-        """
-        Initialize the top-k gating mechanism.
-
-        Args:
-            input_size (`int`):
-                Size of the input.
-            num_experts (`int`):
-                Number of experts.
-            top_k (`int`):
-                Number of top experts to select.
-        """
-        super().__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.top_k = top_k
-
-        self.layer = nn.Linear(input_size, num_experts, bias=False)
-
-    def forward(self, hidden_states):
-        # compute the top_k routing decision
-        logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
-        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
-        top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]
-
-        # compute the number of inputs given to each expert
-        zeros = torch.zeros(
-            [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
-        )  # [num_tokens, num_experts]
-        gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
-        expert_size = gates.long().sum(0)  # [num_experts,]
-        # (This causes torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
-        # (and `DataDependentOutputException`)
-        expert_size = expert_size.tolist()
-
-        # sort and group input tokens according to expert assignment
-        top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
-        _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
-        batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]
-
-        # gather the gate values for grouped input tokens
-        top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
-        batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]
-
-        return index_sorted_experts, batch_index, batch_gates, expert_size, logits
-
-
-class GraniteMoeHybridMoE(nn.Module):
-    """
-    A sparsely gated mixture-of-experts layer with 1-layer feed-forward networks as experts.
-
-    Args:
-        config:
-            Configuration object with model hyperparameters.
-    """
-
-    def __init__(self, config: GraniteMoeHybridConfig):
-        super().__init__()
-
-        self.input_size = config.hidden_size
-        self.hidden_size = config.intermediate_size
-        self.activation = ACT2FN[config.hidden_act]
-        self.input_linear = GraniteMoeHybridParallelExperts(
-            config.num_local_experts, self.input_size, self.hidden_size * 2
-        )
-        self.output_linear = GraniteMoeHybridParallelExperts(
-            config.num_local_experts, self.hidden_size, self.input_size
-        )
-
-        self.router = GraniteMoeHybridTopKGating(
-            input_size=self.input_size,
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-        )
-
-    def forward(self, layer_input):
-        bsz, length, emb_size = layer_input.size()
-        layer_input = layer_input.reshape(-1, emb_size)
-        _, batch_index, batch_gates, expert_size, _ = self.router(layer_input)
-
-        expert_inputs = layer_input[batch_index]
-        hidden_states = self.input_linear(expert_inputs, expert_size)
-        chunked_hidden_states = hidden_states.chunk(2, dim=-1)
-        hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
-        expert_outputs = self.output_linear(hidden_states, expert_size)
-
-        expert_outputs = expert_outputs * batch_gates[:, None]
-
-        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
-        layer_output = zeros.index_add(0, batch_index, expert_outputs)
-        layer_output = layer_output.view(bsz, length, self.input_size)
-        return layer_output
-
-
 class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         # Either attention or mamba will be initialized, depending on the layer type.
         self.self_attn = None
-        self.block_sparse_moe = GraniteMoeHybridMoE(config)
+        self.block_sparse_moe = (
+            GraniteMoeHybridMoE(config)
+            if config.num_local_experts > 0
+            else None
+        )  # Diff with mixtral!
         self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral!
+        self.residual_multiplier = config.residual_multiplier  # Diff with mixtral!
         self.shared_mlp = GraniteMoeHybridMLP(config)
         self.mamba = None
 
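A brief aside on the changed initialization above: `block_sparse_moe` now only exists when `config.num_local_experts > 0`, so a dense hybrid layer simply skips the expert path. Below is a minimal, self-contained sketch of that pattern; the toy class, the `nn.Linear` stand-ins for the MoE and shared MLP, and the `is not None` guard in `forward` are illustrative assumptions, not code from this diff.

```python
import torch
import torch.nn as nn


class ToyHybridLayer(nn.Module):
    """Toy layer showing the optional-MoE pattern (not the actual GraniteMoeHybrid code)."""

    def __init__(self, hidden_size: int, num_local_experts: int):
        super().__init__()
        # Mirror of `config.num_local_experts > 0` above: no experts -> no MoE module at all.
        self.block_sparse_moe = nn.Linear(hidden_size, hidden_size) if num_local_experts > 0 else None
        self.shared_mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.shared_mlp(hidden_states)
        # Assumed guard: the expert contribution is only added when the module was created.
        if self.block_sparse_moe is not None:
            out = out + self.block_sparse_moe(hidden_states)
        return out


# A dense layer (0 experts) and a sparse layer (8 experts) share the same code path.
dense = ToyHybridLayer(hidden_size=16, num_local_experts=0)
sparse = ToyHybridLayer(hidden_size=16, num_local_experts=8)
x = torch.randn(2, 4, 16)
print(dense(x).shape, sparse(x).shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 16])
```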
@@ -1183,6 +1046,51 @@ def forward(
         return hidden_states
 
 
+class GraniteMoeHybridParallelExperts(nn.Module):
+    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
+        """
+        Initialize the GraniteMoeHybridParallelExperts module.
+        The expert weights are stored in [num_experts, output_size, input_size] format, so that the module is
+        compatible with many MoE libraries, such as [Megablocks](https://github.com/databricks/megablocks) and
+        [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
+        [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
+        used in vllm.
+
+        Args:
+            num_experts (int):
+                Number of experts.
+            input_size (int):
+                Size of the input.
+            output_size (int):
+                Size of the output.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.output_size = output_size
+
+    def forward(self, inputs, expert_size):
+        """
+        Forward pass of the GraniteMoeHybridParallelExperts module.
+
+        Args:
+            inputs (Tensor):
+                Input tensor, with rows already grouped by expert along dim 0.
+            expert_size (list[int]):
+                Number of rows of `inputs` routed to each expert.
+
+        Returns:
+            Tensor: Output tensor with the same row order as `inputs`.
+        """
+        input_list = inputs.split(expert_size, dim=0)
+        output_list = []
+        for i in range(self.num_experts):
+            output_list.append(F.linear(input_list[i], self.weight[i]))
+        results = torch.cat(output_list, dim=0)
+        return results
+
+
 @auto_docstring
 class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
     config: GraniteMoeHybridConfig
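Because the relocated `GraniteMoeHybridParallelExperts.forward` relies on its caller to pre-group tokens, here is a small end-to-end sketch of the `expert_size` contract, reconstructed from the removed `GraniteMoeHybridTopKGating` logic in the first hunk. The toy dimensions and the inline replication of the class as bare tensor ops are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

# Toy dimensions, chosen only for illustration.
num_experts, top_k, hidden_size, output_size, num_tokens = 4, 2, 8, 16, 5

# Expert weights stored as [num_experts, output_size, input_size],
# the same layout as GraniteMoeHybridParallelExperts.weight.
weight = torch.randn(num_experts, output_size, hidden_size)
tokens = torch.randn(num_tokens, hidden_size)

# Top-k routing, following the removed GraniteMoeHybridTopKGating code.
logits = torch.randn(num_tokens, num_experts)
top_k_logits, top_k_indices = logits.topk(top_k, dim=1)             # [num_tokens, top_k]

# expert_size[i] = how many (token, expert) pairs expert i receives.
gates = torch.zeros(num_tokens, num_experts).scatter(1, top_k_indices, 1.0)
expert_size = gates.long().sum(0).tolist()                           # python list, length num_experts

# Sort the token copies so all rows for expert 0 come first, then expert 1, ...
_, index_sorted_experts = top_k_indices.flatten().sort(0)
batch_index = index_sorted_experts.div(top_k, rounding_mode="trunc")
expert_inputs = tokens[batch_index]                                  # grouped by expert along dim 0

# What ParallelExperts.forward does with that contract:
# split along dim 0 by expert_size, apply each expert's weight, concatenate back.
outputs = torch.cat(
    [F.linear(x, w) for x, w in zip(expert_inputs.split(expert_size, dim=0), weight)],
    dim=0,
)
print(outputs.shape)  # torch.Size([10, 16]) == [num_tokens * top_k, output_size]
```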