@@ -68,32 +68,42 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,

 namespace vllm {

-// Grid: (num_layers, num_pairs)
+// Grid: (layer_or_pair_idx, num_pairs)
 template <typename scalar_t>
-__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
-                                   int64_t* value_cache_ptrs,
-                                   const int64_t* __restrict__ block_mapping,
-                                   const int numel_per_block) {
-  const int layer_idx = blockIdx.x;
+__global__ void unified_copy_blocks_kernel(
+    int64_t* src_key_cache_ptrs, int64_t* src_value_cache_ptrs,
+    int64_t* dst_key_cache_ptrs, int64_t* dst_value_cache_ptrs,
+    const int64_t* __restrict__ block_mapping, const int numel_per_block) {
+  const int layer_or_pair_idx = blockIdx.x;
   const int pair_idx = blockIdx.y;

-  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
-  scalar_t* value_cache =
-      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  scalar_t* src_key_cache =
+      reinterpret_cast<scalar_t*>(src_key_cache_ptrs[layer_or_pair_idx]);
+  scalar_t* src_value_cache =
+      reinterpret_cast<scalar_t*>(src_value_cache_ptrs[layer_or_pair_idx]);
+  scalar_t* dst_key_cache =
+      reinterpret_cast<scalar_t*>(dst_key_cache_ptrs[layer_or_pair_idx]);
+  scalar_t* dst_value_cache =
+      reinterpret_cast<scalar_t*>(dst_value_cache_ptrs[layer_or_pair_idx]);
+
   int64_t src_block_number = block_mapping[2 * pair_idx];
   int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

   const int64_t src_block_offset = src_block_number * numel_per_block;
   const int64_t dst_block_offset = dst_block_number * numel_per_block;
+
+  // Copy key cache from source to destination
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
     int64_t src_offset = src_block_offset + i;
     int64_t dst_offset = dst_block_offset + i;
-    key_cache[dst_offset] = key_cache[src_offset];
+    dst_key_cache[dst_offset] = src_key_cache[src_offset];
   }
+
+  // Copy value cache from source to destination
   for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
     int64_t src_offset = src_block_offset + i;
     int64_t dst_offset = dst_block_offset + i;
-    value_cache[dst_offset] = value_cache[src_offset];
+    dst_value_cache[dst_offset] = src_value_cache[src_offset];
   }
 }

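For context on the kernel above: `block_mapping` is a device tensor of shape `(num_pairs, 2)` read as a flat `int64_t` array, so `block_mapping[2 * pair_idx]` is the source block number and `block_mapping[2 * pair_idx + 1]` is the destination block number of pair `pair_idx`. A minimal host-side sketch of building such a mapping (the block numbers are made up for illustration, not taken from this change):

```cpp
#include <torch/torch.h>

// Hypothetical helper: builds a (num_pairs, 2) int64 block mapping of the
// form consumed by the kernel above. The block numbers are illustrative.
torch::Tensor make_example_block_mapping() {
  // Two pairs: copy block 3 into block 7, and block 5 into block 1.
  return torch::tensor({{3, 7}, {5, 1}}, torch::kInt64).to(torch::kCUDA);
}
```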
@@ -117,58 +127,108 @@ __global__ void copy_blocks_mla_kernel(

 }  // namespace vllm

-// Note: the key_caches and value_caches vectors are constant but
-// not the Tensors they contain. The vectors need to be const refs
-// in order to satisfy pytorch's C++ operator registration code.
-void copy_blocks(std::vector<torch::Tensor> const& key_caches,
-                 std::vector<torch::Tensor> const& value_caches,
-                 const torch::Tensor& block_mapping) {
-  int num_layers = key_caches.size();
-  TORCH_CHECK(num_layers == value_caches.size());
-  if (num_layers == 0) {
+// Unified implementation function for both copy_blocks and
+// copy_blocks_between_layers
+void copy_blocks_impl(std::vector<torch::Tensor> const& src_key_caches,
+                      std::vector<torch::Tensor> const& src_value_caches,
+                      std::vector<torch::Tensor> const& dst_key_caches,
+                      std::vector<torch::Tensor> const& dst_value_caches,
+                      const torch::Tensor& block_mapping) {
+  int num_src_dst_pairs = src_key_caches.size();
+  TORCH_CHECK(num_src_dst_pairs == src_value_caches.size());
+  TORCH_CHECK(num_src_dst_pairs == dst_key_caches.size());
+  TORCH_CHECK(num_src_dst_pairs == dst_value_caches.size());
+
+  if (num_src_dst_pairs == 0) {
     return;
   }
-  torch::Device cache_device = key_caches[0].device();
+
+  torch::Device cache_device = src_key_caches[0].device();
   TORCH_CHECK(cache_device.is_cuda());

-  // Create data structures for the kernel.
-  // Create an array of pointers to the key and value caches.
-  int64_t key_cache_ptrs[num_layers];
-  int64_t value_cache_ptrs[num_layers];
-  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
-    key_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
-    value_cache_ptrs[layer_idx] =
-        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  // Create arrays of pointers to the source and destination key and value
+  // caches
+  int64_t src_key_cache_ptrs[num_src_dst_pairs];
+  int64_t src_value_cache_ptrs[num_src_dst_pairs];
+  int64_t dst_key_cache_ptrs[num_src_dst_pairs];
+  int64_t dst_value_cache_ptrs[num_src_dst_pairs];
+
+  for (int pair_idx = 0; pair_idx < num_src_dst_pairs; ++pair_idx) {
+    src_key_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(src_key_caches[pair_idx].data_ptr());
+    src_value_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(src_value_caches[pair_idx].data_ptr());
+    dst_key_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(dst_key_caches[pair_idx].data_ptr());
+    dst_value_cache_ptrs[pair_idx] =
+        reinterpret_cast<int64_t>(dst_value_caches[pair_idx].data_ptr());
   }

   // block_mapping is a 2D tensor with shape (num_pairs, 2).
   int num_pairs = block_mapping.size(0);

-  // Move the data structures to the GPU.
-  // NOTE: This synchronizes the CPU and GPU.
-  torch::Tensor key_cache_ptrs_tensor =
-      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
+  // Move the data structures to the GPU
+  torch::Tensor src_key_cache_ptrs_tensor =
+      torch::from_blob(src_key_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
+          .to(cache_device);
+  torch::Tensor src_value_cache_ptrs_tensor =
+      torch::from_blob(src_value_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
+          .to(cache_device);
+  torch::Tensor dst_key_cache_ptrs_tensor =
+      torch::from_blob(dst_key_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
           .to(cache_device);
-  torch::Tensor value_cache_ptrs_tensor =
-      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
+  torch::Tensor dst_value_cache_ptrs_tensor =
+      torch::from_blob(dst_value_cache_ptrs, {num_src_dst_pairs}, torch::kInt64)
           .to(cache_device);

-  // Launch the kernel.
-  const int numel_per_block = key_caches[0][0].numel();
-  dim3 grid(num_layers, num_pairs);
+  // Launch the kernel
+  const int numel_per_block = src_key_caches[0][0].numel();
+  dim3 grid(num_src_dst_pairs, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
-      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
-        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            key_cache_ptrs_tensor.data_ptr<int64_t>(),
-            value_cache_ptrs_tensor.data_ptr<int64_t>(),
+      src_key_caches[0].scalar_type(), "unified_copy_blocks_kernel", ([&] {
+        vllm::unified_copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            src_key_cache_ptrs_tensor.data_ptr<int64_t>(),
+            src_value_cache_ptrs_tensor.data_ptr<int64_t>(),
+            dst_key_cache_ptrs_tensor.data_ptr<int64_t>(),
+            dst_value_cache_ptrs_tensor.data_ptr<int64_t>(),
             block_mapping.data_ptr<int64_t>(), numel_per_block);
       }));
 }

+// Note: the key_caches and value_caches vectors are constant but
+// not the Tensors they contain. The vectors need to be const refs
+// in order to satisfy pytorch's C++ operator registration code.
+void copy_blocks(std::vector<torch::Tensor> const& key_caches,
+                 std::vector<torch::Tensor> const& value_caches,
+                 const torch::Tensor& block_mapping) {
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+
+  // Call the unified implementation with the same caches for both source and
+  // destination
+  copy_blocks_impl(key_caches, value_caches, key_caches, value_caches,
+                   block_mapping);
+}
+
+// Function to copy blocks between different layers
+void copy_blocks_between_layers(
+    std::vector<torch::Tensor> const& src_key_caches,
+    std::vector<torch::Tensor> const& src_value_caches,
+    std::vector<torch::Tensor> const& dst_key_caches,
+    std::vector<torch::Tensor> const& dst_value_caches,
+    const torch::Tensor& block_mapping) {
+  // Call the unified implementation with separate source and destination caches
+  copy_blocks_impl(src_key_caches, src_value_caches, dst_key_caches,
+                   dst_value_caches, block_mapping);
+}
+
 // copy blocks kernel for MLA (assumes a joint KV-cache)
 void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
                      const torch::Tensor& block_mapping) {
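For reference, here is a minimal sketch of how the new `copy_blocks_between_layers` entry point might be exercised from C++, assuming its declaration is in scope. The cache shapes, dtype, and block indices are illustrative assumptions; the only structural requirements visible above are that each cache has the block dimension first (so `cache[0].numel()` is the per-block element count) and that `block_mapping` is a `(num_pairs, 2)` int64 tensor on the same CUDA device.

```cpp
#include <torch/torch.h>

#include <vector>

// Hypothetical usage sketch; shapes, dtype, and block indices are assumptions.
void example_copy_between_layers() {
  const auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);

  // One (src, dst) cache pair; the leading dimension is the block dimension.
  std::vector<torch::Tensor> src_key{torch::randn({16, 16, 8, 128}, opts)};
  std::vector<torch::Tensor> src_value{torch::randn({16, 16, 8, 128}, opts)};
  std::vector<torch::Tensor> dst_key{torch::zeros({16, 16, 8, 128}, opts)};
  std::vector<torch::Tensor> dst_value{torch::zeros({16, 16, 8, 128}, opts)};

  // Copy block 0 -> block 2 and block 1 -> block 3.
  torch::Tensor block_mapping =
      torch::tensor({{0, 2}, {1, 3}}, torch::kInt64).to(torch::kCUDA);

  copy_blocks_between_layers(src_key, src_value, dst_key, dst_value,
                             block_mapping);
}
```

Routing both `copy_blocks` and the new entry point through `copy_blocks_impl` keeps a single kernel-launch path, with the pointer arrays indexed per source/destination pair instead of per layer.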