@@ -31,9 +31,10 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available)
+                        is_pin_memory_available, prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata)
+                                               PallasMetadata,
+                                               get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 
 #########################################################
@@ -139,7 +138,11 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            else:
+                self.kv_cache_dtype = model_dtype
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
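
The hunk above normalizes `kv_cache_dtype` to a `torch.dtype` even when `model_config.dtype` is a string alias. A minimal standalone sketch of the same normalization, using a local mapping rather than vLLM's `STR_DTYPE_TO_TORCH_DTYPE`, would look roughly like this:

```python
import torch

# Illustrative mapping only; the real one is vllm.utils.STR_DTYPE_TO_TORCH_DTYPE.
STR_TO_TORCH_DTYPE = {"bfloat16": torch.bfloat16, "float16": torch.float16}

def resolve_kv_cache_dtype(cache_dtype: str, model_dtype) -> torch.dtype:
    if cache_dtype == "auto":
        # model_config.dtype may be either a torch.dtype or a string alias.
        if isinstance(model_dtype, str):
            return STR_TO_TORCH_DTYPE[model_dtype]
        return model_dtype
    return STR_TO_TORCH_DTYPE[cache_dtype]

assert resolve_kv_cache_dtype("auto", "bfloat16") is torch.bfloat16
assert resolve_kv_cache_dtype("auto", torch.float16) is torch.float16
```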
@@ -192,6 +195,14 @@ def __init__(
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
+        self._num_slices_per_kv_cache_update_block = \
+            _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
+                block_size=self.block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                kv_cache_dtype=self.kv_cache_dtype,
+            ))
+
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
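
For intuition: the page size that feeds the slice-count heuristic is presumably on the order of `block_size * num_kv_heads * head_size * 2` (K and V) times the element size of the KV cache dtype. The sketch below computes it that way; this is an assumption about what `get_page_size_bytes` returns, not a copy of its implementation.

```python
import torch

def approx_page_size_bytes(block_size: int, num_kv_heads: int,
                           head_size: int, kv_cache_dtype: torch.dtype) -> int:
    # Hypothetical re-derivation; the real value comes from get_page_size_bytes()
    # in the Pallas attention backend.
    bytes_per_elem = torch.tensor([], dtype=kv_cache_dtype).element_size()
    # One page holds both the K and the V entries for `block_size` tokens.
    return 2 * block_size * num_kv_heads * head_size * bytes_per_elem

# e.g. block_size=32, 8 KV heads, head_size=128, bfloat16 -> 128 KiB per page
print(approx_page_size_bytes(32, 8, 128, torch.bfloat16))  # 131072
```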
@@ -719,7 +730,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size)
+            self.block_size, self._num_slices_per_kv_cache_update_block)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
             num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                               dtype=torch.int32,
                                               device=self.device),
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1197,7 +1208,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size)
+            num_tokens, self.max_num_reqs, self.block_size,
+            self._num_slices_per_kv_cache_update_block)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1220,8 +1232,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
             num_kv_update_slices=num_kv_update_slices,
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
 
         if self.is_multimodal_model:
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
-                                           page_size: int) -> int:
+def _get_padded_num_kv_cache_update_slices(
+        num_tokens: int, max_num_reqs: int, page_size: int,
+        num_slices_per_kv_cache_update_block: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+        padded_num_slices + num_slices_per_kv_cache_update_block - 1
+    ) // num_slices_per_kv_cache_update_block * \
+        num_slices_per_kv_cache_update_block
     return padded_num_slices
 
 
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel program
+    will increase HBM bandwidth utilization via more in-flight DMAs.
+
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # Conservative VMEM usage limit: 32 MiB
+    vmem_limit = 32 * 1024 * 1024
+    num_slices_per_block = vmem_limit // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if num_slices_per_block > 64:
+        num_slices_per_block = 64
+    return num_slices_per_block
+
+
 def replace_set_lora(model):
 
     def _tpu_set_lora(
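
As a sanity check on the two helpers above, here is a small standalone sketch that re-implements them and evaluates one plausible TPU configuration. `prev_power_of_2` is inlined because the real one is imported from `vllm.utils`, and the concrete numbers are illustrative, not taken from the PR.

```python
def prev_power_of_2(n: int) -> int:
    # Largest power of two <= n (assumed semantics of vllm.utils.prev_power_of_2).
    return 1 << (n.bit_length() - 1) if n > 0 else 0

def get_num_slices_per_block(page_size_bytes: int) -> int:
    vmem_limit = 32 * 1024 * 1024          # 32 MiB VMEM budget
    num_slices = prev_power_of_2(vmem_limit // page_size_bytes)
    return min(num_slices, 64)             # capped at 64 (regression seen at 128 on v6e)

def get_padded_num_slices(num_tokens: int, max_num_reqs: int, page_size: int,
                          slices_per_block: int) -> int:
    padded = min(2 * max_num_reqs + num_tokens // page_size, num_tokens)
    # Round up to a multiple of the per-block slice count to avoid recompilation.
    return (padded + slices_per_block - 1) // slices_per_block * slices_per_block

# Example: 128 KiB pages -> 32 MiB / 128 KiB = 256 -> capped at 64 slices per block.
slices_per_block = get_num_slices_per_block(128 * 1024)
print(slices_per_block)                                        # 64
# Example: 1024 padded tokens, 256 max requests, page_size (block_size) 32.
print(get_padded_num_slices(1024, 256, 32, slices_per_block))  # 576
```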