[Nvidia] Integrate cudnn prefill paged attention kernel for head_dim == 128 models, like Llama family #20850

Draft · wants to merge 2 commits into base: main
73 changes: 69 additions & 4 deletions vllm/v1/attention/backends/flashinfer.py
@@ -7,12 +7,13 @@
from typing import TYPE_CHECKING, Any, Optional

import torch

import vllm.envs as envs
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import trtllm_batch_decode_with_kv_cache

import vllm.envs as envs
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.logger import init_logger
@@ -36,6 +37,13 @@

logger = init_logger(__name__)

CUDNN_SUPPORTED_HEAD_SIZES = [128]


def is_cudnn_supported(head_dim: int):
return head_dim in CUDNN_SUPPORTED_HEAD_SIZES \
and current_platform.has_device_capability(100)


class FlashInferBackend(AttentionBackend):

@@ -203,6 +211,10 @@ class FlashInferMetadata:
num_prefills: int
num_prefill_tokens: int

# For cudnn prefill
max_query_len: int
actual_seq_lens_q: torch.Tensor

# For cascade attention.
use_cascade: bool
shared_qo_indptr: Optional[torch.Tensor] = None
@@ -302,9 +314,13 @@ def reorder_batch(self, input_batch: InputBatch,

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
if is_cudnn_supported(self.kv_cache_spec.head_size):
dtype = torch.int8
else:
dtype = torch.uint8
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
dtype=dtype,
device=self.runner.device)
return self._workspace_buffer

@@ -369,7 +385,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if self._num_prefills > 0:
if self._num_prefills > 0 and not is_cudnn_supported(
attn_metadata.head_dim):
# Decodes are first so prefills start after the last decode
prefill_start = self._num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -441,6 +458,7 @@ def build(self, common_prefix_len: int,
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
max_query_len = common_attn_metadata.max_query_len
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()

@@ -471,6 +489,7 @@
shared_kv_page_indices = None
shared_kv_last_page_len = None

max_seq_len = int(seq_lens.max().item())
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=block_table_tensor.device).unsqueeze(0)
@@ -487,6 +506,7 @@
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)

cache_dtype = self.runner.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -515,7 +535,9 @@
shared_kv_page_indptr=shared_kv_page_indptr,
shared_kv_page_indices=shared_kv_page_indices,
shared_kv_last_page_len=shared_kv_last_page_len,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
workspace_buffer=self._workspace_buffer,
@@ -681,13 +703,56 @@ def forward(
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert prefill_wrapper._sm_scale == self.scale

prefill_wrapper.run(
prefill_query,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
elif num_prefill_tokens > 0 and FlashInferBackend.is_cudnn_supported(
attn_metadata.head_dim):
(total_num_pages, _, page_size, num_kv_heads,
head_dim) = kv_cache.shape

# Validate dimensions match expected head_dim
assert head_dim == self.head_size, (
f"KV cache head_dim {head_dim} != expected {self.head_size}")

k_cache = kv_cache[:, 0].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
page_size * num_kv_heads * head_dim,
head_dim,
num_kv_heads * head_dim,
1,
))
Comment on lines +723 to +729
Contributor
critical

The as_strided call for k_cache has an incorrect stride for the first dimension. The first-dimension stride of kv_cache[:, 0] is kv_cache.stride(0), i.e. 2 * page_size * num_kv_heads * head_dim because each page stores both a K and a V plane, but page_size * num_kv_heads * head_dim is used. This will lead to incorrect memory access. The correct first stride should be kv_cache.stride(0).

Suggested change
-    k_cache = kv_cache[:, 0].as_strided(
-        (total_num_pages, num_kv_heads, page_size, head_dim), (
-            page_size * num_kv_heads * head_dim,
-            head_dim,
-            num_kv_heads * head_dim,
-            1,
-        ))
+    k_cache = kv_cache[:, 0].as_strided(
+        (total_num_pages, num_kv_heads, page_size, head_dim), (
+            kv_cache.stride(0),
+            head_dim,
+            num_kv_heads * head_dim,
+            1,
+        ))
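As a sanity check, here is a minimal sketch (dummy sizes, not taken from the PR) assuming the (total_num_pages, 2, page_size, num_kv_heads, head_dim) cache layout unpacked above. It shows that slicing out the K plane keeps the parent's page stride, so the corrected as_strided view is just that plane with its head and page-slot axes swapped:

import torch

# Hypothetical sizes for illustration only.
total_num_pages, page_size, num_kv_heads, head_dim = 4, 16, 2, 128
kv_cache = torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim)

# Slicing out the K plane does not shrink the page stride: consecutive pages
# are still separated by both the K and the V plane in memory.
assert kv_cache[:, 0].stride(0) == kv_cache.stride(0)
assert kv_cache.stride(0) == 2 * page_size * num_kv_heads * head_dim

# The (pages, heads, page_size, head_dim) view built with the corrected first
# stride equals the K plane with the head and page-slot axes swapped.
k_view = kv_cache[:, 0].as_strided(
    (total_num_pages, num_kv_heads, page_size, head_dim),
    (kv_cache.stride(0), head_dim, num_kv_heads * head_dim, 1))
assert torch.equal(k_view, kv_cache[:, 0].permute(0, 2, 1, 3))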

v_cache = kv_cache[:, 1].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
page_size * num_kv_heads * head_dim,
head_dim,
num_kv_heads * head_dim,
1,
))
Comment on lines +730 to +736
Contributor
critical

Similar to k_cache, the as_strided call for v_cache has an incorrect stride for the first dimension. It should also be kv_cache.stride(0).

Suggested change
-    v_cache = kv_cache[:, 1].as_strided(
-        (total_num_pages, num_kv_heads, page_size, head_dim), (
-            page_size * num_kv_heads * head_dim,
-            head_dim,
-            num_kv_heads * head_dim,
-            1,
-        ))
+    v_cache = kv_cache[:, 1].as_strided(
+        (total_num_pages, num_kv_heads, page_size, head_dim), (
+            kv_cache.stride(0),
+            head_dim,
+            num_kv_heads * head_dim,
+            1,
+        ))
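The same reasoning applies to the V plane. A short sketch (again with made-up sizes, not from the PR) of what the undersized stride actually reads: with page_size * num_kv_heads * head_dim as the first stride, "page 1" of the V view lands on page 1's K plane instead of its V plane.

import torch

# Hypothetical sizes for illustration only.
total_num_pages, page_size, num_kv_heads, head_dim = 4, 16, 2, 128
kv_cache = torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim)

correct_v = kv_cache[:, 1].as_strided(
    (total_num_pages, num_kv_heads, page_size, head_dim),
    (kv_cache.stride(0), head_dim, num_kv_heads * head_dim, 1))
buggy_v = kv_cache[:, 1].as_strided(
    (total_num_pages, num_kv_heads, page_size, head_dim),
    (page_size * num_kv_heads * head_dim, head_dim,
     num_kv_heads * head_dim, 1))

# The corrected view is the V plane with head and page-slot axes swapped.
assert torch.equal(correct_v, kv_cache[:, 1].permute(0, 2, 1, 3))
# With the undersized first stride, "page 1" of the V view reads page 1's
# K plane rather than its V plane.
assert torch.equal(buggy_v[1], kv_cache[1, 0].permute(1, 0, 2))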

output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
q=query[num_decode_tokens:],
k_cache=k_cache,
v_cache=v_cache,
scale=self.scale,
workspace_buffer=attn_metadata.workspace_buffer,
max_token_per_sequence=attn_metadata.max_query_len,
max_sequence_kv=attn_metadata.max_seq_len,
block_tables=attn_metadata.
block_table_tensor[num_decode_tokens:],
actual_seq_lens_q=attn_metadata.
actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1),
actual_seq_lens_kv=attn_metadata.seq_lens[num_decode_tokens:].
view(-1, 1, 1, 1),
causal=True,
return_lse=True,
is_cuda_graph_compatible=True,
)

if decode_wrapper := attn_metadata.decode_wrapper:
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens