Commit 00298e0

[Bugfix] Fix bug of xformer prefill for encoder-decoder (#9026)
1 parent 89feb4c commit 00298e0

File tree

1 file changed (+18 -11 lines)


vllm/attention/backends/xformers.py

Lines changed: 18 additions & 11 deletions
@@ -559,34 +559,41 @@ def forward(
                                                     self.kv_cache_dtype,
                                                     k_scale, v_scale)
 
-        if attn_type != AttentionType.ENCODER:
-            # Decoder self-attention supports chunked prefill.
-            # Encoder/decoder cross-attention requires no chunked
-            # prefill (100% prefill or 100% decode tokens, no mix)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-            num_decode_tokens = attn_metadata.num_decode_tokens
-        else:
+        if attn_type == AttentionType.ENCODER:
             # Encoder attention - chunked prefill is not applicable;
             # derive token-count from query shape & and treat them
             # as 100% prefill tokens
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_encoder_tokens = attn_metadata.num_encoder_tokens
             num_decode_tokens = 0
-
-        if attn_type == AttentionType.DECODER:
+        elif attn_type == AttentionType.DECODER:
+            # Decoder self-attention supports chunked prefill.
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
             # Only enforce this shape-constraint for decoder
             # self-attention
             assert key.shape[0] == num_prefill_tokens + num_decode_tokens
             assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        else:  # attn_type == AttentionType.ENCODER_DECODER
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            if attn_metadata.num_encoder_tokens is not None:
+                num_encoder_tokens = attn_metadata.num_encoder_tokens
+            else:
+                num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
         query = query[:num_prefill_tokens]
         if key is not None and value is not None:
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
+            key = key[:num_encoder_tokens]
+            value = value[:num_encoder_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
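For context, a minimal sketch of why the key/value slice matters for cross-attention: during an encoder-decoder prefill, key and value come from the encoder, so their first dimension is the encoder token count, which generally differs from the decoder-side prefill token count. Slicing them with num_prefill_tokens silently truncates the encoder sequence, which is the bug this commit addresses. The token counts and tensor shapes below are hypothetical, chosen only to show the mismatch; this is not code from the repository.

import torch

# Hypothetical token counts for one encoder-decoder prefill step:
# 4 decoder prompt tokens, 7 encoder tokens, no decode tokens.
num_prefill_tokens = 4
num_encoder_tokens = 7
num_decode_tokens = 0

num_heads, head_size = 2, 8
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)
key = torch.randn(num_encoder_tokens, num_heads, head_size)    # encoder-side K
value = torch.randn(num_encoder_tokens, num_heads, head_size)  # encoder-side V

# Old behaviour: slicing K/V by the decoder prefill length truncates the
# encoder sequence, so cross-attention sees only 4 of the 7 encoder tokens.
key_old = key[:num_prefill_tokens]
# Fixed behaviour: slicing by the encoder token count keeps the full sequence.
key_new = key[:num_encoder_tokens]

print(key_old.shape)  # torch.Size([4, 2, 8]) -- 3 encoder tokens lost
print(key_new.shape)  # torch.Size([7, 2, 8]) -- full encoder context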
