Commit 577bfbb

Update BeamSearchSequence, use the prompt preprocessor, and add stop_reason
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
1 parent: ae0f24e

4 files changed (+52, −47 lines)

vllm/beam_search.py

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,8 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalDataDict
 
 
 @dataclass
@@ -14,6 +17,9 @@ class BeamSearchSequence:
     cum_logprob: float = 0.0
     text: Optional[str] = None
     finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+    multi_modal_data: Optional["MultiModalDataDict"] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
 
 
 @dataclass
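
Taken together, the dataclass after this commit looks roughly like the sketch below. It is a reading aid only: the tokens field is assumed from its use as beam.tokens in vllm/engine/protocol.py, and any fields outside the two hunks above may differ in the real file.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalDataDict


@dataclass
class BeamSearchSequence:
    # Token ids accumulated by this beam (assumed field, inferred from
    # beam.tokens usage in the protocol.py hunks below).
    tokens: List[int]
    cum_logprob: float = 0.0
    text: Optional[str] = None
    finish_reason: Optional[str] = None
    # New: why the beam stopped, mirroring CompletionOutput.stop_reason
    # (the eos token id when the beam hit EOS, otherwise None).
    stop_reason: Union[int, str, None] = None
    # New: each beam carries its own multimodal inputs instead of relying on
    # request-level values being deep-copied at every decoding step.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None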

vllm/engine/protocol.py

Lines changed: 37 additions & 40 deletions
@@ -1,12 +1,12 @@
 import asyncio
 from abc import ABC, abstractmethod
-from copy import deepcopy
 from typing import AsyncGenerator, List, Mapping, Optional, Union
 
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -61,6 +61,7 @@ def generate(
     async def beam_search(
         self,
         prompt: Union[PromptType, List[int]],
+        model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,
     ) -> AsyncGenerator[RequestOutput, None]:
@@ -72,27 +73,16 @@ async def beam_search(
         length_penalty = params.length_penalty
         include_stop_str_in_output = params.include_stop_str_in_output
 
-        tokenizer = await self.get_tokenizer(lora_request=None)
-
-        if isinstance(prompt, dict):
-            if "prompt" in prompt:
-                tokenized_prompt = tokenizer.encode(prompt.get("prompt"))
-                multi_modal_data = prompt.get("multi_modal_data")
-                mm_processor_kwargs = prompt.get("mm_processor_kwargs")
-            elif "prompt_token_ids" in prompt:
-                tokenized_prompt = prompt.get("prompt_token_ids")
-                multi_modal_data = prompt.get("multi_modal_data")
-                mm_processor_kwargs = prompt.get("mm_processor_kwargs")
-            else:
-                raise TypeError(
-                    "Dictionary input must be a TextPrompt or TokensPrompt")
-        else:
-            tokenized_prompt = prompt if isinstance(
-                prompt, list) else tokenizer.encode(prompt)
-            multi_modal_data = None
-            mm_processor_kwargs = None
-
-        tokenized_length = len(tokenized_prompt)
+        tokenizer = await self.get_tokenizer()
+        self.input_preprocessor = InputPreprocessor(model_config,
+                                                    self.tokenizer)
+
+        (prompt_text, prompt_token_ids, multi_modal_data, mm_processor_kwargs
+         ) = self.input_preprocessor._extract_prompt_components(
+             prompt,
+             request_id=request_id,
+         )
+        tokenized_length = len(prompt_token_ids)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
@@ -103,17 +93,18 @@ async def beam_search(
             temperature=temperature,
         )
         all_beams = [
-            BeamSearchSequence(tokens=tokenized_prompt, cum_logprob=0)
+            BeamSearchSequence(tokens=prompt_token_ids,
+                               cum_logprob=0,
+                               multi_modal_data=multi_modal_data,
+                               mm_processor_kwargs=mm_processor_kwargs)
         ]
         completed = []
 
         for _ in range(max_tokens):
             prompts_batch = [
-                TokensPrompt(
-                    prompt_token_ids=beam.tokens,
-                    multi_modal_data=deepcopy(
-                        multi_modal_data),  # always the values from inputs
-                    mm_processor_kwargs=deepcopy(mm_processor_kwargs))
+                TokensPrompt(prompt_token_ids=beam.tokens,
+                             multi_modal_data=beam.multi_modal_data,
+                             mm_processor_kwargs=beam.mm_processor_kwargs)
                 for beam in all_beams
             ]
 
@@ -148,14 +139,18 @@ async def beam_search(
                             else current_beam.tokens,  #
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob,
-                            finish_reason="stop"))
+                            finish_reason="stop",
+                            stop_reason=tokenizer.eos_token_id))
                 else:
                     new_beams.append(
                         BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],  #
+                            tokens=current_beam.tokens + [token_id],
                             cum_logprob=current_beam.cum_logprob +
                             logprob_obj.logprob,
-                        ))
+                            multi_modal_data=current_beam.
+                            multi_modal_data,
+                            mm_processor_kwargs=current_beam.
+                            mm_processor_kwargs))
 
             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]
@@ -169,18 +164,20 @@ async def beam_search(
 
             beam_search_output = RequestOutput(
                 request_id=request_id,
-                prompt=tokenizer.decode(tokenized_prompt),
+                prompt=prompt_text,
                 outputs=[
-                    CompletionOutput(
-                        text=beam.text,
-                        cumulative_logprob=beam.cum_logprob,
-                        token_ids=beam.tokens[tokenized_length:],
-                        index=i,
-                        logprobs=beam.cum_logprob,
-                    ) for (i, beam) in enumerate(best_beams)
+                    CompletionOutput(text=beam.text,
+                                     cumulative_logprob=beam.cum_logprob,
+                                     token_ids=beam.tokens[tokenized_length:],
+                                     index=i,
+                                     logprobs=beam.cum_logprob,
+                                     finish_reason=beam.finish_reason if
+                                     beam.finish_reason is not None else "length",
+                                     stop_reason=beam.stop_reason)
+                    for (i, beam) in enumerate(best_beams)
                 ],
                 finished=True,
-                prompt_token_ids=tokenized_prompt,
+                prompt_token_ids=prompt_token_ids,
                 prompt_logprobs=None)
 
             yield beam_search_output
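
Read end to end, the refactor does two things: it resolves the prompt once through InputPreprocessor rather than the removed isinstance(prompt, dict) branching, and it makes each beam carry its own multimodal payload so the per-step deepcopy of the request-level values is no longer needed. A condensed, illustrative sketch of that flow, using only names that appear in the hunks above (the tuple order from _extract_prompt_components is taken from the diff, not re-verified against vllm.inputs.preprocess):

# Condensed view of the post-commit flow inside beam_search (illustrative,
# not the full method; runs inside the async method shown above).
tokenizer = await self.get_tokenizer()
preprocessor = InputPreprocessor(model_config, self.tokenizer)

# One-time prompt resolution: text, token ids, and multimodal payload.
(prompt_text, prompt_token_ids, multi_modal_data,
 mm_processor_kwargs) = preprocessor._extract_prompt_components(
     prompt, request_id=request_id)

# The seed beam owns the multimodal inputs ...
all_beams = [
    BeamSearchSequence(tokens=prompt_token_ids,
                       cum_logprob=0,
                       multi_modal_data=multi_modal_data,
                       mm_processor_kwargs=mm_processor_kwargs)
]

# ... and every expansion step forwards the beam's own references,
# so copy.deepcopy is no longer imported or called.
prompts_batch = [
    TokensPrompt(prompt_token_ids=beam.tokens,
                 multi_modal_data=beam.multi_modal_data,
                 mm_processor_kwargs=beam.mm_processor_kwargs)
    for beam in all_beams
]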

vllm/entrypoints/openai/serving_chat.py

Lines changed: 4 additions & 3 deletions
@@ -236,9 +236,10 @@ async def create_chat_completion(
 
         if isinstance(sampling_params, BeamSearchParams):
             result_generator = self.engine_client.beam_search(
-                engine_inputs,
-                request_id,
-                sampling_params,
+                prompt=engine_inputs,
+                model_config=self.model_config,
+                request_id=request_id,
+                params=sampling_params,
             )
         else:
             result_generator = self.engine_client.generate(
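
Both OpenAI entrypoints now call beam_search with keyword arguments and forward the serving layer's model_config, which beam_search needs to build its InputPreprocessor. A hedged consumption sketch (it assumes an engine_client implementing the updated EngineClient.beam_search signature from vllm/engine/protocol.py; the print call is purely illustrative):

async for request_output in self.engine_client.beam_search(
        prompt=engine_inputs,
        model_config=self.model_config,
        request_id=request_id,
        params=sampling_params):
    for completion in request_output.outputs:
        # finish_reason is "stop" for EOS-terminated beams, otherwise "length";
        # stop_reason carries the eos token id when the beam hit EOS, else None.
        print(completion.index, completion.finish_reason, completion.stop_reason)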

vllm/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 3 deletions
@@ -150,9 +150,10 @@ async def create_completion(
 
         if isinstance(sampling_params, BeamSearchParams):
             generator = self.engine_client.beam_search(
-                prompt_inputs,
-                request_id_item,
-                sampling_params,
+                prompt=prompt_inputs,
+                model_config=self.model_config,
+                request_id=request_id,
+                params=sampling_params,
             )
         else:
             generator = self.engine_client.generate(
