@@ -212,6 +212,14 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group

+        self.params_dtype = torch.get_default_dtype()
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -228,52 +236,65 @@ def forward(
         else:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
-        num_tokens, hidden_dim = hidden_states.shape
+        if hasattr(attn_metadata, 'with_prefill_across_dp'):
+            is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
+
+        num_tokens, hidden_size = hidden_states.shape

         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)

         if self.tp_size > 1:
-            # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     self.tp_size,
-                                                     dim=0)
-            local_hidden_states = chunk_hidden_states[self.tp_rank]
-        else:
-            local_hidden_states = hidden_states
+            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+                hidden_states = chunks[self.tp_rank]
+            elif not self.enable_graph_mode:
+                num_padding_tokens = (self.tp_size -
+                                      num_tokens % self.tp_size) % self.tp_size
+                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
+                if num_padding_tokens > 0:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, num_padding_tokens))
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         self.tp_size,
+                                                         dim=0)
+                hidden_states = chunk_hidden_states[self.tp_rank]

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(local_hidden_states)
+        router_logits, _ = self.gate(hidden_states)

-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor

         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
-        else:
-            final_hidden_states = router_hidden_states
+            if self.enable_graph_mode:
+                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                    final_hidden_states = torch.zeros(
+                        [num_tokens, hidden_size],
+                        dtype=self.params_dtype,
+                        device="npu")
+                    dist.all_gather_into_tensor(final_hidden_states,
+                                                hidden_states, self.tp_group)
+                    hidden_states = final_hidden_states
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(
+                        hidden_states)
+            else:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                if num_padding_tokens > 0:
+                    hidden_states = hidden_states[:-num_padding_tokens]

         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_size)


 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
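For readers following the eager-mode TP branch above: it pads the token dimension up to a multiple of tp_size, splits the padded tensor with torch.tensor_split, runs the experts on one chunk per rank, then all-gathers the chunks and strips the padding. The single-process sketch below replays that round-trip with plain torch, using torch.cat in place of dist.all_gather and a doubling function as a stand-in for self.experts; tp_size, the tensor shapes, and the stand-in compute are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn


def padded_split_gather(hidden_states: torch.Tensor,
                        tp_size: int) -> torch.Tensor:
    # Same padding arithmetic as the diff: round the token count up to a
    # multiple of tp_size so every rank receives an equal-sized chunk.
    num_tokens, hidden_size = hidden_states.shape
    num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size
    if num_padding_tokens > 0:
        hidden_states = nn.functional.pad(hidden_states,
                                          (0, 0, 0, num_padding_tokens))

    # In the real code each rank keeps only chunk_hidden_states[tp_rank];
    # here all chunks stay in one process to simulate the whole TP group.
    chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)

    # Stand-in for the per-rank expert computation (self.experts(...)).
    processed = [chunk * 2.0 for chunk in chunk_hidden_states]

    # dist.all_gather(list(chunk_hidden_states), ...) plus torch.cat
    # reassembles the padded sequence; the padding rows are then dropped.
    gathered = torch.cat(processed, dim=0)
    if num_padding_tokens > 0:
        gathered = gathered[:-num_padding_tokens]
    return gathered


out = padded_split_gather(torch.randn(5, 8), tp_size=4)
assert out.shape == (5, 8)

Padding to a multiple of tp_size (rather than to tp_size itself, as the removed code did) keeps every rank's chunk the same size for any batch length, which is what the fixed-shape gather on the other side requires.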
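The MC2 decode branch instead slices with torch.chunk and no padding, which appears to presume that decode-side num_tokens divides evenly across the TP group, and the graph-mode gather fills a preallocated [num_tokens, hidden_size] buffer via dist.all_gather_into_tensor. The sketch below imitates that gather in one process by copying each simulated rank's slice into the preallocated buffer; the sizes, the divisibility assumption, and the doubling stand-in for the experts are assumptions for illustration only.

import torch

# Illustrative sizes; num_tokens divisible by tp_size mirrors what the
# chunk-without-padding MC2 path appears to presume.
tp_size, num_tokens, hidden_size = 4, 8, 16
hidden_states = torch.randn(num_tokens, hidden_size)

# Decode-side MC2 slicing from the diff: each rank takes one chunk.
chunks = torch.chunk(hidden_states, tp_size, dim=0)
rank_outputs = [c * 2.0 for c in chunks]  # stand-in for self.experts(...)

# The graph-mode path preallocates the full-size output and fills it with
# dist.all_gather_into_tensor; copying each rank's slice simulates that
# collective in a single process.
final_hidden_states = torch.zeros(num_tokens, hidden_size)
tokens_per_rank = num_tokens // tp_size
for rank, out in enumerate(rank_outputs):
    start = rank * tokens_per_rank
    final_hidden_states[start:start + tokens_per_rank] = out

assert torch.allclose(final_hidden_states, hidden_states * 2.0)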