@@ -559,34 +559,41 @@ def forward(
559
559
self .kv_cache_dtype ,
560
560
k_scale , v_scale )
561
561
562
- if attn_type != AttentionType .ENCODER :
563
- # Decoder self-attention supports chunked prefill.
564
- # Encoder/decoder cross-attention requires no chunked
565
- # prefill (100% prefill or 100% decode tokens, no mix)
566
- num_prefill_tokens = attn_metadata .num_prefill_tokens
567
- num_decode_tokens = attn_metadata .num_decode_tokens
568
- else :
562
+ if attn_type == AttentionType .ENCODER :
569
563
# Encoder attention - chunked prefill is not applicable;
570
564
# derive token-count from query shape & and treat them
571
565
# as 100% prefill tokens
572
566
assert attn_metadata .num_encoder_tokens is not None
573
567
num_prefill_tokens = attn_metadata .num_encoder_tokens
568
+ num_encoder_tokens = attn_metadata .num_encoder_tokens
574
569
num_decode_tokens = 0
575
-
576
- if attn_type == AttentionType .DECODER :
570
+ elif attn_type == AttentionType .DECODER :
571
+ # Decoder self-attention supports chunked prefill.
572
+ num_prefill_tokens = attn_metadata .num_prefill_tokens
573
+ num_encoder_tokens = attn_metadata .num_prefill_tokens
574
+ num_decode_tokens = attn_metadata .num_decode_tokens
577
575
# Only enforce this shape-constraint for decoder
578
576
# self-attention
579
577
assert key .shape [0 ] == num_prefill_tokens + num_decode_tokens
580
578
assert value .shape [0 ] == num_prefill_tokens + num_decode_tokens
579
+ else : # attn_type == AttentionType.ENCODER_DECODER
580
+ # Encoder/decoder cross-attention requires no chunked
581
+ # prefill (100% prefill or 100% decode tokens, no mix)
582
+ num_prefill_tokens = attn_metadata .num_prefill_tokens
583
+ if attn_metadata .num_encoder_tokens is not None :
584
+ num_encoder_tokens = attn_metadata .num_encoder_tokens
585
+ else :
586
+ num_encoder_tokens = attn_metadata .num_prefill_tokens
587
+ num_decode_tokens = attn_metadata .num_decode_tokens
581
588
582
589
output = torch .empty_like (query )
583
590
# Query for decode. KV is not needed because it is already cached.
584
591
decode_query = query [num_prefill_tokens :]
585
592
# QKV for prefill.
586
593
query = query [:num_prefill_tokens ]
587
594
if key is not None and value is not None :
588
- key = key [:num_prefill_tokens ]
589
- value = value [:num_prefill_tokens ]
595
+ key = key [:num_encoder_tokens ]
596
+ value = value [:num_encoder_tokens ]
590
597
591
598
assert query .shape [0 ] == num_prefill_tokens
592
599
assert decode_query .shape [0 ] == num_decode_tokens
0 commit comments