Commit c1479c4

sasha0552 and hissu-hyvarinen authored and committed
[Bugfix] Fix illegal memory access error with chunked prefill, prefix caching, block manager v2 and xformers enabled together (vllm-project#9532)
Signed-off-by: sasha0552 <admin@sasha0552.org>
1 parent d3d5f4e commit c1479c4
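
The title names a specific combination of engine features. As a rough sketch of what enabling that combination looks like with vLLM's offline LLM API: enable_chunked_prefill, enable_prefix_caching, max_model_len and the model follow the new test below, while use_v2_block_manager and the VLLM_ATTENTION_BACKEND environment variable are assumptions about vLLM flags of roughly this vintage, not taken from this commit.

    # Minimal sketch (not from this commit) of the feature combination named
    # in the title. use_v2_block_manager and VLLM_ATTENTION_BACKEND are
    # assumed flag names; the rest mirrors the new test.
    import os

    os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"  # select the xformers backend

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        enable_chunked_prefill=True,
        enable_prefix_caching=True,
        use_v2_block_manager=True,  # block manager v2
        max_model_len=4096,
    )
    llm.generate(["Hello"], SamplingParams(max_tokens=1))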

File tree

2 files changed: +34 -3 lines changed

tests/prefix_caching/test_prefix_caching.py
vllm/attention/backends/utils.py

tests/prefix_caching/test_prefix_caching.py
28 additions, 0 deletions

@@ -5,13 +5,22 @@
 import pytest

 from tests.kernels.utils import override_backend_env_variable
+from vllm import SamplingParams, TokensPrompt

 from ..models.utils import check_outputs_equal

 MODELS = [
     "facebook/opt-125m",
 ]

+UNSTABLE_PROMPT_SEQUENCE = [
+    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
+    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
+    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
+    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
+    ([0] * 588) + ([8] * 1539),
+]
+

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@@ -57,3 +66,22 @@ def test_mixed_requests(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+def test_unstable_prompt_sequence(
+    vllm_runner,
+    backend: str,
+    monkeypatch,
+) -> None:
+    override_backend_env_variable(monkeypatch, backend)
+
+    with vllm_runner(
+            "Qwen/Qwen2.5-0.5B-Instruct",
+            enable_chunked_prefill=True,
+            enable_prefix_caching=True,
+            max_model_len=4096,
+    ) as vllm_model:
+        for prompt in UNSTABLE_PROMPT_SEQUENCE:
+            vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
+                                SamplingParams(max_tokens=1))
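
To exercise just the new test locally, one option (an assumption about local setup, not part of this commit) is to invoke pytest on it directly, for example programmatically:

    # Run only the new test from the repository root (assumes a working
    # GPU-capable vLLM install and the test dependencies).
    import pytest

    pytest.main([
        "tests/prefix_caching/test_prefix_caching.py",
        "-k", "test_unstable_prompt_sequence",
    ])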

vllm/attention/backends/utils.py
6 additions, 3 deletions

@@ -139,7 +139,6 @@ def _add_seq_group(
             chunked_prefill_enabled: bool):
         is_prompt = inter_data.is_prompt
         block_tables = inter_data.block_tables
-        computed_block_nums = inter_data.computed_block_nums

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
@@ -165,10 +164,14 @@ def _add_seq_group(
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
             if inter_data.prefix_cache_hit:
-                block_table = computed_block_nums
+                block_table = block_tables[seq_id]
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
-                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+                if curr_sliding_window_block == 0:
+                    block_table = block_tables[seq_id]
+                else:
+                    block_table = block_tables[seq_id][
+                        -curr_sliding_window_block:]
             self.block_tables.append(block_table)

             # Compute slot mapping.
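
Read outside the diff context, the block-table selection after this change amounts to the following. This is a simplified standalone sketch, not vLLM code: the builder state is passed in as parameters, and the function name is made up for illustration; the variable names follow the hunk above.

    # Standalone sketch of the block-table selection logic after this change.
    from typing import Dict, List, Optional


    def select_block_table(prefix_cache_hit: bool,
                           chunked_prefill_enabled: bool,
                           is_prompt: bool,
                           block_tables: Optional[Dict[int, List[int]]],
                           seq_id: int,
                           curr_sliding_window_block: int) -> List[int]:
        block_table: List[int] = []
        if prefix_cache_hit:
            # Prefix-cache hit: use the sequence's own block table instead of
            # the previously used computed_block_nums.
            block_table = block_tables[seq_id]
        elif ((chunked_prefill_enabled or not is_prompt)
              and block_tables is not None):
            if curr_sliding_window_block == 0:
                # No sliding window: keep the whole block table.
                block_table = block_tables[seq_id]
            else:
                # Sliding window: keep only the most recent blocks.
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
        return block_table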
