Skip to content

Commit 63375f0

Browse files
authored
[V1][Spec Decode] Update N-gram Proposer Interface (#15750)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 70ad3f9 commit 63375f0

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

vllm/v1/spec_decode/ngram_proposer.py

+15-13
Original file line number | Diff line number | Diff line change
@@ -10,14 +10,21 @@
1010
class NgramProposer:
1111

1212
def __init__(self, vllm_config: VllmConfig):
13-
self.vllm_config = vllm_config
13+
# Minimum length of the n-gram to match.
14+
self.min_n = vllm_config.speculative_config.prompt_lookup_min
15+
# Maximum length of the n-gram to match.
16+
self.max_n = vllm_config.speculative_config.prompt_lookup_max
17+
# Number of tokens follow the match. If there are less than k
18+
# tokens follow the match, we will return the maximum amount of
19+
# tokens until the end.
20+
self.k = vllm_config.speculative_config.num_speculative_tokens
21+
# Trigger Numba JIT compilation for N-gram proposer.
22+
# This usually takes less than 1 second.
23+
self.propose(np.zeros(1024, dtype=np.int32))
1424

1525
def propose(
1626
self,
1727
context_token_ids: np.ndarray,
18-
min_n: int,
19-
max_n: int,
20-
k: int,
2128
) -> Optional[np.ndarray]:
2229
"""Proposes the next sequence of tokens based on n-gram pattern
2330
matching in the context. The function finds matches of the last n
@@ -27,17 +34,12 @@ def propose(
2734
Args:
2835
context_token_ids: Numpy array of token IDs representing the
2936
context sequence.
30-
min_n: Minimum length of the n-gram to match.
31-
max_n: Maximum length of the n-gram to match.
32-
k: Number of tokens follow the match. If there are less
33-
than k tokens follow the match, we will return
34-
the maximum amount of tokens until the end.
35-
37+
3638
Returns:
3739
np.ndarray: The sequence of tokens that followed
3840
the matched n-gram in the context.
3941
None: If no matching n-gram pattern is found.
40-
42+
4143
Example:
4244
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
4345
k = 4:
@@ -49,8 +51,8 @@ def propose(
4951
we only have three tokens after the match.
5052
"""
5153
# TODO(woosuk): Optimize this.
52-
for n in range(max_n, min_n - 1, -1):
53-
result = _find_subarray_kmp(context_token_ids, n, k)
54+
for n in range(self.max_n, self.min_n - 1, -1):
55+
result = _find_subarray_kmp(context_token_ids, n, self.k)
5456
if result is not None:
5557
return result
5658
return None

vllm/v1/worker/gpu_model_runner.py

+1-5
Original file line number | Diff line number | Diff line change
@@ -1246,11 +1246,7 @@ def generate_draft_token_ids(
12461246
end_idx = start_idx + num_sampled_ids
12471247
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
12481248
drafter_output = self.drafter.propose(
1249-
self.input_batch.token_ids_cpu[i, :end_idx],
1250-
self.speculative_config.prompt_lookup_min,
1251-
self.speculative_config.prompt_lookup_max,
1252-
self.speculative_config.num_speculative_tokens,
1253-
)
1249+
self.input_batch.token_ids_cpu[i, :end_idx])
12541250
if drafter_output is None or len(drafter_output) == 0:
12551251
draft_token_ids.append([])
12561252
else:

0 commit comments

Comments (0)