Commit 793892f
feat(graph): Enable ACL graph capture for MLA decode
Adds support for capturing the Multi-head Latent Attention (MLA) decode operation into an ACL graph. This improves performance by compiling the attention kernel for single-token decoding.

Key changes include:
- Implementing the graph capture logic for the MLA kernel, including workspace management and parameter updates.
- Modifying the rotary embedding (RoPE) handling to use pre-allocated tensors, which is a requirement for graph capture.
- Adding a `build_for_graph_capture` method to the MLA metadata builder to create dummy metadata during the graph compilation phase.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent d01fd1d commit 793892f
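
The diffs below follow a capture-then-patch pattern for the fused MLA attention kernel: during graph capture, each layer's call is recorded as a task group with a cached workspace and an external event; at replay time, `update_mla_attn_params` patches fresh arguments into the recorded task groups. The sketch below condenses that capture-side pattern from the mla_v1.py and acl_graph.py changes. It is illustrative only; the helper name `record_fused_attention` and the trimmed argument packing are assumptions, not code from this commit.

```python
# Illustrative sketch of the capture-then-patch pattern used in this commit.
# `record_fused_attention` and its signature are assumptions for readability.
import torch
import torch_npu


def record_fused_attention(graph_params, num_tokens, q_nope, k_nope, **kwargs):
    """Record one fused-attention call into the ACL graph being captured."""
    stream = torch_npu.npu.current_stream()

    # The replay-time updater records this event after patching new arguments;
    # the captured graph waits on it before launching the kernel.
    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)

    # The workspace depends only on the padded shape, so it is queried once
    # per captured batch size and reused on every replay.
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
            q_nope, k_nope, k_nope, **kwargs)
        graph_params.workspaces[num_tokens] = workspace

    attn_output = torch.empty_like(q_nope)
    softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device)

    # Recording the call as a task group yields a handle that
    # update_mla_attn_params later uses (graph_task_update_begin/end)
    # to swap in fresh arguments without re-capturing the graph.
    torch.npu.graph_task_group_begin(stream)
    torch_npu.npu_fused_infer_attention_score.out(
        q_nope, k_nope, k_nope, **kwargs,
        workspace=workspace, out=[attn_output, softmax_lse])
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
    return attn_output
```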

File tree

vllm_ascend/attention/attention_v1.py
vllm_ascend/attention/mla_v1.py
vllm_ascend/attention/utils.py
vllm_ascend/compilation/acl_graph.py
vllm_ascend/platform.py
vllm_ascend/worker/model_runner_v1.py

6 files changed: +181 −34 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 0 deletions
@@ -244,6 +244,7 @@ def build_for_graph_capture(
         self,
         common_attn_metadata: AscendCommonAttentionMetadata,
         attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
     ):
         if attn_state == AscendAttentionState.DecodeOnly:
             attn_metadata = self.build(

vllm_ascend/attention/mla_v1.py

Lines changed: 95 additions & 23 deletions
@@ -10,6 +10,7 @@
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -21,6 +22,7 @@
                                           maybe_save_kv_layer_to_connector,
                                           split_decodes_and_prefills,
                                           wait_for_kv_layer_from_connector)
+from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -169,7 +171,7 @@ def split_metadata_for_multistream(
 class AscendMLAMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.NEVER
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -314,6 +316,8 @@ def build(
                 self.model_config.dtype)  # type: ignore
             self.sin_cache = self.sin_cache.to(  # type: ignore
                 self.model_config.dtype)  # type: ignore
+        cos = common_attn_metadata.cos
+        sin = common_attn_metadata.sin

         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         query_lens = query_seq_lens_cpu[:num_reqs]
@@ -395,9 +399,12 @@ def build(
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()

-            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
+            assert self.cos_cache is not None
+            assert self.sin_cache is not None
+
+            cos[:num_decodes, ...] = self.cos_cache[input_positions].unsqueeze(
                 1).unsqueeze(2)
-            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
+            sin[:num_decodes, ...] = self.sin_cache[input_positions].unsqueeze(
                 1).unsqueeze(2)

             decode_metadata = AscendMLADecodeMetadata(
@@ -408,8 +415,8 @@ def build(
                 max_seq_lens=max_seq_lens,
                 attn_mask=common_attn_metadata.spec_attn_mask,
                 actual_seq_lengths_q=actual_seq_lengths_q,
-                sin=sin,
-                cos=cos)
+                sin=sin[:num_decodes, ...],
+                cos=cos[:num_decodes, ...])

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
@@ -429,6 +436,26 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )

+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+        model: Optional[nn.Module] = None,
+    ):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+                model=model,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+

 class DecodeMLAPreprocessResult(NamedTuple):
     ql_nope: Optional[torch.Tensor] = None
@@ -832,24 +859,69 @@ def _forward_decode(
             sparse_mode = 0
             spec_attn_mask = None

-        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
-            q_nope,
-            k_nope,
-            k_nope,
-            query_rope=q_pe,
-            key_rope=k_pe,
-            num_heads=self.num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout=input_layout,
-            atten_mask=spec_attn_mask,
-            sparse_mode=sparse_mode,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            block_table=decode_meta.block_table,
-            block_size=block_size,
-            actual_seq_lengths_kv=decode_meta.seq_lens_list,
-            actual_seq_lengths=actual_seq_lengths)
+        common_kwargs = {
+            'query_rope': q_pe,
+            'key_rope': k_pe,
+            'num_heads': self.num_heads,
+            'num_key_value_heads': self.num_kv_heads,
+            'input_layout': input_layout,
+            'atten_mask': spec_attn_mask,
+            'sparse_mode': sparse_mode,
+            'scale': self.scale,
+            'antiquant_mode': 0,
+            'antiquant_scale': None,
+            'block_table': decode_meta.block_table,
+            'block_size': block_size,
+            "actual_seq_lengths": actual_seq_lengths,
+            "actual_seq_lengths_kv": decode_meta.seq_lens_list,
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope,
+                    k_nope,
+                    k_nope,
+                    **common_kwargs)
+                graph_params.workspaces[num_tokens] = workspace
+
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty(num_tokens,
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
+                 input_layout, spec_attn_mask, sparse_mode, self.scale,
+                 decode_meta.block_table, block_size,
+                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
+                 attn_output, softmax_lse))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                k_nope,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_output, softmax_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                q_nope,
                k_nope,
                k_nope,
                **common_kwargs)

         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:

vllm_ascend/attention/utils.py

Lines changed: 4 additions & 0 deletions
@@ -63,6 +63,10 @@ class AscendCommonAttentionMetadata:

     graph_pad_size: int = -1

+    # NOTE: This is a temporary solution for rotary embedding in MLA
+    cos: torch.Tensor = None
+    sin: torch.Tensor = None
+

 def split_decodes_and_prefills(
     common_attn_metadata: AscendCommonAttentionMetadata,
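
The new `cos`/`sin` fields exist because a captured ACL graph replays against fixed tensor addresses: the RoPE values for each decode step must be copied into persistent buffers rather than materialized as fresh tensors inside `build()`. Below is a minimal sketch of that pattern; the buffer sizes are hypothetical placeholders (the real buffers are allocated in the model runner, where the rotary dim of 64 is currently hard-coded, see the model_runner_v1.py diff further down).

```python
# Minimal sketch (illustrative): per-step RoPE values are written into
# persistent buffers so their storage addresses stay stable across graph replays.
import torch

MAX_NUM_REQS, ROPE_DIM = 128, 64  # hypothetical sizes; 64 mirrors the hard-coded value

cos_buf = torch.ones(MAX_NUM_REQS, 1, 1, ROPE_DIM)   # address baked into the captured graph
sin_buf = torch.zeros(MAX_NUM_REQS, 1, 1, ROPE_DIM)


def refresh_rope(cos_cache, sin_cache, input_positions, num_decodes):
    # In-place slice assignment reuses the existing storage instead of
    # allocating new tensors, which would break replay of the captured graph.
    cos_buf[:num_decodes] = cos_cache[input_positions].unsqueeze(1).unsqueeze(2)
    sin_buf[:num_decodes] = sin_cache[input_positions].unsqueeze(1).unsqueeze(2)
    return cos_buf[:num_decodes], sin_buf[:num_decodes]
```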

vllm_ascend/compilation/acl_graph.py

Lines changed: 46 additions & 0 deletions
@@ -229,6 +229,52 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         event.record(update_stream)


+def update_mla_attn_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+         spec_attn_mask, sparse_mode, scale, block_table, block_size,
+         seq_lens_list, actual_seq_lengths, workspace, attn_output,
+         softmax_lse) = param
+        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
+        seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                               len(seq_lens_list))
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                k_nope,
+                query_rope=q_pe,
+                key_rope=k_pe,
+                num_heads=num_heads,
+                num_key_value_heads=num_kv_heads,
+                input_layout=input_layout,
+                atten_mask=spec_attn_mask,
+                sparse_mode=sparse_mode,
+                scale=scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=seq_lens_list,
+                actual_seq_lengths=actual_seq_lengths,
+                workspace=workspace,
+                out=[attn_output, softmax_lse])
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]
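
For orientation, the per-shape bookkeeping that `update_mla_attn_params` walks over is organized roughly as sketched below. This is an illustrative reconstruction with simplified field types, not the actual `GraphParams` definition from acl_graph.py.

```python
# Illustrative reconstruction (simplified types): everything is keyed by the
# captured runtime shape (padded decode token count), with one list entry per
# attention layer in capture order.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any


@dataclass
class GraphParamsSketch:
    events: dict[int, list[Any]] = field(default_factory=lambda: defaultdict(list))
    workspaces: dict[int, Any] = field(default_factory=dict)
    handles: dict[int, list[Any]] = field(default_factory=lambda: defaultdict(list))
    attn_params: dict[int, list[tuple]] = field(default_factory=lambda: defaultdict(list))
```

The `zip` in `update_mla_attn_params` relies on this layer-ordered layout: the iteration order of `forward_context.attn_metadata` has to match the order in which the layers were recorded during capture.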

vllm_ascend/platform.py

Lines changed: 1 addition & 6 deletions
@@ -211,12 +211,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.level = CompilationLevel.NO_COMPILATION
-        # TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
-        # after MLA being supported
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
-                compilation_config.cudagraph_mode
-                == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
-                and model_config.use_mla):
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
             logger.info(
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")

vllm_ascend/worker/model_runner_v1.py

Lines changed: 34 additions & 5 deletions
@@ -101,7 +101,8 @@
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
-                                               update_attn_params)
+                                               update_attn_params,
+                                               update_mla_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -358,6 +359,20 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.slot_mapping = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device=self.device)
+        # FIXME: Do not hard code 64 here! And also find a better way to
+        # fix the MLA RoPE issue.
+        self.cos = torch.ones(self.max_num_reqs,
+                              1,
+                              1,
+                              64,
+                              dtype=self.dtype,
+                              device=self.device)
+        self.sin = torch.zeros(self.max_num_reqs,
+                               1,
+                               1,
+                               64,
+                               dtype=self.dtype,
+                               device=self.device)

         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1508,6 +1523,8 @@ def _prepare_inputs(
             max_query_len=max_num_scheduled_tokens,
             graph_pad_size=self.graph_pad_size,
             decode_token_per_req=self.decode_token_per_req,
+            cos=self.cos,
+            sin=self.sin,
         )

         if self.speculative_config and \
@@ -1537,7 +1554,7 @@ def _prepare_inputs(
             attn_metadata_i = builder.build(
                 common_prefix_len=common_prefix_len,
                 common_attn_metadata=common_attn_metadata,
-                model=self.model,
+                model=self.get_model(),
                 **extra_attn_metadata_args)

         if self.vllm_config.model_config.use_mla:
@@ -1572,8 +1589,13 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,

         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            update_attn_params(self.update_stream, forward_context,
-                               positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                update_mla_attn_params(self.update_stream, forward_context,
+                                       positions.shape[0])
+            else:
+                update_attn_params(self.update_stream, forward_context,
+                                   positions.shape[0])

         if get_forward_context().flashcomm_v1_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2274,8 +2296,14 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
                 block_table_tensor=block_table_tensor[:num_reqs],
                 slot_mapping=self.slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
+                positions=self.positions,
+                attn_mask=self.attn_mask,
+                spec_attn_mask=self.spec_attn_mask,
+                attn_state=self.attn_state,
                 max_query_len=max_query_len,
                 decode_token_per_req=self.decode_token_per_req,
+                cos=self.cos,
+                sin=self.sin,
             )

             for attn_group in self.attn_groups[kv_cache_group_id]:
@@ -2284,7 +2312,8 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
             else:
                 builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata)
+                    common_attn_metadata, AscendAttentionState.DecodeOnly,
+                    self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i