Commit 8893572

optimize reject sampler in greedy situation

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent: 875a86c

File tree: 2 files changed, +80 -58 lines

tests/e2e/singlecard/sample/test_rejection_sampler.py

Lines changed: 20 additions & 11 deletions
@@ -77,8 +77,9 @@ def test_perfect_match(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -102,8 +103,9 @@ def test_early_mismatch(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -131,7 +133,9 @@ def test_multiple_sequences(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
-        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
+        [output_tokens[0][-1], output_tokens[1][-1]],
+        device=logits.device,
+        dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -155,8 +159,9 @@ def test_single_token_sequence(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -178,8 +183,9 @@ def test_empty_sequence(rejection_sampler):

     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
-    bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
-                                      device=logits.device)
+    bonus_token_tensor = torch.tensor([[output_tokens[0][-1]]],
+                                      device=logits.device,
+                                      dtype=torch.int32)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -203,7 +209,9 @@ def test_multiple_mismatches(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
-        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
+        [output_tokens[0][-1], output_tokens[1][-1]],
+        device=logits.device,
+        dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)

@@ -237,7 +245,8 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
-                                      device=logits.device)
+                                      device=logits.device,
+                                      dtype=torch.int32).unsqueeze(1)
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
                                                          device=logits.device)
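
Every hunk in this test file makes the same adjustment: bonus_token_tensor goes from a 1-D [batch_size] tensor to a [batch_size, 1] tensor with an explicit int32 dtype, built either from a nested list or with .unsqueeze(1). A minimal sketch of the shape equivalence (arbitrary token values, not taken from the tests):

```python
import torch

bonus_1d = torch.tensor([42, 7])                         # old: shape [batch_size]
bonus_2d = torch.tensor([[42], [7]], dtype=torch.int32)  # new: shape [batch_size, 1]

# The two new constructions used across these tests are equivalent:
assert torch.equal(bonus_2d,
                   torch.tensor([42, 7], dtype=torch.int32).unsqueeze(1))
# The new fast path below squeezes the column dimension back out:
assert torch.equal(bonus_2d.squeeze(1), bonus_1d.to(torch.int32))
```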

vllm_ascend/sample/rejection_sampler.py

Lines changed: 60 additions & 47 deletions
@@ -147,16 +147,24 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        rejection_greedy_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-            # num_warps=1,
-        )
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+            rejection_greedy_sample_spec_len_1_pytorch(
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+            )
+        else:
+            rejection_greedy_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                max_spec_len,
+                is_greedy,
+            )
     if sampling_metadata.all_greedy:
         return output_token_ids
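
Both branches above implement the same greedy acceptance rule, which the loop-based implementation removed in the next hunk spells out: every position receives the target's argmax, acceptance stops after the correction emitted at the first mismatching position, and the bonus token is kept only when every draft token matched. A plain-Python restatement of that rule (an illustrative helper, not code from this commit):

```python
def greedy_accept(draft, argmax, bonus):
    """Greedy rejection rule: emit the target argmax up to and including
    the first mismatch; append the bonus token only on a full match."""
    out = []
    for d, t in zip(draft, argmax):
        out.append(t)      # the target's argmax is always emitted here
        if d != t:         # first mismatch ends acceptance
            return out
    out.append(bonus)      # all draft tokens matched: keep the bonus token
    return out

assert greedy_accept([5, 6], [5, 6], 9) == [5, 6, 9]  # perfect match
assert greedy_accept([5, 7], [5, 6], 9) == [5, 6]     # early mismatch
```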

@@ -284,47 +292,52 @@ def sample_recovered_tokens(
     return recovered_token_ids


+def rejection_greedy_sample_spec_len_1_pytorch(
+        output_token_ids,  # [batch_size, 2]
+        draft_token_ids,  # [num_tokens]
+        target_argmax,  # [num_tokens]
+        bonus_token_ids,  # [batch_size]
+):
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
+    assert batch_size == num_tokens
+    accept_req_mask = draft_token_ids == target_argmax
+    output_token_ids[:, 0] = target_argmax
+    bonus_token_ids = bonus_token_ids.squeeze(1)
+    output_token_ids[accept_req_mask, 1] = bonus_token_ids[accept_req_mask]
+
+
 def rejection_greedy_sample_pytorch(
-    output_token_ids,  # [batch_size, max_spec_len + 1]
-    cu_num_draft_tokens,  # [batch_size]
-    draft_token_ids,  # [num_tokens]
-    target_argmax,  # [num_tokens]
-    bonus_token_ids,  # [batch_size]
-    is_greedy=None,  # [batch_size] or None
-    max_spec_len=None,
+        output_token_ids,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        target_argmax,  # [num_tokens]
+        bonus_token_ids,  # [batch_size]
+        max_spec_len,  # int
+        is_greedy=None,  # [batch_size] or None
 ):
     batch_size = output_token_ids.shape[0]
-
+    device = output_token_ids.device
     if is_greedy is None:
-        is_greedy = torch.ones(batch_size,
-                               dtype=torch.bool,
-                               device=output_token_ids.device)
-
-    for req_idx in range(batch_size):
-        if not is_greedy[req_idx]:
-            continue
-
-        if req_idx == 0:
-            start_idx = 0
-        else:
-            start_idx = cu_num_draft_tokens[req_idx - 1].item()
-        end_idx = cu_num_draft_tokens[req_idx].item()
-        num_draft_tokens = end_idx - start_idx
-
-        rejected = False
-        for pos in range(num_draft_tokens):
-            if not rejected:
-                draft_token_id = draft_token_ids[start_idx + pos].item()
-                target_argmax_id = target_argmax[start_idx + pos].item()
-
-                output_token_ids[req_idx, pos] = target_argmax_id
-
-                if draft_token_id != target_argmax_id:
-                    rejected = True
-
-        if not rejected:
-            bonus_token_id = bonus_token_ids[req_idx].item()
-            output_token_ids[req_idx, num_draft_tokens] = bonus_token_id
+        is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
+    draft_token_mask = draft_token_ids == target_argmax
+    pos_ids = torch.arange(0, max_spec_len + 1,
+                           device=device).view(1, -1).expand(batch_size, -1)
+    pos_mask = pos_ids < cu_num_draft_tokens.view(-1, 1)
+    output_token_mask = torch.zeros([batch_size, max_spec_len + 1],
+                                    dtype=torch.bool,
+                                    device=device)
+    output_token_mask[pos_mask] = draft_token_mask
+    output_token_mask = torch.cumprod(output_token_mask,
+                                      dim=1)  # [batch_size, max_spec_len + 1]
+    extra_accept_id = torch.max(
+        pos_ids * output_token_mask, dim=1, keepdim=True) + 1
+    output_token_mask[extra_accept_id] = True
+    output_token_mask *= is_greedy.view(-1, 1)
+    output_token_ids[pos_ids] = draft_token_ids
+    output_token_ids[:, -1] = bonus_token_ids
+    output_token_ids = output_token_ids * output_token_mask
+    return output_token_ids


 def rejection_random_sample_pytorch(
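
The new rejection_greedy_sample_spec_len_1_pytorch covers the case where every request carries exactly one draft token and all requests are greedy, so no cu_num_draft_tokens bookkeeping is needed: position 0 always receives the target argmax, and the bonus token fills position 1 only where the draft matched. A toy run of that masking logic (placeholder values; -1 marks an unwritten slot):

```python
import torch

draft_token_ids = torch.tensor([7, 3], dtype=torch.int32)      # [batch_size]
target_argmax = torch.tensor([7, 5], dtype=torch.int32)        # [batch_size]
bonus_token_ids = torch.tensor([[9], [9]], dtype=torch.int32)  # [batch_size, 1]
output_token_ids = torch.full((2, 2), -1, dtype=torch.int32)   # [batch_size, 2]

accept_req_mask = draft_token_ids == target_argmax  # request 0 accepts, request 1 rejects
output_token_ids[:, 0] = target_argmax              # position 0: always the target argmax
output_token_ids[accept_req_mask, 1] = bonus_token_ids.squeeze(1)[accept_req_mask]
print(output_token_ids.tolist())                    # [[7, 9], [5, -1]]
```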

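The general greedy path is also vectorized: the per-request Python loop is replaced by a batched comparison plus a cumulative product along the position axis, so the acceptance mask stays 1 exactly over each request's matching prefix. A self-contained sketch of the cumprod idea on padded [batch, max_spec_len] tensors (illustrative names and values, not the committed kernel):

```python
import torch

k = 4  # max_spec_len
draft = torch.tensor([[5, 6, 7, 8],
                      [5, 9, 7, 8],
                      [1, 2, 3, 4]])
argmax = torch.tensor([[5, 6, 7, 8],    # full match
                       [5, 6, 7, 8],    # first mismatch at position 1
                       [9, 9, 9, 9]])   # first mismatch at position 0

match = (draft == argmax).long()        # [batch, k]
prefix = torch.cumprod(match, dim=1)    # 1 only while every earlier position matched
n_match = prefix.sum(dim=1)             # length of the accepted prefix
# Each request emits one token past its matching prefix: the correction at the
# first mismatch, or the bonus token when the whole prefix matched.
n_emit = torch.clamp(n_match + 1, max=k + 1)
print(n_match.tolist())  # [4, 1, 0]
print(n_emit.tolist())   # [5, 2, 1]
```

The trade is the loop's data-dependent early exit for a fixed amount of batched tensor work, which generally wins on accelerator backends.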