@@ -1,12 +1,12 @@
 import asyncio
 from abc import ABC, abstractmethod
-from copy import deepcopy
 from typing import AsyncGenerator, List, Mapping, Optional, Union

 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -61,6 +61,7 @@ def generate(
     async def beam_search(
         self,
         prompt: Union[PromptType, List[int]],
+        model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -72,27 +73,16 @@ async def beam_search(
         length_penalty = params.length_penalty
         include_stop_str_in_output = params.include_stop_str_in_output

-        tokenizer = await self.get_tokenizer(lora_request=None)
-
-        if isinstance(prompt, dict):
-            if "prompt" in prompt:
-                tokenized_prompt = tokenizer.encode(prompt.get("prompt"))
-                multi_modal_data = prompt.get("multi_modal_data")
-                mm_processor_kwargs = prompt.get("mm_processor_kwargs")
-            elif "prompt_token_ids" in prompt:
-                tokenized_prompt = prompt.get("prompt_token_ids")
-                multi_modal_data = prompt.get("multi_modal_data")
-                mm_processor_kwargs = prompt.get("mm_processor_kwargs")
-            else:
-                raise TypeError(
-                    "Dictionary input must be a TextPrompt or TokensPrompt")
-        else:
-            tokenized_prompt = prompt if isinstance(
-                prompt, list) else tokenizer.encode(prompt)
-            multi_modal_data = None
-            mm_processor_kwargs = None
-
-        tokenized_length = len(tokenized_prompt)
+        tokenizer = await self.get_tokenizer()
+        self.input_preprocessor = InputPreprocessor(model_config,
+                                                    self.tokenizer)
+
+        (prompt_text, prompt_token_ids, multi_modal_data, mm_processor_kwargs
+         ) = self.input_preprocessor._extract_prompt_components(
+             prompt,
+             request_id=request_id,
+         )
+        tokenized_length = len(prompt_token_ids)

         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
@@ -103,17 +93,18 @@ async def beam_search(
             temperature=temperature,
         )
         all_beams = [
-            BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)
+            BeamSearchSequence(tokens=prompt_token_ids,
+                               cum_logprob=0,
+                               multi_modal_data=multi_modal_data,
+                               mm_processor_kwargs=mm_processor_kwargs)
         ]
         completed = []

         for _ in range(max_tokens):
             prompts_batch = [
-                TokensPrompt(
-                    prompt_token_ids=beam.tokens,
-                    multi_modal_data=deepcopy(
-                        multi_modal_data),  # always the values from inputs
-                    mm_processor_kwargs=deepcopy(mm_processor_kwargs))
+                TokensPrompt(prompt_token_ids=beam.tokens,
+                             multi_modal_data=beam.multi_modal_data,
+                             mm_processor_kwargs=beam.mm_processor_kwargs)
                 for beam in all_beams
             ]

@@ -148,14 +139,18 @@ async def beam_search(
                                     else current_beam.tokens,  #
                                     cum_logprob=current_beam.cum_logprob +
                                     logprob_obj.logprob,
-                                    finish_reason="stop"))
+                                    finish_reason="stop",
+                                    stop_reason=tokenizer.eos_token_id))
                         else:
                             new_beams.append(
                                 BeamSearchSequence(
-                                    tokens=current_beam.tokens + [token_id],  #
+                                    tokens=current_beam.tokens + [token_id],
                                     cum_logprob=current_beam.cum_logprob +
                                     logprob_obj.logprob,
-                                ))
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs))

             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]
@@ -169,18 +164,20 @@ async def beam_search(

         beam_search_output = RequestOutput(
             request_id=request_id,
-            prompt=tokenizer.decode(tokenized_prompt),
+            prompt=prompt_text,
             outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens[tokenized_length:],
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
+                CompletionOutput(text=beam.text,
+                                 cumulative_logprob=beam.cum_logprob,
+                                 token_ids=beam.tokens[tokenized_length:],
+                                 index=i,
+                                 logprobs=beam.cum_logprob,
+                                 finish_reason=beam.finish_reason if
+                                 beam.finish_reason is not None else "length",
+                                 stop_reason=beam.stop_reason)
+                for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenized_prompt,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None)

         yield beam_search_output
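
For context, a minimal driver sketch of the updated API, not part of this diff: it assumes `engine` is an `EngineClient` implementation such as vLLM's `AsyncLLMEngine`, and the `run_beam_search` helper, prompt text, and `image` variable are illustrative. The `model_config` parameter and the multimodal fields carried on each beam are the ones this change introduces.

from vllm.sampling_params import BeamSearchParams


async def run_beam_search(engine, model_config, image):
    # Hypothetical driver: `engine` is assumed to implement the protocol
    # above; `image` is e.g. a PIL.Image passed through `multi_modal_data`.
    params = BeamSearchParams(beam_width=4, max_tokens=64)
    # A TextPrompt-style dict; the multimodal payload now survives beam
    # search because each BeamSearchSequence carries it forward per the
    # diff above, instead of being deep-copied from the original inputs.
    prompt = {
        "prompt": "Describe the image.",
        "multi_modal_data": {"image": image},
    }
    final = None
    # `model_config` is the newly threaded-through argument, used to build
    # the InputPreprocessor that extracts text, token ids, and mm fields.
    async for output in engine.beam_search(prompt,
                                           model_config,
                                           request_id="beam-req-0",
                                           params=params):
        final = output  # the generator yields one finished RequestOutput
    return final.outputs[0].text if final else None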