Commit 10a5b94

Add additional outputs and their input switches to auto complete
* [WIP] Add additional outputs to auto complete
* [WIP] Use an individual input tensor to control each additional output
* [WIP] Parse additional output flags from the request
1 parent b71088a commit 10a5b94
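
For context, here is a minimal client-side sketch of how the per-output switches introduced by this commit could be set. It assumes Triton's Python gRPC client (`tritonclient.grpc`); only the input tensors are built, and the decoupled/streaming plumbing needed to actually send the request is omitted. Tensor names match the inputs declared in the diff below; everything else is illustrative.

```python
# Hypothetical client-side sketch: build the optional BOOL switch tensors that
# enable the additional outputs added by this commit. Sending the request over
# the decoupled (streaming) gRPC API is left out for brevity.
import numpy as np
import tritonclient.grpc as grpcclient


def build_request_inputs(prompt: str, enable_extra_outputs: bool = True):
    text_input = grpcclient.InferInput("text_input", [1], "BYTES")
    text_input.set_data_from_numpy(
        np.array([prompt.encode("utf-8")], dtype=np.object_)
    )

    stream = grpcclient.InferInput("stream", [1], "BOOL")
    stream.set_data_from_numpy(np.array([True]))

    inputs = [text_input, stream]
    if enable_extra_outputs:
        # Each switch is an optional input; omitting it leaves that output disabled.
        for name in (
            "output_finish_reason",
            "output_cumulative_logprob",
            "output_num_token_ids",
        ):
            switch = grpcclient.InferInput(name, [1], "BOOL")
            switch.set_data_from_numpy(np.array([True]))
            inputs.append(switch)
    return inputs
```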

src/model.py

Lines changed: 122 additions & 60 deletions
@@ -48,8 +48,26 @@
 
 
 class TritonPythonModel:
+    @classmethod
+    def auto_complete_config(cls, auto_complete_model_config):
+        # Add inputs/outputs to the model config.
+        cls._auto_complete_inputs_and_outputs(auto_complete_model_config)
+
+        # We need to use decoupled transaction policy for saturating
+        # vLLM engine for max throughput.
+        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
+        # restriction for cases where there is exactly a single response to
+        # a single request.
+        auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
+
+        # Disabling batching in Triton, let vLLM handle the batching on its own.
+        auto_complete_model_config.set_max_batch_size(0)
+
+        return auto_complete_model_config
+
     @staticmethod
-    def auto_complete_config(auto_complete_model_config):
+    def _auto_complete_inputs_and_outputs(auto_complete_model_config):
+        # Inputs/Outputs expected by the backend.
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
             {
@@ -70,10 +88,33 @@ def auto_complete_config(auto_complete_model_config):
                 "dims": [1],
                 "optional": True,
             },
+            {
+                "name": "output_finish_reason",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+            {
+                "name": "output_cumulative_logprob",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+            {
+                "name": "output_num_token_ids",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+        ]
+        outputs = [
+            {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
+            {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
+            {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]},
+            {"name": "num_token_ids", "data_type": "TYPE_UINT32", "dims": [-1]},
         ]
-        outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
 
-        # Store the model configuration as a dictionary.
+        # Collect input and output names from the provided model config.
         config = auto_complete_model_config.as_dict()
         input_names = []
         output_names = []
@@ -82,26 +123,14 @@ def auto_complete_config(auto_complete_model_config):
         for output in config["output"]:
             output_names.append(output["name"])
 
-        # Add only missing inputs and output to the model configuration.
+        # Add missing inputs and outputs to the model config.
         for input in inputs:
             if input["name"] not in input_names:
                 auto_complete_model_config.add_input(input)
         for output in outputs:
             if output["name"] not in output_names:
                 auto_complete_model_config.add_output(output)
 
-        # We need to use decoupled transaction policy for saturating
-        # vLLM engine for max throughtput.
-        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
-        # restriction for cases there is exactly a single response to
-        # a single request.
-        auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
-
-        # Disabling batching in Triton, let vLLM handle the batching on its own.
-        auto_complete_model_config.set_max_batch_size(0)
-
-        return auto_complete_model_config
-
     def initialize(self, args):
         self.args = args
         self.logger = pb_utils.Logger
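
As a quick sanity check of the auto-complete changes above, the new optional switch inputs and the extra outputs should show up in the config Triton serves for the model. A hedged sketch, assuming the Python gRPC client, a local server, and a placeholder model name "vllm_model":

```python
# Hypothetical check: after auto-complete runs, the switch inputs and the extra
# outputs should be present in the served config ("vllm_model" and the server
# address are placeholders, not part of this commit).
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")
config = client.get_model_config("vllm_model").config

input_names = {i.name for i in config.input}
output_names = {o.name for o in config.output}

assert {
    "output_finish_reason",
    "output_cumulative_logprob",
    "output_num_token_ids",
} <= input_names
assert {"finish_reason", "cumulative_logprob", "num_token_ids"} <= output_names
```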
@@ -278,6 +307,63 @@ async def await_shutdown(self):
 
         self.logger.log_info("[vllm] Shutdown complete")
 
+    def _get_input_tensors(self, request):
+        # prompt
+        prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0]
+        if isinstance(prompt, bytes):
+            prompt = prompt.decode("utf-8")
+
+        # stream
+        stream = pb_utils.get_input_tensor_by_name(request, "stream")
+        if stream:
+            stream = stream.as_numpy()[0]
+        else:
+            stream = False
+
+        # prepend_input / exclude_input_in_output
+        prepend_input = pb_utils.get_input_tensor_by_name(
+            request, "exclude_input_in_output"
+        )
+        if prepend_input:
+            # When `exclude_input_in_output` is False, we want to prepend input prompt
+            # to output, thus prepend_input should be True, and vice versa.
+            prepend_input = not prepend_input.as_numpy()[0]
+        elif prepend_input is None and stream:
+            prepend_input = False
+        else:
+            prepend_input = True
+        if prepend_input and stream:
+            raise ValueError(
+                "When streaming, `exclude_input_in_output` = False is not allowed."
+            )
+
+        # parameters / sampling_parameters
+        # An alternative mechanism to receive serialized parameters as an input tensor,
+        # because request parameters are not yet supported via BLS.
+        sampling_parameters = pb_utils.get_input_tensor_by_name(
+            request, "sampling_parameters"
+        )
+        if sampling_parameters:
+            parameters = sampling_parameters.as_numpy()[0].decode("utf-8")
+        else:
+            parameters = request.parameters()
+
+        # output_finish_reason, output_cumulative_logprob, output_num_token_ids
+        additional_outputs = {
+            "output_finish_reason": None,
+            "output_cumulative_logprob": None,
+            "output_num_token_ids": None,
+        }
+        for tensor_name in additional_outputs.keys():
+            tensor = pb_utils.get_input_tensor_by_name(request, tensor_name)
+            if tensor:
+                tensor = bool(tensor.as_numpy()[0])
+            else:
+                tensor = False
+            additional_outputs[tensor_name] = tensor
+
+        return prompt, stream, prepend_input, parameters, additional_outputs
+
     def get_sampling_params_dict(self, params_json):
         """
         This functions parses the dictionary values into their
@@ -331,7 +417,7 @@ def response_loop(self):
             if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
                 self.ongoing_request_count -= 1
 
-    def create_response(self, vllm_output, prepend_input):
+    def create_response(self, vllm_output, prepend_input, additional_outputs):
         """
         Parses the output from the vLLM engine into Triton
         response.
@@ -347,13 +433,17 @@ def create_response(self, vllm_output, prepend_input):
         )
         return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
 
-    def create_stream_response(self, vllm_output, previous_outputs_lengths):
+    def create_stream_response(
+        self, vllm_output, previous_outputs_lengths, additional_outputs
+    ):
         """
         Parses the output from the vLLM engine, extracts only newly generated
         text and packs it into Triton response.
         """
         if previous_outputs_lengths is None:
-            return self.create_response(vllm_output, prepend_input=False)
+            return self.create_response(
+                vllm_output, prepend_input=False, additional_outputs=additional_outputs
+            )
 
         text_outputs = [
             (output.text[prev_output_length:]).encode("utf-8")
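
This WIP commit threads `additional_outputs` into `create_response()` and `create_stream_response()` but does not yet show how the extra tensors get built. Below is a sketch of one way the flags might be honored, assuming vLLM's `CompletionOutput` fields (`finish_reason`, `cumulative_logprob`, `token_ids`); the helper name is made up and not part of this commit.

```python
# Hypothetical helper (not in this commit): append the extra output tensors
# requested via the additional_outputs flags parsed in _get_input_tensors().
import numpy as np
import triton_python_backend_utils as pb_utils


def _append_additional_outputs(vllm_output, additional_outputs, output_tensors):
    completions = vllm_output.outputs  # list of vLLM CompletionOutput objects
    if additional_outputs["output_finish_reason"]:
        finish_reasons = [str(c.finish_reason).encode("utf-8") for c in completions]
        output_tensors.append(
            pb_utils.Tensor("finish_reason", np.asarray(finish_reasons, dtype=np.object_))
        )
    if additional_outputs["output_cumulative_logprob"]:
        logprobs = [c.cumulative_logprob for c in completions]
        output_tensors.append(
            pb_utils.Tensor("cumulative_logprob", np.asarray(logprobs, dtype=np.float32))
        )
    if additional_outputs["output_num_token_ids"]:
        counts = [len(c.token_ids) for c in completions]
        output_tensors.append(
            pb_utils.Tensor("num_token_ids", np.asarray(counts, dtype=np.uint32))
        )
    return output_tensors
```

The dtypes mirror the outputs declared in the auto-complete section (TYPE_STRING, TYPE_FP32, TYPE_UINT32), so the tensors would pass Triton's output validation.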
@@ -380,45 +470,13 @@ async def generate(self, request):
         decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
-            prompt = pb_utils.get_input_tensor_by_name(
-                request, "text_input"
-            ).as_numpy()[0]
-            if isinstance(prompt, bytes):
-                prompt = prompt.decode("utf-8")
-            stream = pb_utils.get_input_tensor_by_name(request, "stream")
-            if stream:
-                stream = stream.as_numpy()[0]
-            else:
-                stream = False
-            prepend_input = pb_utils.get_input_tensor_by_name(
-                request, "exclude_input_in_output"
-            )
-            if prepend_input:
-                # When `exclude_input_in_output` is False, we want to prepend
-                # input prompt to output, thus prepend_input should be True,
-                # and vice versa.
-                prepend_input = not prepend_input.as_numpy()[0]
-            elif prepend_input is None and stream:
-                prepend_input = False
-            else:
-                prepend_input = True
-
-            if prepend_input and stream:
-                raise ValueError(
-                    "When streaming, `exclude_input_in_output` = False is not allowed."
-                )
-
-            # Request parameters are not yet supported via
-            # BLS. Provide an optional mechanism to receive serialized
-            # parameters as an input tensor until support is added
-
-            parameters_input_tensor = pb_utils.get_input_tensor_by_name(
-                request, "sampling_parameters"
-            )
-            if parameters_input_tensor:
-                parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
-            else:
-                parameters = request.parameters()
+            (
+                prompt,
+                stream,
+                prepend_input,
+                parameters,
+                additional_outputs,
+            ) = self._get_input_tensors(request)
 
             sampling_params_dict = self.get_sampling_params_dict(parameters)
             lora_name = sampling_params_dict.pop("lora_name", None)
@@ -465,7 +523,9 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
-                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    response = self.create_stream_response(
+                        output, prev_outputs_lengths, additional_outputs
+                    )
                     flags = 0
                     if output.finished:
                         response_state["last_response_generated"] = True
@@ -478,7 +538,9 @@ async def generate(self, request):
 
             if not stream:
                 response_sender.send(
-                    self.create_response(last_output, prepend_input),
+                    self.create_response(
+                        last_output, prepend_input, additional_outputs
+                    ),
                     flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                 )
 
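
On the client side, the extra outputs arrive as ordinary response tensors whose names match the outputs declared in the auto-complete section. A minimal sketch, assuming a `tritonclient` `InferResult` delivered by the stream callback:

```python
# Hypothetical: pull the additional outputs from a streamed tritonclient
# InferResult; as_numpy() returns None for any output that was not enabled.
def read_additional_outputs(result):
    return {
        "text_output": result.as_numpy("text_output"),
        "finish_reason": result.as_numpy("finish_reason"),            # TYPE_STRING -> bytes
        "cumulative_logprob": result.as_numpy("cumulative_logprob"),  # TYPE_FP32
        "num_token_ids": result.as_numpy("num_token_ids"),            # TYPE_UINT32
    }
```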