10
10
class NgramProposer :
11
11
12
12
def __init__ (self , vllm_config : VllmConfig ):
13
- self .vllm_config = vllm_config
13
+ # Minimum length of the n-gram to match.
14
+ self .min_n = vllm_config .speculative_config .prompt_lookup_min
15
+ # Maximum length of the n-gram to match.
16
+ self .max_n = vllm_config .speculative_config .prompt_lookup_max
17
+ # Number of tokens that follow the match. If fewer than k
18
+ # tokens follow the match, we will return the maximum amount of
19
+ # tokens until the end.
20
+ self .k = vllm_config .speculative_config .num_speculative_tokens
21
+ # Trigger Numba JIT compilation for N-gram proposer.
22
+ # This usually takes less than 1 second.
23
+ self .propose (np .zeros (1024 , dtype = np .int32 ))
14
24
15
25
def propose (
16
26
self ,
17
27
context_token_ids : np .ndarray ,
18
- min_n : int ,
19
- max_n : int ,
20
- k : int ,
21
28
) -> Optional [np .ndarray ]:
22
29
"""Proposes the next sequence of tokens based on n-gram pattern
23
30
matching in the context. The function finds matches of the last n
@@ -27,17 +34,12 @@ def propose(
27
34
Args:
28
35
context_token_ids: Numpy array of token IDs representing the
29
36
context sequence.
30
- min_n: Minimum length of the n-gram to match.
31
- max_n: Maximum length of the n-gram to match.
32
- k: Number of tokens follow the match. If there are less
33
- than k tokens follow the match, we will return
34
- the maximum amount of tokens until the end.
35
-
37
+
36
38
Returns:
37
39
np.ndarray: The sequence of tokens that followed
38
40
the matched n-gram in the context.
39
41
None: If no matching n-gram pattern is found.
40
-
42
+
41
43
Example:
42
44
If context_token_ids = [1,2,3,4,2,3], min_n = 2, max_n = 3, and
43
45
k = 4:
@@ -49,8 +51,8 @@ def propose(
49
51
we only have three tokens after the match.
50
52
"""
51
53
# TODO(woosuk): Optimize this.
52
- for n in range (max_n , min_n - 1 , - 1 ):
53
- result = _find_subarray_kmp (context_token_ids , n , k )
54
+ for n in range (self . max_n , self . min_n - 1 , - 1 ):
55
+ result = _find_subarray_kmp (context_token_ids , n , self . k )
54
56
if result is not None :
55
57
return result
56
58
return None
0 commit comments