@@ -343,19 +343,23 @@ def pack_update_info(self, update_info_generator):
 
             maps.append(new_expert_map[self.rank_id])
 
-            log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map) if self.redundant_enable else None
-            log2phy_all.append(log2phy_map)
+            if self.redundant_enable:
+                log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map)
+                log2phy_all.append(log2phy_map)
 
             layer_ids.append(layer_id)
 
         # Stack the list of Tensors into one big Tensor
-        stacked_maps = torch.stack(maps, dim=0)                           # [N, ...]
-        stacked_log2phy = torch.stack(log2phy_all, dim=0)                 # [N, ...]
-        layer_id_tensor = torch.as_tensor(layer_ids, dtype=torch.int64)   # [N]
+        stacked_maps = torch.stack(maps, dim=0)
+        layer_id_tensor = torch.as_tensor(layer_ids, dtype=torch.int64)
+        stacked_maps.share_memory_()
+        layer_id_tensor.share_memory_()
 
-        # Zero-copy across processes
-        for t in (stacked_maps, stacked_log2phy, layer_id_tensor):
-            t.share_memory_()
+        if self.redundant_enable:
+            stacked_log2phy = torch.stack(log2phy_all, dim=0)
+            stacked_log2phy.share_memory_()
+        else:
+            stacked_log2phy = None
 
         return send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor
 
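Context for the hunk above: `share_memory_()` moves a tensor's storage into OS shared memory, so a `torch.multiprocessing` queue passes only a handle rather than pickling the data. Below is a minimal sketch of that zero-copy handoff; the `consumer` function and the toy shapes are illustrative assumptions, not part of this PR.

    import torch
    import torch.multiprocessing as mp

    def consumer(q):
        # The child receives a shared-memory handle, not a copy of the buffers.
        stacked_maps, layer_id_tensor = q.get()
        print(stacked_maps.shape, layer_id_tensor.tolist())

    if __name__ == "__main__":
        maps = [torch.randint(0, 8, (4,)) for _ in range(3)]  # one expert map per layer
        stacked_maps = torch.stack(maps, dim=0)               # [N, ...]
        layer_id_tensor = torch.as_tensor([0, 1, 2], dtype=torch.int64)

        # Move storage into shared memory before enqueueing.
        stacked_maps.share_memory_()
        layer_id_tensor.share_memory_()

        q = mp.Queue()
        p = mp.Process(target=consumer, args=(q,))
        p.start()
        q.put((stacked_maps, layer_id_tensor))
        p.join()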
@@ -375,7 +379,7 @@ def __init__(self, shared_dict, planner_q, block_update_q, redundant_enable, pol
         self.redundant_enable = redundant_enable
 
         # Create EplbWorker instance
-        self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d)
+        self.worker = EplbWorker(self.shared_dict, self.policy_type, self.enable_d2d, self.redundant_enable)
 
 
     def worker_process(self, planner_q, block_update_q):
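The call-site change implies `EplbWorker.__init__` now takes a fourth argument. The worker class itself is outside this diff, so the following is only an assumed shape, not the PR's actual code:

    class EplbWorker:
        # Assumed signature: the call site above passes four arguments, so
        # __init__ presumably stores redundant_enable next to the existing fields.
        def __init__(self, shared_dict, policy_type, enable_d2d, redundant_enable):
            self.shared_dict = shared_dict
            self.policy_type = policy_type
            self.enable_d2d = enable_d2d
            self.redundant_enable = redundant_enable  # consumed by pack_update_info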
@@ -387,17 +391,12 @@ def worker_process(self, planner_q, block_update_q):
 
                 planner_q.get()
 
-                update_info_generator = self.worker.do_update()
-                update_info_list = []
-
-                for (send_info, recv_info, new_expert_map, layer_id) in update_info_generator:
-                    log2phy_map = ExpertMapUtils.generate_log2phy_map(new_expert_map) if self.redundant_enable else None
-                    update_info_list.append((send_info, recv_info, new_expert_map, log2phy_map, layer_id))
+                packed_update_info = self.worker.do_update()
 
                 while True:
                     if not block_update_q.empty():
                         continue
-                    block_update_q.put(update_info_list)
+                    block_update_q.put(packed_update_info)
                     break
 
             except Exception as e:
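After this change the queue carries one pre-packed tuple per update cycle instead of a per-layer list, and log2phy generation has moved into `pack_update_info`; the busy-wait keeps at most one packed update in flight. A sketch of how the consumer side of `block_update_q` might unpack it (`apply_layer_update` and the local names are hypothetical, not from this PR):

    # Hypothetical consumer side of block_update_q.
    send_all, recv_all, stacked_maps, stacked_log2phy, layer_id_tensor = block_update_q.get()

    for i, layer_id in enumerate(layer_id_tensor.tolist()):
        expert_map = stacked_maps[i]  # a view into shared memory, no copy
        log2phy = stacked_log2phy[i] if stacked_log2phy is not None else None
        apply_layer_update(layer_id, expert_map, log2phy)  # hypothetical per-layer apply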