Skip to content

Commit bf8f5bb

Browse files
committed
[bugfix] fix torchair and mtp functionality
Co-authored-by: hust17yixuan <303660421@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent cd2b4b8 commit bf8f5bb

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import types
2+
from typing import Dict
23

34
import torch
45
import torch.nn as nn
@@ -186,6 +187,8 @@ def generate_token_ids(self,
186187
hidden_states: torch.Tensor = None,
187188
attn_metadata=None,
188189
aux_hidden_states: torch.Tensor = None):
190+
if attn_metadata is not None and isinstance(attn_metadata, Dict):
191+
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
189192
next_token_ids: list[int] = []
190193
for i, token_ids in enumerate(valid_sampled_token_ids):
191194
if token_ids:
@@ -382,7 +385,7 @@ def _propose(
382385
num_computed_tokens_cpu=None,
383386
seq_lens=None)
384387
attn_metadata = self.runner.attn_metadata_builder.build(
385-
common_attn_metadata, self.runner.get_model())
388+
0, common_attn_metadata, self.runner.get_model())
386389

387390
self.positions[:num_tokens] = target_positions
388391
self.hidden_states[:num_tokens] = target_hidden_states

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import math
2121
import types
22-
from typing import Optional
22+
from typing import Optional, Dict
2323

2424
import torch
2525
import torch.distributed as dist
@@ -51,7 +51,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
5151
def __init__(self, vllm_config: VllmConfig, device: torch.device):
5252
super().__init__(vllm_config, device)
5353
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
54-
vllm_config, device)
54+
None, None, vllm_config, device)
5555

5656
ascend_config = get_ascend_config()
5757
self.new_kv_cache_bytes = -1
@@ -135,7 +135,8 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
135135
is_torchair_compile, input_ids,
136136
positions, attn_metadata, num_tokens,
137137
intermediate_tensors, inputs_embeds):
138-
138+
if attn_metadata is not None and isinstance(attn_metadata, Dict):
139+
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
139140
if not with_prefill:
140141
# Only mark static while compiling
141142
if is_torchair_compile:
@@ -281,6 +282,8 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
281282
input_ids, positions,
282283
intermediate_tensors,
283284
inputs_embeds):
285+
if attn_metadata is not None and isinstance(attn_metadata, Dict):
286+
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
284287
model_kwargs = {
285288
"kv_caches": self.kv_caches,
286289
"attn_metadata": attn_metadata

0 commit comments

Comments (0)