Skip to content

Commit b98cc28

Browse files
[Core][Kernels] Use FlashInfer backend for FP8 KV Cache when available. (#7798)
Co-authored-by: Simon Mo <simon.mo@hey.com>
1 parent ef9baee commit b98cc28

File tree

3 files changed

+249
-12
lines changed

3 files changed

+249
-12
lines changed

tests/kernels/test_flashinfer.py

Lines changed: 222 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,14 @@ def ref_paged_attn(
7373
@pytest.mark.parametrize("dtype", DTYPES)
7474
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
7575
@torch.inference_mode
76-
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
77-
num_heads: Tuple[int,
78-
int], head_size: int,
79-
dtype: torch.dtype, block_size: int,
80-
soft_cap: Optional[float]) -> None:
76+
def test_flashinfer_decode_with_paged_kv(
77+
kv_lens: List[int],
78+
num_heads: Tuple[int, int],
79+
head_size: int,
80+
dtype: torch.dtype,
81+
block_size: int,
82+
soft_cap: Optional[float],
83+
) -> None:
8184
torch.set_default_device("cuda")
8285
torch.cuda.manual_seed_all(0)
8386
num_seqs = len(kv_lens)
@@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
8891
scale = head_size**-0.5
8992

9093
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
94+
9195
key_value_cache = torch.randn(NUM_BLOCKS,
9296
2,
9397
block_size,
@@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
125129
wrapper = flashinfer.\
126130
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
127131
use_tensor_cores=(
128-
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
132+
(num_query_heads//num_kv_heads) > 4)
129133
)
130134
wrapper.begin_forward(kv_indptr,
131135
kv_indices,
@@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
249253
soft_cap=soft_cap)
250254
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
251255
f"{torch.max(torch.abs(output - ref_output))}"
256+
257+
258+
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
259+
@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
260+
@pytest.mark.parametrize("head_size", HEAD_SIZES)
261+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
262+
@pytest.mark.parametrize("dtype", DTYPES)
263+
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
264+
def test_flashinfer_prefill_with_paged_fp8_kv(
265+
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
266+
head_size: int, dtype: torch.dtype, block_size: int,
267+
soft_cap: Optional[float]) -> None:
268+
torch.set_default_device("cuda")
269+
torch.cuda.manual_seed_all(0)
270+
num_seqs = len(seq_lens)
271+
query_lens = [x[0] for x in seq_lens]
272+
kv_lens = [x[1] for x in seq_lens]
273+
num_query_heads = num_heads[0]
274+
num_kv_heads = num_heads[1]
275+
assert num_query_heads % num_kv_heads == 0
276+
max_kv_len = max(kv_lens)
277+
scale = head_size**-0.5
278+
279+
kv_cache_dtype = torch.float8_e4m3fn
280+
281+
query = torch.randn(sum(query_lens),
282+
num_query_heads,
283+
head_size,
284+
dtype=dtype)
285+
NUM_BLOCKS_FP8 = 2048
286+
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
287+
2,
288+
block_size,
289+
num_kv_heads,
290+
head_size,
291+
dtype=dtype)
292+
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
293+
key_cache /= head_size**0.5
294+
value_cache /= head_size**0.5
295+
296+
k_scale = key_cache.amax().item() / 448.0
297+
v_scale = value_cache.amax().item() / 448.0
298+
299+
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
300+
dim=1).to(kv_cache_dtype)
301+
302+
assert (kv_cache_fp8.shape == key_value_cache.shape)
303+
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
304+
block_tables = torch.randint(0,
305+
NUM_BLOCKS_FP8,
306+
(num_seqs, max_num_blocks_per_seq),
307+
dtype=torch.int32)
308+
309+
qo_indptr = [0]
310+
kv_indptr = [0]
311+
kv_indices = []
312+
kv_last_page_lens = []
313+
for i in range(num_seqs):
314+
seq_len = kv_lens[i]
315+
assert seq_len > 0
316+
num_blocks = (seq_len + block_size - 1) // block_size
317+
kv_indices.extend(block_tables[i, :num_blocks])
318+
kv_indptr.append(kv_indptr[-1] + num_blocks)
319+
kv_last_page_len = seq_len % block_size
320+
if kv_last_page_len == 0:
321+
kv_last_page_len = block_size
322+
kv_last_page_lens.append(kv_last_page_len)
323+
qo_indptr.append(qo_indptr[-1] + query_lens[i])
324+
325+
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
326+
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
327+
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
328+
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
329+
330+
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
331+
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
332+
workspace_buffer, "NHD")
333+
wrapper.begin_forward(
334+
qo_indptr,
335+
kv_indptr,
336+
kv_indices,
337+
kv_last_page_lens,
338+
num_query_heads,
339+
num_kv_heads,
340+
head_size,
341+
block_size,
342+
)
343+
344+
output = wrapper.forward(query,
345+
kv_cache_fp8,
346+
logits_soft_cap=soft_cap,
347+
k_scale=k_scale,
348+
v_scale=v_scale)
349+
350+
ref_output = ref_paged_attn(query=query,
351+
key_cache=key_cache.squeeze(1),
352+
value_cache=value_cache.squeeze(1),
353+
query_lens=query_lens,
354+
kv_lens=kv_lens,
355+
block_tables=block_tables,
356+
scale=scale,
357+
soft_cap=soft_cap)
358+
del query
359+
del block_tables
360+
# verify prefill fp8
361+
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
362+
f"{torch.max(torch.abs(output - ref_output))}"
363+
364+
365+
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
366+
@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
367+
@pytest.mark.parametrize("head_size", HEAD_SIZES)
368+
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
369+
@pytest.mark.parametrize("dtype", DTYPES)
370+
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
371+
@torch.inference_mode
372+
def test_flashinfer_decode_with_paged_fp8_kv(
373+
kv_lens: List[int],
374+
num_heads: Tuple[int, int],
375+
head_size: int,
376+
dtype: torch.dtype,
377+
block_size: int,
378+
soft_cap: Optional[float],
379+
) -> None:
380+
# test doesn't work for num_heads = (16,16)
381+
torch.set_default_device("cuda")
382+
torch.cuda.manual_seed_all(0)
383+
num_seqs = len(kv_lens)
384+
num_query_heads = num_heads[0]
385+
num_kv_heads = num_heads[1]
386+
assert num_query_heads % num_kv_heads == 0
387+
max_kv_len = max(kv_lens)
388+
scale = head_size**-0.5
389+
use_tensor_cores = (num_query_heads // num_kv_heads) > 4
390+
kv_cache_dtype = torch.float8_e4m3fn
391+
392+
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
393+
NUM_BLOCKS_FP8 = 2048
394+
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
395+
2,
396+
block_size,
397+
num_kv_heads,
398+
head_size,
399+
dtype=dtype)
400+
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
401+
key_cache /= head_size**0.5
402+
value_cache /= head_size**0.5
403+
404+
k_scale = key_cache.amax().item() / 448.0
405+
v_scale = value_cache.amax().item() / 448.0
406+
407+
key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
408+
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
409+
assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
410+
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
411+
412+
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
413+
block_tables = torch.randint(0,
414+
NUM_BLOCKS_FP8,
415+
(num_seqs, max_num_blocks_per_seq),
416+
dtype=torch.int32)
417+
418+
kv_indptr = [0]
419+
kv_indices = []
420+
kv_last_page_lens = []
421+
for i in range(num_seqs):
422+
seq_len = kv_lens[i]
423+
assert seq_len > 0
424+
num_blocks = (seq_len + block_size - 1) // block_size
425+
kv_indices.extend(block_tables[i, :num_blocks])
426+
kv_indptr.append(kv_indptr[-1] + num_blocks)
427+
kv_last_page_len = seq_len % block_size
428+
if kv_last_page_len == 0:
429+
kv_last_page_len = block_size
430+
kv_last_page_lens.append(kv_last_page_len)
431+
432+
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
433+
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
434+
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
435+
436+
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
437+
wrapper = flashinfer.\
438+
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
439+
use_tensor_cores=use_tensor_cores)
440+
wrapper.begin_forward(kv_indptr,
441+
kv_indices,
442+
kv_last_page_lens,
443+
num_query_heads,
444+
num_kv_heads,
445+
head_size,
446+
block_size,
447+
"NONE",
448+
data_type=dtype)
449+
output = wrapper.forward(query,
450+
kv_cache_fp8,
451+
logits_soft_cap=soft_cap,
452+
k_scale=k_scale,
453+
v_scale=v_scale)
454+
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
455+
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
456+
457+
ref_output = ref_paged_attn(query=query,
458+
key_cache=key_cache,
459+
value_cache=value_cache,
460+
query_lens=[1] * num_seqs,
461+
kv_lens=kv_lens,
462+
block_tables=block_tables,
463+
scale=scale,
464+
soft_cap=soft_cap)
465+
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
466+
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
467+
f"{torch.max(torch.abs(output - ref_output))}"

vllm/attention/backends/flashinfer.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ def copy_blocks(
8383
def get_supported_head_sizes() -> List[int]:
8484
return [64, 128, 256]
8585

86+
@staticmethod
87+
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
88+
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
89+
return torch.float8_e4m3fn
90+
elif kv_cache_dtype == "fp8_e5m2":
91+
return torch.float8_e5m2
92+
else:
93+
return ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
94+
8695

8796
class FlashInferState(AttentionState):
8897

@@ -177,9 +186,9 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
177186
self._graph_decode_workspace_buffer, _indptr_buffer,
178187
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
179188
use_tensor_cores)
180-
kv_cache_dtype = get_kv_cache_torch_dtype(
181-
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
182189

190+
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
191+
self.runner.kv_cache_dtype)
183192
paged_kv_indptr_tensor_host = torch.arange(0,
184193
batch_size + 1,
185194
dtype=torch.int32)
@@ -340,7 +349,7 @@ def begin_forward(self):
340349
self.page_size,
341350
# Disable flashinfer's pos encoding and use vllm's rope.
342351
pos_encoding_mode="NONE",
343-
data_type=self.data_type)
352+
)
344353

345354
def asdict_zerocopy(self,
346355
skip_fields: Optional[Set[str]] = None
@@ -366,7 +375,8 @@ def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
366375
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
367376
# Currently chunked prefill is not supported
368377
if self.num_prefills > 0:
369-
assert self.num_decode_tokens == 0
378+
assert self.num_decode_tokens == 0, (
379+
"Chunked prefill is not supported with flashinfer yet.")
370380
return None
371381

372382
return self
@@ -578,6 +588,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
578588

579589
kv_cache_dtype = get_kv_cache_torch_dtype(
580590
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
591+
581592
return FlashInferMetadata(
582593
num_prefills=self.num_prefills,
583594
slot_mapping=slot_mapping_tensor,
@@ -661,7 +672,6 @@ def forward(
661672
if attn_metadata.num_decode_tokens > 0:
662673
assert attn_metadata.num_prefill_tokens == 0, (
663674
"Chunked prefill is not supported with flashinfer yet.")
664-
665675
if kv_cache is not None:
666676
# Use the same reshape and cache kernel as flash attention.
667677
ops.reshape_and_cache_flash(
@@ -674,6 +684,11 @@ def forward(
674684
k_scale,
675685
v_scale,
676686
)
687+
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
688+
# to process the cache in fp8
689+
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
690+
self.kv_cache_dtype)
691+
kv_cache = kv_cache.view(torch_dtype)
677692

678693
query = query.contiguous(
679694
) # Flashinfer requires query to be contiguous
@@ -711,5 +726,7 @@ def forward(
711726
query,
712727
kv_cache,
713728
sm_scale=self.scale,
714-
logits_soft_cap=self.logits_soft_cap)
729+
logits_soft_cap=self.logits_soft_cap,
730+
k_scale=k_scale,
731+
v_scale=v_scale)
715732
return output.view(num_tokens, hidden_size)

vllm/attention/selector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,10 @@ def which_attn_to_use(
226226
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
227227
logger.info(
228228
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
229+
logger.warning(
230+
"Please use FlashInfer backend with FP8 KV Cache for "
231+
"better performance by set environment "
232+
"VLLM_ATTENTION_BACKEND=FLASHINFER")
229233
selected_backend = _Backend.XFORMERS
230234
elif block_size % 16 != 0:
231235
logger.info(

0 commit comments

Comments
 (0)