@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch

 import torch
+from vllm.forward_context import get_forward_context, set_forward_context

 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -95,12 +96,15 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
             0].get_device_tensor.return_value = torch.zeros((10, 10))
         self.mock_runner.max_num_blocks_per_req = 10
         self.mock_runner.query_lens = torch.tensor([3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
         self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
         self.mock_runner.device = 'cpu:0'
         self.mock_runner.attn_mask = torch.ones((10, 10))
         self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
+
+        mock_common_attn_metadata = MagicMock()
+        mock_common_attn_metadata.seq_lens = torch.tensor([5, 6])
+        mock_common_attn_metadata.query_start_loc = torch.tensor([0, 3, 7])
+        mock_common_attn_metadata.seq_lens_list = [5, 6]

         mock_nz_tensor = MagicMock()
         mock_nd_to_nz_2d.return_value = mock_nz_tensor
@@ -110,6 +114,7 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
             num_reqs,
             num_actual_tokens,
             max_query_len,
+            mock_common_attn_metadata,
         )

     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@@ -129,12 +134,15 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
             0].get_device_tensor.return_value = torch.zeros((10, 10))
         self.mock_runner.max_num_blocks_per_req = 10
         self.mock_runner.query_lens = torch.tensor([2, 3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
         self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
         self.mock_runner.device = 'cpu:0'
         self.mock_runner.attn_mask = torch.ones((15, 15))
         self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
+
+        mock_common_attn_metadata = MagicMock()
+        mock_common_attn_metadata.seq_lens = torch.tensor([4, 5, 6])
+        mock_common_attn_metadata.query_start_loc = torch.tensor([0, 2, 5, 9])
+        mock_common_attn_metadata.seq_lens_list = [4, 5, 6]

         mock_ascend_attention_state = MagicMock()
         mock_ascend_attention_state.PrefillNoCache = 0
@@ -143,7 +151,8 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
         mock_nd_to_nz_spec.return_value = mock_nz_tensor
         mock_npu_format_cast.return_value = mock_nz_tensor

-        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
+        self.builder.build(num_reqs, num_actual_tokens, max_query_len,
+                           mock_common_attn_metadata)

     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@@ -157,14 +166,18 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
             0].get_device_tensor.return_value = torch.zeros((10, 10))
         self.mock_runner.max_num_blocks_per_req = 10
         self.mock_runner.query_lens = torch.tensor([2, 3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
         self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
         self.mock_runner.device = 'cpu:0'
         self.mock_runner.attn_mask = torch.ones((15, 15))
         self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])

-        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
+        mock_common_attn_metadata = MagicMock()
+        mock_common_attn_metadata.seq_lens = torch.tensor([4, 5, 6])
+        mock_common_attn_metadata.query_start_loc = torch.tensor([0, 2, 5, 9])
+        mock_common_attn_metadata.seq_lens_list = [4, 5, 6]
+
+        self.builder.build(num_reqs, num_actual_tokens, max_query_len,
+                           mock_common_attn_metadata)


 class TestAscendAttentionBackendImpl(TestBase):
@@ -372,13 +385,19 @@ def test_forward_decode_only(self, mock_paged_attention,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        vllm_config = MagicMock()
+        vllm_config.parallel_config.data_parallel_size = 1
+        with set_forward_context(attn_metadata=metadata,
+                                 vllm_config=vllm_config):
+            forward_context = get_forward_context()
+            forward_context.capturing = False
+            output = self.impl.forward(layer,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)

         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
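For context: these tests follow vLLM's forward-context pattern, in which attention metadata is installed for the duration of a forward pass via set_forward_context and read back inside the implementation via get_forward_context, rather than being threaded through runner attributes. Below is a minimal standalone sketch of that pattern under the same mocked-config assumption the tests make; the fake_forward helper is illustrative only, and exactly which config fields set_forward_context inspects depends on the pinned vLLM version.

from unittest.mock import MagicMock

from vllm.forward_context import get_forward_context, set_forward_context


def fake_forward():
    # Illustrative stand-in for an attention impl: it reads whatever
    # metadata the enclosing set_forward_context installed.
    ctx = get_forward_context()
    return ctx.attn_metadata.seq_lens_list


metadata = MagicMock()
metadata.seq_lens_list = [5, 6]

# A MagicMock config suffices in unit tests, provided any field that is
# compared numerically (e.g. data_parallel_size) is set to a real int.
vllm_config = MagicMock()
vllm_config.parallel_config.data_parallel_size = 1

with set_forward_context(attn_metadata=metadata, vllm_config=vllm_config):
    assert fake_forward() == [5, 6]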