Commit ab78534

[Feature] Add support for custom DeepSeek modeling in ACL Graph mode
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 6193ba6 commit ab78534

File tree

4 files changed (+216, -127 lines)

vllm_ascend/attention/mla_v1.py

Lines changed: 173 additions & 121 deletions
@@ -9,6 +9,8 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -669,130 +671,180 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
-        if attn_metadata is None:
-            # Profiling run.
-            return output
-        self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
-        num_actual_toks = attn_metadata.num_actual_tokens
-        if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_mla_attention_with_output(
+                query=hidden_states_or_q_c,
+                key=hidden_states_or_kv_c_normed,
+                value=k_pe,
+                output=output,
+                layer_name=layer.layer_name)
         else:
-            kv_c_normed = hidden_states_or_kv_c_normed
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        if not self.running_in_graph:
-            # Inputs and outputs may be padded for CUDA graphs
-            output_padded = output
-            output = output[:num_actual_toks, ...]
-            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
-        if not self.running_in_graph:
-            hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
-            prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
-        else:
-            decode_hs_or_q_c = hidden_states_or_q_c
-        if has_decode:
-            decode_k_nope = None
-            assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
-            if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
+            if attn_metadata is None:
+                # Profiling run.
+                return output
+            self.running_in_graph = self.enable_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
+            num_actual_toks = attn_metadata.num_actual_tokens
+            if k_pe is None and not self.running_in_graph:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            else:
+                kv_c_normed = hidden_states_or_kv_c_normed
+            assert attn_metadata.num_decodes is not None and \
+                attn_metadata.num_prefills is not None and \
+                attn_metadata.num_decode_tokens is not None
+            has_decode = attn_metadata.num_decodes > 0
+            has_prefill = attn_metadata.num_prefills > 0
+            num_decode_tokens = attn_metadata.num_decode_tokens
+            if not self.running_in_graph:
+                # Inputs and outputs may be padded for CUDA graphs
+                output_padded = output
+                output = output[:num_actual_toks, ...]
+                kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+                prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+            if not self.running_in_graph:
+                hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
             else:
-                decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                    attn_metadata.decode.input_positions,
-                    decode_q_pe.contiguous(),
-                    decode_k_pe,
-                    max_seq_len=attn_metadata.decode.max_seq_lens)
-        if has_prefill:
-            assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+                decode_hs_or_q_c = hidden_states_or_q_c
+            if has_decode:
+                decode_k_nope = None
+                assert attn_metadata.decode is not None
+                decode_ql_nope, decode_q_pe = \
+                    self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+                if self.running_in_graph:
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                    decode_k_pe, decode_k_nope = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
+                else:
+                    decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
+                        attn_metadata.decode.input_positions,
+                        decode_q_pe.contiguous(),
+                        decode_k_pe,
+                        max_seq_len=attn_metadata.decode.max_seq_lens)
+            if has_prefill:
+                assert attn_metadata.prefill is not None
+                prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
+                    .view(-1, self.num_heads, self.qk_head_dim)
+                prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+                prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+                if self.enable_graph_mode:
+                    num_tokens = prefill_hs_or_q_c.shape[0]
+                    prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
+                                                     -1)
+                    if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
+                        # NOTE: When scaling not specified
+                        ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
+                        prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
+                        prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
+                        prefill_q_pe, prefill_k_pe = self.rotary_emb(
+                            attn_metadata.prefill.input_positions, prefill_q_pe,
+                            prefill_k_pe)
+                        prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
+                        prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
+                    else:
+                        prefill_q_pe, prefill_k_pe = self.rotary_emb(
+                            attn_metadata.prefill.input_positions, prefill_q_pe,
+                            prefill_k_pe)
+                    prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
+                else:
+                    prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
+                        attn_metadata.prefill.input_positions,
+                        prefill_q_pe.contiguous(),
+                        prefill_k_pe,
+                        max_seq_len=attn_metadata.prefill.max_seq_lens)
             if self.enable_graph_mode:
-                num_tokens = prefill_hs_or_q_c.shape[0]
-                prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
-                                                 -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
+                if len(kv_cache) > 0 and kv_cache[0].numel(
+                ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                    slots = attn_metadata.slot_mapping
+                    # NOTE: Separate the kv cache in advance to avoid OOM or other issues
+                    torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
+                        num_tokens, self.num_kv_heads, -1),
+                                                     value=prefill_k_pe,
+                                                     key_cache=kv_cache[0],
+                                                     value_cache=kv_cache[1],
+                                                     slot_indices=slots)
+                elif kv_cache.numel() > 0:
+                    key = torch.cat([
+                        kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
+                        k_pe
+                    ],
+                                    dim=2)
+                    torch_npu._npu_reshape_and_cache_siso(
+                        key=key,
+                        key_cache=kv_cache,
+                        slot_indices=attn_metadata.slot_mapping.flatten())
+            if has_prefill:
+                output[num_decode_tokens:] = self._forward_prefill(
+                    prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
+                    attn_metadata)
+            if has_decode:
+                if self.running_in_graph:
+                    return self._forward_decode(decode_ql_nope, decode_q_pe,
+                                                decode_k_nope, decode_k_pe,
+                                                kv_cache, attn_metadata)
                 else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
-            else:
-                prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                    attn_metadata.prefill.input_positions,
-                    prefill_q_pe.contiguous(),
-                    prefill_k_pe,
-                    max_seq_len=attn_metadata.prefill.max_seq_lens)
-        if self.enable_graph_mode:
-            if len(kv_cache) > 0 and kv_cache[0].numel(
-            ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                slots = attn_metadata.slot_mapping
-                # NOTE: Separate the kv cache in advance to avoid OOM or other issues
-                torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
-                    num_tokens, self.num_kv_heads, -1),
-                                                 value=prefill_k_pe,
-                                                 key_cache=kv_cache[0],
-                                                 value_cache=kv_cache[1],
-                                                 slot_indices=slots)
-        elif kv_cache.numel() > 0:
-            key = torch.cat([
-                kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
-                k_pe
-            ],
-                            dim=2)
-            torch_npu._npu_reshape_and_cache_siso(
-                key=key,
-                key_cache=kv_cache,
-                slot_indices=attn_metadata.slot_mapping.flatten())
-        if has_prefill:
-            output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
-        if has_decode:
-            if self.running_in_graph:
-                return self._forward_decode(decode_ql_nope, decode_q_pe,
-                                            decode_k_nope, decode_k_pe,
-                                            kv_cache, attn_metadata)
-            else:
-                output[:num_decode_tokens] = self._forward_decode(
-                    decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
-                    kv_cache, attn_metadata)
-        return output_padded
+                    output[:num_decode_tokens] = self._forward_decode(
+                        decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe,
+                        kv_cache, attn_metadata)
+            return output_padded
+
+
+def unified_ascend_mla_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    self.impl.forward(self,
+                      query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      output,
+                      trace_flag=False)
+    return
+
+
+def unified_mla_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_mla_attention_with_output",
+    op_func=unified_ascend_mla_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_mla_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
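
The registration at the end of this file is what makes the MLA forward traceable in ACL Graph mode: graph capture sees one opaque torch.ops.vllm.unified_ascend_mla_attention_with_output call whose real body runs under the PrivateUse1 (NPU) dispatch key, while the fake implementation supplies a compute-free stand-in during tracing. Below is a minimal sketch of the same pattern, assuming a vLLM build that exposes direct_register_custom_op from vllm.utils; the op name scale_into and the CPU dispatch key are illustrative placeholders, not part of this commit.

import torch
from vllm.utils import direct_register_custom_op


def scale_into(x: torch.Tensor, out: torch.Tensor) -> None:
    # Eager implementation: writes the result into `out` in place.
    out.copy_(x * 2)


def scale_into_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    # Fake (meta) implementation used while tracing/compiling:
    # no compute, it only describes that `out` is mutated and nothing is returned.
    return


direct_register_custom_op(
    op_name="scale_into",          # hypothetical op, for illustration only
    op_func=scale_into,
    mutates_args=["out"],
    fake_impl=scale_into_fake,
    dispatch_key="CPU",            # this commit uses "PrivateUse1" for the NPU backend
)

# Registered ops live under the shared `vllm` namespace, which is why forward()
# above can call torch.ops.vllm.unified_ascend_mla_attention_with_output.
x = torch.ones(4)
out = torch.empty(4)
torch.ops.vllm.scale_into(x, out)
print(out)  # tensor([2., 2., 2., 2.])

Because the op mutates output in place and returns nothing, mutates_args marks that argument in the inferred schema, which keeps the in-place write visible to the compiler instead of being optimized away.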

vllm_ascend/models/deepseek_v2.py

Lines changed: 36 additions & 0 deletions
@@ -35,6 +35,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_dp_group, get_pp_group,
@@ -207,13 +208,47 @@ def __init__(
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
         self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
+        additional_config = vllm_config.additional_config
+        self.enable_graph_mode = False
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
 
         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
             [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
         self.tp_group = get_tp_group().device_group
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.enable_graph_mode:
+            return self._forward(hidden_states)
+        else:
+            return self._forward_eager(hidden_states)
+
+    def _forward_eager(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits) * self.routed_scaling_factor
+
+        # NOTE(Yizhou): Quite strange that the order of these two operations
+        # is reversed in the original vLLM code
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(
+                final_hidden_states)
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+    def _forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
@@ -538,6 +573,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class CustomDeepseekV2Model(nn.Module):
 
     fall_back_to_pt_during_load = False
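
Taken together, the new path is opt-in: CustomDeepseekV2MoE.__init__ reads enable_graph_mode from vllm_config.additional_config, forward() then routes to _forward() instead of _forward_eager(), and @support_torch_compile lets CustomDeepseekV2Model be compiled and captured. A usage sketch follows, assuming a vLLM entry point that forwards additional_config into VllmConfig; the model id, prompt, and sampling settings are placeholders.

from vllm import LLM, SamplingParams

# Placeholder model id; any DeepSeek-V2 checkpoint served by vllm-ascend is handled the same way.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
    # Lands in vllm_config.additional_config, so the MoE layer picks
    # _forward() (graph mode) instead of _forward_eager().
    additional_config={"enable_graph_mode": True},
)

outputs = llm.generate(["Who are you?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)

Without the flag, additional_config stays empty and the MoE layer falls back to the eager all-reduce path shown in _forward_eager above.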
