Commit 89feb4c

[SpecDec] Remove Batch Expansion (2/3) (#9298)
1 parent ec10cb8 commit 89feb4c

8 files changed: +122 −70 lines

tests/spec_decode/test_scorer.py
Lines changed: 39 additions & 13 deletions

@@ -1,3 +1,6 @@
+import random
+from typing import List
+
 import pytest
 import torch
 
@@ -10,31 +13,45 @@
 from .utils import create_batch, create_worker
 
 
-def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
+def create_proposal(propose_lens: List[int], vocab_size: int,
                     device: str) -> SpeculativeProposals:
-    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
+    batch_size = len(propose_lens)
+    max_propose_len = max(propose_lens)
+    proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
                                 device=device)
-    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
-    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
+
+    proposal_token_ids = torch.full((batch_size, max_propose_len),
+                                    fill_value=-1,
+                                    device=device)
+    for i in range(batch_size):
+        proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
+            proposal_probs[i][:propose_lens[i]], dim=-1)
+
+    propose_lens = torch.tensor(propose_lens, device=device)
     return SpeculativeProposals(proposal_token_ids, proposal_probs,
-                                proposal_lens)
+                                propose_lens)
 
 
 def assert_score_equal(score1: SpeculativeScores,
                        score2: SpeculativeScores) -> None:
     assert torch.allclose(score1.probs, score2.probs)
     assert torch.allclose(score1.logprobs, score2.logprobs)
-    assert torch.equal(score1.token_ids, score2.token_ids)
+    assert torch.equal(
+        score1.token_ids,
+        score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
 
 
 @pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
 @pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
-@pytest.mark.parametrize('propose_len', [1, 3, 5])
+@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
+@pytest.mark.parametrize('mixed_propose_len', [True])
 @pytest.mark.parametrize('device', ['cuda'])
-def test_scoroer(model_name: str, batch_size: int, propose_len: int,
-                 device: str) -> None:
+def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
+                mixed_propose_len: bool, device: str) -> None:
     """
-    Compare the batch expansion scorer and mqa scorer return the same score
+    Compare the batch expansion scorer and mqa scorer return the same score.
+    We test for both queries with the same propose length and different
+    propose length.
     """
     seed = 0
     block_size = 32
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
     should_modify_greedy_probs_inplace = True
 
     vocab_size = scorer_worker.vocab_size
-    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
+
+    if not mixed_propose_len:
+        propose_lens = [max_propose_len] * batch_size
+    else:
+        non_zero_cnt = random.randint(0, batch_size)
+        propose_lens = [max_propose_len
+                        ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
+        random.shuffle(propose_lens)
+
+    proposals = create_proposal(propose_lens, vocab_size, device)
    seq_group_metadatalist, _, _ = create_batch(batch_size,
-                                                propose_len,
+                                                max_propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
     requests = ExecuteModelRequest(seq_group_metadatalist,
-                                   num_lookahead_slots=propose_len)
+                                   num_lookahead_slots=max_propose_len)
 
     batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                       vocab_size)
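
For readers unfamiliar with the padding convention this test now exercises, the following standalone sketch (editor-added, CPU-only, with toy values rather than the test's fixtures) shows what create_proposal produces when proposal lengths differ across the batch: rows shorter than max_propose_len keep -1 in the padded tail.

import torch

# Toy inputs; the real test draws these from the parametrized fixtures.
propose_lens = [3, 0, 2]
vocab_size = 8
batch_size = len(propose_lens)
max_propose_len = max(propose_lens)

proposal_probs = torch.rand(batch_size, max_propose_len, vocab_size)
proposal_token_ids = torch.full((batch_size, max_propose_len), fill_value=-1)
for i, plen in enumerate(propose_lens):
    # Only the first `plen` slots of row i hold real draft tokens.
    proposal_token_ids[i, :plen] = torch.argmax(proposal_probs[i, :plen],
                                                dim=-1)

print(proposal_token_ids)  # rows with shorter proposals end in -1

With mixed_propose_len=True, the test also includes requests whose propose length is 0, which the MQA scorer has to handle alongside full-length proposals.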

vllm/attention/backends/blocksparse_attn.py
Lines changed: 2 additions & 5 deletions

@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens for among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     _cached_prefill_metadata: Optional[
         "BlocksparseFlashAttentionMetadata"] = None

vllm/attention/backends/flash_attn.py
Lines changed: 42 additions & 27 deletions

@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     # Maximum query length in the batch.
     max_query_len: Optional[int]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int]
 
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -173,9 +170,9 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_query_len=0,
             max_decode_seq_len=0,
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
+            query_start_loc=self.query_start_loc[self.num_prefills:]
+            if self.query_start_loc is not None else None,
+            seq_start_loc=self.seq_start_loc[self.num_prefills:]
+            if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
@@ -413,9 +412,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         max_query_len = max(query_lens)
         decode_query_lens = query_lens[self.num_prefills:]
         if len(decode_query_lens) > 0:
-            decode_query_len = max(decode_query_lens)
+            max_decode_query_len = max(decode_query_lens)
         else:
-            decode_query_len = 1
+            max_decode_query_len = 1
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
@@ -468,7 +467,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(
 
     if decode_meta := attn_metadata.decode_metadata:
         # Decoding run.
-        _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
-                                            num_head, head_dim)
-        decode_output = flash_attn_with_kvcache(
-            q=decode_query,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            block_table=decode_meta.block_tables,
-            cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            softcap=logits_soft_cap,
-        ).squeeze(1)
+        # Use flash_attn_varlen_func kernel for speculative decoding
+        # because different queries might have different lengths.
+        assert decode_meta.max_decode_query_len is not None
+        if decode_meta.max_decode_query_len > 1:
+            decode_output = flash_attn_varlen_func(
+                q=decode_query,
+                k=key_cache,
+                v=value_cache,
+                cu_seqlens_q=decode_meta.query_start_loc,
+                max_seqlen_q=decode_meta.max_decode_query_len,
+                cu_seqlens_k=decode_meta.seq_start_loc,
+                max_seqlen_k=decode_meta.max_decode_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+                block_table=decode_meta.block_tables,
+            )
+        else:
+            # Use flash_attn_with_kvcache for normal decoding.
+            decode_output = flash_attn_with_kvcache(
+                q=decode_query.unsqueeze(1),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+            ).squeeze(1)
 
     if prefill_output is None:
         assert decode_output is not None
@@ -739,7 +755,6 @@ def unified_flash_attention(
     # Chunked prefill does not work with speculative decoding.
    # Therefore, the query length for decode should be 1 in chunked prefill.
     assert decode_meta is not None
-    assert decode_meta.decode_query_len == 1
     decode_output = decode_output.squeeze(1)
     output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
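
The new decode branch relies on query_start_loc (the cu_seqlens_q argument) to describe a packed, unpadded query tensor when requests carry different numbers of speculative query tokens. A minimal editor-added sketch of that layout, using made-up lengths and head sizes rather than anything from the commit:

import torch

decode_query_lens = [4, 1, 2]          # per-request decode query tokens
num_heads, head_dim = 2, 4

# Cumulative start offsets, shape (batch_size + 1,): request i owns the
# slice decode_query[query_start_loc[i]:query_start_loc[i + 1]].
lens = torch.tensor(decode_query_lens)
query_start_loc = torch.zeros(len(decode_query_lens) + 1, dtype=torch.long)
query_start_loc[1:] = torch.cumsum(lens, dim=0)

total_tokens = int(query_start_loc[-1])
decode_query = torch.randn(total_tokens, num_heads, head_dim)  # packed, no padding

for i, qlen in enumerate(decode_query_lens):
    q_i = decode_query[query_start_loc[i]:query_start_loc[i + 1]]
    assert q_i.shape == (qlen, num_heads, head_dim)

When every decode request has a single query token (max_decode_query_len == 1), the batch is regular and the existing flash_attn_with_kvcache path is kept.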

vllm/attention/backends/rocm_flash_attn.py
Lines changed: 2 additions & 5 deletions

@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     # so far).
     context_lens_tensor: Optional[torch.Tensor]
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

vllm/attention/backends/utils.py
Lines changed: 1 addition & 1 deletion

@@ -313,7 +313,7 @@ def graph_capture_get_metadata_for_batch(
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
             max_query_len=1,
-            decode_query_len=1,
+            max_decode_query_len=1,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.runner.max_seq_len_to_capture,
             query_start_loc=None,

vllm/attention/backends/xformers.py
Lines changed: 2 additions & 5 deletions

@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
 
-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int] = None
+    # Max number of query tokens among request in the batch.
+    max_decode_query_len: Optional[int] = None
 
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length

vllm/spec_decode/mqa_scorer.py
Lines changed: 34 additions & 8 deletions

@@ -18,6 +18,7 @@ def score_proposals(
         target_seq_id_start = max(
             get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
         all_proposal_tokens = proposals.proposal_token_ids.tolist()
+        all_proposal_lengths = proposals.proposal_lens.tolist()
         for i, seq_group_metadata in enumerate(
                 execute_model_req.seq_group_metadata_list):
             seq_data_dict = seq_group_metadata.seq_data
@@ -27,7 +28,8 @@ def score_proposals(
             seq_data: SequenceData = seq_data_dict[seq_id]
             prompt_token_ids = seq_data.get_prompt_token_ids()
             output_token_ids = seq_data.get_output_token_ids()
-            proposal_token_ids = all_proposal_tokens[i]
+            proposal_token_ids = all_proposal_tokens[
+                i][:all_proposal_lengths[i]]
             new_output_token_ids = [*output_token_ids, *proposal_token_ids]
 
             target_seq_id = target_seq_id_start + i
@@ -62,18 +64,42 @@ def score_proposals(
 
         target_sampler_output = target_sampler_output[0]
 
-        bs, k = proposals.proposal_token_ids.shape
-        all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)
-
-        all_probs = target_sampler_output.sampled_token_probs.reshape(
-            bs, k + 1, self._vocab_size)
-        all_logprobs = target_sampler_output.logprobs.reshape(
-            bs, k + 1, self._vocab_size)
+        k = execute_model_req.num_lookahead_slots
+        bs = len(execute_model_req.seq_group_metadata_list)
+        target_token_ids = target_sampler_output.sampled_token_ids
+        target_probs = target_sampler_output.sampled_token_probs
+        target_logprobs = target_sampler_output.logprobs
+        # If all requests have the same number of query tokens, we can avoid
+        # the for loop to build output for better performance.
+        if min(all_proposal_lengths) == k:
+            bs, _ = proposals.proposal_token_ids.shape
+            all_tokens = target_token_ids.reshape(bs, k + 1)
+            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
+            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
+        else:
+            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
+                                                   fill_value=-1)
+            all_probs = target_probs.new_zeros(*all_tokens.shape,
+                                               self._vocab_size)
+            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
+                                                    fill_value=-float("inf"))
+            target_token_ids = target_token_ids.flatten()
+            start_loc = 0
+            for i, proposed_len in enumerate(all_proposal_lengths):
+                output_len = proposed_len + 1
+                end_loc = start_loc + output_len
+                all_tokens[
+                    i, :output_len] = target_token_ids[start_loc:end_loc]
+                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
+                all_logprobs[
+                    i, :output_len] = target_logprobs[start_loc:end_loc]
+                start_loc = end_loc
 
         hidden_states = None
         if target_sampler_output.hidden_states is not None:
             hidden_states = target_sampler_output.hidden_states.reshape(
                 bs, (k + 1), -1)
+
         return SpeculativeScores(probs=all_probs,
                                  token_ids=all_tokens,
                                  logprobs=all_logprobs,
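
The else branch above packs a flat stream of scored tokens (proposed_len + 1 per request) back into fixed-shape [bs, k + 1] tensors. A self-contained, editor-added sketch of that scatter with toy data, independent of the sampler objects used in the real code:

import torch

k = 3                                   # num_lookahead_slots
proposal_lens = [3, 0, 2]               # per-request proposed lengths
bs = len(proposal_lens)
vocab_size = 8

total_scored = sum(plen + 1 for plen in proposal_lens)
target_token_ids = torch.randint(0, vocab_size, (total_scored,))
target_probs = torch.rand(total_scored, vocab_size)

all_tokens = target_token_ids.new_full((bs, k + 1), fill_value=-1)
all_probs = target_probs.new_zeros(bs, k + 1, vocab_size)

start_loc = 0
for i, proposed_len in enumerate(proposal_lens):
    output_len = proposed_len + 1       # proposed tokens plus one bonus token
    end_loc = start_loc + output_len
    all_tokens[i, :output_len] = target_token_ids[start_loc:end_loc]
    all_probs[i, :output_len] = target_probs[start_loc:end_loc]
    start_loc = end_loc

print(all_tokens)                       # unused slots stay at -1

Slots past proposed_len + 1 keep their fill values (-1 tokens, zero probs, and -inf logprobs in the real scorer), marking positions that were never scored.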

vllm/spec_decode/spec_decode_worker.py
Lines changed: 0 additions & 6 deletions

@@ -190,12 +190,6 @@ def create_worker(
                     "[Speculative Decoding] Disabling MQA scorer as the "
                     "MQA is only available with flash attn backend.")
 
-            if ngram_prompt_lookup_max > 0:
-                disable_mqa_scorer = True
-                logger.info(
-                    "[Speculative Decoding] Disabling MQA scorer as the "
-                    "NGramWorker does not support MQA scorer.")
-
             if "model_config" in draft_worker_kwargs and \
                 draft_worker_kwargs["model_config"].max_model_len < \
                     scorer_worker.model_config.max_model_len:
