
Commit 516d0ad

alexm-redhat authored and LeiWang1999 committed
[Performance] Optimize e2e overheads: Reduce python allocations (vllm-project#7162)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 6fe663e commit 516d0ad

File tree

11 files changed: +550 −125 lines changed


vllm/attention/backends/flash_attn.py

Lines changed: 5 additions & 1 deletion
@@ -259,7 +259,11 @@ def _add_seq_group(
                 block_table = block_tables[seq_id]
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
-                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+                if curr_sliding_window_block == 0:
+                    block_table = block_tables[seq_id]
+                else:
+                    block_table = block_tables[seq_id][
+                        -curr_sliding_window_block:]
             self.block_tables.append(block_table)

             # Compute slot mapping.
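
Why the zero case matters: a Python slice always allocates a new list, and since -0 == 0, the old block_tables[seq_id][-curr_sliding_window_block:] copied the whole table whenever no sliding window was active. A minimal standalone illustration (plain Python, not vLLM code):

    blocks = [10, 11, 12, 13]
    window = 0

    # -0 == 0, so blocks[-window:] is blocks[0:]: same contents, new list.
    assert blocks[-window:] == blocks
    assert blocks[-window:] is not blocks

    # The patched branch reuses the existing list when the window is 0,
    # avoiding a per-sequence allocation on the hot decode path.
    table = blocks if window == 0 else blocks[-window:]
    assert table is blocks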

vllm/attention/backends/utils.py

Lines changed: 10 additions & 2 deletions
@@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
     # tokens are masked and the slot mapping will be
     # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
     block_table = block_tables[seq_id]
-    slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
-    for i in range(max(start_idx, context_len), seq_len):
+
+    def add_slot(i):
         block_number = block_table[i // block_size]
         block_offset = i % block_size
         slot = block_number * block_size + block_offset
         slot_mapping.append(slot)

+    if start_idx == 0 and (seq_len - context_len) == 1:
+        # Optimization for common-case of decoding next token
+        add_slot(seq_len - 1)
+    else:
+        slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
+        for i in range(max(start_idx, context_len), seq_len):
+            add_slot(i)
+

 TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
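
Factoring the slot computation into add_slot lets the common decode step, which appends exactly one new token, skip both the [PAD_SLOT_ID] * ... padding list and the range loop. A rough standalone sketch of that fast path (simplified helper name, not the real compute_slot_mapping):

    PAD_SLOT_ID = -1  # vLLM's sentinel for masked slots

    def slot_for_token(block_table, block_size, i):
        # Map token position i to its physical slot in the paged KV cache.
        block_number = block_table[i // block_size]
        return block_number * block_size + i % block_size

    # Decoding one new token for a sequence with 9 tokens already cached:
    block_table, block_size = [7, 3], 8
    seq_len, context_len, start_idx = 10, 9, 0
    if start_idx == 0 and (seq_len - context_len) == 1:
        slot = slot_for_token(block_table, block_size, seq_len - 1)
    assert slot == 3 * 8 + 1  # position 9 -> physical block 3, offset 1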

vllm/block.py

Lines changed: 45 additions & 3 deletions
@@ -1,5 +1,5 @@
 """Token blocks."""
-from typing import List
+from typing import List, Optional

 from vllm.utils import Device

@@ -37,5 +37,47 @@ def __repr__(self) -> str:
                 f'computed={self.computed})')


-# Mapping: logical block number -> physical block.
-BlockTable = List[PhysicalTokenBlock]
+class BlockTable:
+    """Holds a list of blocks with caching of their associated block_ids
+    """
+
+    def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
+        self._blocks: List[PhysicalTokenBlock] = []
+        self._block_ids: List[int] = []
+
+        if blocks is not None:
+            for block in blocks:
+                self.append(block)
+
+    def append(self, block: PhysicalTokenBlock):
+        self._blocks.append(block)
+        self._block_ids.append(block.block_number)
+
+    def __len__(self) -> int:
+        return len(self._blocks)
+
+    def __getitem__(self, key):
+        return self._blocks[key]
+
+    def __setitem__(self, key, value):
+        if isinstance(key, slice):
+            blocks = value
+            self._blocks[key] = blocks
+            self._block_ids[key] = [b.block_number for b in blocks]
+        else:
+            block = value
+            self._blocks[key] = block
+            self._block_ids[key] = block.block_number
+
+    def reset(self):
+        self._blocks = []
+        self._block_ids = []
+
+    def copy(self) -> "BlockTable":
+        return BlockTable(self._blocks)
+
+    def list(self) -> List[PhysicalTokenBlock]:
+        return self._blocks
+
+    def ids(self) -> List[int]:
+        return self._block_ids
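
With BlockTable as a plain list alias, callers had to rebuild the id list with a comprehension on every scheduler step; the class instead maintains _block_ids in lockstep with _blocks, so ids() just returns an existing list. A short usage sketch (with a duck-typed stub in place of PhysicalTokenBlock for brevity):

    from dataclasses import dataclass
    from vllm.block import BlockTable  # the class added above

    @dataclass
    class StubBlock:
        block_number: int  # the only field BlockTable reads here

    table = BlockTable([StubBlock(i) for i in range(4)])
    assert table.ids() == [0, 1, 2, 3]  # cached: no per-call rebuild
    assert table.ids() is table.ids()   # same list object every call

    table[1] = StubBlock(99)            # __setitem__ keeps the cache in sync
    assert table.ids() == [0, 99, 2, 3]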

vllm/core/block_manager_v1.py

Lines changed: 15 additions & 12 deletions
@@ -170,7 +170,7 @@ def __init__(
         self.num_blocks = num_blocks

         # Initialize the free blocks.
-        self.free_blocks: BlockTable = []
+        self.free_blocks: List[PhysicalTokenBlock] = []
         for i in range(num_blocks):
             block = PhysicalTokenBlock(device=device,
                                        block_number=i,
@@ -256,6 +256,7 @@ def __init__(
             Device.CPU, block_size, num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
+
         # Mapping: req_id -> BlockTable
         # Note that each SequenceGroup has a unique
         # request ID
@@ -299,7 +300,7 @@ def _allocate_sequence(self, \
         # Allocate new physical token blocks that will store the prompt tokens.
         num_prompt_blocks = seq.n_blocks

-        block_table: BlockTable = []
+        block_table: BlockTable = BlockTable()
         for logical_idx in range(num_prompt_blocks):
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
@@ -326,15 +327,19 @@ def allocate(self, seq_group: SequenceGroup) -> None:
         #
         # NOTE: Here we assume that all sequences in the group have the same
         # decoder prompt.
-        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
+        wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
+        seq = wait_seqs[0]
         block_table: BlockTable = \
             self._allocate_sequence(seq,
                                     seq_group.num_seqs(),
                                     is_encoder_decoder)

         # Assign the self-attention block tables for each sequence.
-        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
-            self.block_tables[seq.seq_id] = block_table.copy()
+        if len(wait_seqs) == 1:
+            self.block_tables[wait_seqs[0].seq_id] = block_table
+        else:
+            for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
+                self.block_tables[seq.seq_id] = block_table.copy()

         # Allocate encoder sequence
         if is_encoder_decoder:
@@ -476,6 +481,7 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
             return
         src_block_table = self.block_tables[parent_seq.seq_id]
         self.block_tables[child_seq.seq_id] = src_block_table.copy()
+
         # When using a sliding window, blocks will be eventually reused.
         # In this case the block tables will contain repeated blocks.
         # When forking, we must make sure that each block's `ref_count`
@@ -527,7 +533,7 @@ def _swap_block_table(
             dest_allocator: BlockAllocatorBase,
             mapping: Dict[PhysicalTokenBlock,
                           PhysicalTokenBlock]) -> BlockTable:
-        new_block_table = []
+        new_block_table: BlockTable = BlockTable()

         for from_block in block_table:
             if from_block in mapping:
@@ -553,8 +559,7 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             self.block_tables[seq.seq_id] = \
                 self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
+                                       self.cpu_allocator, self.gpu_allocator,
                                        mapping)

         if seq_group.is_encoder_decoder():
@@ -580,8 +585,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             self.block_tables[seq.seq_id] = \
                 self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
+                                       self.gpu_allocator, self.cpu_allocator,
                                        mapping)

         if seq_group.is_encoder_decoder():
@@ -636,8 +640,7 @@ def reset(self) -> None:
         self.cross_block_tables.clear()

     def get_block_table(self, seq: Sequence) -> List[int]:
-        block_table = self.block_tables[seq.seq_id]
-        return [block.block_number for block in block_table]
+        return self.block_tables[seq.seq_id].ids()

     def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
         block_table = self.cross_block_tables[seq_group.request_id]
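
The allocate() hunk carries most of the win here: ordinary sampling requests have one waiting sequence per group, so the freshly built BlockTable is handed over as-is, and copy(), which re-appends every block, runs only for groups that genuinely share a prompt (n > 1, beam search). A hedged sketch of that branch with simplified names (not the actual scheduler code):

    def assign_block_tables(block_tables, wait_seq_ids, block_table):
        # block_tables: Dict[seq_id, BlockTable], as in BlockSpaceManagerV1.
        if len(wait_seq_ids) == 1:
            # Sole owner: alias the table, skipping an O(num_blocks) copy.
            block_tables[wait_seq_ids[0]] = block_table
        else:
            # Shared prompt: each sequence still gets its own mutable table.
            for seq_id in wait_seq_ids:
                block_tables[seq_id] = block_table.copy()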
