
Commit aa4d2a9

Refactor AscendMultiHeadLatentAttention (#2826)
### What this PR does / why we need it?

Register AscendMultiHeadLatentAttention as a CustomOp, following vllm changes.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

CI passed with newly added and existing tests.

- vLLM version: main
- vLLM main: vllm-project/vllm@b23fb78

---------

Signed-off-by: Icey <1790571317@qq.com>
1 parent 168ad60 commit aa4d2a9
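
For readers unfamiliar with the mechanism the description refers to, the out-of-tree (oot) registration pattern is sketched below. The `register_oot` call is copied from the `vllm_ascend/utils.py` change in this commit; the `CustomOp` import path is an assumption and may differ across vLLM versions.

```python
# Sketch of the oot registration this commit adds (see vllm_ascend/utils.py below).
# Assumption: CustomOp lives at vllm.model_executor.custom_op in the targeted vLLM.
from vllm.model_executor.custom_op import CustomOp

from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention

# Registering the Ascend subclass under the in-tree op name means that when the
# DeepSeek-V2 attention layer builds a MultiHeadLatentAttention (see the
# deepseek_v2.py diff below), vLLM is expected to dispatch to the Ascend class.
CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention,
                      name="MultiHeadLatentAttention")
```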

File tree

4 files changed: 170 additions, 48 deletions


vllm_ascend/models/deepseek_v2.py

Lines changed: 27 additions & 48 deletions
@@ -31,7 +31,7 @@
 import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -48,6 +48,7 @@
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
@@ -68,6 +69,7 @@
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.models.layers.mla import AscendMLAModules
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
@@ -529,29 +531,7 @@ def __init__(
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # In the MLA backend, kv_cache includes both k_c and
-        # pe (i.e. decoupled position embeddings). In particular,
-        # the concat_and_cache_mla op requires
-        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
-        # i.e.
-        # kv_lora_rank + qk_rope_head_dim == head_size
-        self.mla_attn = Attention(
-            num_heads=self.num_local_heads,
-            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
-            scale=self.scaling,
-            num_kv_heads=1,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            use_mla=True,
-            # MLA Args
-            q_lora_rank=self.q_lora_rank,
-            kv_lora_rank=self.kv_lora_rank,
-            qk_nope_head_dim=self.qk_nope_head_dim,
-            qk_rope_head_dim=self.qk_rope_head_dim,
-            qk_head_dim=self.qk_head_dim,
-            v_head_dim=self.v_head_dim,
-            rotary_emb=self.rotary_emb,
+        mla_modules = AscendMLAModules(
             q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
             q_a_layernorm=self.q_a_layernorm
             if self.q_lora_rank is not None else None,
@@ -560,6 +540,28 @@ def __init__(
             kv_a_layernorm=self.kv_a_layernorm,
             kv_b_proj=self.kv_b_proj,
             o_proj=self.o_proj,
+            rotary_emb=self.rotary_emb,
+        )
+
+        self.mla_attn = MultiHeadLatentAttention(
+            self.hidden_size,
+            self.enable_shared_expert_dp,
+            self.debug_layer_idx,
+            self.first_k_dense_replace,
+            self.tp_size,
+            mla_modules,
+            self.num_local_heads,
+            self.scaling,
+            self.layers,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            self.q_lora_rank,
+            self.qk_nope_head_dim,
+            self.qk_head_dim,
+            self.v_head_dim,
+            cache_config,
+            quant_config,
+            prefix,
         )
 
     def forward(
@@ -568,30 +570,7 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        forward_context = get_forward_context()
-        if kv_cache is None:
-            kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
-        num_tokens = hidden_states.shape[0]
-        need_gather_q_kv = False
-        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
-            # Simulate all gather to calculate output shape
-            num_tokens = num_tokens * self.tp_size
-            need_gather_q_kv = True
-        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
-            output_shape = hidden_states.shape
-        else:
-            rows = num_tokens // self.tp_size
-            if num_tokens % self.tp_size:
-                rows += 1
-            output_shape = (rows, hidden_states.shape[1])
-        output = torch.empty(output_shape,
-                             dtype=hidden_states.dtype,
-                             device=hidden_states.device)
-        output = self.mla_attn.impl.forward(hidden_states, kv_cache,
-                                            forward_context.attn_metadata,
-                                            need_gather_q_kv, output)
-        output = output.view(-1, output_shape[-1])
-        return output
+        return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)
 
 
 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):

vllm_ascend/models/layers/__init__.py

Whitespace-only changes.

vllm_ascend/models/layers/mla.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@ (new file)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig


@dataclass
class AscendMLAModules:
    q_a_proj: Optional[torch.nn.Module]
    q_a_layernorm: Optional[torch.nn.Module]
    q_proj: Optional[torch.nn.Module]
    kv_a_proj_with_mqa: torch.nn.Module
    kv_a_layernorm: torch.nn.Module
    kv_b_proj: torch.nn.Module
    o_proj: torch.nn.Module
    rotary_emb: torch.nn.Module


class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):

    def __init__(
        self,
        hidden_size: int,
        enable_shared_expert_dp: bool,
        debug_layer_idx: int,
        first_k_dense_replace: int,
        tp_size: int,
        mla_modules: AscendMLAModules,
        num_local_heads: int,
        scaling: float,
        layers: int,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        q_lora_rank: Optional[int],
        qk_nope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.enable_shared_expert_dp = enable_shared_expert_dp
        self.debug_layer_idx = debug_layer_idx
        self.first_k_dense_replace = first_k_dense_replace
        self.tp_size = tp_size
        self.num_local_heads = num_local_heads
        self.layers = layers
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.q_lora_rank = q_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim

        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=mla_modules.rotary_emb,
            q_a_proj=mla_modules.q_a_proj,
            q_a_layernorm=mla_modules.q_a_layernorm,
            q_proj=mla_modules.q_proj,
            kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
            kv_a_layernorm=mla_modules.kv_a_layernorm,
            kv_b_proj=mla_modules.kv_b_proj,
            o_proj=mla_modules.o_proj,
        )

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        forward_context = get_forward_context()
        if kv_cache is None:
            kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
        num_tokens = hidden_states.shape[0]
        need_gather_q_kv = False
        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
            # Simulate all gather to calculate output shape
            num_tokens = num_tokens * self.tp_size
            need_gather_q_kv = True
        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
            output_shape = hidden_states.shape
        else:
            rows = num_tokens // self.tp_size
            if num_tokens % self.tp_size:
                rows += 1
            output_shape = (rows, hidden_states.shape[1])
        output = torch.empty(output_shape,
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
        output = self.mla_attn.impl.forward(hidden_states, kv_cache,
                                            forward_context.attn_metadata,
                                            need_gather_q_kv, output)
        output = output.view(-1, output_shape[-1])
        return output
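
One note on the `forward` above: when the shared-expert data-parallel branch is taken, the `rows` computation is a ceiling division of `num_tokens` by `tp_size`. A minimal illustration with made-up numbers:

```python
# Illustration only (hypothetical values): the rows computation rounds up.
num_tokens, tp_size = 5, 4
rows = num_tokens // tp_size   # 1
if num_tokens % tp_size:       # remainder of 1, so round up
    rows += 1
assert rows == 2               # i.e. ceil(5 / 4)
```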

vllm_ascend/utils.py

Lines changed: 4 additions & 0 deletions
@@ -529,6 +529,10 @@ def register_ascend_customop():
     from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
     CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE")
 
+    from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
+    CustomOp.register_oot(_decorated_op_cls=AscendMultiHeadLatentAttention,
+                          name="MultiHeadLatentAttention")
+
     # NOTE: Keep this at last to ensure all custom actions are registered
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
 
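
As a usage note, `register_ascend_customop()` presumably has to run before the model is constructed so the dispatch can take effect; the call site below is illustrative only and not part of this commit.

```python
# Illustrative only: invoke the registration added above before building the model.
from vllm_ascend.utils import register_ascend_customop

register_ascend_customop()
# Any MultiHeadLatentAttention constructed afterwards (e.g. inside the DeepSeek-V2
# attention layer in deepseek_v2.py above) is expected to resolve to
# AscendMultiHeadLatentAttention via the oot registry.
```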
