Commit 9f7bcb1

sasha0552 authored and weilong.yu committed
[Bugfix] Fix illegal memory access error with chunked prefill, prefix caching, block manager v2 and xformers enabled together (vllm-project#9532)
Signed-off-by: sasha0552 <admin@sasha0552.org>
1 parent 338bed9 commit 9f7bcb1
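
For context, the failing combination named in the title can be reproduced outside the test suite with vLLM's offline API. The sketch below mirrors the new regression test added in this commit (same model, same engine flags, same shared-prefix prompt shape). The environment-variable backend override is an assumption about how the xFormers backend is usually selected, and block manager v2 is left implicit because the new test does not pass a flag for it either; before this fix, a run like this could abort with an illegal memory access.

# Sketch only: reproduce the chunked-prefill + prefix-caching + xFormers setup.
import os

# Assumption: force the xFormers attention backend before the engine is built.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_model_len=4096,
)

# Prompts that share a long prefix, so later requests hit the prefix cache
# while the remainder of each prompt still goes through chunked prefill.
shared_prefix = [0] * 588 + [1] * 1332
for tail in ([2] * 30 + [3] * 1, [4] * 3 + [5] * 50):
    llm.generate(TokensPrompt(prompt_token_ids=shared_prefix + tail),
                 SamplingParams(max_tokens=1))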

File tree

2 files changed (+34, -3 lines changed)

tests/prefix_caching/test_prefix_caching.py

Lines changed: 28 additions & 0 deletions
@@ -5,13 +5,22 @@
 import pytest

 from tests.kernels.utils import override_backend_env_variable
+from vllm import SamplingParams, TokensPrompt

 from ..models.utils import check_outputs_equal

 MODELS = [
     "facebook/opt-125m",
 ]

+UNSTABLE_PROMPT_SEQUENCE = [
+    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
+    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
+    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
+    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
+    ([0] * 588) + ([8] * 1539),
+]
+

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])

@@ -57,3 +66,22 @@ def test_mixed_requests(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+def test_unstable_prompt_sequence(
+    vllm_runner,
+    backend: str,
+    monkeypatch,
+) -> None:
+    override_backend_env_variable(monkeypatch, backend)
+
+    with vllm_runner(
+            "Qwen/Qwen2.5-0.5B-Instruct",
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            max_model_len=4096,
+    ) as vllm_model:
+        for prompt in UNSTABLE_PROMPT_SEQUENCE:
+            vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
+                                SamplingParams(max_tokens=1))
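
The prompts in UNSTABLE_PROMPT_SEQUENCE are built so that the first four share the same 1,920-token prefix ([0] * 588 followed by [1] * 1332) while the fifth shares only the first 588 tokens, so successive requests hit the prefix cache to different depths while the rest of each prompt still goes through chunked prefill. To run just this test for one backend (the test ID follows pytest's usual parametrize naming), something like the following should work:

pytest tests/prefix_caching/test_prefix_caching.py::test_unstable_prompt_sequence -k XFORMERS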

vllm/attention/backends/utils.py

Lines changed: 6 additions & 3 deletions
@@ -138,7 +138,6 @@ def _add_seq_group(
             chunked_prefill_enabled: bool):
         is_prompt = inter_data.is_prompt
         block_tables = inter_data.block_tables
-        computed_block_nums = inter_data.computed_block_nums

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(

@@ -164,10 +163,14 @@ def _add_seq_group(
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
             if inter_data.prefix_cache_hit:
-                block_table = computed_block_nums
+                block_table = block_tables[seq_id]
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
-                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+                if curr_sliding_window_block == 0:
+                    block_table = block_tables[seq_id]
+                else:
+                    block_table = block_tables[seq_id][
+                        -curr_sliding_window_block:]
             self.block_tables.append(block_table)

             # Compute slot mapping.
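
Two behaviors change in _add_seq_group. On a prefix-cache hit, the block table handed to the attention metadata is now the sequence's full block table (block_tables[seq_id]) rather than computed_block_nums, which tracks only the blocks already computed and cached and, with chunked prefill in play, can be shorter than the blocks the sequence actually occupies; an under-sized block table is a plausible source of the illegal memory access in the title. The sliding-window branch, by contrast, looks like a behavior-preserving cleanup: in Python a slice starting at -0 is the same as one starting at 0, so the old code already kept the whole table when curr_sliding_window_block was 0, and the new branch simply makes that case explicit and skips the needless copy. A quick plain-Python illustration (not vLLM code):

# -0 slicing keeps everything; only a non-zero window actually truncates.
blocks = [10, 11, 12, 13]
window = 0

assert blocks[-0:] == blocks        # old code path when window == 0
assert blocks[-2:] == [12, 13]      # a real sliding window keeps only the tail

# New logic: reuse the existing list when no window applies.
table = blocks if window == 0 else blocks[-window:]
assert table is blocks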
