
Commit 6592d72

fix bug in pack update info
1 parent e4b1ba0 commit 6592d72

File tree

2 files changed: 31 additions & 24 deletions


vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 15 additions & 16 deletions
@@ -343,19 +343,23 @@ def pack_update_info(self, update_info_generator):
 
             maps.append(new_expert_map[self.rank_id])
 
-            log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map) if self.redundant_enable else None
-            log2phy_all.append(log2phy_map)
+            if self.redundant_enable is not None:
+                log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map)
+                log2phy_all.append(log2phy_map)
 
             layer_ids.append(layer_id)
 
         # Stack the list of Tensors into one big Tensor
-        stacked_maps = torch.stack(maps, dim=0)  # [N, ...]
-        stacked_log2phy = torch.stack(log2phy_all, dim=0)  # [N, ...]
-        layer_id_tensor = torch.as_tensor(layer_ids, dtype=torch.int64)  # [N]
+        stacked_maps = torch.stack(maps, dim=0)
+        layer_id_tensor = torch.as_tensor(layer_ids, dtype=torch.int64)
+        stacked_maps.share_memory_()
+        layer_id_tensor.share_memory_()
 
-        # Zero-copy across processes
-        for t in (stacked_maps, stacked_log2phy, layer_id_tensor):
-            t.share_memory_()
+        if self.redundant_enable:
+            stacked_log2phy = torch.stack(log2phy_all, dim=0)
+            stacked_log2phy.share_memory_()
+        else:
+            stacked_log2phy = None
 
         return send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor

@@ -375,7 +379,7 @@ def __init__(self, shared_dict, planner_q, block_update_q, redundant_enable, pol
         self.redundant_enable = redundant_enable
 
         # Create EplbWorker instance
-        self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d)
+        self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d, self.redundant_enable)
 
 
     def worker_process(self, planner_q, block_update_q):
@@ -387,17 +391,12 @@ def worker_process(self, planner_q, block_update_q):
 
                 planner_q.get()
 
-                update_info_generator = self.worker.do_update()
-                update_info_list = []
-
-                for (send_info , recv_info , new_expert_map, layer_id) in update_info_generator:
-                    log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map) if self.redundant_enable else None
-                    update_info_list.append((send_info , recv_info , new_expert_map, log2phy_map, layer_id))
+                packed_update_info = self.worker.do_update()
 
                 while True:
                     if not block_update_q.empty():
                         continue
-                    block_update_q.put(update_info_list)
+                    block_update_q.put(packed_update_info)
                     break
 
             except Exception as e:
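
For context on the changes in this file: pack_update_info now stacks the per-layer tensors and marks them with share_memory_() so worker_process can hand one packed tuple to the consumer through the queue without per-layer copies. A minimal standalone sketch of that handoff pattern (hypothetical producer/consumer names, not part of this commit), assuming torch.multiprocessing:

```python
# Sketch only: illustrates stack + share_memory_() + queue handoff; not the vllm_ascend API.
import torch
import torch.multiprocessing as mp

def producer(q):
    maps = [torch.randint(0, 8, (4,)) for _ in range(3)]  # stand-in for per-layer expert maps
    stacked = torch.stack(maps, dim=0)                     # one [N, ...] tensor instead of N tensors
    stacked.share_memory_()                                # move the storage into shared memory
    q.put(stacked)                                         # the consumer gets a view of the same storage

def consumer(q):
    stacked = q.get()
    per_layer = stacked.unbind(0)                          # N views over the shared storage
    print(len(per_layer), per_layer[0])

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=consumer, args=(q,))
    p.start()
    producer(q)
    p.join()
```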

vllm_ascend/eplb/eplb_updator.py

Lines changed: 16 additions & 8 deletions
@@ -75,7 +75,6 @@ def init_eplb(self, redundant_enable):
 
         logger.info(f"[ModelRunner] Launched EPLB process (pid={self.eplb_process.pid})")
 
-
     def get_update_iteration(self):
         self.cur_iterations = self.cur_iterations + 1
         return self.cur_iterations % self.num_iterations == 0
@@ -191,16 +190,25 @@ def unpack_update_batch(self, packed_update_info):
         """
         send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = packed_update_info
 
-        # Unbind the Tensor into a tuple of N tensors
-        maps = stacked_maps.unbind(0)
-        log2phy = stacked_log2phy.unbind(0)
-
-        # Convert layer_id_tensor into a Python list of ints
+        maps = stacked_maps.unbind(0)
         layer_ids = layer_id_tensor.tolist()
 
+        if self.redundant_enable:
+            log2phy_list = stacked_log2phy.unbind(0)
+        else:
+            log2phy_list = [None] * len(maps)
+
+        _zip = zip
+        _send = send_all
+        _recv = recv_all
+        _maps = maps
+        _l2p = log2phy_list
+        _lids = layer_ids
+
         recovered = [
-            (s, r, m, l, lid)
-            for s, r, m, l, lid in zip(send_all, recv_all, maps, log2phy, layer_ids)
+            (_s, _r, _m, _lp, _lid)
+            for _s, _r, _m, _lp, _lid
+            in _zip(_send, _recv, _maps, _l2p, _lids)
         ]
         return recovered
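
As a quick illustration of the pack/unpack round trip this commit settles on (standalone sketch; pack and unpack here are hypothetical helpers, not the methods in this diff), the consumer rebuilds per-layer tuples with unbind and zip, substituting None for the log2phy maps when redundancy is disabled:

```python
# Sketch only: mirrors the stacked-tensor packing and the unbind/zip unpacking above.
import torch

def pack(maps, layer_ids, log2phy_all, redundant_enable):
    stacked_maps = torch.stack(maps, dim=0)
    layer_id_tensor = torch.as_tensor(layer_ids, dtype=torch.int64)
    stacked_log2phy = torch.stack(log2phy_all, dim=0) if redundant_enable else None
    return stacked_maps, stacked_log2phy, layer_id_tensor

def unpack(stacked_maps, stacked_log2phy, layer_id_tensor, redundant_enable):
    maps = stacked_maps.unbind(0)
    layer_ids = layer_id_tensor.tolist()
    log2phy = stacked_log2phy.unbind(0) if redundant_enable else [None] * len(maps)
    return list(zip(maps, log2phy, layer_ids))

maps = [torch.arange(4), torch.arange(4) + 10]
packed = pack(maps, [0, 1], [], redundant_enable=False)
for m, l2p, lid in unpack(*packed, redundant_enable=False):
    print(lid, m, l2p)  # l2p is None when redundancy is disabled
```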
