
Commit 1a8a46a

support kv cache cpu offload with connector

Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Signed-off-by: AlvisGong <gwly0401@163.com>

Parent: 53ecd89

12 files changed, 2329 additions(+), 66 deletions(-)
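Before the per-file diffs, a rough usage sketch: vLLM enables a KV-transfer connector through its KVTransferConfig, and this commit teaches the Ascend attention backends to call into that connector so KV-cache blocks can be offloaded to CPU memory. The connector name below is a placeholder, not necessarily the identifier registered by the new cpu_offload modules added in this commit.

# Hedged usage sketch. KVTransferConfig and kv_role are vLLM's public config
# surface; "CPUOffloadingConnector" is a placeholder for whatever name the
# new connector in this commit actually registers.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any model served on Ascend NPUs
    kv_transfer_config=KVTransferConfig(
        kv_connector="CPUOffloadingConnector",  # placeholder name
        kv_role="kv_both",  # connector both saves and reloads KV blocks
    ),
)
out = llm.generate("The capital of France is")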

tests/ut/attention/test_mla_v1.py

Lines changed: 5 additions & 1 deletion
@@ -554,7 +554,11 @@ def test_mla_preprocess(self, magic_npu_fetch):
         self.impl.num_kv_heads = self.impl.num_heads
 
         decode_res, prefill_res = self.impl._mla_preprocess(
-            hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+            "mock_layer",
+            hidden_states,
+            kv_cache,
+            attn_metadata,
+            need_gather_q_kv=False)
 
         self.assertIsNotNone(decode_res)
         self.assertIsNotNone(prefill_res)

tests/ut/distributed/cpu_offload/test_cpu_kv_cache_manager.py

Lines changed: 440 additions & 0 deletions
Large diffs are not rendered by default.

tests/ut/distributed/cpu_offload/test_cpu_offload_connector.py

Lines changed: 897 additions & 0 deletions
Large diffs are not rendered by default.

tests/ut/torchair/models/test_torchair_deepseek_v2.py

Lines changed: 3 additions & 23 deletions
@@ -22,10 +22,9 @@
 from vllm.distributed.parallel_state import GroupCoordinator
 
 from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
-    TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
-    TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
-    TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
-    TorchairDeepseekV2RowParallelLinear,
+    TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2MergedReplicatedLinear,
+    TorchairDeepseekV2MLAAttention, TorchairDeepseekV2MLP,
+    TorchairDeepseekV2MoE, TorchairDeepseekV2RowParallelLinear,
     TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
     TorchairDeepseekV2SiluAndMul)
 
@@ -310,22 +309,3 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
         model_config=vllm_config.model_config,
         quant_config=None)
     assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
-
-
-def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
-    model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
-
-    input_ids = torch.randint(0, 10000, (2, 4))
-    positions = torch.arange(4).repeat(2, 1)
-    with patch.object(model.model,
-                      "forward",
-                      return_value=torch.randn(2, 4, 128)):
-        output = model(input_ids, positions)
-        assert output.shape == (2, 4, 128)
-
-    weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
-    with patch(
-            "vllm.model_executor.model_loader.weight_utils.default_weight_loader"
-    ):
-        loaded = model.load_weights(weights)
-        assert loaded is not None

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 35 deletions
@@ -26,51 +26,19 @@
                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group,
-                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         maybe_save_kv_layer_to_connector,
+                                         wait_for_kv_layer_from_connector)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
 
-def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    # TODO: assert ascendMetadata
-    connector.wait_for_layer_load(layer_name)
-
-
-def maybe_save_kv_layer_to_connector(
-    layer_name: str,
-    kv_cache_layer: List[torch.Tensor],
-):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    # TODO: assert ascendMetadata
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
-
-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
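The two helpers deleted above were not dropped; they now live in vllm_ascend/attention/utils.py (last diff below) and are re-imported here. A minimal sketch of how a backend's per-layer forward uses them follows; every name except the two helpers is illustrative, not the exact attention_v1.py code.

# Illustrative wrapper around one layer's attention computation.
from vllm_ascend.attention.utils import (maybe_save_kv_layer_to_connector,
                                         wait_for_kv_layer_from_connector)


def attention_forward(layer_name, hidden_states, kv_cache, attn_metadata, impl):
    # Block until the connector has finished loading this layer's KV
    # (e.g. copying offloaded blocks back from CPU) before prefill reads it.
    wait_for_kv_layer_from_connector(layer_name)

    # impl.run stands in for the real attention kernel invocation.
    attn_output = impl.run(hidden_states, kv_cache, attn_metadata)

    # Hand the freshly written KV blocks of this layer to the connector so it
    # can offload them. Both helpers are no-ops when no KV-transfer group is
    # configured, so the non-offload path is unaffected.
    maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
    return attn_output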

vllm_ascend/attention/mla_v1.py

Lines changed: 14 additions & 4 deletions
@@ -16,7 +16,9 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         maybe_save_kv_layer_to_connector,
+                                         split_decodes_and_prefills,
+                                         wait_for_kv_layer_from_connector)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -857,8 +859,8 @@ def _forward_decode(
             current_ms_metadata.before_comm_event.wait()
         return self._v_up_proj(attn_output)
 
-    def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
-                        need_gather_q_kv):
+    def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
+                        attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -887,6 +889,8 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
             kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
         decode_preprocess_res = None
         prefill_preprocess_res = None
+        if has_prefill:
+            wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
         if has_decode:
             decode_q_c = q_c[:num_decode_tokens]
@@ -933,6 +937,7 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
 
     def forward(
         self,
+        layer_name,
         hidden_states: torch.Tensor,  # query in unified attn
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
@@ -959,7 +964,8 @@ def forward(
 
         # MLA Preprocess
         decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
-            hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
+            layer_name, hidden_states, kv_cache, attn_metadata,
+            need_gather_q_kv)
 
         if decode_preprocess_res is not None:
             # MLA Preprocess for decoding
@@ -1017,4 +1023,8 @@ def forward(
                     is_force_scatter=self.enable_shared_expert_dp)[0]
                 current_ms_metadata.after_comm_event.record()
             del o_proj_input
+
+        has_prefill = attn_metadata.num_prefills > 0
+        if has_prefill:
+            maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded

vllm_ascend/attention/utils.py

Lines changed: 36 additions & 1 deletion
@@ -1,7 +1,11 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, List
 
 import torch
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.forward_context import ForwardContext, get_forward_context
 
 
 @dataclass
@@ -100,3 +104,34 @@ def split_decodes_and_prefills(
     num_decode_tokens = query_start_loc[first_prefill].item()
     num_prefill_tokens = num_tokens - num_decode_tokens
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
+
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
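On the connector side, wait_for_layer_load and save_kv_layer (called by the helpers above) are per-layer worker hooks of vLLM's v1 KV-connector interface. The actual CPU-offload connector added by this commit sits in the large unrendered diffs (its unit tests are under tests/ut/distributed/cpu_offload/); the class below is only a self-contained sketch of those two hooks, and its staging logic and torch_npu stream synchronization are assumptions, not the shipped implementation.

# Sketch of the two worker-side hooks the helpers above call into; not the
# connector implementation added by this commit.
from typing import Dict, List

import torch


class CpuOffloadConnectorSketch:

    def __init__(self) -> None:
        # Hypothetical per-layer staging area in host memory.
        self._cpu_store: Dict[str, List[torch.Tensor]] = {}

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Block until any asynchronous host-to-device copy for this layer has
        # finished so prefill can safely read the restored KV blocks.
        # (Assumes torch_npu streams; a real connector tracks its own events.)
        torch.npu.current_stream().synchronize()

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: List[torch.Tensor],
                      attn_metadata) -> None:
        # Stage this layer's KV tensors into CPU memory. A real connector
        # would copy only the blocks selected via attn_metadata and run the
        # device-to-host transfer asynchronously.
        self._cpu_store[layer_name] = [
            t.to("cpu", non_blocking=True) for t in kv_cache_layer
        ]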
