@@ -338,7 +338,8 @@ def __init__(self, vllm_config: VllmConfig):
338
338
self .finished_reqs : set [str ] = set ()
339
339
self .soc_info = NPUSocInfo ()
340
340
# get decode tp size from extra config
341
- self .done_receiving_counts : defaultdict [str , set [int ]] = defaultdict (set )
341
+ self .done_receiving_counts : defaultdict [str ,
342
+ set [int ]] = defaultdict (set )
342
343
343
344
def listen_for_agent_metadata_req (self , event : threading .Event ):
344
345
assert self .local_agent_metadata is not None
@@ -372,9 +373,11 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
372
373
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data { decode_msg } "
373
374
)
374
375
elif event_msg == LLMDataDistCMgrEvent .ReqForFinished :
375
- finished_req_id , decode_tp_rank , decode_tp_size = decode_msg [:3 ]
376
+ finished_req_id = decode_msg [0 ]
377
+ decode_tp_rank = decode_msg [1 ]
378
+ decode_tp_size = decode_msg [2 ]
376
379
with self .thread_lock :
377
- if self ._increment_task_count (finished_req_id ,
380
+ if self ._increment_task_count (finished_req_id ,
378
381
decode_tp_rank ,
379
382
decode_tp_size ):
380
383
logger .debug (
@@ -387,7 +390,7 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
387
390
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event { event_msg } from remote !"
388
391
)
389
392
390
- def _increment_task_count (self , request_id : str , tp_rank : int ,
393
+ def _increment_task_count (self , request_id : str , tp_rank : int ,
391
394
decode_tp_size : int ):
392
395
if tp_rank in self .done_receiving_counts [request_id ]:
393
396
logger .warning (
@@ -752,8 +755,10 @@ def send_finish_to_remote(self, host: str, port: int, request_id):
752
755
url = f"tcp://{ host } :{ port } "
753
756
logger .debug (f"Sending finished to remote: { url } " )
754
757
msg_encoder = msgspec .msgpack .Encoder ()
755
- msg_send = msg_encoder .encode (
756
- [LLMDataDistCMgrEvent .ReqForFinished , [request_id , self .tp_rank , self .tp_size ]])
758
+ msg_send = msg_encoder .encode ([
759
+ LLMDataDistCMgrEvent .ReqForFinished ,
760
+ [request_id , self .tp_rank , self .tp_size ]
761
+ ])
757
762
with zmq_ctx (zmq .REQ , url ) as sock : # type: ignore[attr-defined]
758
763
try :
759
764
sock .send (msg_send )
0 commit comments