@@ -170,38 +170,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = forward_context.in_profile_run
-
-        is_prefill = forward_context.with_prefill
-
-        old_hidden_states = hidden_states.clone()
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekDBOMoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
-
-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(old_hidden_states)
-
-        if shared_output is not None:
-            hidden_states = hidden_states + shared_output
-
-        return hidden_states
-
     # ----------------------------------------- TBO-related --------------------------------------------
     def _forward_ms_op_shared_expert(
             self,