
Commit 68408bd

Merge pull request #26 from zhangsicheng5/long_seq_tmp
support cp sp pd disaggregate
2 parents 8ecf93a + 48e4456 commit 68408bd

File tree

1 file changed: +117 −61 lines changed


vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 117 additions & 61 deletions
@@ -29,6 +29,9 @@

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
+from vllm_ascend.utils import context_parallel_enable, sequence_parallel_enable
+if context_parallel_enable():
+    from vllm.distributed.parallel_state import get_context_model_parallel_rank

 TORCH_DTYPE_TO_NPU_DTYPE = {
     torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -64,6 +67,8 @@ class ReqMeta:
     remote_port: str
     engine_id: str
     remote_tp_size: str
+    remote_cp_size: str
+    remote_sp_size: str


 class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
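
For illustration only (all values made up): a ReqMeta built from a prefill node's kv_transfer_params now carries the remote CP and SP sizes alongside the TP size, all as strings, and _read_blocks parses them back to int.

```python
# Hypothetical ReqMeta-style payload after this change (illustrative values only).
# remote_cp_size / remote_sp_size travel as strings, like remote_tp_size, and are
# converted back with int(...) at the top of _read_blocks.
req_meta = dict(
    remote_host="10.0.0.1",   # prefill node IP (made up)
    remote_port="5557",       # base RPC port of the prefill DP group (made up)
    remote_tp_size="4",
    remote_cp_size="2",
    remote_sp_size="1",
)
# number of prefill shards a decode worker would pull from in the cp/sp==1 case
print(int(req_meta["remote_cp_size"]) * int(req_meta["remote_sp_size"]))  # -> 2
```
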
@@ -80,6 +85,8 @@ def add_new_req(self, request_id: str, local_block_ids: list[int],
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
             remote_tp_size=kv_transfer_params["remote_tp_size"],
+            remote_cp_size=kv_transfer_params["remote_cp_size"],
+            remote_sp_size=kv_transfer_params["remote_sp_size"],
         )


@@ -180,8 +187,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
         self.tp_size = None
         dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        self.cp_size = self.vllm_config.parallel_config.context_parallel_size if context_parallel_enable() else 1
+        self.sp_size = tp_size if (sequence_parallel_enable() and self.vllm_config.parallel_config.enable_sequence_parallel) else 1

-        self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
+        self.port = dp_rank_local * self.cp_size * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT

         self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}

@@ -284,6 +293,8 @@ def request_finished(
             remote_port=self.port,
             remote_tp_size=str(
                 self.vllm_config.parallel_config.tensor_parallel_size),
+            remote_cp_size=str(self.cp_size),
+            remote_sp_size=str(self.sp_size),
         )


@@ -305,6 +316,9 @@ def __init__(self, vllm_config: VllmConfig):
         self.tp_size = vllm_config.parallel_config.tensor_parallel_size
         self.tp_rank = get_tp_group().rank_in_group
         self.rank = get_world_group().rank
+        self.cp_size = vllm_config.parallel_config.context_parallel_size if context_parallel_enable() else 1
+        self.cp_rank = get_context_model_parallel_rank() if context_parallel_enable() else 0
+        self.sp_size = self.tp_size if (sequence_parallel_enable() and vllm_config.parallel_config.enable_sequence_parallel) else 1
         self.local_ip = get_ip()
         self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config
         self.local_agent_metadata: Optional[
@@ -344,7 +358,8 @@ def __init__(self, vllm_config: VllmConfig):

     def listen_for_agent_metadata_req(self, event: threading.Event):
         assert self.local_agent_metadata is not None
-        port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.tp_size + self.tp_rank if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
+        port = envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.local_dp_rank * self.cp_size * self.tp_size + self.cp_rank * self.tp_size + self.tp_rank \
+            if self.local_dp_rank is not None else envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT + self.tp_size + self.tp_rank
         url = f"tcp://{envs_ascend.VLLM_ASCEND_LLMDD_RPC_IP}:{port}"
         msg_encoder = msgspec.msgpack.Encoder()
         msg_decoder = msgspec.msgpack.Decoder()
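
For context (not part of the commit), a small self-contained sketch of the port layout these formulas implement: each DP group now spans cp_size * tp_size consecutive ports, and within a group the CP rank selects a tp_size-wide stripe. The base-port constant below is illustrative, not the real VLLM_ASCEND_LLMDD_RPC_PORT default.

```python
# Sketch of the RPC-port layout (illustrative base port, made-up rank values).
LLMDD_RPC_PORT = 5557  # stand-in for envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT

def dp_group_base_port(dp_rank, cp_size, tp_size):
    # scheduler-side self.port: one stripe of cp_size * tp_size ports per DP group
    return LLMDD_RPC_PORT + dp_rank * cp_size * tp_size

def worker_listen_port(dp_rank, cp_rank, tp_rank, cp_size, tp_size):
    # worker-side listener port: unique per (dp, cp, tp) rank
    return dp_group_base_port(dp_rank, cp_size, tp_size) + cp_rank * tp_size + tp_rank

# dp=0, cp_size=2, tp_size=2 -> listener ports 5557..5560; dp=1 starts at 5561
print([worker_listen_port(0, cp, tp, 2, 2) for cp in range(2) for tp in range(2)])
print(dp_group_base_port(1, 2, 2))
```
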
@@ -452,9 +467,9 @@ def read_agent_metadata(self, global_rank_table):
                 d for d in device_list if d.get("server_id") == self.local_ip
                 and device_filter(d.get("device_id", ""))
             ]
-            if len(device_list) <= self.tp_rank:
+            if len(device_list) <= self.cp_rank * self.tp_size + self.tp_rank:
                 continue
-            device_info = device_list[self.tp_rank]
+            device_info = device_list[self.cp_rank * self.tp_size + self.tp_rank]
             super_pod_id_ = device_info.get("super_pod_id", None)
             server_id_ = device_info["server_id"]
             device_id_ = device_info["device_id"]
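
Roughly speaking, the device picked from the global rank table is now indexed by the flattened (cp_rank, tp_rank) pair instead of tp_rank alone. A toy illustration with made-up device entries:

```python
# Toy illustration of the lookup above: devices on one server are treated as
# cp_size stripes of tp_size entries, so index = cp_rank * tp_size + tp_rank.
device_list = [{"server_id": "10.0.0.1", "device_id": str(i)} for i in range(4)]  # hypothetical

tp_size = 2
for cp_rank in range(2):
    for tp_rank in range(tp_size):
        device_info = device_list[cp_rank * tp_size + tp_rank]
        print((cp_rank, tp_rank), "->", device_info["device_id"])
# (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3
```
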
@@ -573,6 +588,8 @@ def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
                 remote_engine_id=meta.engine_id,
                 request_id=req_id,
                 remote_tp_size=meta.remote_tp_size,
+                remote_cp_size=meta.remote_cp_size,
+                remote_sp_size=meta.remote_sp_size,
             )
             futures.append(future)

@@ -772,65 +789,104 @@ def _read_blocks(
         remote_engine_id: str,
         request_id: str,
         remote_tp_size: str,
+        remote_cp_size: str,
+        remote_sp_size: str,
     ):
-        # if remote_ip not in self.linked_cluster:
-        tp_offset = self.tp_rank % int(remote_tp_size)
-        remote_cluster_id = self.connect_to_remote_agent(
-            remote_ip, remote_port + tp_offset)
-        num_local_blocks = len(local_block_ids)
-        if num_local_blocks == 0:
-            return
-        num_remote_blocks = len(remote_block_ids)
-        assert num_local_blocks <= num_remote_blocks
-        if num_local_blocks < num_remote_blocks:
-            remote_block_ids = remote_block_ids[-num_local_blocks:]
-
-        logger.info(f"remote cluster id is: {remote_cluster_id}")
-        if self.use_mla:
-            remote_cache_key_k_normed = BlocksCacheKey(
-                cluster_id=remote_cluster_id, model_id=0)
-            remote_cache_key_k_pe = BlocksCacheKey(
-                cluster_id=remote_cluster_id, model_id=1)
-            logger.info("Try pull blocks from remote server")
-            try:
-                self.cache_manager.pull_blocks(
-                    remote_cache_key_k_normed,
-                    self.cache[0],  # type: ignore[has-type]
-                    remote_block_ids,
-                    local_block_ids)
-                self.cache_manager.pull_blocks(
-                    remote_cache_key_k_pe,
-                    self.cache[1],  # type: ignore[has-type]
-                    remote_block_ids,
-                    local_block_ids)
-            except (TypeError, ValueError):
-                raise RuntimeError(
-                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
-                )
-            except LLMException:
-                raise RuntimeError(
-                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
-                )
+        remote_cp_size = int(remote_cp_size)
+        remote_sp_size = int(remote_sp_size)
+        if self.cp_size == remote_cp_size and self.sp_size == remote_sp_size:
+            # same as original, P cpi_spj -> D cpi_spj
+            remote_kv_num = 1
+            # remote_ports = [remote_port + self.tp_rank % int(remote_tp_size)]
+            remote_ports = list(
+                range(remote_port + self.tp_rank,
+                      remote_port + int(remote_tp_size), self.tp_size))
+            num_remote_blocks = [len(remote_block_ids)]
+        elif self.cp_size == 1 and self.sp_size == 1:
+            # only cp/sp in P, each D needs to pull from cp*sp P (to all-gather kv_cache)
+            remote_kv_num = remote_cp_size * remote_sp_size
+            remote_ports = [remote_port + offset for offset in range(remote_cp_size * remote_sp_size)]
+            # recompute cp/sp block assign here, maybe we can also pass it from P node meta
+            num_local_blocks = len(local_block_ids)
+            num_remote_blocks = [num_local_blocks // (remote_cp_size * remote_sp_size)] * remote_cp_size * remote_sp_size
+            num_remain_blocks = num_local_blocks % (remote_cp_size * remote_sp_size)
+            for i in range(num_remain_blocks):
+                num_remote_blocks[i] += 1
         else:
-            remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
-            logger.info("Try pull blocks from remote server")
-            try:
-                self.cache_manager.pull_blocks(
-                    remote_cache_key,
-                    self.cache,  # type: ignore[has-type]
-                    remote_block_ids,
-                    local_block_ids)
-            except (TypeError, ValueError):
-                raise RuntimeError(
-                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
-                )
-            except LLMException:
-                raise RuntimeError(
-                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
-                )
-        remote_ports = list(
-            range(remote_port + self.tp_rank,
-                  remote_port + int(remote_tp_size), self.tp_size))
+            raise NotImplementedError('cp/sp resharding not supported now, only support D cp/sp == 1 or D cp/sp == P cp/sp')
+
+        local_block_offset = 0
+        remote_block_ids_full = remote_block_ids
+        local_block_ids_full = local_block_ids
+        for remote_kv_id in range(remote_kv_num):
+            remote_port = remote_ports[remote_kv_id]
+            num_blocks_to_pull = num_remote_blocks[remote_kv_id]
+            if num_blocks_to_pull == 0:
+                continue
+            remote_block_ids = remote_block_ids_full[:num_blocks_to_pull]
+            local_block_ids = local_block_ids_full[local_block_offset:local_block_offset+num_blocks_to_pull]
+            local_block_offset += num_blocks_to_pull
+            # if remote_ip not in self.linked_cluster:
+            # tp_offset = self.tp_rank % int(remote_tp_size)
+            # remote_cluster_id = self.connect_to_remote_agent(
+            #     remote_ip, remote_port + tp_offset)
+            remote_cluster_id = self.connect_to_remote_agent(remote_ip, remote_port)
+            # TODO maybe this part is for prefix cache, not considered now, need to check
+            assert len(local_block_ids) > 0
+            assert len(local_block_ids) == len(remote_block_ids)
+            """
+            num_local_blocks = len(local_block_ids)
+            if num_local_blocks == 0:
+                return
+            num_remote_blocks = len(remote_block_ids)
+            assert num_local_blocks <= num_remote_blocks
+            if num_local_blocks < num_remote_blocks:
+                remote_block_ids = remote_block_ids[-num_local_blocks:]
+            """
+
+            logger.info(f"remote cluster id is: {remote_cluster_id}")
+            if self.use_mla:
+                remote_cache_key_k_normed = BlocksCacheKey(
+                    cluster_id=remote_cluster_id, model_id=0)
+                remote_cache_key_k_pe = BlocksCacheKey(
+                    cluster_id=remote_cluster_id, model_id=1)
+                logger.info("Try pull blocks from remote server")
+                try:
+                    self.cache_manager.pull_blocks(
+                        remote_cache_key_k_normed,
+                        self.cache[0],  # type: ignore[has-type]
+                        remote_block_ids,
+                        local_block_ids)
+                    self.cache_manager.pull_blocks(
+                        remote_cache_key_k_pe,
+                        self.cache[1],  # type: ignore[has-type]
+                        remote_block_ids,
+                        local_block_ids)
+                except (TypeError, ValueError):
+                    raise RuntimeError(
+                        f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
+                    )
+                except LLMException:
+                    raise RuntimeError(
+                        "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
+                    )
+            else:
+                remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
+                logger.info("Try pull blocks from remote server")
+                try:
+                    self.cache_manager.pull_blocks(
+                        remote_cache_key,
+                        self.cache,  # type: ignore[has-type]
+                        remote_block_ids,
+                        local_block_ids)
+                except (TypeError, ValueError):
+                    raise RuntimeError(
+                        f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
+                    )
+                except LLMException:
+                    raise RuntimeError(
+                        "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
+                    )
         self.send_finish_to_remote(remote_ip, remote_ports, request_id)
         with self.thread_lock:
             self.finished_reqs.add(request_id)
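
To make the new pull plan concrete, here is a standalone sketch (illustrative sizes, not code from the commit) of how _read_blocks now decides which remote ports to contact and how many blocks to pull from each, for the two supported cases:

```python
# Standalone sketch mirroring the two supported branches at the top of _read_blocks
# (illustrative sizes only; names follow the diff above).

def plan_pull(local_block_ids, remote_block_ids, remote_port, remote_tp_size,
              remote_cp_size, remote_sp_size, cp_size, sp_size, tp_size, tp_rank):
    if cp_size == remote_cp_size and sp_size == remote_sp_size:
        # Same CP/SP layout on P and D: one pull, from the matching TP rank(s).
        remote_ports = list(range(remote_port + tp_rank,
                                  remote_port + remote_tp_size, tp_size))
        num_remote_blocks = [len(remote_block_ids)]
    elif cp_size == 1 and sp_size == 1:
        # CP/SP only on the P side: pull a slice from each of the cp*sp prefill
        # workers, spreading the remainder over the first few of them.
        n = remote_cp_size * remote_sp_size
        remote_ports = [remote_port + offset for offset in range(n)]
        num_local_blocks = len(local_block_ids)
        num_remote_blocks = [num_local_blocks // n] * n
        for i in range(num_local_blocks % n):
            num_remote_blocks[i] += 1
    else:
        raise NotImplementedError("cp/sp resharding not supported")
    return remote_ports, num_remote_blocks

# 10 blocks, P ran with cp=2 and sp=2, D runs with cp=sp=1 (tp_size=1, tp_rank=0):
print(plan_pull(list(range(10)), list(range(10)), 5557, 4, 2, 2, 1, 1, 1, 0))
# -> ([5557, 5558, 5559, 5560], [3, 3, 2, 2])
```
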
