
Commit 7655041

mooncake store connector

Co-authored-by: fems14 <1804143737@qq.com>
Co-authored-by: Dreamerleader <2270923832@qq.com>
Co-authored-by: Pz1116 <zpbzpb123123@gmail.com>
Co-authored-by: lizy124 <1950471827@qq.com>
Co-authored-by: zouyida2052 <zouyida2002@gmail.com>
Signed-off-by: LCAIZJ <leichao139636@163.com>

1 parent f97a64b · commit 7655041

File tree: 8 files changed, +1792 −10 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 38 additions & 1 deletion
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
+from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar, List
 
 import torch
 import torch_npu
@@ -12,6 +12,10 @@
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.forward_context import ForwardContext, get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -976,6 +980,7 @@ def forward(
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+        self.wait_for_kv_layer_from_connector(layer.layer_name)
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
@@ -1046,4 +1051,36 @@ def forward(
                     is_force_scatter=self.enable_shared_expert_dp)[0]
                 current_ms_metadata.after_comm_event.record()
             del o_proj_input
+        self.maybe_save_kv_layer_to_connector(layer_name=layer.layer_name, kv_cache_layer=kv_cache)
         return output_padded
+
+    def wait_for_kv_layer_from_connector(self, layer_name: str):
+        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+            return
+
+        connector = get_kv_transfer_group()
+
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+        if attn_metadata is None:
+            return
+        assert isinstance(attn_metadata, AscendMLAMetadata)
+        connector.wait_for_layer_load(layer_name)
+
+    def maybe_save_kv_layer_to_connector(
+        self,
+        layer_name: str,
+        kv_cache_layer: List[torch.Tensor],
+    ):
+        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+            return
+
+        connector = get_kv_transfer_group()
+
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+        if attn_metadata is None:
+            return
+        assert isinstance(attn_metadata, AscendMLAMetadata)
+        connector.save_kv_layer(layer_name, kv_cache_layer,
+                                attn_metadata)
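
The two methods added above follow vLLM's layer-wise v1 KV-connector pattern: before attention runs, the layer blocks until its KV cache has been loaded, and after the output projection the layer's KV tensors are handed to the connector for saving. Below is a minimal sketch of a connector object that would satisfy just these two calls; the class name and its in-process dict are hypothetical stand-ins for the real Mooncake store, which moves KV blocks to and from an external store.

# Hypothetical sketch, for illustration only: the smallest object the hooks
# above could talk to. The real MooncakeConnectorV1 targets a remote
# Mooncake store rather than an in-process dict.
from typing import Dict, List

import torch


class ToyLayerwiseConnector:
    """Per-layer KV connector: load before attention, save after attention."""

    def __init__(self) -> None:
        # Stand-in for the external Mooncake store (assumption).
        self._store: Dict[str, List[torch.Tensor]] = {}

    def wait_for_layer_load(self, layer_name: str) -> None:
        # A real connector would block here until the async transfer of this
        # layer's KV blocks into the device cache has completed.
        self._store.get(layer_name)

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: List[torch.Tensor],
                      attn_metadata) -> None:
        # A real connector would launch an asynchronous copy of the layer's
        # KV blocks out to the store; here we only keep a reference.
        self._store[layer_name] = kv_cache_layer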

vllm_ascend/distributed/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -26,3 +26,8 @@
 KVConnectorFactory.register_connector(
     "MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector",
     "MooncakeConnector")
+
+KVConnectorFactory.register_connector(
+    "MooncakeConnectorStoreV1",
+    "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
+    "MooncakeConnectorV1")
