
Commit 5b6d1a0

support kv cache cpu offload with connector.

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>

Parent: 8dd53c8

File tree: 10 files changed, 990 additions and 44 deletions.
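For context, the new per-layer hooks only fire when a v1 KV-transfer connector is configured at engine startup. A minimal sketch, assuming vLLM's KVTransferConfig API; the connector name and model below are placeholders, not something this commit ships:

# Sketch only: launch with a v1 KV-transfer connector so the wait/save hooks
# added in this commit are actually invoked. "CPUOffloadingConnector" is a
# placeholder name and the model choice is illustrative.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    kv_transfer_config=KVTransferConfig(
        kv_connector="CPUOffloadingConnector",  # placeholder connector name
        kv_role="kv_both",  # this worker both saves and loads KV
    ),
)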

tests/ut/attention/test_mla_v1.py (5 additions, 1 deletion)

@@ -554,7 +554,11 @@ def test_mla_preprocess(self, magic_npu_fetch):
         self.impl.num_kv_heads = self.impl.num_heads
 
         decode_res, prefill_res = self.impl._mla_preprocess(
-            hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+            "mock_layer",
+            hidden_states,
+            kv_cache,
+            attn_metadata,
+            need_gather_q_kv=False)
 
         self.assertIsNotNone(decode_res)
         self.assertIsNotNone(prefill_res)
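The placeholder "mock_layer" is enough here because, as the vllm_ascend/attention/utils.py change below shows, the connector helpers return early when no KV-transfer group is configured, so the unit test never dereferences the layer name:

# In the unit test no connector is registered, so the helper is a no-op
# (assumption based on the has_kv_transfer_group() guard in utils.py below).
wait_for_kv_layer_from_connector("mock_layer")  # returns immediately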

tests/ut/torchair/models/test_torchair_deepseek_v2.py (1 addition, 1 deletion)

@@ -328,4 +328,4 @@ def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
            "vllm.model_executor.model_loader.weight_utils.default_weight_loader"
     ):
         loaded = model.load_weights(weights)
-        assert loaded is not None
+        assert loaded is not None

vllm_ascend/attention/attention_v1.py (3 additions, 35 deletions)

@@ -26,53 +26,21 @@
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group,
-                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         maybe_save_kv_layer_to_connector,
+                                         wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
 
-def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    # TODO: assert ascendMetadata
-    connector.wait_for_layer_load(layer_name)
-
-
-def maybe_save_kv_layer_to_connector(
-    layer_name: str,
-    kv_cache_layer: List[torch.Tensor],
-):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    # TODO: assert ascendMetadata
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
-
-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True

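The two helpers are still imported here, so this backend presumably keeps calling them around each layer's KV read and write; the call sites sit outside the hunks shown. A rough sketch of that per-layer pattern, with hypothetical names (_run_attention and the argument list stand in for the real code):

# Rough sketch of the pattern the imports imply; not the exact code in
# attention_v1.py. `_run_attention` and the signature are illustrative.
def forward(self, layer_name, query, key, value, kv_cache, attn_metadata, output):
    # Block until the connector has finished loading this layer's KV
    # (e.g. copying it back from CPU) before reading or appending to it.
    wait_for_kv_layer_from_connector(layer_name)

    output = _run_attention(query, key, value, kv_cache, attn_metadata, output)

    # Hand the freshly written KV of this layer to the connector so it can
    # offload it (e.g. to host memory) asynchronously.
    maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
    return output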
vllm_ascend/attention/mla_v1.py (14 additions, 4 deletions)

@@ -16,7 +16,9 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         maybe_save_kv_layer_to_connector,
+                                         split_decodes_and_prefills,
+                                         wait_for_kv_layer_from_connector)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -853,8 +855,8 @@ def _forward_decode(
             current_ms_metadata.before_comm_event.wait()
         return self._v_up_proj(attn_output)
 
-    def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
-                        need_gather_q_kv):
+    def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
+                        attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -883,6 +885,8 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
             kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
         decode_preprocess_res = None
         prefill_preprocess_res = None
+        if has_prefill:
+            wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
         if has_decode:
             decode_q_c = q_c[:num_decode_tokens]
@@ -929,6 +933,7 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
 
     def forward(
         self,
+        layer_name,
         hidden_states: torch.Tensor,  # query in unified attn
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
@@ -955,7 +960,8 @@ def forward(
 
         # MLA Preprocess
         decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
-            hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
+            layer_name, hidden_states, kv_cache, attn_metadata,
+            need_gather_q_kv)
 
         if decode_preprocess_res is not None:
             # MLA Preprocess for decoding
@@ -1013,4 +1019,8 @@ def forward(
                     is_force_scatter=self.enable_shared_expert_dp)[0]
                 current_ms_metadata.after_comm_event.record()
             del o_proj_input
+
+        has_prefill = attn_metadata.num_prefills > 0
+        if has_prefill:
+            maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
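For orientation, the hooks added above end up in whatever connector is registered: wait_for_kv_layer_from_connector calls the connector's wait_for_layer_load, and maybe_save_kv_layer_to_connector calls its save_kv_layer. A heavily simplified, hypothetical sketch of the worker-side half of such a CPU-offload connector (the method names follow vLLM's v1 connector interface as I understand it; scheduler-side methods and all real block bookkeeping are omitted):

import torch


class CPUOffloadConnectorSketch:
    """Hypothetical stand-in for the real connector; worker-side hooks only."""

    def __init__(self) -> None:
        # layer_name -> CPU copies of that layer's KV cache tensors
        self._cpu_copies: dict[str, list[torch.Tensor]] = {}

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: list[torch.Tensor],
                      attn_metadata, **kwargs) -> None:
        # Called via maybe_save_kv_layer_to_connector() after the layer has
        # written its KV: snapshot the blocks to host memory.
        self._cpu_copies[layer_name] = [t.detach().cpu() for t in kv_cache_layer]

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Called via wait_for_kv_layer_from_connector() before the layer reads
        # its KV. This synchronous sketch copies eagerly, so there is nothing
        # to wait for here.
        return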

vllm_ascend/attention/utils.py (36 additions, 1 deletion)

@@ -1,7 +1,11 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, List
 
 import torch
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.forward_context import ForwardContext, get_forward_context
 
 
 @dataclass
@@ -100,3 +104,34 @@ def split_decodes_and_prefills(
     num_decode_tokens = query_start_loc[first_prefill].item()
     num_prefill_tokens = num_tokens - num_decode_tokens
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
+
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
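Note that both helpers read the attention metadata from the ambient forward context rather than taking it as an argument, so they only do real work inside a model forward entered via vLLM's set_forward_context (otherwise attn_metadata is None and they return early). A minimal sketch of that dependency; the layer name, attn_metadata, vllm_config and kv_cache values are illustrative:

# Sketch: the helpers are only effective inside a forward context.
from vllm.forward_context import set_forward_context

with set_forward_context(attn_metadata, vllm_config):
    wait_for_kv_layer_from_connector("model.layers.0.self_attn.attn")
    ...  # run the layer's attention, writing into its kv_cache tensors
    maybe_save_kv_layer_to_connector("model.layers.0.self_attn.attn",
                                     list(kv_cache))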
