
Commit 61b8cea

[Attention] Optimize FlashInfer MetadataBuilder Build call (#21137)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 526078a commit 61b8cea

3 files changed: +94 −78 lines

tests/v1/attention/test_attention_backends.py

Lines changed: 10 additions & 3 deletions
@@ -11,7 +11,8 @@
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+                                              set_kv_cache_layout)
 from vllm.v1.kv_cache_interface import FullAttentionSpec

 BACKENDS_TO_TEST = [
@@ -212,7 +213,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,

     from vllm.v1.attention.backends.flashinfer import PerLayerParameters

-    def mock_get_per_layer_parameters(vllm_config):
+    def mock_get_per_layer_parameters(vllm_config, impl_cls):
         # Return mock parameters for a single layer
         head_size = vllm_config.model_config.get_head_size()
         return {
@@ -297,7 +298,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
     batch_spec = BATCH_SPECS[batch_spec_name]
-    vllm_config = create_vllm_config(model_name=model)
+    vllm_config = create_vllm_config(model_name=model,
+                                     max_model_len=max(batch_spec.seq_lens))
     device = torch.device("cuda:0")

     kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -419,6 +421,11 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         if backend_name == _Backend.FLASHINFER_VLLM_V1:
             kv_cache_for_backend = kv_cache.transpose(0, 1)

+            # For FlashInfer default to HND layout and
+            kv_cache_for_backend = kv_cache_for_backend.transpose(
+                2, 3).contiguous().transpose(2, 3)
+            set_kv_cache_layout("HND")
+
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
                                                vllm_config, device,
                                                common_attn_metadata,
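A note on the new test branch above: the transpose-contiguous-transpose idiom changes only the memory layout of the cache, not its logical shape, which is what the HND layout selection relies on. The sketch below is illustrative only (the tensor dimensions are made up, not the test's actual cache shape):

import torch

# Made-up cache shape; only the layout trick matters here.
kv_cache = torch.randn(2, 4, 16, 8, 64)

# Swap two inner dims, materialize that ordering, then swap back:
# the logical shape is unchanged, but the underlying storage order
# now follows the transposed (heads-first) layout.
hnd = kv_cache.transpose(2, 3).contiguous().transpose(2, 3)

assert hnd.shape == kv_cache.shape      # same logical shape
assert not hnd.is_contiguous()          # different underlying layout
assert torch.equal(hnd, kv_cache)       # identical values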

tests/v1/attention/utils.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def create_common_attn_metadata(
     num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

     # Create block table (random for testing)
-    max_blocks = max(batch_spec.seq_lens) // block_size + 1
+    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
     block_table_tensor = torch.randint(0,
                                        max_block_idx,
                                        (batch_spec.batch_size, max_blocks),
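The one-line change above replaces `max(seq_lens) // block_size + 1` with a proper ceiling division, which drops the spurious extra block when the longest sequence is an exact multiple of the block size. A quick check with illustrative numbers:

block_size = 16

for seq_len in (15, 16, 17):
    old = seq_len // block_size + 1                 # previous formula
    new = (seq_len + block_size - 1) // block_size  # ceiling division
    print(seq_len, old, new)

# 15 -> old 1, new 1
# 16 -> old 2, new 1  (old formula allocated one block too many)
# 17 -> old 2, new 2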

vllm/v1/attention/backends/flashinfer.py

Lines changed: 83 additions & 74 deletions
@@ -18,6 +18,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -158,7 +159,7 @@ class FlashInferMetadata:
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
-    qo_indptr: torch.Tensor
+    qo_indptr_cpu: torch.Tensor
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
@@ -167,13 +168,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (on device for plan)
     paged_kv_indices: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads
@@ -201,22 +202,17 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int

-    # For cascade attention.
+    # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None

     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

-    @property
-    def query_start_loc(self):
-        # The GPUModelRunner expects to be able to access this property.
-        return self.qo_indptr
-
     def __post_init__(self):
         if self.head_dim is not None:
             FlashInferBackend.validate_head_size(self.head_dim)
@@ -238,6 +234,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
         self.kv_cache_spec = kv_cache_spec
+        max_num_blocks_per_request = cdiv(
+            vllm_config.model_config.max_model_len,
+            self.kv_cache_spec.block_size)
+        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+                                               dtype=torch.int32,
+                                               device=self.device)

     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
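The `__init__` change above sizes a `torch.arange` buffer once for the worst case (`cdiv(max_model_len, block_size)` blocks) so that each `build()` call can slice it rather than allocate a fresh range tensor. A rough sketch of that pattern with hypothetical names (`BuilderSketch`, `build_mask`), not the builder's real interface:

import torch

class BuilderSketch:
    def __init__(self, max_model_len: int, block_size: int,
                 device: str = "cpu"):
        # Worst-case block count per request, computed once up front.
        max_blocks = -(-max_model_len // block_size)  # ceiling division
        self.block_table_arange = torch.arange(max_blocks,
                                               dtype=torch.int32,
                                               device=device)

    def build_mask(self, block_table_bounds: torch.Tensor) -> torch.Tensor:
        # Per call: slice the cached arange and compare against the
        # per-request block counts; no new arange allocation.
        max_num_blocks = int(block_table_bounds.max())
        arange = self.block_table_arange[:max_num_blocks]
        return arange.unsqueeze(0) < block_table_bounds.unsqueeze(1)

# Example: three requests needing 1, 3 and 2 blocks respectively.
sketch = BuilderSketch(max_model_len=64, block_size=16)
mask = sketch.build_mask(torch.tensor([1, 3, 2], dtype=torch.int32))
print(mask)  # shape (3, 3); True where a block-table entry is valid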
@@ -285,21 +287,25 @@ def _plan(self, num_prefills: int, num_decodes: int,
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
+                    attn_metadata.shared_kv_page_indices_cpu,
                     attn_metadata.paged_kv_indices
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
@@ -320,22 +326,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
                 attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
@@ -357,9 +363,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                     attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
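The prefill planning path above rebases the query indptr so that it is relative to the first prefill token, since the wrapper is later run on `query[num_decode_tokens:]` (as the diff's own comment notes). A small worked example with made-up request lengths:

import torch

# Example query_start_loc for 2 decode requests followed by 2 prefills
# (decode queries have length 1, prefill queries have lengths 4 and 5).
qo_indptr_cpu = torch.tensor([0, 1, 2, 6, 11], dtype=torch.int32)
num_decodes = 2
prefill_start = num_decodes

# Rebase so the prefill indptr starts at 0, matching the sliced
# query tensor that the prefill wrapper sees.
prefill_qo_indptr = (qo_indptr_cpu[prefill_start:]
                     - qo_indptr_cpu[prefill_start])
print(prefill_qo_indptr.tolist())  # [0, 4, 9]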
@@ -383,55 +389,58 @@ def build(self,
             split_decodes_and_prefills(common_attn_metadata)

         page_size = self.kv_cache_spec.block_size
-        device = self.device
-        qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor

-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
                 0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
+                                                       dtype=torch.int32,
+                                                       device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
-
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
+
+        max_num_blocks = block_table_bounds_cpu.max()
+        block_table_bounds = block_table_bounds_cpu.to(self.device,
+                                                       non_blocking=True)
+        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
-
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
-        ])
-
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
+
+        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
+                                          dtype=torch.int32,
+                                          device='cpu')
+        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
+            dim=0, dtype=torch.int32)
+
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
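The rewritten indptr construction above keeps the tensor on the host and drops the earlier `torch.cat` of a zero tensor with the cumulative sum; a preallocated zeros tensor filled in place produces the same result. A small equivalence check with example block counts (the variable name mirrors the diff, the values are made up):

import torch

# Example per-request block counts, standing in for block_table_bounds_cpu.
block_table_bounds_cpu = torch.tensor([3, 2, 4], dtype=torch.int32)

# Previous construction: concatenate a zero with the cumulative sum.
indptr_cat = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32),
])

# New construction: preallocate on the CPU and fill in place.
indptr_fill = torch.zeros(len(block_table_bounds_cpu) + 1, dtype=torch.int32)
indptr_fill[1:] = block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32)

assert torch.equal(indptr_cat, indptr_fill)
print(indptr_fill.tolist())  # [0, 3, 5, 9]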
@@ -440,10 +449,10 @@ def build(self,
             kv_cache_dtype = self.kv_cache_spec.dtype
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
-            qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -457,14 +466,14 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            workspace_buffer=self._workspace_buffer,
+            workspace_buffer=self._get_workspace_buffer(),
         )

         self._plan(num_prefills, num_decodes, attn_metadata)
