Commit a2e6e95

fix(granitemoehybrid): Only set self.block_sparse_moe if num_local_experts > 0
Branch: GraniteMoeAsDenseFix
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 561233c · commit a2e6e95

1 file changed: +8 -0 lines

src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,7 @@
     GraniteMoeSharedDecoderLayer,
     GraniteMoeSharedForCausalLM,
     GraniteMoeSharedMLP,
+    GraniteMoeSharedMoE,
     GraniteMoeSharedModel,
     GraniteMoeSharedPreTrainedModel,
     eager_attention_forward,
@@ -107,6 +108,10 @@ class GraniteMoeHybridRotaryEmbedding(Gemma2RotaryEmbedding):
     pass


+class GraniteMoeHybridMoE(GraniteMoeSharedMoE):
+    pass
+
+
 class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
     def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         super().__init__(config, layer_idx)
@@ -121,6 +126,9 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
         self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
         self.layer_type = config.layers_block_type[layer_idx]

+        # Allow non-MoE (dense)
+        self.block_sparse_moe = GraniteMoeHybridMoE(config) if config.num_local_experts > 0 else None
+
         # Accept 0 experts: skip MoE if num_local_experts == 0
         self.has_experts = getattr(config, "num_local_experts", 0) > 0
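
For readers unfamiliar with the pattern, the minimal sketch below (illustrative only, not the transformers implementation) shows how a layer built this way degrades to a plain dense layer when num_local_experts == 0: the MoE submodule is never instantiated, and the has_experts flag lets the forward pass skip the expert branch. Only the attribute names block_sparse_moe, has_experts, and num_local_experts come from the diff above; ToyDenseOrMoeLayer, shared_mlp, and the nn.Linear stand-ins are hypothetical.

import torch
import torch.nn as nn


class ToyDenseOrMoeLayer(nn.Module):
    # Illustrative stand-in, not a transformers class: mirrors the guard
    # added in this commit. When num_local_experts == 0, no MoE submodule is
    # created and the layer behaves as a plain dense block.
    def __init__(self, hidden_size: int, num_local_experts: int):
        super().__init__()
        self.shared_mlp = nn.Linear(hidden_size, hidden_size)  # dense path (hypothetical)
        self.has_experts = num_local_experts > 0
        # Same pattern as the diff: only build the MoE when experts exist.
        self.block_sparse_moe = (
            nn.Linear(hidden_size, hidden_size) if num_local_experts > 0 else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.shared_mlp(hidden_states)
        if self.has_experts:  # dense configs skip the expert branch entirely
            out = out + self.block_sparse_moe(hidden_states)
        return out


# Dense configuration: no expert parameters are allocated and forward still works.
layer = ToyDenseOrMoeLayer(hidden_size=8, num_local_experts=0)
print(layer.block_sparse_moe is None)      # True
print(layer(torch.randn(2, 4, 8)).shape)   # torch.Size([2, 4, 8])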

0 commit comments

Comments
 (0)