
Commit 1a97261

fix: fix Ascend attention metadata
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent: e1079d9

File tree: 4 files changed (+9, -17 lines)


tests/e2e/singlecard/test_aclgraph.py

Lines changed: 4 additions & 5 deletions

@@ -53,11 +53,10 @@ def test_models_with_aclgraph(
     # while running pytest
     if full_graph:
         vllm_model = LLM(model,
-                         compilation_config={
-                             "full_cuda_graph": True,
-                             "cudagraph_capture_sizes":
-                             [1, 4, 16, 64, 256]
-                         })
+                         compilation_config={
+                             "full_cuda_graph": True,
+                             "cudagraph_capture_sizes": [1, 4, 16, 64, 256]
+                         })
     else:
         vllm_model = LLM(model, max_model_len=1024)
     vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
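
The change to this test is purely cosmetic: the compilation_config literal is collapsed onto fewer lines, with no behavioral difference. For context, a minimal standalone sketch of the configuration the test exercises, assuming vLLM's public LLM/SamplingParams API; the model name, prompt, and sampling settings are placeholders rather than values from the test:

```python
from vllm import LLM, SamplingParams

# Placeholder prompt and sampling settings (not taken from the test).
prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# The configuration under test: capture the full model graph, and
# pre-capture it for a fixed set of batch sizes.
vllm_model = LLM(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    compilation_config={
        "full_cuda_graph": True,
        "cudagraph_capture_sizes": [1, 4, 16, 64, 256],
    })
outputs = vllm_model.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
```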

vllm_ascend/attention/attention_v1.py

Lines changed: 5 additions & 8 deletions

@@ -32,8 +32,9 @@
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, get_graph_params, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
+                               get_graph_params, is_310p, nd_to_nz_2d,
+                               nd_to_nz_spec)
 from vllm_ascend.worker.npu_input_batch import InputBatch

@@ -135,7 +136,7 @@ class AscendMetadata:
     # tokens + new tokens (is None if it is a decoding).
     # (batch_size,)
     seq_lens: torch.Tensor = None
-    seq_lens_list: Optional[list[int]]
+    seq_lens_list: Optional[list[int]] = None
     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
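
The = None added to seq_lens_list is the substantive fix in this commit. AscendMetadata is a dataclass, and Python does not allow a dataclass field without a default to follow fields that have one; since seq_lens (and its neighbors) already default to None, declaring seq_lens_list without a default raises a TypeError the moment the class is defined, i.e. at import time. A dependency-free sketch of the failure and the fix, with plain types standing in for the vLLM ones:

```python
from dataclasses import dataclass
from typing import Optional

# Simplified stand-in for AscendMetadata: a field with a default
# followed by one without reproduces the original bug.
try:
    @dataclass
    class BrokenMetadata:
        seq_lens: Optional[object] = None
        seq_lens_list: Optional[list[int]]  # non-default after a default
except TypeError as err:
    # e.g. "non-default argument 'seq_lens_list' follows default argument"
    print(err)

# The one-line fix from the diff: give the trailing field a default too.
@dataclass
class FixedMetadata:
    seq_lens: Optional[object] = None
    seq_lens_list: Optional[list[int]] = None

print(FixedMetadata())  # FixedMetadata(seq_lens=None, seq_lens_list=None)
```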

@@ -163,10 +164,7 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
+    def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_attn_metadata: CommonAttentionMetadata):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(

@@ -225,7 +223,6 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                 num_reqs=num_reqs,
                 num_actual_tokens=num_actual_tokens,
                 max_query_len=num_scheduled_tokens.max(),
-                common_prefix_len=0,
                 common_attn_metadata=common_attn_metadata,
             )
         else:
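
Dropping common_prefix_len=0 from build_dummy_metadata matches the build() signature shown above, which declares no such parameter; keeping the keyword would make the call fail with a TypeError. A small illustration using a hypothetical stub in place of the real builder:

```python
# Hypothetical stub mirroring the build() signature from the diff.
class DummyBuilder:
    def build(self, num_reqs, num_actual_tokens, max_query_len,
              common_attn_metadata):
        return (num_reqs, num_actual_tokens, max_query_len,
                common_attn_metadata)

builder = DummyBuilder()
try:
    builder.build(num_reqs=1,
                  num_actual_tokens=1,
                  max_query_len=1,
                  common_prefix_len=0,  # not in the signature
                  common_attn_metadata=None)
except TypeError as err:
    # e.g. "build() got an unexpected keyword argument 'common_prefix_len'"
    print(err)
```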

vllm_ascend/attention/mla_v1.py

Lines changed: 0 additions & 2 deletions

The only change in this file removes an import of the CommonAttentionMetadata alias that is no longer referenced in this module (an unused-import cleanup):

@@ -17,8 +17,6 @@
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import \
-    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn

vllm_ascend/worker/model_runner_v1.py

Lines changed: 0 additions & 2 deletions

@@ -270,8 +270,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.query_lens = torch.zeros(self.max_num_reqs,
                                       dtype=torch.int32,
                                       device=self.device)
-        # None in the first PP rank. The rest are set after load_model.
-        self.intermediate_tensors: Optional[IntermediateTensors] = None
 
         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
