
Commit e358053

[Performance] Enable chunked prefill and prefix caching together (#7753)
1 parent f508e03 commit e358053
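
For context, the tests in this commit drive both features through vllm_runner. A minimal end-user sketch (not part of this commit; the model name and token limits are illustrative) of enabling both flags through the public LLM API, keeping the chunk size a multiple of the 16-token default block size as the new scheduler check requires, might look like this:

from vllm import LLM, SamplingParams

# Illustrative sketch only: enable chunked prefill and prefix caching together.
# max_num_batched_tokens (the chunk size) should be divisible by the KV-cache
# block size (16 by default), otherwise the scheduler now raises a ValueError.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_num_batched_tokens=2048,
)

prompts = ["You are a helpful AI assistant. " * 20 + "Question"]
params = SamplingParams(temperature=0.0, max_tokens=16)
print(llm.generate(prompts, params)[0].outputs[0].text)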

File tree

9 files changed: +225 −27 lines changed


tests/basic_correctness/test_chunked_prefill.py

Lines changed: 66 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 Run `pytest tests/models/test_chunked_prefill.py`.
 """
+from contextlib import nullcontext
 
 import pytest
 
@@ -156,3 +157,68 @@ def test_models_with_fp8_kv_cache(
         name_0="no_chunked_prefill",
         name_1="chunked_prefill",
     )
+
+
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("chunk_size", [30, 32])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset the distributed env properly. Use a value > 1 only for local testing.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_with_prefix_caching(
+    vllm_runner,
+    max_tokens: int,
+    enforce_eager: bool,
+    chunk_size: int,
+    use_v2_block_manager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    """
+    Checks exact-match decode with and without prefix caching
+    when chunked prefill is enabled.
+    """
+    model = "meta-llama/Llama-2-7b-chat-hf"
+    # The common prompt has 142 tokens with the Llama-2 tokenizer.
+    common_prompt = "You are a helpful AI assistant " * 20
+    unique_prompts = [
+        "Question",  # Warmup
+        "Question",  # Fully cached
+        "Another question",  # Partially cached
+    ]
+    full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
+
+    max_num_batched_tokens = max_num_seqs = chunk_size
+    outputs = {}  # type: ignore
+    check_result = True
+    for enable in (True, False):
+        with vllm_runner(
+                model,
+                dtype="half",
+                max_num_batched_tokens=max_num_batched_tokens,
+                enable_chunked_prefill=True,
+                enable_prefix_caching=enable,
+                tensor_parallel_size=tensor_parallel_size,
+                use_v2_block_manager=use_v2_block_manager,
+                enforce_eager=enforce_eager,
+                max_num_seqs=max_num_seqs,
+        ) as vllm_model:
+            # It should fail when prefix caching is enabled and the chunk
+            # size is not a multiple of the block size (16).
+            should_fail = chunk_size % 16 != 0 and enable
+            check_result &= not should_fail
+            outputs[enable] = []
+            # Send the requests one by one to ensure the cache is populated.
+            with pytest.raises(ValueError) if should_fail else nullcontext():
+                for prompt in full_prompts:
+                    outputs[enable] += vllm_model.generate_greedy([prompt],
+                                                                  max_tokens)
+
+    # Check results only if we did not expect a failure.
+    if check_result:
+        check_outputs_equal(
+            outputs_0_lst=outputs[False],
+            outputs_1_lst=outputs[True],
+            name_0="w/o prefix caching",
+            name_1="with prefix caching",
+        )
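
A quick sanity check of the numbers behind this parametrization, assuming vLLM's default block size of 16 (plain arithmetic, not part of the test):

# The 142-token common prompt covers 8 full blocks that later requests can
# reuse; chunk_size=32 divides evenly into blocks, while chunk_size=30 does
# not and is expected to raise a ValueError when prefix caching is enabled.
BLOCK_SIZE = 16
common_prompt_tokens = 142  # per the comment in the test above
assert common_prompt_tokens // BLOCK_SIZE == 8  # reusable full blocks

for chunk_size in (30, 32):
    valid = chunk_size % BLOCK_SIZE == 0
    print(f"chunk_size={chunk_size}: {'ok' if valid else 'raises ValueError'}")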

tests/core/test_block_manager.py

Lines changed: 40 additions & 0 deletions
@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
 
     # assert all blocks are free now
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
+
+
+def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
+    """When prefix cache and chunked prefill are enabled, the block manager
+    should only mark a chunk of blocks as computed instead of all blocks.
+    """
+
+    block_size = 4
+    num_cpu_blocks = 0
+    num_gpu_blocks = 16
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_gpu_blocks,
+                                        num_cpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)
+
+    # Set prompt size to have num_gpu_blocks - 1 full blocks.
+    prompt_length = block_size * num_gpu_blocks - 1
+
+    # Allocate (reserve) all blocks.
+    _, seq_group = create_dummy_prompt("0",
+                                       prompt_length,
+                                       block_size=block_size)
+    block_manager.allocate(seq_group)
+    assert seq_group.seqs[0].n_blocks == num_gpu_blocks
+
+    # 1st chunk: Compute 2 and a half blocks. Should mark 2 blocks as computed.
+    token_chunk_size = int(block_size * 2.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 2
+
+    # Actual computed tokens.
+    seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
+
+    # 2nd chunk: Complete the 3rd block and 4 additional blocks.
+    token_chunk_size = int(block_size * 4.5)
+    block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
+    computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
+    assert len(computed_blocks) == 7
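
The expected block counts above follow directly from the new accounting rule. A standalone sketch (a hypothetical helper mirroring the idea behind compute_full_blocks_in_seq, not the actual block-manager API):

# A block is marked as computed only once every token in it is covered by
# already-computed tokens plus the chunk currently being scheduled.
def full_blocks_after_chunk(num_computed_tokens: int, token_chunk_size: int,
                            block_size: int) -> int:
    return (num_computed_tokens + token_chunk_size) // block_size

block_size = 4
# 1st chunk: 10 tokens (2.5 blocks) -> only 2 full blocks are computed.
assert full_blocks_after_chunk(0, int(block_size * 2.5), block_size) == 2
# 2nd chunk: 18 more tokens -> 28 tokens in total -> 7 full blocks.
assert full_blocks_after_chunk(int(block_size * 2.5), int(block_size * 4.5),
                               block_size) == 7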

tests/core/test_chunked_prefill_scheduler.py

Lines changed: 39 additions & 0 deletions
@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
     assert len(get_sequence_groups(out)) == max_seqs
     assert not running[0].is_prefill()
     assert not running[1].is_prefill()
+
+
+def test_perfix_caching():
+    """Verify allocating full blocks when prefix caching is enabled."""
+    block_size = 4
+    max_seqs = 10
+    max_model_len = 80
+    max_num_batched_tokens = 64
+    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+                                       max_seqs,
+                                       max_model_len,
+                                       enable_chunked_prefill=True)
+    cache_config = CacheConfig(block_size,
+                               1.0,
+                               1,
+                               "auto",
+                               enable_prefix_caching=True)
+    cache_config.num_cpu_blocks = 0
+    cache_config.num_gpu_blocks = 32
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+    running: List[SequenceGroup] = []
+
+    # Add seq groups to scheduler.
+    for i in range(2):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           block_size=block_size,
+                                           prompt_length=50)
+        scheduler.add_seq_group(seq_group)
+        running.append(seq_group)
+
+    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
+    assert set(get_sequence_groups(out)) == set(running)
+    assert seq_group_meta[0].token_chunk_size == 50
+    # Verify it is chunked. Note that although the budget is 64-50=14,
+    # we only allocate full blocks for prefix caching, so only 4*(14//4)=12
+    # tokens are allocated.
+    assert seq_group_meta[1].token_chunk_size == 12
+    assert out.num_prefill_groups == 2
+    assert out.num_batched_tokens == 62
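
The asserted chunk sizes can be re-derived with a few lines of arithmetic (illustrative only, not scheduler code):

# Re-deriving the expectations of test_perfix_caching above.
block_size = 4
token_budget = 64
prompt_len = 50

first_chunk = prompt_len                       # fits entirely in the budget
remaining = token_budget - first_chunk         # 14 tokens left
# With prefix caching enabled, only whole blocks are allocated:
second_chunk = (remaining // block_size) * block_size
assert second_chunk == 12
assert first_chunk + second_chunk == 62        # num_batched_tokens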

vllm/core/block_manager_v1.py

Lines changed: 13 additions & 6 deletions
@@ -681,14 +681,20 @@ def access_all_blocks_in_seq(
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_full_blocks_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // self.block_size - 1
+
+        # When chunked prefill is enabled, the computed full blocks
+        # should be calculated based on the number of computed tokens.
+        max_computed_tokens = (seq.data.get_num_computed_tokens() +
+                               token_chunk_size)
+        computed_full_blocks = max_computed_tokens // self.block_size
+
         block_table = self.block_tables[seq.seq_id]
-        if max_full_block == -1:
+        if computed_full_blocks == 0:
             return
-        for i in reversed(range(max_full_block)):
+        for i in reversed(range(computed_full_blocks)):
             if block_table[i].computed:
                 break
             block_table[i].computed = True
@@ -718,10 +724,11 @@ def get_common_computed_block_ids(
         ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
-                self.compute_full_blocks_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq, token_chunk_size)
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         if device == Device.GPU:

vllm/core/block_manager_v2.py

Lines changed: 2 additions & 1 deletion
@@ -290,7 +290,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
             self._last_access_blocks_tracker.update_last_access(
                 seq.seq_id, now)
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         # If prefix caching is enabled, mark immutable blocks as computed
         # right after they have been scheduled (for prefill). This assumes
         # the scheduler is synchronous so blocks are actually computed when

vllm/core/embedding_model_block_manager.py

Lines changed: 2 additions & 1 deletion
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
                                       seq_group: List[Sequence]) -> List[int]:
         return []
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:

vllm/core/interfaces.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
         pass
 
     @abstractmethod
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     @abstractmethod

vllm/core/scheduler.py

Lines changed: 24 additions & 6 deletions
@@ -1226,7 +1226,8 @@ def schedule(
             # will crash the vLLM instance / will not retry.
             for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
                 self.block_manager.mark_blocks_as_computed(
-                    scheduled_seq_group.seq_group)
+                    scheduled_seq_group.seq_group,
+                    scheduled_seq_group.token_chunk_size)
 
         self._seq_group_metadata_cache[self.next_cache_id].reset()
 
@@ -1457,10 +1458,27 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
         for seq in seqs:
             num_new_tokens += seq.get_num_new_tokens()
         assert num_new_tokens > 0
-        # Chunk if a running request cannot fit in.
-        # If number of seq > 1, it means it is doing beam search in a
-        # decode phase. Do not chunk in that case.
+        # Chunk if a running request cannot fit in the given budget.
+        # If number of seq > 1, it means it is doing beam search
+        # in a decode phase. Do not chunk.
         if enable_chunking and len(seqs) == 1:
-            num_new_tokens = min(num_new_tokens,
-                                 budget.remaining_token_budget())
+            remaining_token_budget = budget.remaining_token_budget()
+            if self.cache_config.enable_prefix_caching:
+                # When prefix caching is enabled, we always allocate
+                # the number of new tokens that is dividable by the block size
+                # to avoid partial block matching.
+                block_size = self.cache_config.block_size
+                reminder = budget.token_budget % block_size
+                if reminder != 0:
+                    raise ValueError("When enabling chunked prefill and "
+                                     "prefix caching, max_num_batched_tokens "
+                                     "(chunk size) must be dividable by "
+                                     "block size, but got chunk_size "
+                                     f"({budget.token_budget}) % block_size "
+                                     f"({block_size}) = {reminder}")
+                if remaining_token_budget < num_new_tokens:
+                    num_new_tokens = (remaining_token_budget //
+                                      block_size) * block_size
+            else:
+                num_new_tokens = min(num_new_tokens, remaining_token_budget)
         return num_new_tokens
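
For readers skimming the diff, here is the new chunking rule distilled into a standalone sketch (a hypothetical free function with illustrative names; the real logic lives inside Scheduler._get_num_new_tokens and reads the budget and cache config from self):

# Simplified sketch of the rule added above, not the actual Scheduler API.
def chunk_new_tokens(num_new_tokens: int, remaining_token_budget: int,
                     token_budget: int, block_size: int,
                     enable_prefix_caching: bool) -> int:
    if not enable_prefix_caching:
        return min(num_new_tokens, remaining_token_budget)
    if token_budget % block_size != 0:
        raise ValueError("chunk size (max_num_batched_tokens) must be "
                         "divisible by the block size when prefix caching "
                         "is enabled")
    if remaining_token_budget < num_new_tokens:
        # Round down to a whole-block boundary to avoid partial block matches.
        return (remaining_token_budget // block_size) * block_size
    return num_new_tokens

# Mirrors the scheduler test above: 14 spare tokens, block size 4.
assert chunk_new_tokens(50, 14, 64, 4, enable_prefix_caching=False) == 14
assert chunk_new_tokens(50, 14, 64, 4, enable_prefix_caching=True) == 12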

vllm/worker/model_runner.py

Lines changed: 37 additions & 12 deletions
@@ -501,23 +501,48 @@ def _compute_for_prefix_cache_hit(
                             and self.sliding_window is None
                             and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-        if self.chunked_prefill_enabled and prefix_cache_hit:
-            raise RuntimeError(
-                "chunked prefill cannot be used with prefix caching now.")
-
-        # If prefix cache is hit, advance context length to bypass
-        # hit blocks. Accordingly, input tokens, position and query length
-        # have to be updated.
-        if prefix_cache_hit:
-            assert computed_block_nums is not None
-            context_len = len(computed_block_nums) * self.block_size
+
+        if not prefix_cache_hit:
+            return
+
+        assert computed_block_nums is not None
+        # The cache hit prompt tokens in this sequence. Note that
+        # this may be larger than the sequence length if chunked
+        # prefill is enabled.
+        prefix_cache_len = len(computed_block_nums) * self.block_size
+        # The number of so far computed prompt tokens in this sequence.
+        context_len = inter_data.context_lens[seq_idx]
+        # The total number of prompt tokens in this sequence.
+        # When chunked prefill is enabled, this is the token number of
+        # computed chunks + current chunk.
+        seq_len = inter_data.seq_lens[seq_idx]
+        if prefix_cache_len <= context_len:
+            # We already passed the cache hit region,
+            # so do normal computation.
+            pass
+        elif context_len < prefix_cache_len < seq_len:
+            # Partial hit. Compute the missing part.
+            uncomputed_start = prefix_cache_len - context_len
             inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][context_len:]
+                seq_idx][uncomputed_start:]
             inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][context_len:]
+                seq_idx][uncomputed_start:]
+            context_len = prefix_cache_len
+
             inter_data.context_lens[seq_idx] = context_len
             inter_data.query_lens[
                 seq_idx] = inter_data.seq_lens[seq_idx] - context_len
+        elif seq_len <= prefix_cache_len:
+            # Full hit. Only compute the last token to avoid
+            # erroneous behavior. FIXME: Ideally we should directly
+            # mark all tokens as computed in the scheduler and do not
+            # schedule this sequence, so this case should not happen.
+            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                seq_idx][-1:]
+            inter_data.input_positions[seq_idx] = inter_data.input_positions[
+                seq_idx][-1:]
+            inter_data.query_lens[seq_idx] = 1
+            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
 
     def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                     seq_idx: int,