
Commit 9bbf32b

Author: lwq

Commit message: format code

Signed-off-by: lwq <liwenquan5@huawei.com>

Parent: 2295536

4 files changed: +70, -76 lines


vllm_ascend/ascend_config.py

Lines changed: 1 addition & 2 deletions
@@ -49,8 +49,7 @@ def __init__(self, vllm_config):
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         self.enable_mla_prefetch = additional_config.get(
-            "enable_mla_prefetch", True
-        )
+            "enable_mla_prefetch", True)
         self.enable_kv_nz = additional_config.get("enable_kv_nz", False)


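Note: both options are read with plain dict.get lookups, so a missing key silently falls back to the default (True for enable_mla_prefetch, False for enable_kv_nz). A minimal sketch of that behavior, assuming additional_config is the dict parsed from vLLM's --additional-config JSON (values illustrative):

# Sketch only: additional_config is assumed to be a plain dict parsed
# from vLLM's --additional-config option; the values here are made up.
additional_config = {"enable_kv_nz": True}

# Key present: the stored value wins.
enable_kv_nz = additional_config.get("enable_kv_nz", False)
assert enable_kv_nz is True

# Key absent: the default wins, so MLA prefetch stays enabled unless
# the user explicitly turns it off in additional_config.
enable_mla_prefetch = additional_config.get("enable_mla_prefetch", True)
assert enable_mla_prefetch is True
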
vllm_ascend/attention/mla_v1.py

Lines changed: 59 additions & 61 deletions
@@ -20,7 +20,8 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, npu_stream_switch, npu_wait_tensor)
+from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
+                                        npu_stream_switch, npu_wait_tensor)
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch

@@ -668,35 +669,35 @@ def _forward_prefill(
                                   dtype=q_nope.dtype,
                                   device=q_nope.device)
         attn_lse = torch.empty(self.num_heads,
-                              num_tokens,
-                              dtype=torch.float32,
-                              device=q_nope.device)
+                               num_tokens,
+                               dtype=torch.float32,
+                               device=q_nope.device)
         self.prefill_mask = torch.triu(
             torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype),
             1)  # 512: mask only support 512
         if attn_metadata.num_prefills > 1:
-            self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
-                                                                      1)
-        torch_npu.atb.npu_ring_mla(
-            q_nope=q_nope,
-            q_rope=q_pe,
-            k_nope=k_nope,
-            k_rope=k_pe,
-            value=value,
-            mask=self.prefill_mask,
-            seqlen=torch.tensor(attn_metadata.prefill.query_lens,
-                                dtype=torch.int32),
-            head_num=self.num_heads,
-            kv_head_num=self.num_heads,
-            pre_out=None,
-            prev_lse=None,
-            qk_scale=self.scale,
-            kernel_type="kernel_type_high_precision",
-            mask_type="mask_type_triu",
-            input_layout="type_bsnd",
-            calc_type="calc_type_first_ring",
-            output=attn_output,
-            softmax_lse=attn_lse)
+            self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
+                attn_metadata.num_prefills, 1, 1)
+        torch_npu.atb.npu_ring_mla(q_nope=q_nope,
+                                   q_rope=q_pe,
+                                   k_nope=k_nope,
+                                   k_rope=k_pe,
+                                   value=value,
+                                   mask=self.prefill_mask,
+                                   seqlen=torch.tensor(
+                                       attn_metadata.prefill.query_lens,
+                                       dtype=torch.int32),
+                                   head_num=self.num_heads,
+                                   kv_head_num=self.num_heads,
+                                   pre_out=None,
+                                   prev_lse=None,
+                                   qk_scale=self.scale,
+                                   kernel_type="kernel_type_high_precision",
+                                   mask_type="mask_type_triu",
+                                   input_layout="type_bsnd",
+                                   calc_type="calc_type_first_ring",
+                                   output=attn_output,
+                                   softmax_lse=attn_lse)
         attn_output, attn_lse = self._compute_prefill_context( \
             q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
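
For context on the mask this hunk reflows: torch.triu(..., 1) keeps only the entries strictly above the diagonal, which flags the future positions a prefill query must not attend to, and multiple prefill requests get the same mask repeated along a new batch dim. A standalone sketch (4x4 instead of the kernel's fixed 512x512, purely for readability):

import torch

# Ones strictly above the diagonal mark disallowed (future) positions.
mask = torch.triu(torch.ones(4, 4), diagonal=1)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])

# With several prefill requests, broadcast one copy per request,
# mirroring the unsqueeze(0).repeat(num_prefills, 1, 1) above.
num_prefills = 3
batched = mask.unsqueeze(0).repeat(num_prefills, 1, 1)
assert batched.shape == (3, 4, 4)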

@@ -716,7 +717,8 @@ def exec_kv_decode(
         N = self.num_kv_heads
         S = 1
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
-        kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        kv_no_split = kv_no_split.view(
+            B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv_no_split,
@@ -743,7 +745,8 @@ def exec_kv_prefill(
         N = self.num_kv_heads
         S = 1
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
-        kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        kv_no_split = kv_no_split.view(
+            B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
         _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv_no_split,
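
Both exec_kv hunks split the same long view call: the fused KV projection output is reshaped to the [B, N, S, D] layout that npu_kv_rmsnorm_rope_cache expects, per the comment in the code. A shape-only sketch of that bookkeeping, using illustrative DeepSeek-V2-style sizes (kv_lora_rank=512, qk_rope_head_dim=64, one KV head are assumptions here):

import torch

# Tokens become the batch dim; head and sequence dims are singleton.
B, N, S = 8, 1, 1
kv_lora_rank, qk_rope_head_dim = 512, 64

kv_no_split = torch.randn(B, kv_lora_rank + qk_rope_head_dim)
kv_bnsd = kv_no_split.view(B, N, S, kv_lora_rank + qk_rope_head_dim)
assert kv_bnsd.shape == (8, 1, 1, 576)
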
@@ -788,15 +791,15 @@ def _forward_decode(
         actual_seq_lengths = None
         if self.enable_kv_nz:
             k_nope = k_nope.view(-1, self.num_kv_heads,
-                                self.kv_lora_rank // 16, block_size, 16)
+                                 self.kv_lora_rank // 16, block_size, 16)
             k_pe = k_pe.view(-1, self.num_kv_heads,
-                            self.qk_rope_head_dim // 16, block_size, 16)
+                             self.qk_rope_head_dim // 16, block_size, 16)
             input_layout = "BSND"
         else:
             k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                                self.kv_lora_rank)
+                                 self.kv_lora_rank)
             k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                            self.qk_rope_head_dim)
+                             self.qk_rope_head_dim)
             input_layout = "BNSD"

         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
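
The two branches above pick a decode-cache layout. A shape-level sketch of the view arithmetic (sizes illustrative; this shows only the reshape bookkeeping, not the NPU's actual NZ in-memory tiling, which is left to the kernel):

import torch

num_kv_heads, block_size, kv_lora_rank = 1, 128, 512
k_nope = torch.randn(4, num_kv_heads, block_size, kv_lora_rank)

# Plain blocked layout ("BNSD") used when enable_kv_nz is off.
bnsd = k_nope.view(-1, num_kv_heads, block_size, kv_lora_rank)

# NZ path: the feature dim is split into 16-element tiles, giving
# [-1, N, D // 16, block_size, 16] to match the kernel's expected shape.
nz = k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16)
assert nz.shape == (4, 1, 32, 128, 16)
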
@@ -846,13 +849,8 @@ def _forward_decode(
                 current_ms_metadata.before_comm_event.wait()
         return self._v_up_proj(attn_output)

-    def _mla_preprocess(
-        self,
-        hidden_states,
-        kv_cache,
-        attn_metadata,
-        need_gather_q_kv
-    ):
+    def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
+                        need_gather_q_kv):
         # MLA Preprocess:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -877,8 +875,7 @@ def _mla_preprocess(
         kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
         # Process for shared_expert_dp
         if need_gather_q_kv:
-            q_c = get_tp_group().all_gather(
-                q_c, 0)
+            q_c = get_tp_group().all_gather(q_c, 0)
             kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
         decode_preprocess_res = None
         prefill_preprocess_res = None
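
For reference, the all_gather being compacted here hands every tensor-parallel rank the full token batch by concatenating each rank's shard along dim 0. A single-process sketch of the shape effect, with torch.cat standing in for the collective (tp_size and tensor sizes are illustrative assumptions):

import torch

tp_size, tokens_per_rank, hidden = 4, 8, 16
shards = [torch.randn(tokens_per_rank, hidden) for _ in range(tp_size)]

# all_gather(x, 0) functionally yields the dim-0 concatenation of all
# ranks' shards, so every rank ends up with the full token batch.
gathered = torch.cat(shards, dim=0)
assert gathered.shape == (tp_size * tokens_per_rank, hidden)
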
@@ -893,33 +890,37 @@ def _mla_preprocess(
             decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
             decode_kv_no_split = kv_no_split[:num_decode_tokens]
             decode_k_pe, decode_k_nope = self.exec_kv_decode(
-                decode_kv_no_split, cos, sin, kv_cache,
-                decode_slots)
+                decode_kv_no_split, cos, sin, kv_cache, decode_slots)
             decode_preprocess_res = DecodeMLAPreprocessResult(
                 decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
         # Preprocess for prefill tokens
         if has_prefill:
-            prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens]
+            prefill_kv_no_split = kv_no_split[
+                num_decode_tokens:num_actual_tokens]
             prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
             prefill_q = self.q_proj(prefill_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             cos = attn_metadata.prefill.cos
             sin = attn_metadata.prefill.sin
-            prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens]
+            prefill_slots = attn_metadata.slot_mapping[
+                num_decode_tokens:num_actual_tokens]
             prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
             prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
-                prefill_kv_no_split, cos, sin, kv_cache,
-                prefill_slots)
-            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], self.num_kv_heads,
-                                             -1)
-            prefill_k_nope, prefill_value = self.kv_b_proj(prefill_k_c_normed)[0].view(
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
-                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
+                prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
+            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
+                                             self.num_kv_heads, -1)
+            prefill_k_nope, prefill_value = self.kv_b_proj(
+                prefill_k_c_normed)[0].view(
+                    -1, self.num_heads,
+                    self.qk_nope_head_dim + self.v_head_dim).split(
+                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            prefill_k_pe = prefill_k_pe.expand(
+                (*prefill_k_nope.shape[:-1], -1))
             prefill_preprocess_res = PrefillMLAPreprocessResult(
-                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value)
+                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
+                prefill_value)
         return decode_preprocess_res, prefill_preprocess_res

     def forward(
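
The kv_b_proj reflow above is the step that recovers the no-RoPE key and the value from one fused projection output. A standalone sketch of that view-then-split, with illustrative DeepSeek-V2-style sizes (num_heads=16, qk_nope_head_dim=128, v_head_dim=128 are assumptions):

import torch

tokens, num_heads = 4, 16
qk_nope_head_dim, v_head_dim = 128, 128

# One linear layer emits k_nope and value fused along the last dim;
# split the last dim to recover them, as the diff above does.
kv = torch.randn(tokens, num_heads * (qk_nope_head_dim + v_head_dim))
kv = kv.view(-1, num_heads, qk_nope_head_dim + v_head_dim)
k_nope, value = kv.split([qk_nope_head_dim, v_head_dim], dim=-1)
assert k_nope.shape == (tokens, num_heads, qk_nope_head_dim)
assert value.shape == (tokens, num_heads, v_head_dim)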
@@ -972,13 +973,10 @@ def forward(
                 # FIX: aicore move should be also placed on the comm stream in dbo,
                 # otherwise it may affect the accuracy
                 # TODO: use an elegant way to overlap
-                output_prefill = self._forward_prefill(prefill_preprocess_res.q_nope,
-                                                       prefill_preprocess_res.q_pe,
-                                                       prefill_preprocess_res.k_nope,
-                                                       prefill_preprocess_res.k_pe,
-                                                       prefill_preprocess_res.value,
-                                                       kv_cache,
-                                                       attn_metadata)
+                output_prefill = self._forward_prefill(
+                    prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
+                    prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
+                    prefill_preprocess_res.value, kv_cache, attn_metadata)
                 current_ms_metadata = get_multistream_comm_context()
                 if current_ms_metadata is not None:
                     with torch.npu.stream(current_ms_metadata.comm_stream):

vllm_ascend/models/deepseek_v2.py

Lines changed: 5 additions & 7 deletions
@@ -565,7 +565,8 @@ def __init__(
             v_head_dim=self.v_head_dim,
             rotary_emb=self.rotary_emb,
             q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
-            q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None,
+            q_a_layernorm=self.q_a_layernorm
+            if self.q_lora_rank is not None else None,
             q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
             kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
             kv_a_layernorm=self.kv_a_layernorm,
@@ -598,11 +599,9 @@ def forward(
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
                              device=hidden_states.device)
-        output = self.mla_attn.impl.forward(hidden_states,
-                                            kv_cache,
+        output = self.mla_attn.impl.forward(hidden_states, kv_cache,
                                             forward_context.attn_metadata,
-                                            need_gather_q_kv,
-                                            output)
+                                            need_gather_q_kv, output)
         output = output.view(-1, output_shape[-1])
         return output

@@ -735,8 +734,7 @@ def forward(
                 hidden_states, residual)

         if isinstance(self.mlp, CustomDeepseekV2MoE):
-            hidden_states = self.mlp(hidden_states,
-                                     attn_metadata)
+            hidden_states = self.mlp(hidden_states, attn_metadata)
         else:
             hidden_states = self.mlp(hidden_states)

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 5 additions & 6 deletions
@@ -48,12 +48,11 @@ def __init__(
                                         device=self.runner.device)
         self.torchair_compiled_model = None  # type: ignore
         self.torchair_compiled_models = {}  # type: ignore
-        self.reserved_mc2_mask = torch.zeros(
-            512,
-            dtype=torch.bool,
-            device=self.runner.device
-        )
-        self.torchair_graph_enabled = get_ascend_config().torchair_graph_config.enabled
+        self.reserved_mc2_mask = torch.zeros(512,
+                                             dtype=torch.bool,
+                                             device=self.runner.device)
+        self.torchair_graph_enabled = get_ascend_config(
+        ).torchair_graph_config.enabled

     @staticmethod
     def prepare_inputs(
