@@ -75,9 +75,11 @@
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec)
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        EncoderOnlyAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
@@ -317,10 +319,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+                self.device, self.model_config.runner_type == "generate")
         else:
             self.attn_mask_builder = AttentionMaskBuilder(
-                self.model_config.max_model_len, self.dtype)
+                self.model_config.max_model_len,
+                self.dtype,
+                tril=self.model_config.runner_type == "generate")

         # Set up speculative decoding.
         self.spec_attn_mask = None
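Note on the new flag: the extra argument threads `self.model_config.runner_type == "generate"` into `AttentionMaskBuilder`, so the lower-triangular (causal) mask is presumably built only for generate (decoder) models, while pooling models keep full visibility. Below is a minimal sketch of that `tril` idea in plain PyTorch; the helper name is hypothetical and this is not the AttentionMaskBuilder implementation:

import torch

def make_attn_mask(max_len: int, dtype: torch.dtype, tril: bool) -> torch.Tensor:
    # tril=True (runner_type == "generate"): causal mask, token i attends to positions <= i.
    # tril=False (e.g. pooling models): every token may attend to every other token.
    visible = torch.ones(max_len, max_len, dtype=torch.bool)
    if tril:
        visible = visible.tril()
    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    return torch.zeros(max_len, max_len, dtype=dtype).masked_fill(~visible, float("-inf"))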
@@ -1477,14 +1481,29 @@ def _prepare_inputs(
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()
-            slot_mapping = blk_table.slot_mapping_cpu[:
-                                                      total_num_scheduled_tokens]
-            self.slot_mapping[:total_num_scheduled_tokens].copy_(
-                slot_mapping[:total_num_scheduled_tokens],
-                non_blocking=True,
-            )
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                slot_mapping = torch.zeros(
+                    (total_num_scheduled_tokens, ),
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+            else:
+                blk_table = self.input_batch.block_table[kv_cache_group_id]
+                blk_table_tensor = blk_table.get_device_tensor()
+                slot_mapping = blk_table.slot_mapping_cpu[:
+                                                          total_num_scheduled_tokens]
+                self.slot_mapping[:total_num_scheduled_tokens].copy_(
+                    slot_mapping[:total_num_scheduled_tokens],
+                    non_blocking=True,
+                )

         # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
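For reference, the placeholders created for an `EncoderOnlyAttentionSpec` group are plain zero tensors with the shapes and dtypes shown above; they only exist so the downstream metadata plumbing has tensors of the expected layout. A self-contained sketch, with a hypothetical helper name and CPU as the default device instead of the NPU:

import torch

def dummy_block_table_and_slots(num_reqs: int, total_tokens: int,
                                device: torch.device = torch.device("cpu")):
    # One zero "block" per request: encoder-only layers allocate no KV cache
    # blocks, so these indices are never actually dereferenced.
    blk_table_tensor = torch.zeros((num_reqs, 1), dtype=torch.int32, device=device)
    # One zero slot per scheduled token, matching the int64 slot-mapping dtype.
    slot_mapping = torch.zeros((total_tokens, ), dtype=torch.int64, device=device)
    return blk_table_tensor, slot_mapping

# e.g. dummy_block_table_and_slots(4, 37) -> shapes (4, 1) and (37,)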
@@ -1533,6 +1552,11 @@ def _prepare_inputs(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
                     **extra_attn_metadata_args)
+            elif self.model_config.runner_type == "pooling":
+                attn_metadata_i = builder.build(
+                    common_prefix_len=common_prefix_len,
+                    common_attn_metadata=common_attn_metadata,
+                    **extra_attn_metadata_args)
             else:
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
@@ -2639,6 +2663,33 @@ def _convert_torch_format(self, tensor):
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor

+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype,
+                    use_mla=use_mla)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
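The new `may_add_encoder_only_layers_to_kv_cache_config` relies on the attention spec being hashable, so layers that share an identical spec collapse onto one dictionary key, which then becomes a single `KVCacheGroupSpec`. A small standalone sketch of that grouping pattern; `FakeSpec` is a stand-in for `EncoderOnlyAttentionSpec`, which is assumed to be a hashable dataclass:

from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeSpec:  # stand-in for EncoderOnlyAttentionSpec
    block_size: int
    num_kv_heads: int
    head_size: int

layers = {
    "model.layers.0.attn": FakeSpec(128, 8, 64),
    "model.layers.1.attn": FakeSpec(128, 8, 64),
}

groups: dict[FakeSpec, list[str]] = defaultdict(list)
for name, spec in layers.items():
    groups[spec].append(name)

# Identical specs hash to the same key, so both layers land in one group.
assert len(groups) == 1, "Only support one encoder-only attention spec now"
spec, layer_names = groups.popitem()
print(layer_names)  # ['model.layers.0.attn', 'model.layers.1.attn']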
@@ -2648,9 +2699,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
-        self.may_reinitialize_input_batch(kv_cache_config)


         if self.model_config.is_deepseek_mla:
             kv_caches = self.initialize_kv_cache_tensors_deepseek(