 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -158,7 +159,7 @@ class FlashInferMetadata:
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
-    qo_indptr: torch.Tensor
+    qo_indptr_cpu: torch.Tensor
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
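
A quick illustration of the layout the comments above describe (a sketch for orientation, not part of the diff): `qo_indptr` is an exclusive prefix sum over per-request query lengths.

```python
import torch

# Sketch: subquery lengths [4, 6] produce qo_indptr [0, 4, 10], so
# request i's query tokens live in query[qo_indptr[i]:qo_indptr[i + 1]].
query_lens = torch.tensor([4, 6], dtype=torch.int32)
qo_indptr = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(query_lens, dim=0, dtype=torch.int32)
print(qo_indptr)  # tensor([ 0,  4, 10], dtype=torch.int32)
```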
@@ -167,13 +168,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (on device for plan)
     paged_kv_indices: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads
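
The same CSR pattern covers the paged KV cache; here is the worked version of the example in the comments above (sketch only):

```python
import torch

# Three requests holding pages [0, 5, 8], [1, 6, 7], and [3, 4]
# of the paged KV cache.
pages = [[0, 5, 8], [1, 6, 7], [3, 4]]
paged_kv_indices = torch.tensor(sum(pages, []), dtype=torch.int32)
# tensor([0, 5, 8, 1, 6, 7, 3, 4])
paged_kv_indptr = torch.zeros(len(pages) + 1, dtype=torch.int32)
paged_kv_indptr[1:] = torch.cumsum(
    torch.tensor([len(p) for p in pages]), dim=0, dtype=torch.int32)
# tensor([0, 3, 6, 8]); request i owns
# paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]]
```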
@@ -201,22 +202,17 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
-    # For cascade attention.
+    # For cascade attention (CPU for planning).
    use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
 
     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
-    @property
-    def query_start_loc(self):
-        # The GPUModelRunner expects to be able to access this property.
-        return self.qo_indptr
-
     def __post_init__(self):
         if self.head_dim is not None:
             FlashInferBackend.validate_head_size(self.head_dim)
@@ -238,6 +234,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
         self.kv_cache_spec = kv_cache_spec
+        max_num_blocks_per_request = cdiv(
+            vllm_config.model_config.max_model_len,
+            self.kv_cache_spec.block_size)
+        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+                                               dtype=torch.int32,
+                                               device=self.device)
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
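
Context for the new `__init__` code: `cdiv` (from `vllm.utils`) is ceiling division, so `max_num_blocks_per_request` is the longest block-table row any request can need, and the `torch.arange` is hoisted out of `build()` so it is allocated once rather than on every step. A standalone sketch with assumed example values:

```python
import torch

def cdiv(a: int, b: int) -> int:
    # Ceiling division without floats (same behavior as vllm.utils.cdiv).
    return -(a // -b)

max_model_len, block_size = 8192, 16           # assumed example values
max_num_blocks = cdiv(max_model_len, block_size)
assert max_num_blocks == 512
block_table_arange = torch.arange(max_num_blocks, dtype=torch.int32)
```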
@@ -285,21 +287,25 @@ def _plan(self, num_prefills: int, num_decodes: int,
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
+                    attn_metadata.shared_kv_page_indices_cpu,
                     attn_metadata.paged_kv_indices
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
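
For orientation: each list passed to `cascade_wrapper.plan` has two levels, `[shared_prefix, per_request]`. The shared level treats the common prefix as a single pseudo-request spanning all `num_actual_tokens` query tokens. A small sketch of the shared-level CPU tensors under assumed sizes:

```python
import torch

page_size, common_prefix_len, num_actual_tokens = 16, 32, 10  # assumed
num_common_kv_blocks = common_prefix_len // page_size         # 2 full pages

shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], dtype=torch.int32)
shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
                                         dtype=torch.int32)
shared_kv_last_page_len_cpu = torch.tensor([page_size], dtype=torch.int32)
# The *_cpu suffix is the point of this diff: FlashInfer's plan() does its
# work on the host, so host tensors avoid a synchronizing GPU-to-CPU copy.
```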
@@ -320,22 +326,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
                 attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
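
The rebasing step above is easy to get wrong, so a worked example: with 2 decode requests (1 token each) followed by prefills of 4 and 6 tokens, the batch-wide `qo_indptr_cpu` is `[0, 1, 2, 6, 12]`, and the prefill wrapper needs offsets relative to `query[num_decode_tokens:]`:

```python
import torch

qo_indptr_cpu = torch.tensor([0, 1, 2, 6, 12], dtype=torch.int32)
prefill_start = 2  # num_decodes; decodes are reordered to the front
rebased = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
print(rebased)  # tensor([ 0,  4, 10], dtype=torch.int32)
```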
@@ -357,9 +363,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                     attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
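
Because decodes sit at the front of the reordered batch, the decode plan only needs the leading slices. Continuing the example above with assumed per-request block counts `[3, 3, 2, 1]`:

```python
import torch

paged_kv_indptr_cpu = torch.tensor([0, 3, 6, 8, 9], dtype=torch.int32)
num_decodes = 2
print(paged_kv_indptr_cpu[:num_decodes + 1])  # tensor([0, 3, 6])
# paged_kv_indices itself stays whole (and on the GPU); the sliced indptr
# already limits the decode wrapper to the first 6 page entries.
```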
@@ -383,55 +389,58 @@ def build(self,
             split_decodes_and_prefills(common_attn_metadata)
 
         page_size = self.kv_cache_spec.block_size
-        device = self.device
-        qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
 
-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
                 0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
+                                                       dtype=torch.int32,
+                                                       device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
-
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
+
+        max_num_blocks = block_table_bounds_cpu.max()
+        block_table_bounds = block_table_bounds_cpu.to(self.device,
+                                                       non_blocking=True)
+        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
-
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
-        ])
-
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
+
+        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
+                                          dtype=torch.int32,
+                                          device='cpu')
+        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
+            dim=0, dtype=torch.int32)
+
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
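
Two pieces of the rewritten `build()` are worth unpacking: the boolean mask gathers each request's valid block IDs row by row into the flat `paged_kv_indices`, and the `torch.where` maps a sequence length that is an exact multiple of `page_size` to a full last page rather than 0. A standalone sketch with `page_size=16` and `seq_lens=[35, 32]` (example values only):

```python
import torch

page_size = 16
seq_lens_cpu = torch.tensor([35, 32], dtype=torch.int32)
block_table = torch.tensor([[0, 5, 8, -1],
                            [1, 6, -1, -1]], dtype=torch.int32)  # -1 = unused

bounds = (seq_lens_cpu + page_size - 1) // page_size  # tensor([3, 2])
max_num_blocks = int(bounds.max())                    # 3
arange = torch.arange(4, dtype=torch.int32)           # precomputed in __init__
mask = arange[:max_num_blocks].unsqueeze(0) < bounds.unsqueeze(1)
paged_kv_indices = block_table[:, :max_num_blocks][mask]
print(paged_kv_indices)  # tensor([0, 5, 8, 1, 6], dtype=torch.int32)

last = seq_lens_cpu % page_size                # tensor([3, 0])
last = torch.where(last == 0, page_size, last) # tensor([ 3, 16])
```

Slicing the block table to `max_num_blocks` before applying the mask keeps the gather proportional to the longest live request instead of the full `max_model_len` row width.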
@@ -440,10 +449,10 @@ def build(self,
             kv_cache_dtype = self.kv_cache_spec.dtype
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
-            qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -457,14 +466,14 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._workspace_buffer,
+            workspace_buffer=self._get_workspace_buffer(),
         )
 
         self._plan(num_prefills, num_decodes, attn_metadata)
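
Taken together, the change keeps `paged_kv_indices` (consumed by the GPU kernels) on the device while everything `plan()` reads moves to the CPU. A hedged illustration of the stall this avoids (assuming a CUDA device; illustrative, not from the diff):

```python
import torch

indptr_gpu = torch.tensor([0, 3, 6, 8], dtype=torch.int32, device='cuda')
indptr_cpu = indptr_gpu.cpu()  # synchronizes: host waits for the GPU stream
# Building the indptr arrays on the host inside build(), as this diff does,
# skips that synchronization on every scheduling step.
```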