-
-
Notifications
You must be signed in to change notification settings - Fork 8.8k
[Nvidia] Integrate cudnn prefill paged attention kernel for head_dim == 128 models, like Llama family #20850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
elfiegg
wants to merge
2
commits into
vllm-project:main
Choose a base branch
from
elfiegg:llama-cudnn
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+69
−4
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,12 +7,13 @@ | |||||||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any, Optional | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import vllm.envs as envs | ||||||||||||||||||||||||||||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, | ||||||||||||||||||||||||||||||
BatchPrefillWithPagedKVCacheWrapper, | ||||||||||||||||||||||||||||||
MultiLevelCascadeAttentionWrapper) | ||||||||||||||||||||||||||||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import vllm.envs as envs | ||||||||||||||||||||||||||||||
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache | ||||||||||||||||||||||||||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||||||||||||||||||||||||||||||
AttentionType) | ||||||||||||||||||||||||||||||
from vllm.logger import init_logger | ||||||||||||||||||||||||||||||
|
@@ -36,6 +37,13 @@ | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
logger = init_logger(__name__) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
CUDNN_SUPPORTED_HEAD_SIZES = [128] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def is_cudnn_supported(head_dim: int): | ||||||||||||||||||||||||||||||
return head_dim in CUDNN_SUPPORTED_HEAD_SIZES \ | ||||||||||||||||||||||||||||||
and current_platform.has_device_capability(100) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
class FlashInferBackend(AttentionBackend): | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -203,6 +211,10 @@ class FlashInferMetadata: | |||||||||||||||||||||||||||||
num_prefills: int | ||||||||||||||||||||||||||||||
num_prefill_tokens: int | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# For cudnn prefill | ||||||||||||||||||||||||||||||
max_query_len: int | ||||||||||||||||||||||||||||||
actual_seq_lens_q: torch.Tensor | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# For cascade attention. | ||||||||||||||||||||||||||||||
use_cascade: bool | ||||||||||||||||||||||||||||||
shared_qo_indptr: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||
|
@@ -302,9 +314,13 @@ def reorder_batch(self, input_batch: InputBatch, | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _get_workspace_buffer(self): | ||||||||||||||||||||||||||||||
if self._workspace_buffer is None: | ||||||||||||||||||||||||||||||
if is_cudnn_supported(self.kv_cache_spec.head_size): | ||||||||||||||||||||||||||||||
dtype = torch.int8 | ||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||
dtype = torch.uint8 | ||||||||||||||||||||||||||||||
self._workspace_buffer = torch.empty( | ||||||||||||||||||||||||||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE, | ||||||||||||||||||||||||||||||
dtype=torch.uint8, | ||||||||||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||||||||||
device=self.runner.device) | ||||||||||||||||||||||||||||||
return self._workspace_buffer | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -369,7 +385,8 @@ def _plan(self, attn_metadata: FlashInferMetadata): | |||||||||||||||||||||||||||||
# Regular attention (common case). | ||||||||||||||||||||||||||||||
# Decodes are at the front and prefills are at the back, | ||||||||||||||||||||||||||||||
# according to reorder_batch() | ||||||||||||||||||||||||||||||
if self._num_prefills > 0: | ||||||||||||||||||||||||||||||
if self._num_prefills > 0 and not is_cudnn_supported( | ||||||||||||||||||||||||||||||
attn_metadata.head_dim): | ||||||||||||||||||||||||||||||
# Decodes are first so prefills start after the last decode | ||||||||||||||||||||||||||||||
prefill_start = self._num_decodes | ||||||||||||||||||||||||||||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper() | ||||||||||||||||||||||||||||||
|
@@ -441,6 +458,7 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) | ||||||||||||||||||||||||||||||
seq_lens = common_attn_metadata.seq_lens | ||||||||||||||||||||||||||||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] | ||||||||||||||||||||||||||||||
max_query_len = common_attn_metadata.max_query_len | ||||||||||||||||||||||||||||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( | ||||||||||||||||||||||||||||||
self.runner.device, non_blocking=True).long() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -471,6 +489,7 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
shared_kv_page_indices = None | ||||||||||||||||||||||||||||||
shared_kv_last_page_len = None | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
max_seq_len = int(seq_lens.max().item()) | ||||||||||||||||||||||||||||||
mask = (torch.arange(block_table_tensor.size(1), | ||||||||||||||||||||||||||||||
dtype=block_table_tensor.dtype, | ||||||||||||||||||||||||||||||
device=block_table_tensor.device).unsqueeze(0) | ||||||||||||||||||||||||||||||
|
@@ -487,6 +506,7 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
paged_kv_last_page_len = seq_lens % page_size | ||||||||||||||||||||||||||||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, | ||||||||||||||||||||||||||||||
page_size, paged_kv_last_page_len) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
cache_dtype = self.runner.cache_config.cache_dtype | ||||||||||||||||||||||||||||||
if cache_dtype.startswith("fp8"): | ||||||||||||||||||||||||||||||
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( | ||||||||||||||||||||||||||||||
|
@@ -515,7 +535,9 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
shared_kv_page_indptr=shared_kv_page_indptr, | ||||||||||||||||||||||||||||||
shared_kv_page_indices=shared_kv_page_indices, | ||||||||||||||||||||||||||||||
shared_kv_last_page_len=shared_kv_last_page_len, | ||||||||||||||||||||||||||||||
max_query_len=max_query_len, | ||||||||||||||||||||||||||||||
max_seq_len=max_seq_len, | ||||||||||||||||||||||||||||||
actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1], | ||||||||||||||||||||||||||||||
seq_lens=seq_lens, | ||||||||||||||||||||||||||||||
block_table_tensor=block_table_tensor, | ||||||||||||||||||||||||||||||
workspace_buffer=self._workspace_buffer, | ||||||||||||||||||||||||||||||
|
@@ -681,13 +703,56 @@ def forward( | |||||||||||||||||||||||||||||
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap | ||||||||||||||||||||||||||||||
or 0.0) | ||||||||||||||||||||||||||||||
assert prefill_wrapper._sm_scale == self.scale | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
prefill_wrapper.run( | ||||||||||||||||||||||||||||||
prefill_query, | ||||||||||||||||||||||||||||||
kv_cache.permute(*stride_order), | ||||||||||||||||||||||||||||||
k_scale=layer._k_scale_float, | ||||||||||||||||||||||||||||||
v_scale=layer._v_scale_float, | ||||||||||||||||||||||||||||||
out=output[num_decode_tokens:], | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
elif num_prefill_tokens > 0 and FlashInferBackend.is_cudnn_supported( | ||||||||||||||||||||||||||||||
attn_metadata.head_dim): | ||||||||||||||||||||||||||||||
(total_num_pages, _, page_size, num_kv_heads, | ||||||||||||||||||||||||||||||
head_dim) = kv_cache.shape | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# Validate dimensions match expected head_dim | ||||||||||||||||||||||||||||||
assert head_dim == self.head_size, ( | ||||||||||||||||||||||||||||||
f"KV cache head_dim {head_dim} != expected {self.head_size}") | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
k_cache = kv_cache[:, 0].as_strided( | ||||||||||||||||||||||||||||||
(total_num_pages, num_kv_heads, page_size, head_dim), ( | ||||||||||||||||||||||||||||||
page_size * num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
head_dim, | ||||||||||||||||||||||||||||||
num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
1, | ||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||
v_cache = kv_cache[:, 1].as_strided( | ||||||||||||||||||||||||||||||
(total_num_pages, num_kv_heads, page_size, head_dim), ( | ||||||||||||||||||||||||||||||
page_size * num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
head_dim, | ||||||||||||||||||||||||||||||
num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
1, | ||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||
Comment on lines
+730
to
+736
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to
Suggested change
|
||||||||||||||||||||||||||||||
output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache( | ||||||||||||||||||||||||||||||
q=query[num_decode_tokens:], | ||||||||||||||||||||||||||||||
k_cache=k_cache, | ||||||||||||||||||||||||||||||
v_cache=v_cache, | ||||||||||||||||||||||||||||||
scale=self.scale, | ||||||||||||||||||||||||||||||
workspace_buffer=attn_metadata.workspace_buffer, | ||||||||||||||||||||||||||||||
max_token_per_sequence=attn_metadata.max_query_len, | ||||||||||||||||||||||||||||||
max_sequence_kv=attn_metadata.max_seq_len, | ||||||||||||||||||||||||||||||
block_tables=attn_metadata. | ||||||||||||||||||||||||||||||
block_table_tensor[num_decode_tokens:], | ||||||||||||||||||||||||||||||
actual_seq_lens_q=attn_metadata. | ||||||||||||||||||||||||||||||
actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1), | ||||||||||||||||||||||||||||||
actual_seq_lens_kv=attn_metadata.seq_lens[num_decode_tokens:]. | ||||||||||||||||||||||||||||||
view(-1, 1, 1, 1), | ||||||||||||||||||||||||||||||
causal=True, | ||||||||||||||||||||||||||||||
return_lse=True, | ||||||||||||||||||||||||||||||
is_cuda_graph_compatible=True, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if decode_wrapper := attn_metadata.decode_wrapper: | ||||||||||||||||||||||||||||||
decode_query = query[:num_decode_tokens] | ||||||||||||||||||||||||||||||
assert decode_query.shape[0] == num_decode_tokens | ||||||||||||||||||||||||||||||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
as_strided
call fork_cache
has an incorrect stride for the first dimension. The stride forkv_cache[:, 0]
's first dimension ispage_size * num_kv_heads * head_dim
, butpage_size * num_kv_heads * head_dim
is used. This will lead to incorrect memory access. The correct first stride should bekv_cache.stride(0)
.