Commit c586b55

[TPU] Optimize kv cache update kernel (#20415)
Signed-off-by: Yifei Teng <tengyifei88@gmail.com>
1 parent 33d5600 commit c586b55

File tree

3 files changed (+63, -16):
  vllm/utils/__init__.py
  vllm/v1/attention/backends/pallas.py
  vllm/v1/worker/tpu_model_runner.py


vllm/utils/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -947,6 +947,13 @@ def next_power_of_2(n) -> int:
     return 1 << (n - 1).bit_length()
 
 
+def prev_power_of_2(n: int) -> int:
+    """The previous power of 2 (inclusive)"""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
 
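
Note: the new helper rounds down to the nearest power of two (inclusive). A quick doctest-style check, assuming it is imported from vllm.utils as added above:

    >>> from vllm.utils import prev_power_of_2
    >>> prev_power_of_2(1024)   # already a power of two: returned unchanged
    1024
    >>> prev_power_of_2(1000)   # rounded down to the previous power of two
    512
    >>> prev_power_of_2(0)      # non-positive inputs return 0
    0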

vllm/v1/attention/backends/pallas.py

Lines changed: 6 additions & 0 deletions
@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                page_size: int,
                                num_slices_per_block: int) -> torch.Tensor:
     return kv_cache
+
+
+def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
+                        kv_cache_dtype: torch.dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
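
Note: as a rough sanity check of the new helper, here is a sketch with illustrative cache dimensions (the concrete numbers are assumptions for the example, not values from this change):

    import torch

    from vllm.v1.attention.backends.pallas import get_page_size_bytes

    # Hypothetical cache shape: 16-token pages, 8 KV heads, head_size 128,
    # stored in bfloat16 (2 bytes per element).
    page_bytes = get_page_size_bytes(block_size=16,
                                     num_kv_heads=8,
                                     head_size=128,
                                     kv_cache_dtype=torch.bfloat16)
    assert page_bytes == 16 * 8 * 128 * 2  # 32 KiB per page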

vllm/v1/worker/tpu_model_runner.py

Lines changed: 50 additions & 16 deletions
@@ -31,9 +31,10 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available)
+                        is_pin_memory_available, prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata)
+                                               PallasMetadata,
+                                               get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 
 #########################################################
@@ -139,7 +138,11 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            else:
+                self.kv_cache_dtype = model_dtype
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
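
Note: the "auto" branch previously assigned self.dtype directly, but that value may still be a string such as "bfloat16", which has no itemsize attribute and therefore cannot feed the page-size computation added below. A minimal sketch of the resolution logic (the mapping entries here are illustrative; the real table is vllm.utils.STR_DTYPE_TO_TORCH_DTYPE):

    import torch

    # Illustrative stand-in for vllm.utils.STR_DTYPE_TO_TORCH_DTYPE.
    _STR_TO_DTYPE = {"bfloat16": torch.bfloat16, "half": torch.half}

    def _resolve_kv_cache_dtype(model_dtype):
        # Strings have no .itemsize, so map them to a torch.dtype first.
        if isinstance(model_dtype, str):
            return _STR_TO_DTYPE[model_dtype]
        return model_dtype

    assert _resolve_kv_cache_dtype("bfloat16") is torch.bfloat16
    assert _resolve_kv_cache_dtype(torch.float32) is torch.float32
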
@@ -192,6 +195,14 @@ def __init__(
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
+        self._num_slices_per_kv_cache_update_block = \
+            _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
+                block_size=self.block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                kv_cache_dtype=self.kv_cache_dtype,
+            ))
+
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
@@ -719,7 +730,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size)
+            self.block_size, self._num_slices_per_kv_cache_update_block)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
             num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                               dtype=torch.int32,
                                               device=self.device),
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1197,7 +1208,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size)
+            num_tokens, self.max_num_reqs, self.block_size,
+            self._num_slices_per_kv_cache_update_block)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1220,8 +1232,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
             num_kv_update_slices=num_kv_update_slices,
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
 
         if self.is_multimodal_model:
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
-                                           page_size: int) -> int:
+def _get_padded_num_kv_cache_update_slices(
+        num_tokens: int, max_num_reqs: int, page_size: int,
+        num_slices_per_kv_cache_update_block: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+        padded_num_slices + num_slices_per_kv_cache_update_block - 1
+    ) // num_slices_per_kv_cache_update_block * \
+        num_slices_per_kv_cache_update_block
     return padded_num_slices
 
 
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel program
+    will increase HBM bandwidth utilization via more in-flight DMAs.
+
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # Conservative VMEM usage limit: 32 MiB
+    vmem_limit = 32 * 1024 * 1024
+    num_slices_per_block = vmem_limit // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if num_slices_per_block > 64:
+        num_slices_per_block = 64
+    return num_slices_per_block
+
+
 def replace_set_lora(model):
 
     def _tpu_set_lora(
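
Note: putting the new sizing logic together, a hedged walk-through assuming a 32 KiB page (all concrete numbers below are illustrative, not taken from this change):

    from vllm.utils import prev_power_of_2

    # Assume one KV cache page is 32 KiB, e.g. as returned by
    # get_page_size_bytes for the illustrative shape above.
    page_size_bytes = 32 * 1024

    # Mirrors _get_num_slices_per_kv_cache_update_block:
    vmem_limit = 32 * 1024 * 1024               # conservative 32 MiB VMEM budget
    num_slices = vmem_limit // page_size_bytes  # 1024 slices would fit in VMEM
    num_slices = prev_power_of_2(num_slices)    # keep a power of two
    num_slices = min(num_slices, 64)            # cap at 64 (128 regressed on v6e)
    assert num_slices == 64

    # Mirrors _get_padded_num_kv_cache_update_slices for, say, 512 scheduled
    # tokens, max_num_reqs=16 and a 16-token page:
    num_tokens, max_num_reqs, page_size = 512, 16, 16
    padded = min(2 * max_num_reqs + num_tokens // page_size, num_tokens)  # 64
    padded = (padded + num_slices - 1) // num_slices * num_slices         # 64
    assert padded == 64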
