From 7655041061b58688aec316053a5dfa73fa9a862f Mon Sep 17 00:00:00 2001 From: LCAIZJ Date: Sat, 13 Sep 2025 23:25:04 +0800 Subject: [PATCH 01/10] mooncake store connector Co-authored-by: fems14 <1804143737@qq.com> Co-authored-by: Dreamerleader <2270923832@qq.com> Co-authored-by: Pz1116 Co-authored-by: lizy124 <1950471827@qq.com> Co-authored-by: zouyida2052 Signed-off-by: LCAIZJ --- vllm_ascend/attention/mla_v1.py | 39 +- vllm_ascend/distributed/__init__.py | 5 + .../distributed/mooncake/config_data.py | 477 +++++++++++++++++ .../distributed/mooncake/kv_transfer.py | 203 +++++++ .../distributed/mooncake/mooncake_engine.py | 494 ++++++++++++++++++ .../distributed/mooncake/mooncake_store.py | 106 ++++ .../mooncake/mooncake_store_connector_v1.py | 466 +++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 12 +- 8 files changed, 1792 insertions(+), 10 deletions(-) create mode 100644 vllm_ascend/distributed/mooncake/config_data.py create mode 100644 vllm_ascend/distributed/mooncake/kv_transfer.py create mode 100644 vllm_ascend/distributed/mooncake/mooncake_engine.py create mode 100644 vllm_ascend/distributed/mooncake/mooncake_store.py create mode 100644 vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 0031513742..5e47814609 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar, List import torch import torch_npu @@ -12,6 +12,10 @@ from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.forward_context import ForwardContext, get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -976,6 +980,7 @@ def forward( assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ attn_metadata.num_decode_tokens is not None + self.wait_for_kv_layer_from_connector(layer.layer_name) num_decode_tokens = attn_metadata.num_decode_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output @@ -1046,4 +1051,36 @@ def forward( is_force_scatter=self.enable_shared_expert_dp)[0] current_ms_metadata.after_comm_event.record() del o_proj_input + self.maybe_save_kv_layer_to_connector(layer_name=layer.layer_name, kv_cache_layer=kv_cache) return output_padded + + def wait_for_kv_layer_from_connector(self, layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + assert isinstance(attn_metadata, AscendMLAMetadata) + connector.wait_for_layer_load(layer_name) + + def maybe_save_kv_layer_to_connector( + self, + layer_name: str, + kv_cache_layer: List[torch.Tensor], + ): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + assert 
isinstance(attn_metadata, AscendMLAMetadata) + connector.save_kv_layer(layer_name, kv_cache_layer, + attn_metadata) diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 458b814d7e..26ddd8f9dc 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -26,3 +26,8 @@ KVConnectorFactory.register_connector( "MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector", "MooncakeConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnectorStoreV1", + "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1", + "MooncakeConnectorV1") diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py new file mode 100644 index 0000000000..47eda15afd --- /dev/null +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -0,0 +1,477 @@ +# Standard +from dataclasses import dataclass +import hashlib +from typing import Any, Iterable, List, Optional, Tuple, Union +import json +import os +# Third Party +from numpy import array +import torch, torch_npu +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.utils import logger +from vllm.utils import cdiv + +# First Party + +@dataclass +class MooncakeEngineMetadata: + """name of the LLM model""" + + model_name: str + """ world size when running under a distributed setting """ + world_size: int + """ worker id when running under a distributed setting """ + worker_id: int + """ the format of kv tensors """ + kv_dtype: torch.dtype + """ the shape of kv tensors """ + """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ + kv_shape: tuple[int, int, int, int, int] + block_size: int = 128 + """ whether use MLA""" + use_mla: bool = False + +@dataclass(order=True) +class MooncakeEngineKey: + model_name: str + world_size: int + worker_id: int + chunk_hash: str + + def __hash__(self): + return hash( + ( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + ) + ) + + def to_string(self): + return ( + f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}" + ) + + def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: + """Split the key into multiple keys for each layer""" + keys = [] + for layer_id in range(num_layers): + keys.append( + LayerMooncakeEngineKey( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + layer_id, + ) + ) + return keys + + @staticmethod + def from_string(s): + parts = s.split("@") + if len(parts) != 5: + raise ValueError(f"Invalid key string: {s}") + return MooncakeEngineKey( + parts[0], int(parts[1]), int(parts[2]), parts[3] + ) + + def to_dict(self): + # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. 
+ return { + "__type__": "CacheEngineKey", + "model_name": self.model_name, + "world_size": self.world_size, + "worker_id": self.worker_id, + "chunk_hash": self.chunk_hash, + } + + @staticmethod + def from_dict(d): + return MooncakeEngineKey( + model_name=d["model_name"], + world_size=d["world_size"], + worker_id=d["worker_id"], + chunk_hash=d["chunk_hash"], + ) + + +@dataclass(order=True) +class LayerMooncakeEngineKey(MooncakeEngineKey): + """A key for the layer cache engine""" + + layer_id: int + + def __hash__(self): + return hash( + ( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + self.layer_id, + ) + ) + + def to_string(self): + return ( + f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}" + ) + + @staticmethod + def from_string(s): + parts = s.split("@") + return LayerMooncakeEngineKey( + parts[0], + int(parts[1]), + int(parts[2]), + parts[3], + int(parts[4]), + ) + + +class ChunkedTokenDatabase(): + def __init__( + self, + metadata: Optional[MooncakeEngineMetadata] = None, + ): + self.metadata = metadata + + def _make_key_by_hash(self, chunk_hash: str, layer_id: Optional[int] = None): + assert self.metadata is not None + return MooncakeEngineKey( + self.metadata.model_name, + self.metadata.world_size, + self.metadata.worker_id, + chunk_hash, + ) + + def _hash( + self, + tokens: Union[torch.Tensor, List[int]], + prefix_hash: str, + ) -> str: + # TODO: change it to a more efficient hash function + if isinstance(tokens, torch.Tensor): + tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() + elif isinstance(tokens, list): + tokens_bytes = array.array("I", tokens).tobytes() + return hashlib.sha256(prefix_hash.encode("ascii") + tokens_bytes).hexdigest() + + def _chunk_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + ) -> Iterable[Union[torch.Tensor, List[int]]]: + """ + Chunk the tokens into chunks of size self.metadata.block_size. + + :param tokens: the input tokens, with shape [seq_len] + device: the target device after chunking + + :return: a generator of chunks of tokens, each with + shape [metadata.block_size] + """ + for i in range(0, len(tokens), self.metadata.block_size): + yield tokens[i : i + self.metadata.block_size] + + def _prefix_hash( + self, + token_chunks: Iterable[Union[torch.Tensor, List[int]]], + ) -> Iterable[str]: + prefix_hash = '' + for token_chunk in token_chunks: + prefix_hash = self._hash(token_chunk, prefix_hash) + yield prefix_hash + + def process_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + mask: Optional[torch.Tensor] = None, + make_key: bool = True, + ) -> Iterable[Tuple[int, int, Union[MooncakeEngineKey, str]]]: + """Process the tokens and return the corresponding cache engine keys. + + :param Union[torch.Tensor, List[int]] tokens: The tokens to process. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched, + and the Falses will ALWAYS be at the PREFIX of the tensor. + + :param bool make_key: Whether to make the cache engine key or not. + If False, the hash value will be returned instead. + + :returns: A iterable of tuples with three elements. The first element + is the start index of the tokens for the key. The second element + is the end index of the tokens for the key. The third element is + the cache engine key (or hash) for the tokens. 
+ + :raises: ValueError if the number of Falses in the mask is not a + multiple of the chunk size. + """ + if mask is not None: + num_falses = mask.numel() - mask.long().sum().item() + else: + num_falses = 0 + + if num_falses % self.metadata.block_size != 0: + raise ValueError( + "The number of Falses in the mask is not a multiple of the chunk size." + ) + total_len = len(tokens) + + token_chunks = self._chunk_tokens(tokens) + prefix_hashes = self._prefix_hash(token_chunks) + + start_idx = 0 + for chunk_id, hash_val in enumerate(prefix_hashes): + start_idx = chunk_id * self.metadata.block_size + end_idx = min(start_idx + self.metadata.block_size, total_len) + if start_idx < num_falses: + continue + else: + if make_key: + yield start_idx, end_idx, self._make_key_by_hash(hash_val) + else: + yield start_idx, end_idx, hash_val + + +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in mooncake + mooncake_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + +@dataclass +class SaveSpec: + # Skip already saved tokens + skip_leading_tokens: int + # Whether the scheduler allow us to save the tokens + can_save: bool + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # The token ids that has been scheduled so far + token_ids: list[int] + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been savd + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. 
+ + """ + # vLLM 0.9.0 update: request.block_ids changed from list[int] to + # list[list[int]] + # Need to check the type of request.block_ids + + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + unfolded_block_ids = new_request.block_ids[0].copy() + + return RequestTracker( + req_id=new_request.req_id, + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(), + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[tuple[list[int], ...], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_ids.extend(new_token_ids) + + if len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_ids: torch.Tensor + + block_ids: list[int] + # # Slot mapping if exchange for block_id + # slot_mapping: torch.Tensor + # Skip save or not + save_spec: Optional[SaveSpec] = None + # load_spec + load_spec: Optional[LoadSpec] = None + + is_last_chunk: Optional[bool] = None + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + load_spec: Optional[LoadSpec] = None, + skip_save: bool = False, + is_last_chunk: Optional[bool] = None, + discard_partial_chunks: bool = True, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + skip_save (bool): whether to skip the save operation. + discard_partial_chunks (bool): whether to discard partial chunks. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_ids = tracker.token_ids + input_token_len = len(input_token_ids) + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. 
number of unsaved tokens is not reached the chunk boundary + skip_leading_tokens = tracker.num_saved_tokens + chunk_boundary = ( + cdiv(tracker.num_saved_tokens + 1, block_size) * block_size + ) + skip_save = skip_save or ( + tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary + ) + + if skip_save and load_spec is None: + return None + + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ( + (input_token_len // block_size * block_size) + if discard_partial_chunks + else input_token_len + ) + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + save_spec = SaveSpec(skip_leading_tokens, not skip_save) + + # Calculate the token ids and slot mappings for load and save + # OPTIMIZATION: pre-allocate the buffer for token ids and block ids + token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] + + # # For load operation: check whether the request is scheduled to load + if load_spec is not None and load_spec.can_load: + logger.debug( + "Scheduled to load %d tokens for request %s", + load_spec.mooncake_cached_tokens, + tracker.req_id, + ) + else: + # Do not load if not in `can_load` state + load_spec = None + + return ReqMeta( + req_id=tracker.req_id, + token_ids=token_ids, + block_ids=tracker.allocated_block_ids, + save_spec=save_spec, + load_spec=load_spec, + is_last_chunk=is_last_chunk, + ) + + +@dataclass +class MooncakeConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. + + Args: + req_meta (ReqMeta): the request metadata. + """ + self.requests.append(req_meta) + + +@dataclass +class LasyerMultiBlockReqMeta: + req_id: str + keys: List[LayerMooncakeEngineKey] + starts: List[int] + ends: list[int] + block_ids: list[int] + layer_id: int + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> "MooncakeStoreConfig": + with open(file_path) as file: + config = json.load(file) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", 3355443200), + local_buffer_size=config.get("local_buffer_size", 1073741824), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address") + ) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + config_path = os.getenv("MOONCAKE_CONFIG_PATH") + if not config_path: + raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py new file mode 100644 index 0000000000..4da6b06463 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -0,0 +1,203 @@ +import threading +import queue +import torch, torch_npu +import zmq +from typing import Any, Iterable, List, Optional, Tuple, Union +from collections import defaultdict, deque +from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor +from vllm.utils import logger, get_ip, logger, make_zmq_path, 
make_zmq_socket +from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey, MooncakeEngineMetadata, ChunkedTokenDatabase, LayerMooncakeEngineKey, MooncakeConnectorMetadata, LasyerMultiBlockReqMeta +from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore +import os + + +class KVTransferThread(threading.Thread): + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, + block_len: list[int], block_size:int, ready_event: threading.Event, name:str): + super().__init__(daemon=True, name=name) + self.tp_rank = tp_rank + self.tp_size = tp_size + self.m_store = m_store + self.ready_event = ready_event + self.kv_caches_base_addr = local_kv_caches_base_addr + self.block_len = block_len + self.token_database = token_database + self.block_size = block_size + self.done_task_lock = threading.Lock() + # TODO(jianzs): find a better way to detect MLA. + self.use_mla = len(block_len) == 2 + + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + self.finished_requests: set[str] = set() + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list=[] + size_list=[] + block_id=block_ids[start//self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr=base_addr+block_id*block_len + length=int(block_len/self.block_size*(end-start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int): + block_id=block_ids[start//self.block_size] + if self.use_mla: + addr_k=self.kv_caches_base_addr[layer_id*2]+block_id*self.block_len[0] + addr_v=self.kv_caches_base_addr[layer_id*2+1]+block_id*self.block_len[1] + length_k=int(self.block_len[0]/self.block_size*(end-start)) + length_v=int(self.block_len[1]/self.block_size*(end-start)) + size_list=[length_k, length_v] + else: + addr_k=self.kv_caches_base_addr[layer_id*2]+block_id*self.block_len[0] + addr_v=self.kv_caches_base_addr[layer_id*2+1]+block_id*self.block_len[0] + length=int(self.block_len[0]/self.block_size*(end-start)) + size_list=[length, length] + addr_list=[addr_k,addr_v] + return addr_list, size_list + + def add_request( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + is_last_chunk: Optional[bool] = None, + ) -> torch.Tensor: + req=({ + "req_id": req_id, + "tokens": tokens, + "block_ids": block_ids, + "mask": mask, + "is_last_chunk":is_last_chunk, + }) + self.request_queue.put(req) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + def set_finished_request(self, req_id): + with self.done_task_lock: + self.finished_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + pass + + +class KVCacheStoreSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, + block_len: list[int], block_size:int, ready_event: threading.Event): + super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, + token_database, block_len, block_size, ready_event, name="KVCacheSendingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens=req_meta["tokens"] + mask=req_meta["mask"] + block_ids=req_meta["block_ids"] + req_id=req_meta["req_id"] + torch.npu.current_stream().synchronize() + for start, end, key in self.token_database.process_tokens(tokens, mask): + addr, size, _ =self.prepare_value(start, end, block_ids) + self.m_store.put(key, addr, size) + if is_last_chunk: + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, + block_len: list[int], block_size:int, ready_event: threading.Event): + super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, + token_database, block_len, block_size, ready_event, name="KVCacheStoreRecvingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + for start, end, key in self.token_database.process_tokens(tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + self.m_store.get(key, addr, size) + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerSendingThread(KVTransferThread): + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, + block_len: list[int], block_size:int, ready_event: threading.Event, num_layers:int): + super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, + token_database, block_len, block_size, ready_event, name="KVCacheStoreLayerSendingThread") + self.final_layer_id = num_layers - 1 + + def add_request( + self, + req_meta: LasyerMultiBlockReqMeta + ) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request(self, req_meta: dict[str, Any]): #chunk + torch.npu.current_stream().synchronize() + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id) + self.m_store.put(key, addr, size) + if req_meta.layer_id==self.final_layer_id: + self.set_finished_request(req_meta.req_id) + self.request_queue.task_done() + + 
+class KVCacheStoreLayerRecvingThread(KVTransferThread): + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, + block_len: list[int], block_size:int, ready_event: threading.Event, get_event: threading.Event): + super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, + token_database, block_len, block_size, ready_event, name="KVCacheStoreLayerRecvingThread") + self.get_event=get_event + + def add_request( + self, + req_meta: LasyerMultiBlockReqMeta + ) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request(self, req_meta: dict[str, Any]): #chunk + for index, key in enumerate(req_meta.keys): + addr, size=self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id) + self.m_store.get(key, addr, size) + self.request_queue.task_done() + self.get_event.set() + diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py new file mode 100644 index 0000000000..362a80608f --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -0,0 +1,494 @@ +# Standard +from typing import Dict, Generator, List, Optional, Union +import math +import asyncio +import multiprocessing +import time +import threading +import queue +from dataclasses import dataclass + +# Third Party +import torch, torch_npu +from vllm.utils import cdiv, get_kv_cache_torch_dtype, round_down +from vllm.utils import logger +from vllm.config import ( + VllmConfig, +) +from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey, MooncakeEngineMetadata, ChunkedTokenDatabase, LayerMooncakeEngineKey, MooncakeConnectorMetadata, LasyerMultiBlockReqMeta +from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore +from vllm_ascend.distributed.mooncake.kv_transfer import KVTransferThread, KVCacheStoreSendingThread, KVCacheStoreRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreLayerRecvingThread +# First Party + + +class MooncakeEngine: + #The main class for the cache engine. 
+ + def __init__( + self, + vllm_config: VllmConfig, + use_layerwize: bool, + skip_last_n_tokens: int, + ): + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.use_mla = False + if ( + hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla + ): + self.use_mla = True + self.use_layerwise=use_layerwize + self.skip_last_n_tokens = skip_last_n_tokens + self.tp_rank = parallel_config.rank + self.tp_size = parallel_config.tensor_parallel_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.block_size = vllm_config.cache_config.block_size + self.current_layer = 0 + # self.use_mla = first_kv_cache_tuple[0].size( + # -1) != first_kv_cache_tuple[1].size(-1) + self.num_layers = model_config.get_num_layers(parallel_config) + self.block_size = vllm_config.cache_config.block_size + num_kv_head = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + kv_dtype = get_kv_cache_torch_dtype(vllm_config.cache_config.cache_dtype, model_config.dtype) + self.hidden_dim_size = num_kv_head * head_size + if self.use_mla: + kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) + else: + kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, head_size) + self.metadata = MooncakeEngineMetadata( + model_config.model, + parallel_config.world_size, + parallel_config.rank, + kv_dtype, + kv_shape, + self.block_size, + self.use_mla, + ) + + self.token_database = ChunkedTokenDatabase(self.metadata) + + self.m_store = Mooncakestore(parallel_config) + + self.kv_send_thread: Optional[KVTransferThread] = None + self.kv_recv_thread: Optional[KVTransferThread] = None + + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + + # TODO(tms): Find a more robust way to detect and handle MLA + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. 
use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + self.m_store.set_kv_caches(kv_caches.values()) + self.kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + + if self.use_layerwise: + self.get_event = threading.Event() + if self.kv_role == 'kv_producer': + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreLayerSendingThread(self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event_sending, self.num_layers) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreLayerRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event, self.get_event) + self.kv_recv_thread.start() + ready_event.wait() + else: + if self.kv_role == 'kv_producer': + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreSendingThread(self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event_sending) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event) + self.kv_recv_thread.start() + ready_event.wait() + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + self.current_layer = 0 + self.layerwise_retrievers = [] + for request in metadata.requests: + load_spec = request.load_spec + if load_spec is None or not load_spec.can_load: #load =0 + continue + tokens = request.token_ids + req_id = request.req_id + if (load_spec.mooncake_cached_tokens % self.block_size != 0) and (load_spec.mooncake_cached_tokens == tokens.shape[0] - 1): + tokens = tokens[: request.load_spec.mooncake_cached_tokens + 1] + else: + tokens = tokens[: request.load_spec.mooncake_cached_tokens] + masked_token_count = ( + request.load_spec.vllm_cached_tokens + // self.block_size + * self.block_size + ) + token_mask = torch.ones_like(tokens, dtype=torch.bool) + token_mask[:masked_token_count] = False + if self.use_layerwise: + layerwise_retriever = self.retrieve_layer( + req_id, + tokens, + request.block_ids, + token_mask, + ) + next(layerwise_retriever) # first layer load + self.layerwise_retrievers.append(layerwise_retriever) + else: + self.kv_recv_thread.add_request( + req_id, + tokens, + request.block_ids, + token_mask, + ) + + def wait_for_layer_load(self) -> None: + """MooncakeConnector does not do layerwise saving.""" + for layerwise_retriever in self.layerwise_retrievers: + ret_token_mask = next(layerwise_retriever) + if self.current_layer == self.num_layers - 1: + assert ret_token_mask is not None + num_retrieved_tokens = ret_token_mask.sum().item() + logger.info(f"Retrieved {num_retrieved_tokens} tokens") + + def save_kv_layer(self, connector_metadata: MooncakeConnectorMetadata) -> None: + """MooncakeConnector does not save explicitly.""" + if self.current_layer == 0: + 
self.layerwise_storers = [] + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + # TODO: whther need to remov saveThread + # no lookup, skipmask + skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + continue # skip this request + + skip_leading_tokens = ( + skip_leading_tokens + // self.block_size + * self.block_size + ) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + layerwise_storer = self.store_layer( + req_id, + token_ids, + mask=store_mask, + block_ids=request.block_ids, + ) + self.layerwise_storers.append(layerwise_storer) + for layerwise_storer in self.layerwise_storers: + try: + next(layerwise_storer) + except Exception as e: + raise + self.current_layer = self.current_layer + 1 + + def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): + """MooncakeConnector does not save explicitly.""" + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + # token_ids = token_ids[: -self.skip_last_n_tokens] + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + continue # skip this request + + skip_leading_tokens = ( + skip_leading_tokens + // self.block_size + * self.block_size + ) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + self.kv_send_thread.add_request( + req_id, + token_ids, + request.block_ids, + store_mask, + request.is_last_chunk, + ) + + def retrieve_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[Optional[torch.Tensor], None, None]: + """ + Retrieve the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the KV transfer which + will be passed into the npu_transfer. + + return: A generator that yields Optional[torch.Tensor]. The tensor will + be the boolean mask indicating which tokens are retrieved and will + only be returned in the last iteration. 
+ """ + + if mask is not None: + num_required_tokens = torch.sum(mask).item() + else: + num_required_tokens = len(tokens) + + ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") + + starts = [] + ends = [] + keys = [] + first_flag= True + for start, end, key in self.token_database.process_tokens(tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) + ret_mask[start:end] = True + + if keys: + # Transpose the keys into layer major format + keys = [list(row) for row in zip(*keys, strict=False)] # [num_layer,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + if not first_flag: + is_finish=self.get_event.wait(timeout=3) #try---cache + if not is_finish: + raise SystemError("Layerwise get failed") + self.get_event.clear() + req_meta=LasyerMultiBlockReqMeta( + req_id, + keys_multi_chunk, + starts, + ends, + block_ids, + layer_id + ) + self.kv_recv_thread.add_request(req_meta) + first_flag=False + yield None + else: + # If no cache are found, we still need to yield to avoid + # `StopIteration` + for layer_id in range(self.num_layers): + yield None + + retrieved_tokens = torch.sum(ret_mask) + logger.debug( + f"Retrieved {retrieved_tokens} " + f"out of {num_required_tokens} " + f"out of total {len(tokens)} tokens" + ) + + yield ret_mask + + def store_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[None, None, None]: + """ + Store the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the storage backend which + will be passed into the gpu_connector. + + return: A generator that yields None. In the first iteration, the + generator allocates the memory objects for all layers and moves + the KV cache of the first layer from GPU to CPU. In the next + iterations, it moves the KV cache of layer i from GPU to the memory + objects (on CPU) and puts the memory objects of layer i-1 to the + storage backends. In the last iteration, it puts the memory objects + of the last layer to the storage backends. + """ + + if mask is not None: + num_stored_tokens = torch.sum(mask).item() + else: + num_stored_tokens = len(tokens) + + starts = [] + ends = [] + keys = [] + for start, end, key in self.token_database.process_tokens(tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) #[block_num,layer_num] + + if keys: + keys = [list(row) for row in zip(*keys, strict=False)] #[layer_num,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + req_meta=LasyerMultiBlockReqMeta( + req_id, + keys_multi_chunk, + starts, + ends, + block_ids, + layer_id + ) + self.kv_send_thread.add_request(req_meta) + yield + else: + for layer_id in range(self.num_layers): + yield + logger.debug(f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_thread. 
+ get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_producer' else set()) + done_recving = self.kv_recv_thread.get_and_clear_finished_requests() # type: ignore[union-attr] + + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), self.tp_rank) + return done_sending, done_recving + + def wait_layer_transfer_finish(self): + time.sleep(10) + pass + + def lookup( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + + :param tokens: the input tokens, with shape [seq_len] + + :return: An int indicating how many prefix tokens are cached. + """ + end = 0 + + for start, end, key in self.token_database.process_tokens(tokens): + try: + if use_layerwise: + keys=[] + keys_multi_layer = key.split_layers(self.num_layers) + for key in keys_multi_layer: + keys.append(key.to_string()) + # batch is_exists + ress=self.m_store.batch_exists(keys) + res=1 + for value in ress: + if value != 1: + res=0 + break + else: + res=self.m_store.exists(key) + if res == 1: + continue + else: + return start + except Exception as e: + logger.warning(f"Remote connection failed in contains: {e}") + return start + + # all tokens where found, return the maximal end + return end + + def close(self) -> None: + """Close the cache engine and free all the resources""" + self.m_store.close() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py new file mode 100644 index 0000000000..6eaf77f6fc --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -0,0 +1,106 @@ +# Standard +from contextlib import contextmanager +from dataclasses import dataclass +from functools import reduce +from typing import List, Optional, no_type_check +from enum import Enum +import asyncio +import json +import operator +import os +import struct +import ctypes +import time +import csv +from contextlib import contextmanager +# Third Party +import torch, torch_npu +from vllm.config import ParallelConfig + +# First Party +from vllm.utils import logger +from vllm.distributed.parallel_state import (get_dp_group, + get_tensor_model_parallel_rank, + get_tp_group) +from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey +from .config_data import MooncakeStoreConfig + +METADATA_BYTES_LEN = 24 +BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) + + +class Mooncakestore(): + def __init__( + self, parallel_config: ParallelConfig + ): + try: + from mooncake.store import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + tp_rank = get_tensor_model_parallel_rank() + tp_size = parallel_config.tensor_parallel_size + dp_rank = parallel_config.data_parallel_rank_local + all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) + if not all_device_ids: + device_ids_list = list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) + else: + device_ids_list = list(map(int, all_device_ids.split(','))) + assert len(device_ids_list) > tp_rank + device_id = device_ids_list[tp_rank] + self.config = MooncakeStoreConfig.load_from_env() + if self.config.protocol == "ascend": + local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + 
int(device_id)) + \ + ":npu_" + str(device_id) + else: + local_hostname = self.config.local_hostname + self.store = MooncakeDistributedStore() + ret = self.store.setup( + local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address + ) + if ret != 0: + msg = "Initialize mooncake failed." + logger.error(msg) + raise RuntimeError(msg) + + def set_kv_caches(self, kvcache): + self.kvcache = list(kvcache) + + def exists(self, key: MooncakeEngineKey) -> bool: + return self.store.is_exist(key.to_string()) == 1 + + def batch_exists(self, keys:list[str]) -> list[bool]: + return self.store.batch_is_exist(keys) + + def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + key_str = key.to_string() + try: + res = self.store.batch_get_into_ascend(key_str, addr, size) + if res[0] != expect_res: + logger.error(f"Failed to get key: [{key_str}] .") + except Exception as e: + logger.error(f"Failed to get key: [{key_str}] .") + return res + + + def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + key_str = key.to_string() + try: + ret = self.store.batch_put_from_ascend(key_str, addr, size) + if ret[0] != 0: + logger.error(f"Failed to put key {key_str}.") + except Exception as e: + logger.error(f"Failed to put key {key_str}.") + + return ret + + def close(self): + self.store.close() + logger.info("Closed the mooncake store connection") diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py new file mode 100644 index 0000000000..3bde09a8cf --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -0,0 +1,466 @@ + +import threading +from enum import Enum +from collections import defaultdict +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Dict, List, Tuple, Union +import msgspec +import torch +import zmq +import threading + +from concurrent.futures import Future + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.utils import logger +from vllm.utils import make_zmq_path, make_zmq_socket, round_down, get_ip,cdiv +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.core.sched.output import SchedulerOutput +from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine +from vllm.v1.request import Request +from vllm.forward_context import ForwardContext +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm_ascend.distributed.mooncake.config_data import MooncakeConnectorMetadata, RequestTracker, LoadSpec, ReqMeta + + + +class MooncakeConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self.use_layerwise=vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + self.skip_last_n_tokens = 
vllm_config.kv_transfer_config.get_from_extra_config( + "skip_last_n_tokens", 1 + ) + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(vllm_config, self.skip_last_n_tokens, self.use_layerwise) + else: + self.connector_worker = MooncakeEngine( + vllm_config, + self.use_layerwise, + self.skip_last_n_tokens, + ) + + assert self.connector_worker is not None + if vllm_config.parallel_config.rank == 0: + self.lookup_server = MooncakeLookupServer( + self.connector_worker, vllm_config, self.use_layerwise + ) + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + attn_metadata = forward_context.attn_metadata + # if attn_metadata is None: + # logger.warning("In connector.start_load_kv, but the attn_metadata is None") + # return + assert self.connector_worker is not None + assert isinstance(self._get_connector_metadata(), MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._get_connector_metadata()) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeStoreConnector does not do layerwise saving.""" + if not self.use_layerwise: + return + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeStoreConnector does not save explicitly.""" + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + self.connector_worker.save_kv_layer(self._get_connector_metadata()) + + def wait_for_save(self): + """MooncakeStoreConnector does not save explicitly.""" + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + self.connector_worker.wait_layer_transfer_finish() + return + + self.connector_worker.wait_for_save(self._get_connector_metadata()) + #time.sleep(1) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert 
self.connector_worker is not None + return self.connector_worker.get_finished() + + +def get_zmq_rpc_path_mooncake( + vllm_config: Optional["VllmConfig"] = None, +) -> str: + base_url = envs.VLLM_RPC_BASE_PATH + # Default to 0 if not configured + rpc_port = 0 + if vllm_config is not None: + rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( + "mooncake_rpc_port", 0 + ) + logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) + return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" + + +class MooncakeStoreConnectorV1Scheduler: + def __init__(self, vllm_config: "VllmConfig", skip_last_n_tokens, use_layerwise): + self.client=MooncakeLookupClient(vllm_config) + self.use_layerwise=use_layerwise + self.kv_role = vllm_config.kv_transfer_config.kv_role + # request_id -> (vllm cached tokes, mooncake cached tokens) + self.load_specs: dict[str, LoadSpec] = {} + self.skip_last_n_tokens = skip_last_n_tokens + self._block_size = vllm_config.cache_config.block_size + # request_id -> full_token_ids + self._request_trackers: dict[str, RequestTracker] = {} + # Whether to discard partial chunks + self._discard_partial_chunks = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", False + ) + ) + self._unfinished_requests: dict[str, Request] = {} + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Check for external KV cache hit. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + + if self._discard_partial_chunks: + token_block_end = len(request.prompt_token_ids) // self._block_size * self._block_size + token_ids = torch.tensor(request.prompt_token_ids[:token_block_end]) + else: + token_ids = torch.tensor(request.prompt_token_ids) + + num_external_hit_tokens = self.client.lookup(token_ids) + + if num_external_hit_tokens == request.num_tokens: + num_external_hit_tokens -= 1 + + need_to_allocate = num_external_hit_tokens - num_computed_tokens + + logger.info( + "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d", + request.request_id, + request.num_tokens, + num_external_hit_tokens, + need_to_allocate, + ) + + if need_to_allocate <= 0: + return 0, False + + self.load_specs[request.request_id] = LoadSpec( + vllm_cached_tokens=num_computed_tokens, + mooncake_cached_tokens=num_external_hit_tokens, + can_load=False, + ) + + return need_to_allocate, not self.use_layerwise + + def update_state_after_alloc(self, request: "Request", blocks:"KVCacheBlocks", num_external_tokens: int): + """ + Update KVConnector state after temporary buffer alloc. + + For SharedStorageConnector, update _request_needs_load + if the CacheManager this allocated blocks for us. 
+ """ + local_block_ids=[] + if num_external_tokens > 0: + local_block_ids = blocks.get_block_ids()[0] + + self._unfinished_requests[request.request_id] = ( + request, local_block_ids) + if request.request_id not in self.load_specs: + # No KV tokens from external KV cache, return + return + + if num_external_tokens == 0: + # No need to load anything + self.load_specs[request.request_id].can_load = False + return + + assert ( + num_external_tokens > 0 + and num_external_tokens + == self.load_specs[request.request_id].mooncake_cached_tokens + - self.load_specs[request.request_id].vllm_cached_tokens + ), ( + f"Mismatch in number of tokens: {num_external_tokens} vs " + f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " + f"{self.load_specs[request.request_id].vllm_cached_tokens}" + f" for request {request.request_id}" + ) + + self.load_specs[request.request_id].can_load = True + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + """Attach the connector metadata to the request object. + + This function should NOT modify other fields in the scheduler_output + except the `kv_connector_metadata` field. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + + force_skip_save = self.kv_role == "kv_consumer" + + meta = MooncakeConnectorMetadata() + + for finished_req_id in scheduler_output.finished_req_ids: + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + + for request in scheduler_output.scheduled_new_reqs: + # Right now, we only load KV for new requests + load_spec = self.load_specs.pop(request.req_id, None) + num_tokens_to_compute = ( + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id] + ) + request_tracker = RequestTracker.from_new_request( + request, num_tokens_to_compute + ) + self._request_trackers[request.req_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids)>=len(request.prompt_token_ids), + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + cached_reqs = scheduler_output.scheduled_cached_reqs + if isinstance(cached_reqs, list): + for i, req in enumerate(cached_reqs): + request_tracker = self._request_trackers[req.req_id] + request_tracker.update(req.new_token_ids, req.new_block_ids) + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids)>=len(req.prompt_token_ids), + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + else: + for i, req_id in enumerate(cached_reqs.req_ids): + request_tracker = self._request_trackers[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + req_tuple = self._unfinished_requests.get(req_id) + if req_tuple: + request = req_tuple[0] + num_current_tokens = len(request_tracker.token_ids) + new_token_ids = request.all_token_ids[ + num_current_tokens : num_current_tokens + num_new_tokens + ] + else: + raise ValueError( + f"Request {req_id} is not in _unfinished_requests, " + f"but it is scheduled to be cached" + ) + new_block_ids = cached_reqs.new_block_ids[i] + if not new_block_ids: + continue + 
request_tracker.update(new_token_ids, new_block_ids) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids)>=len(request.prompt_token_ids), + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs] + for request_id, (request,block_ids) in self._unfinished_requests.items(): + if not request_id in request_ids and not request_id in cached_reqs.req_ids: + load_spec = self.load_specs.pop(request_id, None) + if not load_spec: + continue + num_tokens_to_compute = load_spec.mooncake_cached_tokens + if (num_tokens_to_compute % self._block_size != 0) and (num_tokens_to_compute == len(request.prompt_token_ids) - 1): + num_tokens_to_compute = num_tokens_to_compute + 1 + request_tracker = RequestTracker( + req_id=request_id, + token_ids=request.prompt_token_ids[:num_tokens_to_compute].copy(), + allocated_block_ids=block_ids, + num_saved_tokens=0, + ) + + self._request_trackers[request_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=None, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + if self.kv_role == "kv_consumer": + return False, None + delay_free_blocks = len(block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(block_ids), request.request_id) + return delay_free_blocks, None + + +class MooncakeLookupClient: + def __init__(self, vllm_config: "VllmConfig"): + self.encoder = MsgpackEncoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REQ, # type: ignore[attr-defined] + bind=False, + ) + + def lookup(self, token_ids: torch.Tensor) -> int: + request = self.encoder.encode(token_ids) + self.socket.send_multipart(request, copy=False) + resp = self.socket.recv() + result = int.from_bytes(resp, "big") + return result + + def close(self): + self.socket.close(linger=0) + + +class MooncakeLookupServer: + def __init__( + self, + mooncake_engine: MooncakeEngine, + vllm_config: "VllmConfig", + use_layerwise: bool, + ): + self.decoder = MsgpackDecoder(torch.Tensor) + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REP, # type: ignore[attr-defined] + bind=True, + ) + + self.mooncake_engine = mooncake_engine + self.running = True + + def process_request(): + while self.running: + frames = self.socket.recv_multipart(copy=False) + token_ids = self.decoder.decode(frames) + result = self.mooncake_engine.lookup(token_ids, use_layerwise) + response = result.to_bytes(4, "big") + self.socket.send(response) + + self.thread = threading.Thread(target=process_request, daemon=True) + self.thread.start() + + def close(self): + self.socket.close(linger=0) + # TODO: close the thread! 
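For context, the connector above is configured in two places: MooncakeStoreConfig.load_from_env() reads a JSON file whose path is given by the MOONCAKE_CONFIG_PATH environment variable, and the scheduler/worker sides read kv_connector_extra_config keys ("use_layerwise", "skip_last_n_tokens", "discard_partial_chunks", "mooncake_rpc_port"). The sketch below shows one plausible setup; host names, ports, sizes and the file path are illustrative assumptions, and the exact engine-argument spelling for kv_transfer_config depends on the vLLM version in use.

import json
import os

# Fields match what MooncakeStoreConfig.from_file() reads; values are assumptions.
mooncake_cfg = {
    "local_hostname": "127.0.0.1",
    "metadata_server": "127.0.0.1:2379",
    "global_segment_size": 3355443200,      # default used by from_file()
    "local_buffer_size": 1073741824,        # default used by from_file()
    "protocol": "ascend",                   # "tcp" is the default; "ascend" makes
                                            # Mooncakestore append a per-device port
                                            # and an npu_<id> suffix to local_hostname
    "device_name": "",
    "master_server_address": "127.0.0.1:50051",
}
with open("/tmp/mooncake.json", "w") as f:
    json.dump(mooncake_cfg, f)
os.environ["MOONCAKE_CONFIG_PATH"] = "/tmp/mooncake.json"

# Connector selection uses the name registered in vllm_ascend/distributed/__init__.py;
# the extra-config keys are the ones read by MooncakeConnectorV1 and its scheduler.
kv_transfer_config = {
    "kv_connector": "MooncakeConnectorStoreV1",
    "kv_role": "kv_producer",               # or "kv_consumer" on the decode side
    "kv_connector_extra_config": {
        "use_layerwise": False,
        "skip_last_n_tokens": 1,
        "discard_partial_chunks": False,
        "mooncake_rpc_port": 0,
    },
}
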
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ab8f593cf1..241725429f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1613,13 +1613,9 @@ def execute_model( if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, aux_hidden_states = hidden_states - kv_connector_output = None - if finished_sending is not None or finished_recving is not None: - kv_connector_output = KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving) - else: - kv_connector_output = None + kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving) finished_sending = None finished_recving = None with ProfileExecuteDuration().capture_async("post process"): @@ -1867,8 +1863,6 @@ def kv_connector_no_forward( # For the case of no forward caused by receiving remote kv, # one round of dummy inference is necessary # to prevent hang over the collective calls. - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = KVConnectorOutput( From da1f99f3e20e63efc46f7518221f86c61c0b4844 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Tue, 16 Sep 2025 11:25:26 +0800 Subject: [PATCH 02/10] mooncake store bugfix Signed-off-by: fems14 <1804143737@qq.com> --- vllm_ascend/attention/mla_v1.py | 35 +-- .../distributed/mooncake/config_data.py | 146 +++++------ .../distributed/mooncake/kv_transfer.py | 204 ++++++++++------ .../distributed/mooncake/mooncake_engine.py | 217 ++++++++-------- .../distributed/mooncake/mooncake_store.py | 60 ++--- .../mooncake/mooncake_store_connector_v1.py | 231 ++++++++++-------- 6 files changed, 437 insertions(+), 456 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 5e47814609..73420e50a9 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -980,7 +980,6 @@ def forward( assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ attn_metadata.num_decode_tokens is not None - self.wait_for_kv_layer_from_connector(layer.layer_name) num_decode_tokens = attn_metadata.num_decode_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output @@ -1051,36 +1050,4 @@ def forward( is_force_scatter=self.enable_shared_expert_dp)[0] current_ms_metadata.after_comm_event.record() del o_proj_input - self.maybe_save_kv_layer_to_connector(layer_name=layer.layer_name, kv_cache_layer=kv_cache) - return output_padded - - def wait_for_kv_layer_from_connector(self, layer_name: str): - if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): - return - - connector = get_kv_transfer_group() - - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if attn_metadata is None: - return - assert isinstance(attn_metadata, AscendMLAMetadata) - connector.wait_for_layer_load(layer_name) - - def maybe_save_kv_layer_to_connector( - self, - layer_name: str, - kv_cache_layer: List[torch.Tensor], - ): - if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): - return - - connector = get_kv_transfer_group() - - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if attn_metadata is None: - return - assert isinstance(attn_metadata, AscendMLAMetadata) - 
connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata) + return output_padded \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 47eda15afd..a49393be45 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -1,17 +1,16 @@ -# Standard -from dataclasses import dataclass import hashlib -from typing import Any, Iterable, List, Optional, Tuple, Union import json import os -# Third Party +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +import torch from numpy import array -import torch, torch_npu -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata -from vllm.utils import logger -from vllm.utils import cdiv +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.utils import cdiv, logger +from vllm.v1.core.sched.output import NewRequestData -# First Party @dataclass class MooncakeEngineMetadata: @@ -30,7 +29,8 @@ class MooncakeEngineMetadata: block_size: int = 128 """ whether use MLA""" use_mla: bool = False - + + @dataclass(order=True) class MooncakeEngineKey: model_name: str @@ -39,20 +39,16 @@ class MooncakeEngineKey: chunk_hash: str def __hash__(self): - return hash( - ( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - ) - ) + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + )) def to_string(self): - return ( - f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}" - ) + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}") def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: """Split the key into multiple keys for each layer""" @@ -65,18 +61,8 @@ def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: self.worker_id, self.chunk_hash, layer_id, - ) - ) - return keys - - @staticmethod - def from_string(s): - parts = s.split("@") - if len(parts) != 5: - raise ValueError(f"Invalid key string: {s}") - return MooncakeEngineKey( - parts[0], int(parts[1]), int(parts[2]), parts[3] - ) + )) + return keys def to_dict(self): # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. 
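The ChunkedTokenDatabase changes in this hunk keep the key scheme intact: prompt tokens are cut into block_size chunks, each chunk's hash is folded into a running prefix hash with SHA-256 over uint32 token bytes, and the resulting key string is "{model_name}@{world_size}@{worker_id}@{chunk_hash}", with an extra "@{layer_id}" suffix for the layerwise keys. A rough, self-contained sketch of that derivation follows; the initial prefix seed and the exact chaining inside _prefix_hash are not visible in this patch, so the empty-string seed here is an assumption for illustration only.

import array
import hashlib
from typing import Iterable, List

def sketch_chunk_keys(tokens: List[int], block_size: int, model_name: str,
                      world_size: int, worker_id: int) -> Iterable[str]:
    # Illustrative only: derive per-chunk cache keys the way
    # ChunkedTokenDatabase appears to, i.e. a prefix-chained SHA-256
    # over uint32 token bytes. The "" seed is an assumption.
    prefix_hash = ""
    for i in range(0, len(tokens), block_size):
        chunk = tokens[i:i + block_size]
        tokens_bytes = array.array("I", chunk).tobytes()
        prefix_hash = hashlib.sha256(
            prefix_hash.encode("ascii") + tokens_bytes).hexdigest()
        yield f"{model_name}@{world_size}@{worker_id}@{prefix_hash}"

# Example: keys for a 300-token prompt with the default 128-token blocks.
keys = list(sketch_chunk_keys(list(range(300)), 128, "example-model", 16, 0))

Because each key embeds the hash of everything before it, a lookup hit on chunk N implies hits on all earlier chunks, which is what lets lookup() stop at the first miss.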
@@ -105,42 +91,30 @@ class LayerMooncakeEngineKey(MooncakeEngineKey): layer_id: int def __hash__(self): - return hash( - ( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - self.layer_id, - ) - ) + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + self.layer_id, + )) def to_string(self): - return ( - f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}" - ) - - @staticmethod - def from_string(s): - parts = s.split("@") - return LayerMooncakeEngineKey( - parts[0], - int(parts[1]), - int(parts[2]), - parts[3], - int(parts[4]), - ) + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") class ChunkedTokenDatabase(): + def __init__( self, metadata: Optional[MooncakeEngineMetadata] = None, ): self.metadata = metadata - def _make_key_by_hash(self, chunk_hash: str, layer_id: Optional[int] = None): + def _make_key_by_hash(self, + chunk_hash: str, + layer_id: Optional[int] = None): assert self.metadata is not None return MooncakeEngineKey( self.metadata.model_name, @@ -159,7 +133,8 @@ def _hash( tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() elif isinstance(tokens, list): tokens_bytes = array.array("I", tokens).tobytes() - return hashlib.sha256(prefix_hash.encode("ascii") + tokens_bytes).hexdigest() + return hashlib.sha256(prefix_hash.encode("ascii") + + tokens_bytes).hexdigest() def _chunk_tokens( self, @@ -175,7 +150,7 @@ def _chunk_tokens( shape [metadata.block_size] """ for i in range(0, len(tokens), self.metadata.block_size): - yield tokens[i : i + self.metadata.block_size] + yield tokens[i:i + self.metadata.block_size] def _prefix_hash( self, @@ -248,6 +223,7 @@ class LoadSpec: # Whether the scheduler allow us to load the tokens can_load: bool + @dataclass class SaveSpec: # Skip already saved tokens @@ -255,6 +231,7 @@ class SaveSpec: # Whether the scheduler allow us to save the tokens can_save: bool + @dataclass class RequestTracker: # Request id @@ -299,7 +276,8 @@ def from_new_request( return RequestTracker( req_id=new_request.req_id, - token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(), + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. + copy(), allocated_block_ids=unfolded_block_ids, num_saved_tokens=0, ) @@ -322,7 +300,8 @@ def update( elif isinstance(new_block_ids, list): pass else: - raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}") + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") self.allocated_block_ids.extend(new_block_ids) @@ -342,6 +321,7 @@ class ReqMeta: load_spec: Optional[LoadSpec] = None is_last_chunk: Optional[bool] = None + @staticmethod def from_request_tracker( tracker: RequestTracker, @@ -371,24 +351,17 @@ def from_request_tracker( # 1. has already been saved before (num_saved_tokens > 0) # 2. 
number of unsaved tokens is not reached the chunk boundary skip_leading_tokens = tracker.num_saved_tokens - chunk_boundary = ( - cdiv(tracker.num_saved_tokens + 1, block_size) * block_size - ) - skip_save = skip_save or ( - tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary - ) + chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * + block_size if discard_partial_chunks else 0) + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ((input_token_len // block_size * block_size) + if discard_partial_chunks else input_token_len) + skip_save = skip_save or num_tokens_to_save < chunk_boundary if skip_save and load_spec is None: return None - # Calculate number of tokens to save based on discard_partial_chunks - # setting - num_tokens_to_save = ( - (input_token_len // block_size * block_size) - if discard_partial_chunks - else input_token_len - ) - # If we need to save, update the number of saved tokens if not skip_save: tracker.num_saved_tokens = num_tokens_to_save @@ -408,7 +381,9 @@ def from_request_tracker( else: # Do not load if not in `can_load` state load_spec = None - + logger.debug( + f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" + ) return ReqMeta( req_id=tracker.req_id, token_ids=token_ids, @@ -419,12 +394,11 @@ def from_request_tracker( ) -@dataclass class MooncakeConnectorMetadata(KVConnectorMetadata): - requests: list[ReqMeta] - def __init__(self): + def __init__(self, unfinished_request_ids): self.requests = [] + self.unfinished_request_ids = unfinished_request_ids def add_request(self, req_meta: ReqMeta) -> None: """Add a request to the metadata. @@ -466,12 +440,12 @@ def from_file(file_path: str) -> "MooncakeStoreConfig": local_buffer_size=config.get("local_buffer_size", 1073741824), protocol=config.get("protocol", "tcp"), device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address") - ) - + master_server_address=config.get("master_server_address")) + @staticmethod def load_from_env() -> "MooncakeStoreConfig": config_path = os.getenv("MOONCAKE_CONFIG_PATH") if not config_path: - raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeStoreConfig.from_file(config_path) + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py index 4da6b06463..40b0682b03 100644 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -1,21 +1,22 @@ -import threading import queue -import torch, torch_npu -import zmq -from typing import Any, Iterable, List, Optional, Tuple, Union -from collections import defaultdict, deque -from dataclasses import dataclass +import threading from concurrent.futures import ThreadPoolExecutor -from vllm.utils import logger, get_ip, logger, make_zmq_path, make_zmq_socket -from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey, MooncakeEngineMetadata, ChunkedTokenDatabase, LayerMooncakeEngineKey, MooncakeConnectorMetadata, LasyerMultiBlockReqMeta +from typing import Any, Optional + +import torch +from vllm.utils import logger + +from vllm_ascend.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta) from vllm_ascend.distributed.mooncake.mooncake_store 
import Mooncakestore -import os class KVTransferThread(threading.Thread): + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, - block_len: list[int], block_size:int, ready_event: threading.Event, name:str): + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, name: str): super().__init__(daemon=True, name=name) self.tp_rank = tp_rank self.tp_size = tp_size @@ -33,37 +34,42 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, # TODO(jianzs): make this configurable self.executor = ThreadPoolExecutor(max_workers=32) self.finished_requests: set[str] = set() - + def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list=[] - size_list=[] - block_id=block_ids[start//self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) + if self.use_mla else self.block_len[0]) - addr=base_addr+block_id*block_len - length=int(block_len/self.block_size*(end-start)) + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) addr_list.append(addr) size_list.append(length) return addr_list, size_list, block_id - - def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int): - block_id=block_ids[start//self.block_size] + + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], + layer_id: int): + block_id = block_ids[start // self.block_size] if self.use_mla: - addr_k=self.kv_caches_base_addr[layer_id*2]+block_id*self.block_len[0] - addr_v=self.kv_caches_base_addr[layer_id*2+1]+block_id*self.block_len[1] - length_k=int(self.block_len[0]/self.block_size*(end-start)) - length_v=int(self.block_len[1]/self.block_size*(end-start)) - size_list=[length_k, length_v] + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[1] + length_k = int(self.block_len[0] / self.block_size * (end - start)) + length_v = int(self.block_len[1] / self.block_size * (end - start)) + size_list = [length_k, length_v] else: - addr_k=self.kv_caches_base_addr[layer_id*2]+block_id*self.block_len[0] - addr_v=self.kv_caches_base_addr[layer_id*2+1]+block_id*self.block_len[0] - length=int(self.block_len[0]/self.block_size*(end-start)) - size_list=[length, length] - addr_list=[addr_k,addr_v] + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[0] + length = int(self.block_len[0] / self.block_size * (end - start)) + size_list = [length, length] + addr_list = [addr_k, addr_v] return addr_list, size_list - + def add_request( self, req_id: str, @@ -72,12 +78,12 @@ def add_request( mask: Optional[torch.Tensor] = None, is_last_chunk: Optional[bool] = None, ) -> torch.Tensor: - req=({ + req = ({ "req_id": req_id, "tokens": tokens, "block_ids": block_ids, "mask": mask, - "is_last_chunk":is_last_chunk, + "is_last_chunk": is_last_chunk, }) self.request_queue.put(req) @@ -91,7 +97,7 @@ def get_and_clear_finished_requests(self) -> set[str]: finished_requests = 
self.finished_requests.copy() self.finished_requests.clear() return finished_requests - + def set_finished_request(self, req_id): with self.done_task_lock: self.finished_requests.add(req_id) @@ -117,87 +123,125 @@ def _handle_request(self, req_meta: dict[str, Any]): class KVCacheStoreSendingThread(KVTransferThread): def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, - block_len: list[int], block_size:int, ready_event: threading.Event): - super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, - token_database, block_len, block_size, ready_event, name="KVCacheSendingThread") + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheSendingThread") def _handle_request(self, req_meta: dict[str, Any]): - tokens=req_meta["tokens"] - mask=req_meta["mask"] - block_ids=req_meta["block_ids"] - req_id=req_meta["req_id"] + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + is_last_chunk = req_meta["is_last_chunk"] torch.npu.current_stream().synchronize() - for start, end, key in self.token_database.process_tokens(tokens, mask): - addr, size, _ =self.prepare_value(start, end, block_ids) + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) self.m_store.put(key, addr, size) if is_last_chunk: self.set_finished_request(req_id) self.request_queue.task_done() - - + + class KVCacheStoreRecvingThread(KVTransferThread): def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, - block_len: list[int], block_size:int, ready_event: threading.Event): - super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, - token_database, block_len, block_size, ready_event, name="KVCacheStoreRecvingThread") + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreRecvingThread") def _handle_request(self, req_meta: dict[str, Any]): tokens = req_meta["tokens"] mask = req_meta["mask"] block_ids = req_meta["block_ids"] req_id = req_meta["req_id"] - for start, end, key in self.token_database.process_tokens(tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) self.m_store.get(key, addr, size) self.set_finished_request(req_id) self.request_queue.task_done() class KVCacheStoreLayerSendingThread(KVTransferThread): + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, - block_len: list[int], block_size:int, ready_event: threading.Event, num_layers:int): - super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, - token_database, block_len, block_size, ready_event, name="KVCacheStoreLayerSendingThread") + 
local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + num_layers: int): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerSendingThread") self.final_layer_id = num_layers - 1 - - def add_request( - self, - req_meta: LasyerMultiBlockReqMeta - ) -> torch.Tensor: + + def add_request(self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self.request_queue.put(req_meta) - + def _handle_request(self, req_meta: dict[str, Any]): #chunk torch.npu.current_stream().synchronize() for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id) + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) self.m_store.put(key, addr, size) - if req_meta.layer_id==self.final_layer_id: + if req_meta.layer_id == self.final_layer_id: self.set_finished_request(req_meta.req_id) self.request_queue.task_done() class KVCacheStoreLayerRecvingThread(KVTransferThread): + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], token_database: ChunkedTokenDatabase, - block_len: list[int], block_size:int, ready_event: threading.Event, get_event: threading.Event): - super().__init__(tp_rank, tp_size, m_store, local_kv_caches_base_addr, - token_database, block_len, block_size, ready_event, name="KVCacheStoreLayerRecvingThread") - self.get_event=get_event - - def add_request( - self, - req_meta: LasyerMultiBlockReqMeta - ) -> torch.Tensor: + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + get_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerRecvingThread") + self.get_event = get_event + + def add_request(self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self.request_queue.put(req_meta) - + def _handle_request(self, req_meta: dict[str, Any]): #chunk for index, key in enumerate(req_meta.keys): - addr, size=self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id) + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) self.m_store.get(key, addr, size) self.request_queue.task_done() - self.get_event.set() - + self.get_event.set() \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 362a80608f..f6eda3ed00 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -1,23 +1,22 @@ # Standard -from typing import Dict, Generator, List, Optional, Union import math -import asyncio -import multiprocessing -import time import threading -import queue -from dataclasses import dataclass +import time +from typing import Generator, List, Optional, Union # Third Party -import torch, torch_npu -from vllm.utils import cdiv, get_kv_cache_torch_dtype, round_down -from vllm.utils import logger -from vllm.config import ( - VllmConfig, -) -from 
vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey, MooncakeEngineMetadata, ChunkedTokenDatabase, LayerMooncakeEngineKey, MooncakeConnectorMetadata, LasyerMultiBlockReqMeta +import torch +from vllm.config import VllmConfig +from vllm.utils import get_kv_cache_torch_dtype, logger + +from vllm_ascend.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, + MooncakeEngineMetadata) +from vllm_ascend.distributed.mooncake.kv_transfer import ( + KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, + KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore -from vllm_ascend.distributed.mooncake.kv_transfer import KVTransferThread, KVCacheStoreSendingThread, KVCacheStoreRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreLayerRecvingThread + # First Party @@ -28,19 +27,15 @@ def __init__( self, vllm_config: VllmConfig, use_layerwize: bool, - skip_last_n_tokens: int, ): model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.use_mla = False - if ( - hasattr(model_config, "use_mla") - and isinstance(model_config.use_mla, bool) - and model_config.use_mla - ): + if (hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla): self.use_mla = True - self.use_layerwise=use_layerwize - self.skip_last_n_tokens = skip_last_n_tokens + self.use_layerwise = use_layerwize self.tp_rank = parallel_config.rank self.tp_size = parallel_config.tensor_parallel_size self.kv_role = vllm_config.kv_transfer_config.kv_role @@ -52,12 +47,14 @@ def __init__( self.block_size = vllm_config.cache_config.block_size num_kv_head = model_config.get_num_kv_heads(parallel_config) head_size = model_config.get_head_size() - kv_dtype = get_kv_cache_torch_dtype(vllm_config.cache_config.cache_dtype, model_config.dtype) + kv_dtype = get_kv_cache_torch_dtype( + vllm_config.cache_config.cache_dtype, model_config.dtype) self.hidden_dim_size = num_kv_head * head_size if self.use_mla: kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) else: - kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, head_size) + kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, + head_size) self.metadata = MooncakeEngineMetadata( model_config.model, parallel_config.world_size, @@ -71,11 +68,10 @@ def __init__( self.token_database = ChunkedTokenDatabase(self.metadata) self.m_store = Mooncakestore(parallel_config) - + self.kv_send_thread: Optional[KVTransferThread] = None self.kv_recv_thread: Optional[KVTransferThread] = None - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _, first_kv_cache_tuple = next(iter(kv_caches.items())) first_kv_cache = first_kv_cache_tuple[0] @@ -122,51 +118,57 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for cache in cache_list: base_addr = cache.data_ptr() self.kv_caches_base_addr.append(base_addr) - + if self.use_layerwise: self.get_event = threading.Event() if self.kv_role == 'kv_producer': ready_event_sending = threading.Event() - self.kv_send_thread = KVCacheStoreLayerSendingThread(self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event_sending, self.num_layers) + self.kv_send_thread = KVCacheStoreLayerSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + 
self.block_len, self.block_size, ready_event_sending, + self.num_layers) self.kv_send_thread.start() ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreLayerRecvingThread( self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event, self.get_event) + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event, self.get_event) self.kv_recv_thread.start() ready_event.wait() else: if self.kv_role == 'kv_producer': ready_event_sending = threading.Event() - self.kv_send_thread = KVCacheStoreSendingThread(self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event_sending) + self.kv_send_thread = KVCacheStoreSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending) self.kv_send_thread.start() ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreRecvingThread( self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event) + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event) self.kv_recv_thread.start() ready_event.wait() - + def start_load_kv(self, metadata: MooncakeConnectorMetadata): self.current_layer = 0 self.layerwise_retrievers = [] for request in metadata.requests: load_spec = request.load_spec - if load_spec is None or not load_spec.can_load: #load =0 + if load_spec is None or not load_spec.can_load: #load =0 continue tokens = request.token_ids req_id = request.req_id - if (load_spec.mooncake_cached_tokens % self.block_size != 0) and (load_spec.mooncake_cached_tokens == tokens.shape[0] - 1): - tokens = tokens[: request.load_spec.mooncake_cached_tokens + 1] + if (load_spec.mooncake_cached_tokens % self.block_size + != 0) and (load_spec.mooncake_cached_tokens + == tokens.shape[0] - 1): + tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] else: - tokens = tokens[: request.load_spec.mooncake_cached_tokens] - masked_token_count = ( - request.load_spec.vllm_cached_tokens - // self.block_size - * self.block_size - ) + tokens = tokens[:request.load_spec.mooncake_cached_tokens] + masked_token_count = (request.load_spec.vllm_cached_tokens // + self.block_size * self.block_size) token_mask = torch.ones_like(tokens, dtype=torch.bool) token_mask[:masked_token_count] = False if self.use_layerwise: @@ -176,7 +178,7 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): request.block_ids, token_mask, ) - next(layerwise_retriever) # first layer load + next(layerwise_retriever) # first layer load self.layerwise_retrievers.append(layerwise_retriever) else: self.kv_recv_thread.add_request( @@ -185,7 +187,7 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): request.block_ids, token_mask, ) - + def wait_for_layer_load(self) -> None: """MooncakeConnector does not do layerwise saving.""" for layerwise_retriever in self.layerwise_retrievers: @@ -194,8 +196,9 @@ def wait_for_layer_load(self) -> None: assert ret_token_mask is not None num_retrieved_tokens = ret_token_mask.sum().item() logger.info(f"Retrieved {num_retrieved_tokens} tokens") - - def save_kv_layer(self, connector_metadata: MooncakeConnectorMetadata) -> None: + + def save_kv_layer(self, + connector_metadata: MooncakeConnectorMetadata) -> None: """MooncakeConnector does not save 
explicitly.""" if self.current_layer == 0: self.layerwise_storers = [] @@ -209,20 +212,19 @@ def save_kv_layer(self, connector_metadata: MooncakeConnectorMetadata) -> None: assert isinstance(token_ids, torch.Tensor) assert token_ids.is_cpu - # TODO: whther need to remov saveThread + # TODO: whether need to remov saveThread # no lookup, skipmask skip_leading_tokens = max( self.lookup(token_ids, self.use_layerwise), save_spec.skip_leading_tokens, ) if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request(req_id) continue # skip this request - skip_leading_tokens = ( - skip_leading_tokens - // self.block_size - * self.block_size - ) + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) store_mask = torch.ones_like(token_ids, dtype=torch.bool) store_mask[:skip_leading_tokens] = False @@ -245,7 +247,7 @@ def save_kv_layer(self, connector_metadata: MooncakeConnectorMetadata) -> None: for layerwise_storer in self.layerwise_storers: try: next(layerwise_storer) - except Exception as e: + except Exception: raise self.current_layer = self.current_layer + 1 @@ -257,7 +259,6 @@ def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): continue token_ids = request.token_ids - # token_ids = token_ids[: -self.skip_last_n_tokens] req_id = request.req_id assert isinstance(token_ids, torch.Tensor) assert token_ids.is_cpu @@ -267,17 +268,16 @@ def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): save_spec.skip_leading_tokens, ) if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request(req_id) continue # skip this request - skip_leading_tokens = ( - skip_leading_tokens - // self.block_size - * self.block_size - ) + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) store_mask = torch.ones_like(token_ids, dtype=torch.bool) store_mask[:skip_leading_tokens] = False - + logger.info( "Storing KV cache for %d out of %d tokens " "(skip_leading_tokens=%d) for request %s", @@ -286,14 +286,14 @@ def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): skip_leading_tokens, request.req_id, ) - + self.kv_send_thread.add_request( - req_id, - token_ids, - request.block_ids, - store_mask, - request.is_last_chunk, - ) + req_id, + token_ids, + request.block_ids, + store_mask, + request.is_last_chunk, + ) def retrieve_layer( self, @@ -323,14 +323,15 @@ def retrieve_layer( num_required_tokens = torch.sum(mask).item() else: num_required_tokens = len(tokens) - + ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") starts = [] ends = [] keys = [] - first_flag= True - for start, end, key in self.token_database.process_tokens(tokens, mask): + first_flag = True + for start, end, key in self.token_database.process_tokens( + tokens, mask): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) @@ -339,23 +340,19 @@ def retrieve_layer( if keys: # Transpose the keys into layer major format - keys = [list(row) for row in zip(*keys, strict=False)] # [num_layer,block_num] + keys = [list(row) for row in zip(*keys, strict=False) + ] # [num_layer,block_num] for layer_id, keys_multi_chunk in enumerate(keys): if not first_flag: - is_finish=self.get_event.wait(timeout=3) #try---cache + is_finish = self.get_event.wait(timeout=3) #try---cache if not is_finish: raise SystemError("Layerwise get failed") self.get_event.clear() - req_meta=LasyerMultiBlockReqMeta( - req_id, - 
keys_multi_chunk, - starts, - ends, - block_ids, - layer_id - ) + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) self.kv_recv_thread.add_request(req_meta) - first_flag=False + first_flag = False yield None else: # If no cache are found, we still need to yield to avoid @@ -364,11 +361,9 @@ def retrieve_layer( yield None retrieved_tokens = torch.sum(ret_mask) - logger.debug( - f"Retrieved {retrieved_tokens} " - f"out of {num_required_tokens} " - f"out of total {len(tokens)} tokens" - ) + logger.debug(f"Retrieved {retrieved_tokens} " + f"out of {num_required_tokens} " + f"out of total {len(tokens)} tokens") yield ret_mask @@ -408,42 +403,42 @@ def store_layer( starts = [] ends = [] keys = [] - for start, end, key in self.token_database.process_tokens(tokens, mask): + for start, end, key in self.token_database.process_tokens( + tokens, mask): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) - keys.append(keys_multi_layer) #[block_num,layer_num] - + keys.append(keys_multi_layer) #[block_num,layer_num] + if keys: - keys = [list(row) for row in zip(*keys, strict=False)] #[layer_num,block_num] + keys = [list(row) for row in zip(*keys, strict=False) + ] #[layer_num,block_num] for layer_id, keys_multi_chunk in enumerate(keys): - req_meta=LasyerMultiBlockReqMeta( - req_id, - keys_multi_chunk, - starts, - ends, - block_ids, - layer_id - ) + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) self.kv_send_thread.add_request(req_meta) yield else: for layer_id in range(self.num_layers): yield - logger.debug(f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + logger.debug( + f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") def get_finished(self) -> tuple[set[str], set[str]]: done_sending = ( self.kv_send_thread. 
get_and_clear_finished_requests( # type: ignore[union-attr] ) if self.kv_role == 'kv_producer' else set()) - done_recving = self.kv_recv_thread.get_and_clear_finished_requests() # type: ignore[union-attr] - + done_recving = self.kv_recv_thread.get_and_clear_finished_requests( + ) # type: ignore[union-attr] + logger.debug( "Number of completed KV cache send requests: %d, receive " - "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), self.tp_rank) + "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), + self.tp_rank) return done_sending, done_recving - + def wait_layer_transfer_finish(self): time.sleep(10) pass @@ -465,19 +460,19 @@ def lookup( for start, end, key in self.token_database.process_tokens(tokens): try: if use_layerwise: - keys=[] + keys = [] keys_multi_layer = key.split_layers(self.num_layers) for key in keys_multi_layer: keys.append(key.to_string()) # batch is_exists - ress=self.m_store.batch_exists(keys) - res=1 + ress = self.m_store.batch_exists(keys) + res = 1 for value in ress: if value != 1: - res=0 + res = 0 break else: - res=self.m_store.exists(key) + res = self.m_store.exists(key) if res == 1: continue else: @@ -490,5 +485,5 @@ def lookup( return end def close(self) -> None: - """Close the cache engine and free all the resources""" - self.m_store.close() + """Close the cache engine and free all the resources""" + self.m_store.close() \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py index 6eaf77f6fc..ec33fe9c9c 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -1,28 +1,13 @@ # Standard -from contextlib import contextmanager -from dataclasses import dataclass -from functools import reduce -from typing import List, Optional, no_type_check -from enum import Enum -import asyncio -import json -import operator import os -import struct -import ctypes -import time -import csv -from contextlib import contextmanager + # Third Party -import torch, torch_npu from vllm.config import ParallelConfig - -# First Party +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.utils import logger -from vllm.distributed.parallel_state import (get_dp_group, - get_tensor_model_parallel_rank, - get_tp_group) + from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey + from .config_data import MooncakeStoreConfig METADATA_BYTES_LEN = 24 @@ -30,9 +15,8 @@ class Mooncakestore(): - def __init__( - self, parallel_config: ParallelConfig - ): + + def __init__(self, parallel_config: ParallelConfig): try: from mooncake.store import MooncakeDistributedStore except ImportError as e: @@ -45,7 +29,8 @@ def __init__( dp_rank = parallel_config.data_parallel_rank_local all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) if not all_device_ids: - device_ids_list = list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) + device_ids_list = list( + range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) else: device_ids_list = list(map(int, all_device_ids.split(','))) assert len(device_ids_list) > tp_rank @@ -57,14 +42,11 @@ def __init__( else: local_hostname = self.config.local_hostname self.store = MooncakeDistributedStore() - ret = self.store.setup( - local_hostname, - self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, self.config.device_name, - self.config.master_server_address - ) + ret = 
self.store.setup(local_hostname, self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) if ret != 0: msg = "Initialize mooncake failed." logger.error(msg) @@ -75,32 +57,32 @@ def set_kv_caches(self, kvcache): def exists(self, key: MooncakeEngineKey) -> bool: return self.store.is_exist(key.to_string()) == 1 - - def batch_exists(self, keys:list[str]) -> list[bool]: + + def batch_exists(self, keys: list[str]) -> list[bool]: return self.store.batch_is_exist(keys) def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + expect_res = sum(size) key_str = key.to_string() try: res = self.store.batch_get_into_ascend(key_str, addr, size) if res[0] != expect_res: logger.error(f"Failed to get key: [{key_str}] .") - except Exception as e: + except Exception: logger.error(f"Failed to get key: [{key_str}] .") return res - def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): key_str = key.to_string() try: ret = self.store.batch_put_from_ascend(key_str, addr, size) if ret[0] != 0: logger.error(f"Failed to put key {key_str}.") - except Exception as e: + except Exception: logger.error(f"Failed to put key {key_str}.") - + return ret - + def close(self): self.store.close() - logger.info("Closed the mooncake store connection") + logger.info("Closed the mooncake store connection") \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index 3bde09a8cf..21aff7ee2b 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -1,34 +1,23 @@ - -import threading -from enum import Enum -from collections import defaultdict -from collections.abc import Iterator -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Dict, List, Tuple, Union -import msgspec -import torch -import zmq import threading +from typing import Any, Optional -from concurrent.futures import Future - +import torch import vllm.envs as envs +import zmq +from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) -from vllm.utils import logger -from vllm.utils import make_zmq_path, make_zmq_socket, round_down, get_ip,cdiv -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder -from vllm.v1.core.sched.output import SchedulerOutput -from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine -from vllm.v1.request import Request from vllm.forward_context import ForwardContext +from vllm.utils import logger, make_zmq_socket from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm_ascend.distributed.mooncake.config_data import MooncakeConnectorMetadata, RequestTracker, LoadSpec, ReqMeta +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm_ascend.distributed.mooncake.config_data import ( + LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) +from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine 
class MooncakeConnectorV1(KVConnectorBase_V1): @@ -37,30 +26,29 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self.kv_role = vllm_config.kv_transfer_config.kv_role - self.use_layerwise=vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False) + self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "use_layerwise", False) self.kv_caches: dict[str, torch.Tensor] = {} self._block_size = vllm_config.cache_config.block_size - self.skip_last_n_tokens = vllm_config.kv_transfer_config.get_from_extra_config( - "skip_last_n_tokens", 1 - ) + self.sended_but_unfinished_reqs: set[str] = set() if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(vllm_config, self.skip_last_n_tokens, self.use_layerwise) + self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( + vllm_config, self.use_layerwise) else: self.connector_worker = MooncakeEngine( vllm_config, self.use_layerwise, - self.skip_last_n_tokens, ) assert self.connector_worker is not None if vllm_config.parallel_config.rank == 0: self.lookup_server = MooncakeLookupServer( - self.connector_worker, vllm_config, self.use_layerwise - ) + self.connector_worker, vllm_config, self.use_layerwise) + ############################################################ # Scheduler Side Methods ############################################################ @@ -85,7 +73,7 @@ def build_connector_meta( ) -> KVConnectorMetadata: assert self.connector_scheduler is not None return self.connector_scheduler.build_connector_meta(scheduler_output) - + def request_finished( self, request: "Request", @@ -94,7 +82,6 @@ def request_finished( assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) - ############################################################ # Worker Side Methods ############################################################ @@ -104,14 +91,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: - attn_metadata = forward_context.attn_metadata - # if attn_metadata is None: - # logger.warning("In connector.start_load_kv, but the attn_metadata is None") - # return assert self.connector_worker is not None - assert isinstance(self._get_connector_metadata(), MooncakeConnectorMetadata) + assert isinstance(self._get_connector_metadata(), + MooncakeConnectorMetadata) self.connector_worker.start_load_kv(self._get_connector_metadata()) - + def wait_for_layer_load(self, layer_name: str) -> None: """MooncakeStoreConnector does not do layerwise saving.""" if not self.use_layerwise: @@ -123,7 +107,7 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, """MooncakeStoreConnector does not save explicitly.""" if not self.use_layerwise: return - + if self.kv_role == "kv_consumer": # Don't do save if the role is kv_consumer return @@ -138,50 +122,60 @@ def wait_for_save(self): if self.use_layerwise: self.connector_worker.wait_layer_transfer_finish() return - + self.connector_worker.wait_for_save(self._get_connector_metadata()) #time.sleep(1) def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """Get the finished recving and sending requests.""" - assert self.connector_worker is not None - return self.connector_worker.get_finished() + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished 
recving and sending requests.""" + assert self.connector_worker is not None + meta = self._get_connector_metadata() + done_sending, done_recving = self.connector_worker.get_finished() + sended_and_finished: set[str] = set() + for item in list(self.sended_but_unfinished_reqs): + if item not in meta.unfinished_request_ids: + sended_and_finished.add(item) + self.sended_but_unfinished_reqs.remove(item) + for item in done_sending: + if item in meta.unfinished_request_ids: + self.sended_but_unfinished_reqs.add(item) + else: + sended_and_finished.add(item) + + return sended_and_finished, done_recving def get_zmq_rpc_path_mooncake( - vllm_config: Optional["VllmConfig"] = None, -) -> str: + vllm_config: Optional["VllmConfig"] = None, ) -> str: base_url = envs.VLLM_RPC_BASE_PATH # Default to 0 if not configured rpc_port = 0 if vllm_config is not None: rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( - "mooncake_rpc_port", 0 - ) + "mooncake_rpc_port", 0) logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" class MooncakeStoreConnectorV1Scheduler: - def __init__(self, vllm_config: "VllmConfig", skip_last_n_tokens, use_layerwise): - self.client=MooncakeLookupClient(vllm_config) - self.use_layerwise=use_layerwise + + def __init__(self, vllm_config: "VllmConfig", use_layerwise): + self.client = MooncakeLookupClient(vllm_config) + self.use_layerwise = use_layerwise self.kv_role = vllm_config.kv_transfer_config.kv_role - # request_id -> (vllm cached tokes, mooncake cached tokens) + # request_id -> (vllm cached tokes, mooncake cached tokens) self.load_specs: dict[str, LoadSpec] = {} - self.skip_last_n_tokens = skip_last_n_tokens self._block_size = vllm_config.cache_config.block_size - # request_id -> full_token_ids + # request_id -> full_token_ids self._request_trackers: dict[str, RequestTracker] = {} - # Whether to discard partial chunks + # Whether to discard partial chunks self._discard_partial_chunks = ( vllm_config.kv_transfer_config.get_from_extra_config( - "discard_partial_chunks", False - ) - ) - self._unfinished_requests: dict[str, Request] = {} - + "discard_partial_chunks", True)) + self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} + self._unfinished_request_ids: set[str] = set() + def get_num_new_matched_tokens( self, request: "Request", @@ -201,8 +195,10 @@ def get_num_new_matched_tokens( """ if self._discard_partial_chunks: - token_block_end = len(request.prompt_token_ids) // self._block_size * self._block_size - token_ids = torch.tensor(request.prompt_token_ids[:token_block_end]) + token_block_end = len(request.prompt_token_ids + ) // self._block_size * self._block_size + token_ids = torch.tensor( + request.prompt_token_ids[:token_block_end]) else: token_ids = torch.tensor(request.prompt_token_ids) @@ -220,10 +216,10 @@ def get_num_new_matched_tokens( num_external_hit_tokens, need_to_allocate, ) - + if need_to_allocate <= 0: return 0, False - + self.load_specs[request.request_id] = LoadSpec( vllm_cached_tokens=num_computed_tokens, mooncake_cached_tokens=num_external_hit_tokens, @@ -232,19 +228,22 @@ def get_num_new_matched_tokens( return need_to_allocate, not self.use_layerwise - def update_state_after_alloc(self, request: "Request", blocks:"KVCacheBlocks", num_external_tokens: int): + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): """ Update KVConnector state after temporary buffer alloc. 
For SharedStorageConnector, update _request_needs_load if the CacheManager this allocated blocks for us. """ - local_block_ids=[] + local_block_ids = [] if num_external_tokens > 0: local_block_ids = blocks.get_block_ids()[0] - self._unfinished_requests[request.request_id] = ( - request, local_block_ids) + self._unfinished_requests[request.request_id] = (request, + local_block_ids) + self._unfinished_request_ids.add(request.request_id) if request.request_id not in self.load_specs: # No KV tokens from external KV cache, return return @@ -255,22 +254,18 @@ def update_state_after_alloc(self, request: "Request", blocks:"KVCacheBlocks", n return assert ( - num_external_tokens > 0 - and num_external_tokens - == self.load_specs[request.request_id].mooncake_cached_tokens - - self.load_specs[request.request_id].vllm_cached_tokens - ), ( - f"Mismatch in number of tokens: {num_external_tokens} vs " + num_external_tokens > 0 and num_external_tokens + == self.load_specs[request.request_id].mooncake_cached_tokens - + self.load_specs[request.request_id].vllm_cached_tokens + ), (f"Mismatch in number of tokens: {num_external_tokens} vs " f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " f"{self.load_specs[request.request_id].vllm_cached_tokens}" - f" for request {request.request_id}" - ) + f" for request {request.request_id}") self.load_specs[request.request_id].can_load = True def build_connector_meta( - self, scheduler_output: SchedulerOutput - ) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: """Attach the connector metadata to the request object. This function should NOT modify other fields in the scheduler_output @@ -282,101 +277,122 @@ def build_connector_meta( """ force_skip_save = self.kv_role == "kv_consumer" - - meta = MooncakeConnectorMetadata() for finished_req_id in scheduler_output.finished_req_ids: self._request_trackers.pop(finished_req_id, None) self._unfinished_requests.pop(finished_req_id, None) - + self._unfinished_request_ids.remove(finished_req_id) + + meta = MooncakeConnectorMetadata(self._unfinished_request_ids) + for request in scheduler_output.scheduled_new_reqs: # Right now, we only load KV for new requests load_spec = self.load_specs.pop(request.req_id, None) num_tokens_to_compute = ( - request.num_computed_tokens - + scheduler_output.num_scheduled_tokens[request.req_id] - ) + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id]) request_tracker = RequestTracker.from_new_request( - request, num_tokens_to_compute - ) + request, num_tokens_to_compute) self._request_trackers[request.req_id] = request_tracker - + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else len( + request.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=load_spec, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids)>=len(request.prompt_token_ids), + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: meta.add_request(req_meta) cached_reqs = scheduler_output.scheduled_cached_reqs - if isinstance(cached_reqs, list): + if isinstance(cached_reqs, list) and not force_skip_save: for i, req in enumerate(cached_reqs): request_tracker = self._request_trackers[req.req_id] request_tracker.update(req.new_token_ids, req.new_block_ids) - + last_chunk_tokens_num = 
((len(req.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(req.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=None, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids)>=len(req.prompt_token_ids), + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: meta.add_request(req_meta) - else: + elif not force_skip_save: for i, req_id in enumerate(cached_reqs.req_ids): request_tracker = self._request_trackers[req_id] num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] req_tuple = self._unfinished_requests.get(req_id) - if req_tuple: + if req_tuple: request = req_tuple[0] num_current_tokens = len(request_tracker.token_ids) new_token_ids = request.all_token_ids[ - num_current_tokens : num_current_tokens + num_new_tokens - ] + num_current_tokens:num_current_tokens + num_new_tokens] else: raise ValueError( f"Request {req_id} is not in _unfinished_requests, " - f"but it is scheduled to be cached" - ) + f"but it is scheduled to be cached") new_block_ids = cached_reqs.new_block_ids[i] if not new_block_ids: continue request_tracker.update(new_token_ids, new_block_ids) + # decode not save + if len(request_tracker.token_ids) > len( + request.prompt_token_ids): + continue + + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(request.prompt_token_ids)) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=None, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids)>=len(request.prompt_token_ids), + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: meta.add_request(req_meta) - request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs] - for request_id, (request,block_ids) in self._unfinished_requests.items(): - if not request_id in request_ids and not request_id in cached_reqs.req_ids: + request_ids = [ + req.req_id for req in scheduler_output.scheduled_new_reqs + ] + for request_id, (request, + block_ids) in self._unfinished_requests.items(): + if request_id not in request_ids and request_id not in cached_reqs.req_ids: load_spec = self.load_specs.pop(request_id, None) if not load_spec: continue num_tokens_to_compute = load_spec.mooncake_cached_tokens - if (num_tokens_to_compute % self._block_size != 0) and (num_tokens_to_compute == len(request.prompt_token_ids) - 1): + if (num_tokens_to_compute % self._block_size + != 0) and (num_tokens_to_compute + == len(request.prompt_token_ids) - 1): num_tokens_to_compute = num_tokens_to_compute + 1 request_tracker = RequestTracker( req_id=request_id, - token_ids=request.prompt_token_ids[:num_tokens_to_compute].copy(), + token_ids=request.prompt_token_ids[:num_tokens_to_compute]. + copy(), allocated_block_ids=block_ids, num_saved_tokens=0, ) self._request_trackers[request_id] = request_tracker - + req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, @@ -397,9 +413,10 @@ def request_finished( Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. 
""" - if self.kv_role == "kv_consumer": return False, None + if self._request_trackers[request.request_id].num_saved_tokens <= 0: + return False, None delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: logger.info("Delaying free of %d blocks for request %s", @@ -408,6 +425,7 @@ def request_finished( class MooncakeLookupClient: + def __init__(self, vllm_config: "VllmConfig"): self.encoder = MsgpackEncoder() self.ctx = zmq.Context() # type: ignore[attr-defined] @@ -431,6 +449,7 @@ def close(self): class MooncakeLookupServer: + def __init__( self, mooncake_engine: MooncakeEngine, @@ -463,4 +482,4 @@ def process_request(): def close(self): self.socket.close(linger=0) - # TODO: close the thread! + # TODO: close the thread! \ No newline at end of file From 131be5f39af3fd1e23b666ed4c9c18fecbb9cebd Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Tue, 16 Sep 2025 11:32:47 +0800 Subject: [PATCH 03/10] mooncake store bugfix Signed-off-by: fems14 <1804143737@qq.com> --- vllm_ascend/attention/mla_v1.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 73420e50a9..0031513742 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar, List +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar import torch import torch_npu @@ -12,10 +12,6 @@ from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) -from vllm.forward_context import ForwardContext, get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -1050,4 +1046,4 @@ def forward( is_force_scatter=self.enable_shared_expert_dp)[0] current_ms_metadata.after_comm_event.record() del o_proj_input - return output_padded \ No newline at end of file + return output_padded From 97f3ae44fa7d41f2e75638936ed57ddd5e19f475 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Tue, 16 Sep 2025 20:25:24 +0800 Subject: [PATCH 04/10] mooncake store lint fix Signed-off-by: fems14 <1804143737@qq.com> --- vllm_ascend/distributed/mooncake/__init__.py | 0 .../distributed/mooncake/config_data.py | 10 +++------ .../distributed/mooncake/kv_transfer.py | 12 +++++++---- .../distributed/mooncake/mooncake_engine.py | 21 +++++++++---------- .../distributed/mooncake/mooncake_store.py | 2 +- .../mooncake/mooncake_store_connector_v1.py | 1 - 6 files changed, 22 insertions(+), 24 deletions(-) create mode 100644 vllm_ascend/distributed/mooncake/__init__.py diff --git a/vllm_ascend/distributed/mooncake/__init__.py b/vllm_ascend/distributed/mooncake/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index a49393be45..90a10ed0fc 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -1,3 +1,4 @@ +import array import hashlib import json import os @@ -5,7 +6,6 @@ from typing import Iterable, List, Optional, Tuple, Union import torch -from numpy import array from vllm.distributed.kv_transfer.kv_connector.v1.base import \ KVConnectorMetadata 
from vllm.utils import cdiv, logger @@ -108,7 +108,7 @@ class ChunkedTokenDatabase(): def __init__( self, - metadata: Optional[MooncakeEngineMetadata] = None, + metadata: MooncakeEngineMetadata, ): self.metadata = metadata @@ -165,7 +165,6 @@ def process_tokens( self, tokens: Union[torch.Tensor, List[int]], mask: Optional[torch.Tensor] = None, - make_key: bool = True, ) -> Iterable[Tuple[int, int, Union[MooncakeEngineKey, str]]]: """Process the tokens and return the corresponding cache engine keys. @@ -208,10 +207,7 @@ def process_tokens( if start_idx < num_falses: continue else: - if make_key: - yield start_idx, end_idx, self._make_key_by_hash(hash_val) - else: - yield start_idx, end_idx, hash_val + yield start_idx, end_idx, self._make_key_by_hash(hash_val) @dataclass diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py index 40b0682b03..89fea7c974 100644 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -199,10 +199,12 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, name="KVCacheStoreLayerSendingThread") self.final_layer_id = num_layers - 1 - def add_request(self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + def add_request( + self, + req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: # type: ignore self.request_queue.put(req_meta) - def _handle_request(self, req_meta: dict[str, Any]): #chunk + def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): torch.npu.current_stream().synchronize() for index, key in enumerate(req_meta.keys): addr, size = self.prepare_value_layer(req_meta.starts[index], @@ -233,10 +235,12 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, name="KVCacheStoreLayerRecvingThread") self.get_event = get_event - def add_request(self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + def add_request( + self, + req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: # type: ignore self.request_queue.put(req_meta) - def _handle_request(self, req_meta: dict[str, Any]): #chunk + def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): for index, key in enumerate(req_meta.keys): addr, size = self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index f6eda3ed00..764a8ce3f6 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -17,8 +17,6 @@ KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore -# First Party - class MooncakeEngine: #The main class for the cache engine. 
@@ -181,7 +179,7 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): next(layerwise_retriever) # first layer load self.layerwise_retrievers.append(layerwise_retriever) else: - self.kv_recv_thread.add_request( + self.kv_recv_thread.add_request( # type: ignore[union-attr] req_id, tokens, request.block_ids, @@ -269,7 +267,8 @@ def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): ) if skip_leading_tokens == len(token_ids): if request.is_last_chunk: - self.kv_send_thread.set_finished_request(req_id) + self.kv_send_thread.set_finished_request( + req_id) # type: ignore[union-attr] continue # skip this request skip_leading_tokens = (skip_leading_tokens // self.block_size * @@ -287,7 +286,7 @@ def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): request.req_id, ) - self.kv_send_thread.add_request( + self.kv_send_thread.add_request( # type: ignore[union-attr] req_id, token_ids, request.block_ids, @@ -340,8 +339,7 @@ def retrieve_layer( if keys: # Transpose the keys into layer major format - keys = [list(row) for row in zip(*keys, strict=False) - ] # [num_layer,block_num] + keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] for layer_id, keys_multi_chunk in enumerate(keys): if not first_flag: is_finish = self.get_event.wait(timeout=3) #try---cache @@ -351,7 +349,8 @@ def retrieve_layer( req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, layer_id) - self.kv_recv_thread.add_request(req_meta) + self.kv_recv_thread.add_request( + req_meta) # type: ignore[union-attr] first_flag = False yield None else: @@ -411,13 +410,13 @@ def store_layer( keys.append(keys_multi_layer) #[block_num,layer_num] if keys: - keys = [list(row) for row in zip(*keys, strict=False) - ] #[layer_num,block_num] + keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] for layer_id, keys_multi_chunk in enumerate(keys): req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, layer_id) - self.kv_send_thread.add_request(req_meta) + self.kv_send_thread.add_request( + req_meta) # type: ignore[union-attr] yield else: for layer_id in range(self.num_layers): diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py index ec33fe9c9c..2383749c94 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -18,7 +18,7 @@ class Mooncakestore(): def __init__(self, parallel_config: ParallelConfig): try: - from mooncake.store import MooncakeDistributedStore + from mooncake.store import MooncakeDistributedStore # type: ignore except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index 21aff7ee2b..6254e47521 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -124,7 +124,6 @@ def wait_for_save(self): return self.connector_worker.wait_for_save(self._get_connector_metadata()) - #time.sleep(1) def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: From 04f4bf94f08cfd2eb01a07316e0c8bda2fb7f571 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Tue, 16 Sep 2025 20:57:53 +0800 Subject: [PATCH 05/10] mooncake store lint fix Signed-off-by: fems14 <1804143737@qq.com> --- 
.../distributed/mooncake/config_data.py | 4 ++-- .../distributed/mooncake/kv_transfer.py | 12 ++++++------ .../distributed/mooncake/mooncake_engine.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py index 90a10ed0fc..abb3c9ee0d 100644 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -165,7 +165,7 @@ def process_tokens( self, tokens: Union[torch.Tensor, List[int]], mask: Optional[torch.Tensor] = None, - ) -> Iterable[Tuple[int, int, Union[MooncakeEngineKey, str]]]: + ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: """Process the tokens and return the corresponding cache engine keys. :param Union[torch.Tensor, List[int]] tokens: The tokens to process. @@ -323,7 +323,7 @@ def from_request_tracker( tracker: RequestTracker, block_size: int, load_spec: Optional[LoadSpec] = None, - skip_save: bool = False, + skip_save: Optional[bool] = False, is_last_chunk: Optional[bool] = None, discard_partial_chunks: bool = True, ) -> Optional["ReqMeta"]: diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py index 89fea7c974..c0adfc2f7e 100644 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -199,12 +199,12 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, name="KVCacheStoreLayerSendingThread") self.final_layer_id = num_layers - 1 - def add_request( + def add_request( # type: ignore[override] self, - req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: # type: ignore + req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self.request_queue.put(req_meta) - def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): + def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): # type: ignore[override] torch.npu.current_stream().synchronize() for index, key in enumerate(req_meta.keys): addr, size = self.prepare_value_layer(req_meta.starts[index], @@ -235,12 +235,12 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, name="KVCacheStoreLayerRecvingThread") self.get_event = get_event - def add_request( + def add_request( # type: ignore[override] self, - req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: # type: ignore + req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self.request_queue.put(req_meta) - def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): + def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): # type: ignore[override] for index, key in enumerate(req_meta.keys): addr, size = self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 764a8ce3f6..cb8d83daee 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -218,7 +218,7 @@ def save_kv_layer(self, ) if skip_leading_tokens == len(token_ids): if request.is_last_chunk: - self.kv_send_thread.set_finished_request(req_id) + self.kv_send_thread.set_finished_request(req_id) # type: ignore[union-attr] continue # skip this request skip_leading_tokens = (skip_leading_tokens // self.block_size * @@ -267,8 +267,8 @@ def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): ) if skip_leading_tokens == len(token_ids): if request.is_last_chunk: - 
self.kv_send_thread.set_finished_request( - req_id) # type: ignore[union-attr] + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) continue # skip this request skip_leading_tokens = (skip_leading_tokens // self.block_size * @@ -349,8 +349,8 @@ def retrieve_layer( req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, layer_id) - self.kv_recv_thread.add_request( - req_meta) # type: ignore[union-attr] + self.kv_recv_thread.add_request( # type: ignore[union-attr][call-arg] + req_meta) # type: ignore[arg-type] first_flag = False yield None else: @@ -415,8 +415,8 @@ def store_layer( req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, layer_id) - self.kv_send_thread.add_request( - req_meta) # type: ignore[union-attr] + self.kv_send_thread.add_request( # type: ignore[union-attr][call-arg] + req_meta) # type: ignore[arg-type] yield else: for layer_id in range(self.num_layers): @@ -429,8 +429,8 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.kv_send_thread. get_and_clear_finished_requests( # type: ignore[union-attr] ) if self.kv_role == 'kv_producer' else set()) - done_recving = self.kv_recv_thread.get_and_clear_finished_requests( - ) # type: ignore[union-attr] + done_recving = self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr] + ) logger.debug( "Number of completed KV cache send requests: %d, receive " From c9104e298a6b2971369c07055609b0ba802d51fb Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Tue, 16 Sep 2025 21:50:42 +0800 Subject: [PATCH 06/10] mooncake store lint fix Signed-off-by: fems14 <1804143737@qq.com> --- vllm_ascend/distributed/mooncake/kv_transfer.py | 14 +++++++------- .../distributed/mooncake/mooncake_engine.py | 9 +++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py index c0adfc2f7e..dee5101b68 100644 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -200,11 +200,11 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, self.final_layer_id = num_layers - 1 def add_request( # type: ignore[override] - self, - req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self.request_queue.put(req_meta) - def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): # type: ignore[override] + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): torch.npu.current_stream().synchronize() for index, key in enumerate(req_meta.keys): addr, size = self.prepare_value_layer(req_meta.starts[index], @@ -236,11 +236,11 @@ def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, self.get_event = get_event def add_request( # type: ignore[override] - self, - req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self.request_queue.put(req_meta) - def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): # type: ignore[override] + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): for index, key in enumerate(req_meta.keys): addr, size = self.prepare_value_layer(req_meta.starts[index], req_meta.ends[index], @@ -248,4 +248,4 @@ def _handle_request(self, req_meta: LasyerMultiBlockReqMeta): # type: ignore[ov req_meta.layer_id) self.m_store.get(key, addr, size) 
self.request_queue.task_done() - self.get_event.set() \ No newline at end of file + self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index cb8d83daee..2482fc9969 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -218,7 +218,8 @@ def save_kv_layer(self, ) if skip_leading_tokens == len(token_ids): if request.is_last_chunk: - self.kv_send_thread.set_finished_request(req_id) # type: ignore[union-attr] + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) continue # skip this request skip_leading_tokens = (skip_leading_tokens // self.block_size * @@ -350,7 +351,7 @@ def retrieve_layer( starts, ends, block_ids, layer_id) self.kv_recv_thread.add_request( # type: ignore[union-attr][call-arg] - req_meta) # type: ignore[arg-type] + req_meta) # type: ignore[union-attr][call-arg] first_flag = False yield None else: @@ -416,7 +417,7 @@ def store_layer( starts, ends, block_ids, layer_id) self.kv_send_thread.add_request( # type: ignore[union-attr][call-arg] - req_meta) # type: ignore[arg-type] + req_meta) # type: ignore[union-attr][call-arg] yield else: for layer_id in range(self.num_layers): @@ -485,4 +486,4 @@ def lookup( def close(self) -> None: """Close the cache engine and free all the resources""" - self.m_store.close() \ No newline at end of file + self.m_store.close() From 0009434993a033a9ee31cc364a85558314bb19ad Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Wed, 17 Sep 2025 10:20:19 +0800 Subject: [PATCH 07/10] mooncake store lintfix Signed-off-by: fems14 <1804143737@qq.com> --- vllm_ascend/distributed/mooncake/mooncake_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 2482fc9969..96f2949045 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -350,8 +350,8 @@ def retrieve_layer( req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, layer_id) - self.kv_recv_thread.add_request( # type: ignore[union-attr][call-arg] - req_meta) # type: ignore[union-attr][call-arg] + self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] first_flag = False yield None else: @@ -416,8 +416,8 @@ def store_layer( req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, layer_id) - self.kv_send_thread.add_request( # type: ignore[union-attr][call-arg] - req_meta) # type: ignore[union-attr][call-arg] + self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] yield else: for layer_id in range(self.num_layers): From c615c379d4e9a1a569ff288f405b2fc506fd4916 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Wed, 17 Sep 2025 11:27:25 +0800 Subject: [PATCH 08/10] add md Signed-off-by: fems14 <1804143737@qq.com> --- ...oncake_connector_store_deployment_guide.md | 264 ++++++++++++++++++ .../distributed/mooncake/mooncake_engine.py | 2 +- 2 files changed, 265 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/distributed/mooncake/mooncake_connector_store_deployment_guide.md diff --git a/vllm_ascend/distributed/mooncake/mooncake_connector_store_deployment_guide.md 
b/vllm_ascend/distributed/mooncake/mooncake_connector_store_deployment_guide.md
new file mode 100644
index 0000000000..8d2ec2a3f2
--- /dev/null
+++ b/vllm_ascend/distributed/mooncake/mooncake_connector_store_deployment_guide.md
@@ -0,0 +1,264 @@
+# MultiConnector + Mooncake Basic Scenario Verification & Pooling and Mixed Deployment Scenario Verification
+
+## Environmental Dependencies
+
+* Software:
+  * Python >= 3.9, < 3.12
+  * CANN >= 8.2.rc1
+  * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
+  * vLLM:Mainline branch
+  * vLLM-Ascend:Mainline branch
+  * Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)
+  * mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
+
+## run mooncake master
+
+### 1.Configure mooncake.json
+
+The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path where mooncake.json is located.
+
+```
+{
+    "local_hostname": "xx.xx.xx.xx",
+    "metadata_server": "P2PHANDSHAKE",
+    "protocol": "ascend",
+    "device_name": "",
+    "master_server_address": "xx.xx.xx.xx:50088",
+    "global_segment_size": 30000000000
+}
+```
+
+**local_hostname**: the IP address of the current master node.
+**metadata_server**: set to **P2PHANDSHAKE**.
+**protocol**: set to ascend, which uses Mooncake's HCCL communication.
+**device_name**: left empty ("").
+**master_server_address**: the IP and port of the master service.
+**global_segment_size**: the size of the kvcache that the P/D node registers with the master.
+
+### 2. Start mooncake_master
+
+Under the mooncake folder:
+
+```
+mooncake_master --port 50088
+```
+
+## multiConnector + mooncake basic scenario
+
+### 1.Run `prefill` Node and `decode` Node
+
+`prefill` Node:
+
+```
+bash multi_producer.sh
+```
+
+The content of the multi_producer.sh script:
+
+```
+export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
+export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm
+export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json"
+export VLLM_USE_V1=1
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
+export ASCEND_TRANSPORT_PRINT=1
+# The environment variable above is the switch for mooncake memory-transfer logging: 1 enables it, 0 disables it.
+export ASCEND_AGGREGATE_ENABLE=1
+# The environment variable above is the switch for the mooncake aggregation feature: 1 enables it, 0 disables it.
+
+python3 -m vllm.entrypoints.openai.api_server \
+    --model /xxxxx/Qwen2.5-7B-Instruct \
+    --port 8100 \
+    --trust-remote-code \
+    --enforce-eager \
+    --no_enable_prefix_caching \
+    --tensor-parallel-size 1 \
+    --data-parallel-size 1 \
+    --max-model-len 10000 \
+    --block-size 128 \
+    --max-num-batched-tokens 4096 \
+    --kv-transfer-config \
+    '{
+    "kv_connector": "MultiConnector",
+    "kv_role": "kv_producer",
+    "kv_connector_extra_config": {
+        "use_layerwise": false,
+        "connectors": [
+            {
+                "kv_connector": "MooncakeConnectorV1",
+                "kv_role": "kv_producer",
+                "kv_port": "20001",
+                "kv_connector_extra_config": {
+                    "prefill": {
+                        "dp_size": 1,
+                        "tp_size": 1
+                    },
+                    "decode": {
+                        "dp_size": 1,
+                        "tp_size": 1
+                    }
+                }
+            },
+            {
+                "kv_connector": "MooncakeConnectorStoreV1",
+                "kv_role": "kv_producer"
+            }
+        ]
+    }
+}' > p.log 2>&1
+```
+
+`decode` Node:
+
+```
+bash multi_consumer.sh
+```
+
+The content of multi_consumer.sh:
+
+```
+export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
+export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm
+export MOONCAKE_CONFIG_PATH="/xxxxx/mooncake.json"
+export VLLM_USE_V1=1
+export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7
+export ASCEND_TRANSPORT_PRINT=1
+# The environment variable above is the switch for mooncake memory-transfer logging: 1 enables it, 0 disables it.
+export ASCEND_AGGREGATE_ENABLE=1
+# The environment variable above is the switch for the mooncake aggregation feature: 1 enables it, 0 disables it.
+
+python3 -m vllm.entrypoints.openai.api_server \
+    --model /xxxxx/Qwen2.5-7B-Instruct \
+    --port 8200 \
+    --trust-remote-code \
+    --enforce-eager \
+    --no_enable_prefix_caching \
+    --tensor-parallel-size 1 \
+    --data-parallel-size 1 \
+    --max-model-len 10000 \
+    --block-size 128 \
+    --max-num-batched-tokens 4096 \
+    --kv-transfer-config \
+    '{
+    "kv_connector": "MultiConnector",
+    "kv_role": "kv_consumer",
+    "kv_connector_extra_config": {
+        "use_layerwise": false,
+        "connectors": [
+            {
+                "kv_connector": "MooncakeConnectorV1",
+                "kv_role": "kv_consumer",
+                "kv_port": "20002",
+                "kv_connector_extra_config": {
+                    "prefill": {
+                        "dp_size": 1,
+                        "tp_size": 1
+                    },
+                    "decode": {
+                        "dp_size": 1,
+                        "tp_size": 1
+                    }
+                }
+            },
+            {
+                "kv_connector": "MooncakeConnectorStoreV1",
+                "kv_role": "kv_consumer"
+            }
+        ]
+    }
+    }' > d.log 2>&1
+```
+
+### 2. Start proxy_server
+
+```
+bash proxy.sh
+```
+
+The content of proxy.sh:
+Change localhost to your actual IP address.
+
+```
+python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \
+    --host localhost \
+    --prefiller-hosts localhost \
+    --prefiller-ports 8100 \
+    --decoder-hosts localhost \
+    --decoder-ports 8200
+```
+
+### 3. Run Inference
+
+Configure the localhost, port, and model weight path in the command to your own settings.
+
+Short question:
+
+```
+curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. 
The president of the United States is", "max_tokens": 200, "temperature":0.0 }'
+```
+
+Long question:
+
+```
+curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }'
+```
+
+## Pooling and Mixed Deployment Scenario
+
+### 1. Run Mixed Deployment Script
+
+The mixed-deployment script is essentially a pure pooling scenario running on the P node.
+
+```
+bash mixed_department.sh
+```
+
+The content of mixed_department.sh:
+
+```
+export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH
+export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm
+export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json"
+export VLLM_USE_V1=1
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
+export ASCEND_TRANSPORT_PRINT=1
+# The environment variable above is the switch for mooncake memory-transfer logging: 1 enables it, 0 disables it.
+export ASCEND_AGGREGATE_ENABLE=1
+# The environment variable above is the switch for the mooncake aggregation feature: 1 enables it, 0 disables it.
+
+python3 -m vllm.entrypoints.openai.api_server \
+    --model /xxxxx/Qwen2.5-7B-Instruct \
+    --port 8100 \
+    --trust-remote-code \
+    --enforce-eager \
+    --no_enable_prefix_caching \
+    --tensor-parallel-size 1 \
+    --data-parallel-size 1 \
+    --max-model-len 10000 \
+    --block-size 128 \
+    --max-num-batched-tokens 4096 \
+    --kv-transfer-config \
+    '{
+    "kv_connector": "MooncakeConnectorStoreV1",
+    "kv_role": "kv_producer",
+    "kv_connector_extra_config": {
+        "use_layerwise": false
+    }
+}' > mix.log 2>&1
+```
+
+### 2. Run Inference
+
+Configure the localhost, port, and model weight path in the command to your own settings. Requests go directly to the port used by the mixed-deployment script, so there is no need to start a separate proxy.
+
+Short question:
+
+```
+curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. 
The president of the United States is", "max_tokens": 200, "temperature":0.0 }' +``` + +Long question: + +``` +curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' +``` \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 96f2949045..53c2724b45 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -345,7 +345,7 @@ def retrieve_layer( if not first_flag: is_finish = self.get_event.wait(timeout=3) #try---cache if not is_finish: - raise SystemError("Layerwise get failed") + logger.info("Layerwise get failed") self.get_event.clear() req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, From dc919aad79b679e5f73b2ae14c79043f174ddb63 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Wed, 17 Sep 2025 11:42:09 +0800 Subject: [PATCH 09/10] add md Signed-off-by: fems14 <1804143737@qq.com> --- .../mooncake_connector_store_deployment_guide.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {vllm_ascend/distributed/mooncake => examples/disaggregated_prefill_v1}/mooncake_connector_store_deployment_guide.md (100%) diff --git a/vllm_ascend/distributed/mooncake/mooncake_connector_store_deployment_guide.md b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md similarity index 100% rename from vllm_ascend/distributed/mooncake/mooncake_connector_store_deployment_guide.md rename to examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md From 98edcb4bbf2b4d3f6163440c3a79e16b56891096 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Thu, 18 Sep 2025 10:19:42 +0800 Subject: [PATCH 10/10] modify md Signed-off-by: fems14 <1804143737@qq.com> --- ...oncake_connector_store_deployment_guide.md | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md index 8d2ec2a3f2..b91705aee7 100644 --- a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md +++ b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md @@ -1,4 +1,4 @@ -# 
MultiConnector + Mooncake Basic Scenario Verification & Pooling and Mixed Deployment Scenario Verification
+# Mooncake Store Deployment Guide
 
 ## Environmental Dependencies
 
@@ -6,11 +6,11 @@
   * Python >= 3.9, < 3.12
   * CANN >= 8.2.rc1
   * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
-  * vLLM:Mainline branch
-  * vLLM-Ascend:Mainline branch
-  * Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)
-  * mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
-
+  * vLLM: main branch
+  * vLLM-Ascend: main branch
+  * Mooncake: [AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy) (currently available branch code, continuously updated)
+    Installation and Compilation Guide: https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy?tab=readme-ov-file#build-and-use-binaries
+
 ## run mooncake master
 
 ### 1.Configure mooncake.json
@@ -43,10 +43,12 @@ Under the mooncake folder:
 mooncake_master --port 50088
 ```
 
-## multiConnector + mooncake basic scenario
+## Pooling and Prefill/Decode Disaggregation Scenario
 
 ### 1.Run `prefill` Node and `decode` Node
 
+This scenario uses MultiConnector to combine the P2P connector and the pooling connector: P2P performs the kv_transfer between the prefill and decode nodes, while pooling provides a larger prefix cache.
+
 `prefill` Node:
 
@@ -254,11 +256,11 @@ Configure the localhost, port, and model weight path in the command to your own
 
 Short question:
 
 ```
-curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. The president of the United States is", "max_tokens": 200, "temperature":0.0 }'
+curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. The president of the United States is", "max_tokens": 200, "temperature":0.0 }'
 ```
 
 Long question:
 
 ```
-curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? 
Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' +curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' ``` \ No newline at end of file
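
The deployment guide above verifies the setup with curl. As a convenience, the same /v1/completions request can also be driven from a short Python script. The sketch below is illustrative only and not part of the patches; the base URL and model path are assumptions and must be adapted to your own proxy or mixed-deployment endpoint.

```
# Minimal sketch (not part of the patch series): issue the same /v1/completions
# request as the curl examples in the deployment guide. BASE_URL and MODEL are
# assumptions -- point them at your own proxy or mixed-deployment server.
import json
import urllib.request

BASE_URL = "http://localhost:8100"      # proxy or mixed-deployment port
MODEL = "/xxxxx/Qwen2.5-7B-Instruct"    # same weight path passed to --model


def completion(prompt: str, max_tokens: int = 200) -> str:
    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }
    req = urllib.request.Request(
        f"{BASE_URL}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    # OpenAI-compatible completions response: take the first choice's text.
    return body["choices"][0]["text"]


if __name__ == "__main__":
    prompt = "Hello. I have a question. The president of the United States is"
    # Sending the same prompt twice: with the Mooncake store acting as a shared
    # prefix cache, the second request should be able to reuse the cached KV blocks.
    print(completion(prompt))
    print(completion(prompt))
```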