 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
+from vllm_ascend.utils import context_parallel_enable, sequence_parallel_enable
+if context_parallel_enable():
+    from vllm.distributed.parallel_state import get_context_model_parallel_rank

 TORCH_DTYPE_TO_NPU_DTYPE = {
     torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -64,6 +67,8 @@ class ReqMeta:
     remote_port: str
     engine_id: str
     remote_tp_size: str
+    remote_cp_size: str
+    remote_sp_size: str


 class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
@@ -80,6 +85,8 @@ def add_new_req(self, request_id: str, local_block_ids: list[int],
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
             remote_tp_size=kv_transfer_params["remote_tp_size"],
+            remote_cp_size=kv_transfer_params["remote_cp_size"],
+            remote_sp_size=kv_transfer_params["remote_sp_size"],
         )

@@ -180,8 +187,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
         self.tp_size = None
         dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        self.cp_size = self.vllm_config.parallel_config.context_parallel_size if context_parallel_enable() else 1
+        self.sp_size = tp_size if (sequence_parallel_enable() and self.vllm_config.parallel_config.enable_sequence_parallel) else 1

-        self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
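+        # Each DP rank now reserves cp_size * tp_size consecutive RPC ports above the base port.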
+        self.port = dp_rank_local * self.cp_size * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT

         self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}

@@ -284,6 +293,8 @@ def request_finished(
             remote_port=self.port,
             remote_tp_size=str(
                 self.vllm_config.parallel_config.tensor_parallel_size),
+            remote_cp_size=str(self.cp_size),
+            remote_sp_size=str(self.sp_size),
         )


@@ -305,6 +316,9 @@ def __init__(self, vllm_config: VllmConfig):
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.tp_rank = get_tp_group().rank_in_group
         self.rank = get_world_group().rank
+        self.cp_size = vllm_config.parallel_config.context_parallel_size if context_parallel_enable() else 1
+        self.cp_rank = get_context_model_parallel_rank() if context_parallel_enable() else 0
+        self.sp_size = self.tp_size if (sequence_parallel_enable() and vllm_config.parallel_config.enable_sequence_parallel) else 1
         self.local_ip = get_ip()
         self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
         self.local_agent_metadata: Optional[
@@ -344,7 +358,8 @@ def __init__(self, vllm_config: VllmConfig):

     def listen_for_agent_metadata_req(self, event: threading.Event):
         assert self.local_agent_metadata is not None
-        port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
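+        # Listener port when a local DP rank is set: base + dp_rank * cp_size * tp_size + cp_rank * tp_size + tp_rank.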
+        port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.cp_size * self.tp_size + self.cp_rank * self.tp_size + self.tp_rank \
+            if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
         url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
         msg_encoder = msgspec.msgpack.Encoder()
         msg_decoder = msgspec.msgpack.Decoder()
@@ -452,9 +467,9 @@ def read_agent_metadata(self, global_rank_table):
                 d for d in device_list if d.get("server_id") == self.local_ip
                 and device_filter(d.get("device_id", ""))
             ]
-            if len(device_list) <= self.tp_rank:
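+            # With cp, this worker's device index on the host is cp_rank * tp_size + tp_rank.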
+            if len(device_list) <= self.cp_rank * self.tp_size + self.tp_rank:
                 continue
-            device_info = device_list[self.tp_rank]
+            device_info = device_list[self.cp_rank * self.tp_size + self.tp_rank]
             super_pod_id_ = device_info.get("super_pod_id", None)
             server_id_ = device_info["server_id"]
             device_id_ = device_info["device_id"]
@@ -573,6 +588,8 @@ def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
                 remote_engine_id=meta.engine_id,
                 request_id=req_id,
                 remote_tp_size=meta.remote_tp_size,
+                remote_cp_size=meta.remote_cp_size,
+                remote_sp_size=meta.remote_sp_size,
             )
             futures.append(future)

@@ -772,65 +789,104 @@ def _read_blocks(
         remote_engine_id: str,
         request_id: str,
         remote_tp_size: str,
+        remote_cp_size: str,
+        remote_sp_size: str,
     ):
-        # if remote_ip not in self.linked_cluster:
-        tp_offset = self.tp_rank % int(remote_tp_size)
-        remote_cluster_id = self.connect_to_remote_agent(
-            remote_ip, remote_port + tp_offset)
-        num_local_blocks = len(local_block_ids)
-        if num_local_blocks == 0:
-            return
-        num_remote_blocks = len(remote_block_ids)
-        assert num_local_blocks <= num_remote_blocks
-        if num_local_blocks < num_remote_blocks:
-            remote_block_ids = remote_block_ids[-num_local_blocks:]
-
-        logger.info(f"remote cluster id is: {remote_cluster_id}")
-        if self.use_mla:
-            remote_cache_key_k_normed = BlocksCacheKey(
-                cluster_id=remote_cluster_id, model_id=0)
-            remote_cache_key_k_pe = BlocksCacheKey(
-                cluster_id=remote_cluster_id, model_id=1)
-            logger.info("Try pull blocks from remote server")
-            try:
-                self.cache_manager.pull_blocks(
-                    remote_cache_key_k_normed,
-                    self.cache[0],  # type: ignore[has-type]
-                    remote_block_ids,
-                    local_block_ids)
-                self.cache_manager.pull_blocks(
-                    remote_cache_key_k_pe,
-                    self.cache[1],  # type: ignore[has-type]
-                    remote_block_ids,
-                    local_block_ids)
-            except (TypeError, ValueError):
-                raise RuntimeError(
-                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
-                )
-            except LLMException:
-                raise RuntimeError(
-                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
-                )
+        remote_cp_size = int(remote_cp_size)
+        remote_sp_size = int(remote_sp_size)
+        if self.cp_size == remote_cp_size and self.sp_size == remote_sp_size:
+            # Same as the original path: decode rank (cp_i, sp_j) pulls from prefill rank (cp_i, sp_j).
+            remote_kv_num = 1
+            # remote_ports = [remote_port + self.tp_rank % int(remote_tp_size)]
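+            # Only remote_ports[0] is pulled from here (remote_kv_num == 1); the full
+            # list still receives the finish signal via send_finish_to_remote below.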
+            remote_ports = list(
+                range(remote_port + self.tp_rank,
+                      remote_port + int(remote_tp_size), self.tp_size))
+            num_remote_blocks = [len(remote_block_ids)]
+        elif self.cp_size == 1 and self.sp_size == 1:
+            # cp/sp only on the prefill side: each decode worker pulls from all cp*sp prefill ranks to gather the full kv_cache.
+            remote_kv_num = remote_cp_size * remote_sp_size
+            remote_ports = [remote_port + offset for offset in range(remote_cp_size * remote_sp_size)]
+            # Recompute the cp/sp block assignment here; it could also be passed in the prefill node's metadata.
+            num_local_blocks = len(local_block_ids)
+            num_remote_blocks = [num_local_blocks // (remote_cp_size * remote_sp_size)] * remote_cp_size * remote_sp_size
+            num_remain_blocks = num_local_blocks % (remote_cp_size * remote_sp_size)
+            for i in range(num_remain_blocks):
+                num_remote_blocks[i] += 1
         else:
-            remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
-            logger.info("Try pull blocks from remote server")
-            try:
-                self.cache_manager.pull_blocks(
-                    remote_cache_key,
-                    self.cache,  # type: ignore[has-type]
-                    remote_block_ids,
-                    local_block_ids)
-            except (TypeError, ValueError):
-                raise RuntimeError(
-                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
-                )
-            except LLMException:
-                raise RuntimeError(
-                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
-                )
-        remote_ports = list(
-            range(remote_port + self.tp_rank,
-                  remote_port + int(remote_tp_size), self.tp_size))
+            raise NotImplementedError('cp/sp resharding is not supported yet; only D cp/sp == 1 or D cp/sp == P cp/sp is supported')
+
+        local_block_offset = 0
+        remote_block_ids_full = remote_block_ids
+        local_block_ids_full = local_block_ids
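+        # Pull from each selected remote rank in turn; local blocks are consumed in
+        # order, advancing local_block_offset by the count taken from that rank.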
+        for remote_kv_id in range(remote_kv_num):
+            remote_port = remote_ports[remote_kv_id]
+            num_blocks_to_pull = num_remote_blocks[remote_kv_id]
+            if num_blocks_to_pull == 0:
+                continue
+            remote_block_ids = remote_block_ids_full[:num_blocks_to_pull]
+            local_block_ids = local_block_ids_full[local_block_offset:local_block_offset + num_blocks_to_pull]
+            local_block_offset += num_blocks_to_pull
+            # if remote_ip not in self.linked_cluster:
+            # tp_offset = self.tp_rank % int(remote_tp_size)
+            # remote_cluster_id = self.connect_to_remote_agent(
+            #     remote_ip, remote_port + tp_offset)
+            remote_cluster_id = self.connect_to_remote_agent(remote_ip, remote_port)
+            # TODO: this part may be needed for prefix caching; not handled yet, needs checking.
+            assert len(local_block_ids) > 0
+            assert len(local_block_ids) == len(remote_block_ids)
+            """
+            num_local_blocks = len(local_block_ids)
+            if num_local_blocks == 0:
+                return
+            num_remote_blocks = len(remote_block_ids)
+            assert num_local_blocks <= num_remote_blocks
+            if num_local_blocks < num_remote_blocks:
+                remote_block_ids = remote_block_ids[-num_local_blocks:]
+            """
+
+            logger.info(f"remote cluster id is: {remote_cluster_id}")
+            if self.use_mla:
+                remote_cache_key_k_normed = BlocksCacheKey(
+                    cluster_id=remote_cluster_id, model_id=0)
+                remote_cache_key_k_pe = BlocksCacheKey(
+                    cluster_id=remote_cluster_id, model_id=1)
+                logger.info("Try pull blocks from remote server")
+                try:
+                    self.cache_manager.pull_blocks(
+                        remote_cache_key_k_normed,
+                        self.cache[0],  # type: ignore[has-type]
+                        remote_block_ids,
+                        local_block_ids)
+                    self.cache_manager.pull_blocks(
+                        remote_cache_key_k_pe,
+                        self.cache[1],  # type: ignore[has-type]
+                        remote_block_ids,
+                        local_block_ids)
+                except (TypeError, ValueError):
+                    raise RuntimeError(
+                        f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
+                    )
+                except LLMException:
+                    raise RuntimeError(
+                        "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
+                    )
+            else:
+                remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
+                logger.info("Try pull blocks from remote server")
+                try:
+                    self.cache_manager.pull_blocks(
+                        remote_cache_key,
+                        self.cache,  # type: ignore[has-type]
+                        remote_block_ids,
+                        local_block_ids)
+                except (TypeError, ValueError):
+                    raise RuntimeError(
+                        f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
+                    )
+                except LLMException:
+                    raise RuntimeError(
+                        "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
+                    )
         self.send_finish_to_remote(remote_ip, remote_ports, request_id)
         with self.thread_lock:
             self.finished_reqs.add(request_id)
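A minimal standalone sketch (hypothetical sizes, mirroring the arithmetic added to _read_blocks above) of how a decode worker without cp/sp splits its pull targets and block counts across the prefill side's cp*sp ranks:

# Hypothetical values; the real ones come from the request metadata.
remote_cp_size, remote_sp_size = 2, 2    # prefill side runs cp=2, sp=2
remote_port = 5600                       # base port advertised by the prefill DP rank
local_block_ids = list(range(10))        # 10 KV blocks to fill on the decode side

remote_kv_num = remote_cp_size * remote_sp_size
remote_ports = [remote_port + offset for offset in range(remote_kv_num)]

# Even split of the local blocks, with the remainder spread over the first ranks.
num_remote_blocks = [len(local_block_ids) // remote_kv_num] * remote_kv_num
for i in range(len(local_block_ids) % remote_kv_num):
    num_remote_blocks[i] += 1

print(remote_ports)        # [5600, 5601, 5602, 5603]
print(num_remote_blocks)   # [3, 3, 2, 2]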