Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 198 additions & 34 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
Expand Down Expand Up @@ -67,6 +67,18 @@ def get_impl_cls() -> Type["MLAAttentionImpl"]:
@dataclass
class AscendMLAPrefillMetadata:
""" Prefill Specific Metadata for Ascend"""

@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor

attn_mask: torch.Tensor
query_lens: list[int]
seq_lens: list[int]
Expand All @@ -76,6 +88,7 @@ class AscendMLAPrefillMetadata:
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
chunked_context: Optional[ChunkedContextMetadata] = None


@dataclass
Expand Down Expand Up @@ -170,7 +183,32 @@ def __init__(self,
if metadata_cls is not None else AscendMLAMetadata # type: ignore
self.runner = runner
scheduler_config = runner.scheduler_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
model_config = runner.model_config
self.block_size = runner.block_size
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
model_config.get_head_size()),
dtype=model_config.dtype,
device=runner.device,
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

Expand Down Expand Up @@ -348,6 +386,7 @@ def build(
query_start_loc = common_attn_metadata.query_start_loc

prefill_metadata = None
chunked_context_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
tokens_start = self._num_decode_tokens
Expand All @@ -357,6 +396,41 @@ def build(
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]

context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
max_context_chunk = round_down(
max_context_chunk, self.block_size) #待确认是否需要是block_size的倍数

assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)

prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
query_lens=query_lens[tokens_start:],
Expand All @@ -367,6 +441,7 @@ def build(
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
)

decode_metadata = None
Expand Down Expand Up @@ -556,6 +631,83 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)

def _compute_prefill_context(
self,
query: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
rope_dim: int,
attn_metadata: AscendMLAMetadata,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
):
prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse

iters = len(prefill_metadata.chunked_context.seq_tot)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]

seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]

seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([seq_len1, seq_len2])
kv_c_normed = torch.empty(toks,
kv_c_and_k_pe_cache.size(2),
latent_kv_dim,
dtype=query.dtype,
device=query.device)
k_pe = torch.empty(toks,
kv_c_and_k_pe_cache.size(2),
rope_dim,
dtype=query.dtype,
device=query.device)

torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
seq_len2.to(query.device),
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)

kv_c_normed = kv_c_normed.squeeze()
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
return prefix_output, prefix_lse

def _forward_prefill(
self,
query: torch.Tensor,
Expand All @@ -567,45 +719,57 @@ def _forward_prefill(
assert attn_metadata.prefill is not None

num_tokens = query.size(0)
attn_output = None
attn_output = torch.empty(num_tokens,
self.num_heads,
self.v_head_dim,
dtype=query.dtype,
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding
]:
attn_output = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
# current requests is chunked in prefill, disable flash attention with chunked prefill
vanilla_chunked_prefill_mla(
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1) # 512: mask only support 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_seq_lens,
nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
scale=self.scale,
alibi_slopes=None,
causal=True)
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)

elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
attn_output = torch.empty(num_tokens,
self.num_heads,
self.v_head_dim,
dtype=query.dtype,
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
Expand Down
Loading
Loading