@@ -130,6 +130,24 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0

+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into the sequence. E.g., if the
+    # sequence length is [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+

 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
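
The `[4, 6] -> [0, 4, 10]` example in the new field's comment can be reproduced with a short, self-contained sketch (illustrative only; the variable names below are not part of the PR):

    import torch

    # Per-sequence encoder lengths for a batch of two sequences.
    encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)

    # Cumulative start offsets, shape (batch_size + 1,): tensor([0, 4, 10]).
    encoder_seq_start_loc = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
    ])
    max_encoder_seq_len = int(encoder_seq_lens.max())  # 6

    # Tokens of sequence i in the packed encoder tensor live at
    # [encoder_seq_start_loc[i] : encoder_seq_start_loc[i + 1]].
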
@@ -207,7 +225,13 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
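
A minimal illustration of the branch above (assumed values, not taken from the PR): when cross-attention metadata is present, the key/value side of attention comes from the encoder output, so the builder uses the encoder-side maximum rather than the decoder sequence lengths.

    import torch

    seq_lens_cpu = torch.tensor([3, 7])    # decoder-side sequence lengths
    max_encoder_seq_len = 12               # encoder-side maximum
    cross_slot_mapping = torch.arange(19)  # present only for cross-attention

    if cross_slot_mapping is not None and max_encoder_seq_len is not None:
        max_seq_len = max_encoder_seq_len  # K/V come from the encoder
    else:
        max_seq_len = int(seq_lens_cpu.max())
    assert max_seq_len == 12
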
@@ -326,6 +350,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
         )
         return attn_metadata

@@ -375,17 +403,31 @@ def __init__(

         FlashAttentionBackend.validate_head_size(head_size)

-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
             and not flash_attn_supports_fp8():
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")

+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: Returns `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise returns `False`.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
     def forward(
         self,
         layer: torch.nn.Module,
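
A standalone sketch of the rule `_get_causal_option` encodes (plain strings stand in for the real `AttentionType` constants; the exact values are assumed here): only decoder self-attention uses a causal mask, while encoder, encoder-only, and encoder/decoder cross-attention are bidirectional.

    # Assumed string values mirroring AttentionType, for illustration only.
    ENCODER, ENCODER_ONLY, ENCODER_DECODER, DECODER = (
        "encoder", "encoder_only", "encoder_decoder", "decoder")

    def get_causal_option(attn_type: str) -> bool:
        # Causal masking applies only to decoder self-attention.
        return attn_type not in (ENCODER, ENCODER_ONLY, ENCODER_DECODER)

    assert get_causal_option(DECODER) is True
    assert get_causal_option(ENCODER_DECODER) is False
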
@@ -422,6 +464,14 @@ def forward(
             # Profiling run.
             return output

+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -432,22 +482,40 @@ def forward(
         # performance to make sure it does not introduce any overhead.

         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)

-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
             # not padded. However, we don't need to do key[:num_actual_tokens]
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
             reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                 self.kv_cache_dtype,
                 layer._k_scale,
                 layer._v_scale,
@@ -471,7 +539,7 @@ def forward(
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata

-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
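
The `descale_shape` change above sizes the per-request FP8 descale tensors as `(num_requests, num_kv_heads)` from the layer's own configuration rather than from `key.shape[1]`, presumably because `key` can be `None` on cross-attention decode steps. A minimal sketch of the broadcast, assuming the layer scales are 0-dim tensors:

    import torch

    num_requests, num_kv_heads = 3, 8  # assumed sizes for illustration
    k_scale = torch.tensor(0.5)        # stand-in for layer._k_scale

    descale_shape = (num_requests, num_kv_heads)
    k_descale = k_scale.expand(descale_shape)  # broadcast view, no copy
    assert k_descale.shape == (3, 8)
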
@@ -483,7 +551,7 @@ def forward(
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -518,12 +586,86 @@ def forward(
             fa_version=self.vllm_flash_attn_version,
             prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
             suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
         )
         return output

+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+        )
+
+        return output
+

 def use_cascade_attention(
     common_prefix_len: int,
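
For intuition about what the new `_forward_encoder_attention` path computes, here is a tiny reference (not the vLLM code path) for a single sequence with no sliding window, logits soft-cap, or ALiBi: plain bidirectional attention, i.e. scaled dot-product attention with `is_causal=False`.

    import torch
    import torch.nn.functional as F

    seq_len, num_heads, head_size = 10, 4, 64  # assumed shapes for illustration
    q = torch.randn(seq_len, num_heads, head_size)
    k = torch.randn(seq_len, num_heads, head_size)
    v = torch.randn(seq_len, num_heads, head_size)

    # (seq, heads, dim) -> (heads, seq, dim) for SDPA, then back.
    out = F.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
        is_causal=False,  # encoder attention is bidirectional
    ).transpose(0, 1)
    assert out.shape == (seq_len, num_heads, head_size)
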