Commit 5f63a39

fix: update the unit tests to mock the common attention metadata object instead of runner attributes. The forward-pass test now also wraps the call in a forward context to manage the graph-capturing state.
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 2305243 commit 5f63a39
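
For context, the heart of the test-side change: the per-step attention metadata is now mocked as a single object rather than as attributes on the runner mock. A minimal sketch, assuming a builder fixture and the num_reqs / num_actual_tokens / max_query_len values from tests/ut/attention/test_attention_v1.py are in scope:

from unittest.mock import MagicMock

import torch

# Mock the common attention metadata object directly instead of setting
# seq_lens_cpu / query_start_loc_cpu on the runner mock.
mock_common_attn_metadata = MagicMock()
mock_common_attn_metadata.seq_lens = torch.tensor([5, 6])
mock_common_attn_metadata.query_start_loc = torch.tensor([0, 3, 7])
mock_common_attn_metadata.seq_lens_list = [5, 6]

# build() now takes the common metadata as its fourth positional argument.
builder.build(num_reqs, num_actual_tokens, max_query_len,
              mock_common_attn_metadata)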

File tree (3 files changed: +43 -19 lines)

tests/ut/attention/test_attention_v1.py
vllm_ascend/attention/attention_v1.py
vllm_ascend/worker/model_runner_v1.py

tests/ut/attention/test_attention_v1.py

Lines changed: 34 additions & 15 deletions
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import torch
+from vllm.forward_context import get_forward_context, set_forward_context
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -95,12 +96,15 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
             0].get_device_tensor.return_value = torch.zeros((10, 10))
         self.mock_runner.max_num_blocks_per_req = 10
         self.mock_runner.query_lens = torch.tensor([3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
         self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
         self.mock_runner.device = 'cpu:0'
         self.mock_runner.attn_mask = torch.ones((10, 10))
         self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
+
+        mock_common_attn_metadata = MagicMock()
+        mock_common_attn_metadata.seq_lens = torch.tensor([5, 6])
+        mock_common_attn_metadata.query_start_loc = torch.tensor([0, 3, 7])
+        mock_common_attn_metadata.seq_lens_list = [5, 6]
 
         mock_nz_tensor = MagicMock()
         mock_nd_to_nz_2d.return_value = mock_nz_tensor
@@ -110,6 +114,7 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
             num_reqs,
             num_actual_tokens,
             max_query_len,
+            mock_common_attn_metadata,
         )
 
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@@ -129,12 +134,15 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
             0].get_device_tensor.return_value = torch.zeros((10, 10))
         self.mock_runner.max_num_blocks_per_req = 10
         self.mock_runner.query_lens = torch.tensor([2, 3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
         self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
         self.mock_runner.device = 'cpu:0'
         self.mock_runner.attn_mask = torch.ones((15, 15))
         self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
+
+        mock_common_attn_metadata = MagicMock()
+        mock_common_attn_metadata.seq_lens = torch.tensor([4, 5, 6])
+        mock_common_attn_metadata.query_start_loc = torch.tensor([0, 2, 5, 9])
+        mock_common_attn_metadata.seq_lens_list = [4, 5, 6]
 
         mock_ascend_attention_state = MagicMock()
         mock_ascend_attention_state.PrefillNoCache = 0
@@ -143,7 +151,8 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
         mock_nd_to_nz_spec.return_value = mock_nz_tensor
         mock_npu_format_cast.return_value = mock_nz_tensor
 
-        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
+        self.builder.build(num_reqs, num_actual_tokens, max_query_len,
+                           mock_common_attn_metadata)
 
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@@ -157,14 +166,18 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
             0].get_device_tensor.return_value = torch.zeros((10, 10))
         self.mock_runner.max_num_blocks_per_req = 10
         self.mock_runner.query_lens = torch.tensor([2, 3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
         self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
         self.mock_runner.device = 'cpu:0'
         self.mock_runner.attn_mask = torch.ones((15, 15))
         self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
 
-        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
+        mock_common_attn_metadata = MagicMock()
+        mock_common_attn_metadata.seq_lens = torch.tensor([4, 5, 6])
+        mock_common_attn_metadata.query_start_loc = torch.tensor([0, 2, 5, 9])
+        mock_common_attn_metadata.seq_lens_list = [4, 5, 6]
+
+        self.builder.build(num_reqs, num_actual_tokens, max_query_len,
+                           mock_common_attn_metadata)
 
 
 class TestAscendAttentionBackendImpl(TestBase):
@@ -372,13 +385,19 @@ def test_forward_decode_only(self, mock_paged_attention,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        vllm_config = MagicMock()
+        vllm_config.parallel_config.data_parallel_size = 1
+        with set_forward_context(attn_metadata=metadata,
+                                 vllm_config=vllm_config):
+            forward_context = get_forward_context()
+            forward_context.capturing = False
+            output = self.impl.forward(layer,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)
 
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
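
The forward-test change above exists because the DecodeOnly path in attention_v1.py (next file) now reads a capturing flag from the ambient forward context, so calling impl.forward() outside set_forward_context would raise since no context is set. A minimal sketch of the pattern, with metadata, layer, and the tensor arguments standing in for the test fixtures:

from unittest.mock import MagicMock

from vllm.forward_context import get_forward_context, set_forward_context

vllm_config = MagicMock()
vllm_config.parallel_config.data_parallel_size = 1  # single-DP test setup

with set_forward_context(attn_metadata=metadata, vllm_config=vllm_config):
    # Force the eager (non-capturing) branch of the DecodeOnly path.
    get_forward_context().capturing = False
    output = impl.forward(layer, query, key, value, kv_cache, metadata,
                          trace_flag=False)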

vllm_ascend/attention/attention_v1.py

Lines changed: 8 additions & 3 deletions
@@ -167,7 +167,10 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
-    def build(self, num_reqs, num_actual_tokens, max_query_len,
+    def build(self,
+              num_reqs,
+              num_actual_tokens,
+              max_query_len,
               common_attn_metadata: CommonAttentionMetadata,
               enable_dbo_across_dp: bool = False,
               is_only_prefill: bool = False):
@@ -179,9 +182,9 @@ def build(self, num_reqs, num_actual_tokens, max_query_len,
 
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        # TODO: Refactor these two param to common metadata in runners,
+        # TODO: Refactor this param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        query_lens = self.runner.query_lens
         # Since FIA for GQA is not active now, we temporarily silence it
         seq_lens_list = common_attn_metadata.seq_lens_list
 
@@ -407,6 +410,8 @@ def forward(
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             graph_params = get_graph_params()
 
+            # TODO(Yizhou): Find another way to handle the
+            # graph capturing mode. Say, GraphCaptureContext?
             forward_context = get_forward_context()
             if not forward_context.capturing:
                 if is_310p():
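
For reference, a duck-typed stand-in that satisfies the contract build() now expects from common_attn_metadata: only the three attributes read above are required. This is an illustrative sketch, not the real AscendCommonAttentionMetadata class, which may carry more fields:

from dataclasses import dataclass
from typing import List

import torch

@dataclass
class FakeCommonAttentionMetadata:
    # Only the attributes build() reads above; the actual
    # AscendCommonAttentionMetadata definition may differ.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_list: List[int]

meta = FakeCommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 2, 5, 9]),
    seq_lens=torch.tensor([4, 5, 6]),
    seq_lens_list=[4, 5, 6],
)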

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 1 deletion
@@ -79,9 +79,9 @@
                                                AscendMetadata)
 from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
-from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
+from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
