@@ -18,6 +18,7 @@
 
 import torch.distributed as dist
 from vllm.logger import logger
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 class ExpertWeightUpdateState(Enum):
@@ -38,6 +39,7 @@ def __init__(self, eplb_adaptor):
         self.state = ExpertWeightUpdateState.WAITING
         self.recv_expert_list = []
         self.mock_flag = True
+        self.enable_weight_nz_layout = get_ascend_config().enable_weight_nz_layout
 
     def generate_expert_d2d_transfer_task(self, expert_send_info,
                                           expert_recv_info, updated_expert_map,
@@ -61,10 +63,14 @@ def generate_expert_d2d_transfer_task(self, expert_send_info,
             dst_rank, global_expert_id_to_send = send_info
             local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
                 layer_id][global_expert_id_to_send].item()
+            idx = 0
             for src_tensor in self.eplb_adaptor.expert_param_per_layer[
-                    layer_id][local_expert_id]:
+                    layer_id][local_expert_id]:
+                if self.enable_weight_nz_layout and idx < 2:
+                    src_tensor = src_tensor.clone()
                 self.comm_op_list.append(
                     dist.P2POp(dist.isend, src_tensor, dst_rank))
+                idx += 1
 
         buffer_tensor_id = 0
         for recv_info in expert_recv_info:
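
For context, here is a minimal standalone sketch of the clone-before-isend pattern this hunk introduces. It assumes the motivation is that torch.distributed point-to-point ops expect plain contiguous tensors, while weights kept in Ascend NZ layout may not satisfy this, so a contiguous copy is sent instead. The helper name `queue_expert_send`, the `weights` list, and the interpretation of which tensors `idx < 2` covers are illustrative assumptions, not taken from this commit.

```python
import torch
import torch.distributed as dist


def queue_expert_send(weights, dst_rank, enable_weight_nz_layout):
    """Build the P2P send ops for one expert's weight tensors.

    A process group must be initialized before these ops are executed,
    e.g. via dist.batch_isend_irecv(comm_op_list).
    """
    comm_op_list = []
    for idx, src_tensor in enumerate(weights):
        # Assumption: under NZ layout the first two tensors are the
        # projection weights stored in NZ format, so clone() is used to
        # materialize a contiguous copy that is safe to send directly.
        if enable_weight_nz_layout and idx < 2:
            src_tensor = src_tensor.clone()
        comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))
    return comm_op_list
```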