@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
+from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar, List
 
 import torch
 import torch_npu
@@ -12,6 +12,10 @@
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.forward_context import ForwardContext, get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -976,6 +980,7 @@ def forward(
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+        self.wait_for_kv_layer_from_connector(layer.layer_name)
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
@@ -1046,4 +1051,36 @@ def forward(
                     is_force_scatter=self.enable_shared_expert_dp)[0]
                 current_ms_metadata.after_comm_event.record()
         del o_proj_input
+        self.maybe_save_kv_layer_to_connector(layer_name=layer.layer_name, kv_cache_layer=kv_cache)
         return output_padded
+
+    def wait_for_kv_layer_from_connector(self, layer_name: str):
+        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+            return
+
+        connector = get_kv_transfer_group()
+
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+        if attn_metadata is None:
+            return
+        assert isinstance(attn_metadata, AscendMLAMetadata)
+        connector.wait_for_layer_load(layer_name)
+
+    def maybe_save_kv_layer_to_connector(
+        self,
+        layer_name: str,
+        kv_cache_layer: List[torch.Tensor],
+    ):
+        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+            return
+
+        connector = get_kv_transfer_group()
+
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+        if attn_metadata is None:
+            return
+        assert isinstance(attn_metadata, AscendMLAMetadata)
+        connector.save_kv_layer(layer_name, kv_cache_layer,
+                                attn_metadata)
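For readers unfamiliar with vLLM's KV-transfer connector, the two new hooks bracket each attention layer's forward pass: `wait_for_kv_layer_from_connector` blocks until the layer's KV blocks have arrived from the producer (prefill) instance, and `maybe_save_kv_layer_to_connector` hands the layer's freshly written KV cache to the transfer backend. The sketch below only illustrates that call ordering; `DummyConnector` and `run_layer` are hypothetical stand-ins (not vLLM or vllm-ascend APIs), while in the real code the connector comes from `get_kv_transfer_group()` as in the diff.

```python
# Minimal sketch of the per-layer connector hook pattern added above.
# DummyConnector is a hypothetical stand-in for the object returned by
# get_kv_transfer_group(); only the two methods used in the diff are modeled.
from typing import Any, List

import torch


class DummyConnector:
    def wait_for_layer_load(self, layer_name: str) -> None:
        # A real connector blocks here until the remote KV blocks for
        # `layer_name` have been written into the local cache.
        print(f"wait_for_layer_load({layer_name!r})")

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: List[torch.Tensor],
                      attn_metadata: Any) -> None:
        # A real connector enqueues the just-written KV blocks of
        # `layer_name` for transfer to the consumer instance.
        print(f"save_kv_layer({layer_name!r}, {len(kv_cache_layer)} tensors)")


def run_layer(connector: DummyConnector, layer_name: str,
              kv_cache: List[torch.Tensor], attn_metadata: Any) -> None:
    # Mirrors the ordering the diff adds to the MLA forward pass:
    # 1) wait for this layer's KV to arrive, 2) run attention and write new
    # KV into kv_cache, 3) hand the layer's KV cache to the connector.
    if attn_metadata is None:
        return  # e.g. a profiling run: no metadata, nothing to transfer
    connector.wait_for_layer_load(layer_name)      # before attention
    # ... attention computation writes new KV blocks into kv_cache here ...
    connector.save_kv_layer(layer_name, kv_cache, attn_metadata)  # after


if __name__ == "__main__":
    kv_cache = [torch.zeros(4, 128, 64), torch.zeros(4, 128, 512)]
    run_layer(DummyConnector(), "model.layers.0.self_attn.attn",
              kv_cache, attn_metadata=object())
```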