@@ -75,11 +75,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
75
75
# Maximum query length in the batch.
76
76
max_query_len : Optional [int ]
77
77
78
- # Number of query tokens for each request in the batch.
79
- # Currently, we require that all requests have the same number of query
80
- # tokens during the decoding phase. When speculavie decoding is enabled,
81
- # decode_query_len might be greater than 1. In all other cases, it is 1.
82
- decode_query_len : Optional [int ]
78
+ # Max number of query tokens among request in the batch.
79
+ max_decode_query_len : Optional [int ]
83
80
84
81
# Maximum sequence length among prefill batch. 0 if there are decoding
85
82
# requests only.
@@ -140,7 +137,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
140
137
slot_mapping = slot_mapping ,
141
138
seq_lens = self .seq_lens [:self .num_prefills ],
142
139
seq_lens_tensor = self .seq_lens_tensor [:self .num_prefills ],
143
- decode_query_len = 0 ,
140
+ max_decode_query_len = 0 ,
144
141
max_query_len = self .max_query_len ,
145
142
max_prefill_seq_len = self .max_prefill_seq_len ,
146
143
max_decode_seq_len = 0 ,
@@ -172,7 +169,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
172
169
slot_mapping = slot_mapping ,
173
170
seq_lens = None ,
174
171
seq_lens_tensor = self .seq_lens_tensor [self .num_prefills :],
175
- decode_query_len = self .decode_query_len ,
172
+ max_decode_query_len = self .max_decode_query_len ,
176
173
max_query_len = None ,
177
174
max_prefill_seq_len = 0 ,
178
175
max_decode_seq_len = self .max_decode_seq_len ,
@@ -256,9 +253,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
256
253
max_query_len = max (query_lens )
257
254
decode_query_lens = query_lens [self .num_prefills :]
258
255
if len (decode_query_lens ) > 0 :
259
- decode_query_len = max (decode_query_lens )
256
+ max_decode_query_len = max (decode_query_lens )
260
257
else :
261
- decode_query_len = 1
258
+ max_decode_query_len = 1
262
259
max_prefill_seq_len = max (self .prefill_seq_lens , default = 0 )
263
260
max_decode_seq_len = max (self .curr_seq_lens , default = 0 )
264
261
num_decode_tokens = self .num_decode_tokens
@@ -304,7 +301,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
304
301
seq_lens = seq_lens ,
305
302
seq_lens_tensor = seq_lens_tensor ,
306
303
max_query_len = max_query_len ,
307
- decode_query_len = decode_query_len ,
304
+ max_decode_query_len = max_decode_query_len ,
308
305
max_prefill_seq_len = max_prefill_seq_len ,
309
306
max_decode_seq_len = max_decode_seq_len ,
310
307
query_start_loc = query_start_loc ,
0 commit comments