Commit e77dbaf

Add kv copy kernel for between layers

1 parent: a766a66

9 files changed: +231 −105 lines changed

csrc/cache.h

Lines changed: 8 additions & 1 deletion
@@ -15,6 +15,13 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                  std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);
 
+void copy_blocks_between_layers(
+    std::vector<torch::Tensor> const& src_key_caches,
+    std::vector<torch::Tensor> const& src_value_caches,
+    std::vector<torch::Tensor> const& dst_key_caches,
+    std::vector<torch::Tensor> const& dst_value_caches,
+    const torch::Tensor& block_mapping);
+
 void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
                      const torch::Tensor& block_mapping);
 
@@ -45,4 +52,4 @@ void gather_cache(
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

csrc/cache_kernels.cu

Lines changed: 103 additions & 43 deletions
@@ -68,32 +68,42 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
 
 namespace vllm {
 
-// Grid: (num_layers, num_pairs)
+// Grid: (layer_or_pair_idx, num_pairs)
 template <typename scalar_t>
-__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
-                                   int64_t* value_cache_ptrs,
-                                   const int64_t* __restrict__ block_mapping,
-                                   const int numel_per_block) {
-  const int layer_idx = blockIdx.x;
+__global__ void unified_copy_blocks_kernel(
+    int64_t* src_key_cache_ptrs, int64_t* src_value_cache_ptrs,
+    int64_t* dst_key_cache_ptrs, int64_t* dst_value_cache_ptrs,
+    const int64_t* __restrict__ block_mapping, const int numel_per_block) {
+  const int layer_or_pair_idx = blockIdx.x;
   const int pair_idx = blockIdx.y;
 
-  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
-  scalar_t* value_cache =
-      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  scalar_t* src_key_cache =
+      reinterpret_cast<scalar_t*>(src_key_cache_ptrs[layer_or_pair_idx]);
+  scalar_t* src_value_cache =
+      reinterpret_cast<scalar_t*>(src_value_cache_ptrs[layer_or_pair_idx]);
+  scalar_t* dst_key_cache =
+      reinterpret_cast<scalar_t*>(dst_key_cache_ptrs[layer_or_pair_idx]);
+  scalar_t* dst_value_cache =
+      reinterpret_cast<scalar_t*>(dst_value_cache_ptrs[layer_or_pair_idx]);
+
   int64_t src_block_number = block_mapping[2 * pair_idx];
   int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
 
   const int64_t src_block_offset = src_block_number * numel_per_block;
   const int64_t dst_block_offset = dst_block_number * numel_per_block;
+
+  // Copy key cache from source to destination
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
     int64_t src_offset = src_block_offset + i;
     int64_t dst_offset = dst_block_offset + i;
-    key_cache[dst_offset] = key_cache[src_offset];
+    dst_key_cache[dst_offset] = src_key_cache[src_offset];
   }
+
+  // Copy value cache from source to destination
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
     int64_t src_offset = src_block_offset + i;
     int64_t dst_offset = dst_block_offset + i;
-    value_cache[dst_offset] = value_cache[src_offset];
+    dst_value_cache[dst_offset] = src_value_cache[src_offset];
   }
 }
 
@@ -117,58 +127,108 @@ __global__ void copy_blocks_mla_kernel(
 
 }  // namespace vllm
 
-// Note: the key_caches and value_caches vectors are constant but
-// not the Tensors they contain. The vectors need to be const refs
-// in order to satisfy pytorch's C++ operator registration code.
-void copy_blocks(std::vector<torch::Tensor> const& key_caches,
-                 std::vector<torch::Tensor> const& value_caches,
-                 const torch::Tensor& block_mapping) {
-  int num_layers = key_caches.size();
-  TORCH_CHECK(num_layers == value_caches.size());
-  if (num_layers == 0) {
+// Unified implementation function for both copy_blocks and
+// copy_blocks_between_caches
+void copy_blocks_impl(std::vector<torch::Tensor> const& src_key_caches,
+                      std::vector<torch::Tensor> const& src_value_caches,
+                      std::vector<torch::Tensor> const& dst_key_caches,
+                      std::vector<torch::Tensor> const& dst_value_caches,
+                      const torch::Tensor& block_mapping) {
+  int num_src_dst_pairs = src_key_caches.size();
+  TORCH_CHECK(num_src_dst_pairs == src_value_caches.size());
+  TORCH_CHECK(num_src_dst_pairs == dst_key_caches.size());
+  TORCH_CHECK(num_src_dst_pairs == dst_value_caches.size());
+
+  if (num_src_dst_pairs == 0) {
     return;
   }
-  torch::Device cache_device = key_caches[0].device();
+
+  torch::Device cache_device = src_key_caches[0].device();
   TORCH_CHECK(cache_device.is_cuda());
 
-  // Create data structures for the kernel.
-  // Create an array of pointers to the key and value caches.
-  int64_t key_cache_ptrs[num_layers];
-  int64_t value_cache_ptrs[num_layers];
-  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    key_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
-    value_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  // Create arrays of pointers to the source and destination key and value
+  // caches
+  int64_t src_key_cache_ptrs[num_src_dst_pairs];
+  int64_t src_value_cache_ptrs[num_src_dst_pairs];
+  int64_t dst_key_cache_ptrs[num_src_dst_pairs];
+  int64_t dst_value_cache_ptrs[num_src_dst_pairs];
+
+  for (int pair_idx = 0; pair_idx < num_src_dst_pairs; ++pair_idx) {
+    src_key_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(src_key_caches[pair_idx].data_ptr());
+    src_value_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(src_value_caches[pair_idx].data_ptr());
+    dst_key_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(dst_key_caches[pair_idx].data_ptr());
+    dst_value_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(dst_value_caches[pair_idx].data_ptr());
   }
 
   // block_mapping is a 2D tensor with shape (num_pairs, 2).
   int num_pairs = block_mapping.size(0);
 
-  // Move the data structures to the GPU.
-  // NOTE: This synchronizes the CPU and GPU.
-  torch::Tensor key_cache_ptrs_tensor =
-      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
+  // Move the data structures to the GPU
+  torch::Tensor src_key_cache_ptrs_tensor =
+      torch::from_blob(src_key_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
+          .to(cache_device);
+  torch::Tensor src_value_cache_ptrs_tensor =
+      torch::from_blob(src_value_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
+          .to(cache_device);
+  torch::Tensor dst_key_cache_ptrs_tensor =
+      torch::from_blob(dst_key_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
          .to(cache_device);
-  torch::Tensor value_cache_ptrs_tensor =
-      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
+  torch::Tensor dst_value_cache_ptrs_tensor =
+      torch::from_blob(dst_value_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
          .to(cache_device);
 
-  // Launch the kernel.
-  const int numel_per_block = key_caches[0][0].numel();
-  dim3 grid(num_layers, num_pairs);
+  // Launch the kernel
+  const int numel_per_block = src_key_caches[0][0].numel();
+  dim3 grid(num_src_dst_pairs, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
-        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            key_cache_ptrs_tensor.data_ptr<int64_t>(),
-            value_cache_ptrs_tensor.data_ptr<int64_t>(),
+      src_key_caches[0].scalar_type(), "unified_copy_blocks_kernel", ([&] {
+        vllm::unified_copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            src_key_cache_ptrs_tensor.data_ptr<int64_t>(),
+            src_value_cache_ptrs_tensor.data_ptr<int64_t>(),
+            dst_key_cache_ptrs_tensor.data_ptr<int64_t>(),
+            dst_value_cache_ptrs_tensor.data_ptr<int64_t>(),
            block_mapping.data_ptr<int64_t>(), numel_per_block);
       }));
 }
 
+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
+                 const torch::Tensor& block_mapping) {
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+
+  // Call the unified implementation with the same caches for both source and
+  // destination
+  copy_blocks_impl(key_caches, value_caches, key_caches, value_caches,
+                   block_mapping);
+}
+
+// Function to copy blocks between different layers
+void copy_blocks_between_layers(
+    std::vector<torch::Tensor> const& src_key_caches,
+    std::vector<torch::Tensor> const& src_value_caches,
+    std::vector<torch::Tensor> const& dst_key_caches,
+    std::vector<torch::Tensor> const& dst_value_caches,
+    const torch::Tensor& block_mapping) {
+  // Call the unified implementation with separate source and destination
+  // caches
+  copy_blocks_impl(src_key_caches, src_value_caches, dst_key_caches,
+                   dst_value_caches, block_mapping);
+}
+
 // copy blocks kernel for MLA (assumes a joint KV-cache)
 void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
                      const torch::Tensor& block_mapping) {
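
For readers skimming the kernel change, here is a minimal PyTorch sketch of the semantics that copy_blocks_impl / unified_copy_blocks_kernel implement (not the kernel itself): for every (src_block, dst_block) pair in block_mapping, block src_block of each source cache is copied into block dst_block of the corresponding destination cache, and copy_blocks() is simply the case where the source and destination lists are the same tensors. The sketch assumes blocks are indexed along the first cache dimension, which is what numel_per_block = src_key_caches[0][0].numel() suggests.

    import torch

    def copy_blocks_reference(src_key_caches, src_value_caches,
                              dst_key_caches, dst_value_caches,
                              block_mapping):
        # block_mapping: int64 tensor of shape (num_pairs, 2) holding
        # (src_block_idx, dst_block_idx) pairs.
        for src_kc, src_vc, dst_kc, dst_vc in zip(src_key_caches,
                                                  src_value_caches,
                                                  dst_key_caches,
                                                  dst_value_caches):
            for src_block, dst_block in block_mapping.tolist():
                dst_kc[dst_block].copy_(src_kc[src_block])
                dst_vc[dst_block].copy_(src_vc[src_block])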

csrc/torch_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -660,6 +660,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
 
+  // Copy blocks between different caches
+  cache_ops.def(
+      "copy_blocks_between_layers(Tensor(a!)[] src_key_caches, Tensor(b!)[] "
+      "src_value_caches, "
+      "Tensor(c!)[] dst_key_caches, Tensor(d!)[] dst_value_caches, "
+      "Tensor block_mapping) -> ()");
+  cache_ops.impl("copy_blocks_between_layers", torch::kCUDA,
+                 &copy_blocks_between_layers);
+
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"

examples/offline_inference/spec_decode.py

Lines changed: 2 additions & 2 deletions
@@ -74,9 +74,9 @@ def parse_args():
         action="store_false",
         help="Disable prefill token shift (default: enabled)",
     )
-    parser.add_argument("--target_kv_layer_copy_from", type=int, default=-1)
+    parser.add_argument("--target-kv-layer-copy-from", type=int, default=-1)
     parser.add_argument(
-        "--draft_kv_layer_copy_to",
+        "--draft-kv-layer-copy-to",
         type=str,
        default="",
        help="comma separated list of layer indices to copy to",

tests/kernels/attention/test_cache.py

Lines changed: 34 additions & 0 deletions
@@ -117,6 +117,40 @@ def test_copy_blocks(
                                     cloned_value_caches):
         torch.testing.assert_close(value_cache, cloned_value_cache)
 
+    # Test copy_blocks_between_layers
+    num_source_layers = num_layers // 4
+    source_layers = random.sample(range(num_layers), num_source_layers)
+    target_layers = random.sample(range(num_layers), num_source_layers)
+
+    # Get source and target key/value caches using list comprehension
+    src_key_caches = [key_caches[i] for i in source_layers]
+    src_value_caches = [value_caches[i] for i in source_layers]
+    dst_key_caches = [key_caches[i] for i in target_layers]
+    dst_value_caches = [value_caches[i] for i in target_layers]
+
+    opcheck(torch.ops._C_cache_ops.copy_blocks_between_layers,
+            (src_key_caches, src_value_caches, dst_key_caches,
+             dst_value_caches, block_mapping_tensor),
+            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+            cond=(head_size == HEAD_SIZES[0]))
+    ops.copy_blocks_between_layers(src_key_caches, src_value_caches,
+                                   dst_key_caches, dst_value_caches,
+                                   block_mapping_tensor)
+    # Run the reference implementation for copy_blocks_between_layers
+    for src, dst in block_mapping:
+        for src_layer, dst_layer in zip(source_layers, target_layers):
+            cloned_key_caches[dst_layer][dst].copy_(
+                cloned_key_caches[src_layer][src])
+            cloned_value_caches[dst_layer][dst].copy_(
+                cloned_value_caches[src_layer][src])
+
+    # Compare the results for copy_blocks_between_layers
+    for src_layer, dst_layer in zip(source_layers, target_layers):
+        torch.testing.assert_close(key_caches[dst_layer],
+                                   cloned_key_caches[dst_layer])
+        torch.testing.assert_close(value_caches[dst_layer],
+                                   cloned_value_caches[dst_layer])
+
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)

vllm/_custom_ops.py

Lines changed: 22 additions & 0 deletions
@@ -1655,6 +1655,28 @@ def copy_blocks_mla(kv_caches: list[torch.Tensor],
     torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
 
 
+def copy_blocks_between_layers(src_key_caches: list[torch.Tensor],
+                               src_value_caches: list[torch.Tensor],
+                               dst_key_caches: list[torch.Tensor],
+                               dst_value_caches: list[torch.Tensor],
+                               block_mapping: torch.Tensor) -> None:
+    """Copy blocks between different key-value caches across model layers.
+
+    Args:
+        src_key_caches: List of source key cache tensors.
+        src_value_caches: List of source value cache tensors.
+        dst_key_caches: List of destination key cache tensors.
+        dst_value_caches: List of destination value cache tensors.
+        block_mapping: Tensor of shape (num_blocks, 2) containing pairs of
+            (src_block_idx, dst_block_idx) to copy.
+    """
+    torch.ops._C_cache_ops.copy_blocks_between_layers(src_key_caches,
+                                                      src_value_caches,
+                                                      dst_key_caches,
+                                                      dst_value_caches,
+                                                      block_mapping)
+
+
 def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
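
A hedged usage sketch of the new wrapper. The tensor shapes here are made up for illustration; the op requires CUDA tensors whose first dimension indexes blocks and a block layout shared by source and destination caches.

    import torch
    from vllm import _custom_ops as ops

    num_blocks, block_shape = 4, (16, 8, 128)  # hypothetical cache layout
    src_key = [torch.randn(num_blocks, *block_shape, device="cuda")]
    src_value = [torch.randn(num_blocks, *block_shape, device="cuda")]
    dst_key = [torch.zeros_like(src_key[0])]
    dst_value = [torch.zeros_like(src_value[0])]

    # Copy block 0 -> block 2 and block 1 -> block 3 in every src/dst pair.
    block_mapping = torch.tensor([[0, 2], [1, 3]],
                                 dtype=torch.int64, device="cuda")

    ops.copy_blocks_between_layers(src_key, src_value, dst_key, dst_value,
                                   block_mapping)
    assert torch.equal(dst_key[0][2], src_key[0][0])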

vllm/v1/spec_decode/eagle.py

Lines changed: 7 additions & 12 deletions
@@ -16,8 +16,8 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
-from vllm.utils import is_pin_memory_available
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -109,8 +109,7 @@ def _prepare_adjusted_tensors(
         block_table: torch.Tensor,
         batch_size: int,
         num_tokens: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int,
-               torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
         """
         Prepare adjusted tensors for different request types
         (partial prefill, full prefill, full decode).
@@ -130,7 +129,7 @@ def _prepare_adjusted_tensors(
 
         Returns:
             tuple: (target_positions, target_hidden_states, target_slot_mapping,
-                    cu_num_tokens, current_pos, partial_prefill_mask)
+                    cu_num_tokens, current_pos)
 
         Algorithm design:
         - Suppose target tokens are [1,2,3,...N], next token is N+1
@@ -358,7 +357,6 @@ def _prepare_adjusted_tensors(
             target_slot_mapping,
             cu_num_tokens,
             current_pos,
-            partial_prefill_mask,
         )
 
     def propose(
@@ -411,7 +409,6 @@ def propose(
             target_slot_mapping,
             query_start_loc,
             num_tokens,
-            partial_prefill_mask,
         ) = self._prepare_adjusted_tensors(
             target_token_ids,
             target_positions,
@@ -452,19 +449,17 @@ def propose(
             max_num_blocks_per_req = block_table.shape[1]
             segment_indices = torch.arange(len(target_positions),
                                            device=target_positions.device)
-            segment_indices = (
-                segment_indices.unsqueeze(0)
-                >= common_attn_metadata.query_start_loc[:-1].unsqueeze(1)).sum(
-                    dim=0) - 1
+            segment_indices = (segment_indices.unsqueeze(0)
+                               >= common_attn_metadata.query_start_loc[:-1]
+                               .unsqueeze(1)).sum(dim=0) - 1
             # Calculate the block table indices
             block_table_indices = (
                 target_positions // self.block_size +
                 segment_indices * max_num_blocks_per_req)
             block_numbers = block_table.flatten()[block_table_indices]
             block_offsets = target_positions % self.block_size
             common_attn_metadata.slot_mapping = (
-                block_numbers * self.block_size + block_offsets
-            )
+                block_numbers * self.block_size + block_offsets)
 
             # Use the original last token indices
             last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
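
The segment_indices expression reflowed in the last hunk is worth unpacking: query_start_loc marks where each request's tokens begin in the flattened token stream, so comparing every token index against every start offset and summing counts how many request starts the token has passed; subtracting one gives the request (segment) each token belongs to, which then indexes the per-request block table. A toy illustration with made-up numbers:

    import torch

    # Three requests with 3, 2 and 4 tokens in the flattened stream.
    query_start_loc = torch.tensor([0, 3, 5, 9])
    token_idx = torch.arange(int(query_start_loc[-1]))

    # For each token, count how many request starts it has reached; minus 1
    # gives the index of the request (segment) the token belongs to.
    segment_indices = (token_idx.unsqueeze(0)
                       >= query_start_loc[:-1].unsqueeze(1)).sum(dim=0) - 1
    assert segment_indices.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]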
