Commit 408a8ea

Author: lwq (committed)

fix rebase bugs

Signed-off-by: lwq <liwenquan5@huawei.com>

1 parent 8a4479a · commit 408a8ea

File tree: 1 file changed (+3, -2 lines)


vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 2 deletions
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 import torch_npu
+from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
@@ -21,6 +22,7 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata, npu_stream_switch, npu_wait_tensor)
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -211,6 +213,7 @@ def __init__(self,
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -230,10 +233,8 @@ def reorder_batch(self, input_batch: "InputBatch",
             # We treat spec decoding as decode.
             if num_tokens - num_spec_tokens == 1:
                 decodes.append(i)
-                num_decode_tokens += num_tokens
             else:
                 prefills.append(i)
-                num_prefill_tokens += num_tokens
 
         # We hope that this is fairly minimal since decodes
         # should be around for a number of iterations so hopefully they are
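
For context: the last hunk touches the decode/prefill split in reorder_batch. With speculative decoding, a request whose scheduled tokens minus its speculative tokens equal exactly 1 is classified as a decode; everything else is a prefill. The removed num_decode_tokens/num_prefill_tokens accumulations were presumably left over from the rebase (hence "fix rebase bugs"). Below is a minimal, self-contained sketch of that classification loop; the helper name and the input shapes (num_scheduled_tokens, scheduled_spec_decode_tokens) are assumptions for illustration, not a verbatim copy of the file.

# Hypothetical sketch of the decode/prefill split in reorder_batch;
# input names are assumed, only the classification rule comes from the diff.
def split_decodes_and_prefills(req_ids, num_scheduled_tokens,
                               scheduled_spec_decode_tokens):
    decodes, prefills = [], []
    for i, req_id in enumerate(req_ids):
        num_tokens = num_scheduled_tokens[req_id]
        num_spec_tokens = len(scheduled_spec_decode_tokens.get(req_id, ()))
        # We treat spec decoding as decode: discount the speculative
        # tokens before testing for the single new token.
        if num_tokens - num_spec_tokens == 1:
            decodes.append(i)
        else:
            prefills.append(i)
    return decodes, prefills

# Example: "b" schedules 3 tokens, 2 of them speculative, so it still
# counts as a decode; "a" schedules 7 tokens and is a prefill.
decodes, prefills = split_decodes_and_prefills(
    ["a", "b"], {"a": 7, "b": 3}, {"b": [101, 102]})
assert decodes == [1] and prefills == [0]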
