 class TritonPythonModel:
+    @classmethod
+    def auto_complete_config(cls, auto_complete_model_config):
+        # Add inputs/outputs to the model config.
+        cls._auto_complete_inputs_and_outputs(auto_complete_model_config)
+
+        # We need to use the decoupled transaction policy to saturate the
+        # vLLM engine for max throughput.
+        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
+        # restriction for cases where there is exactly a single response to
+        # a single request.
+        auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
+
+        # Disable batching in Triton; let vLLM handle the batching on its own.
+        auto_complete_model_config.set_max_batch_size(0)
+
+        return auto_complete_model_config
+
     @staticmethod
-    def auto_complete_config(auto_complete_model_config):
+    def _auto_complete_inputs_and_outputs(auto_complete_model_config):
+        # Inputs/Outputs expected by the backend.
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
             {
@@ -70,10 +88,33 @@ def auto_complete_config(auto_complete_model_config):
                 "dims": [1],
                 "optional": True,
             },
+            {
+                "name": "output_finish_reason",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+            {
+                "name": "output_cumulative_logprob",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+            {
+                "name": "output_num_token_ids",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
+        ]
+        outputs = [
+            {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
+            {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
+            {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]},
+            {"name": "num_token_ids", "data_type": "TYPE_UINT32", "dims": [-1]},
         ]
-        outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
 
-        # Store the model configuration as a dictionary.
+        # Collect input and output names from the provided model config.
         config = auto_complete_model_config.as_dict()
         input_names = []
         output_names = []
@@ -82,26 +123,14 @@ def auto_complete_config(auto_complete_model_config):
         for output in config["output"]:
             output_names.append(output["name"])
 
-        # Add only missing inputs and output to the model configuration.
+        # Add missing inputs and outputs to the model config.
         for input in inputs:
             if input["name"] not in input_names:
                 auto_complete_model_config.add_input(input)
         for output in outputs:
             if output["name"] not in output_names:
                 auto_complete_model_config.add_output(output)
 
-        # We need to use decoupled transaction policy for saturating
-        # vLLM engine for max throughtput.
-        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
-        # restriction for cases there is exactly a single response to
-        # a single request.
-        auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True))
-
-        # Disabling batching in Triton, let vLLM handle the batching on its own.
-        auto_complete_model_config.set_max_batch_size(0)
-
-        return auto_complete_model_config
-
     def initialize(self, args):
         self.args = args
         self.logger = pb_utils.Logger
@@ -278,6 +307,63 @@ async def await_shutdown(self):
 
         self.logger.log_info("[vllm] Shutdown complete")
 
+    def _get_input_tensors(self, request):
+        # prompt
+        prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0]
+        if isinstance(prompt, bytes):
+            prompt = prompt.decode("utf-8")
+
+        # stream
+        stream = pb_utils.get_input_tensor_by_name(request, "stream")
+        if stream:
+            stream = stream.as_numpy()[0]
+        else:
+            stream = False
+
+        # prepend_input / exclude_input_in_output
+        prepend_input = pb_utils.get_input_tensor_by_name(
+            request, "exclude_input_in_output"
+        )
+        if prepend_input:
+            # When `exclude_input_in_output` is False, we want to prepend input prompt
+            # to output, thus prepend_input should be True, and vice versa.
+            prepend_input = not prepend_input.as_numpy()[0]
+        elif prepend_input is None and stream:
+            prepend_input = False
+        else:
+            prepend_input = True
+        if prepend_input and stream:
+            raise ValueError(
+                "When streaming, `exclude_input_in_output` = False is not allowed."
+            )
+
+        # parameters / sampling_parameters
+        # An alternative mechanism to receive serialized parameters as an input tensor,
+        # because request parameters are not yet supported via BLS.
+        sampling_parameters = pb_utils.get_input_tensor_by_name(
+            request, "sampling_parameters"
+        )
+        if sampling_parameters:
+            parameters = sampling_parameters.as_numpy()[0].decode("utf-8")
+        else:
+            parameters = request.parameters()
+
+        # output_finish_reason, output_cumulative_logprob, output_num_token_ids
+        additional_outputs = {
+            "output_finish_reason": None,
+            "output_cumulative_logprob": None,
+            "output_num_token_ids": None,
+        }
+        for tensor_name in additional_outputs.keys():
+            tensor = pb_utils.get_input_tensor_by_name(request, tensor_name)
+            if tensor:
+                tensor = bool(tensor.as_numpy()[0])
+            else:
+                tensor = False
+            additional_outputs[tensor_name] = tensor
+
+        return prompt, stream, prepend_input, parameters, additional_outputs
+
 
     def get_sampling_params_dict(self, params_json):
         """
         This function parses the dictionary values into their
@@ -331,7 +417,7 @@ def response_loop(self):
             if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
                 self.ongoing_request_count -= 1
 
-    def create_response(self, vllm_output, prepend_input):
+    def create_response(self, vllm_output, prepend_input, additional_outputs):
         """
         Parses the output from the vLLM engine into Triton
         response.
@@ -347,13 +433,17 @@ def create_response(self, vllm_output, prepend_input):
         )
         return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
 
-    def create_stream_response(self, vllm_output, previous_outputs_lengths):
+    def create_stream_response(
+        self, vllm_output, previous_outputs_lengths, additional_outputs
+    ):
         """
         Parses the output from the vLLM engine, extracts only newly generated
         text and packs it into Triton response.
         """
         if previous_outputs_lengths is None:
-            return self.create_response(vllm_output, prepend_input=False)
+            return self.create_response(
+                vllm_output, prepend_input=False, additional_outputs=additional_outputs
+            )
 
         text_outputs = [
             (output.text[prev_output_length:]).encode("utf-8")
@@ -380,45 +470,13 @@ async def generate(self, request):
             decrement_ongoing_request_count = True
         try:
             request_id = random_uuid()
-            prompt = pb_utils.get_input_tensor_by_name(
-                request, "text_input"
-            ).as_numpy()[0]
-            if isinstance(prompt, bytes):
-                prompt = prompt.decode("utf-8")
-            stream = pb_utils.get_input_tensor_by_name(request, "stream")
-            if stream:
-                stream = stream.as_numpy()[0]
-            else:
-                stream = False
-            prepend_input = pb_utils.get_input_tensor_by_name(
-                request, "exclude_input_in_output"
-            )
-            if prepend_input:
-                # When `exclude_input_in_output` is False, we want to prepend
-                # input prompt to output, thus prepend_input should be True,
-                # and vice versa.
-                prepend_input = not prepend_input.as_numpy()[0]
-            elif prepend_input is None and stream:
-                prepend_input = False
-            else:
-                prepend_input = True
-
-            if prepend_input and stream:
-                raise ValueError(
-                    "When streaming, `exclude_input_in_output` = False is not allowed."
-                )
-
-            # Request parameters are not yet supported via
-            # BLS. Provide an optional mechanism to receive serialized
-            # parameters as an input tensor until support is added
-
-            parameters_input_tensor = pb_utils.get_input_tensor_by_name(
-                request, "sampling_parameters"
-            )
-            if parameters_input_tensor:
-                parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
-            else:
-                parameters = request.parameters()
+            (
+                prompt,
+                stream,
+                prepend_input,
+                parameters,
+                additional_outputs,
+            ) = self._get_input_tensors(request)
 
             sampling_params_dict = self.get_sampling_params_dict(parameters)
             lora_name = sampling_params_dict.pop("lora_name", None)
@@ -465,7 +523,9 @@ async def generate(self, request):
                             len(prev_output.text)
                             for prev_output in prev_outputs.outputs
                         ]
-                    response = self.create_stream_response(output, prev_outputs_lengths)
+                    response = self.create_stream_response(
+                        output, prev_outputs_lengths, additional_outputs
+                    )
                     flags = 0
                     if output.finished:
                         response_state["last_response_generated"] = True
@@ -478,7 +538,9 @@ async def generate(self, request):
 
             if not stream:
                 response_sender.send(
-                    self.create_response(last_output, prepend_input),
+                    self.create_response(
+                        last_output, prepend_input, additional_outputs
+                    ),
                     flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                 )
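
Note: because `auto_complete_config` marks the model decoupled, clients must use Triton's streaming gRPC API; a plain synchronous `infer()` call is rejected for decoupled models. Below is a minimal client sketch exercising the new optional flags. The tensor names come from this diff; the model name `vllm_model`, the endpoint, and the prompt are illustrative assumptions, and the `finish_reason`/`num_token_ids` outputs are only populated once the rest of this PR wires the flags into the response tensors.

```python
import queue

import numpy as np
import tritonclient.grpc as grpcclient

results = queue.Queue()

def callback(result, error):
    # Streamed responses (or errors) arrive on a background thread.
    results.put(error if error is not None else result)

client = grpcclient.InferenceServerClient(url="localhost:8001")

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
    grpcclient.InferInput("output_finish_reason", [1], "BOOL"),
    grpcclient.InferInput("output_num_token_ids", [1], "BOOL"),
]
inputs[0].set_data_from_numpy(np.array([b"Hello, my name is"], dtype=np.object_))
inputs[1].set_data_from_numpy(np.array([True], dtype=bool))
inputs[2].set_data_from_numpy(np.array([True], dtype=bool))
inputs[3].set_data_from_numpy(np.array([True], dtype=bool))

client.start_stream(callback=callback)
client.async_stream_infer(model_name="vllm_model", inputs=inputs)
client.stop_stream()  # waits for the stream to drain

while not results.empty():
    result = results.get()
    if isinstance(result, Exception):
        raise result
    print(result.as_numpy("text_output"))
    print(result.as_numpy("finish_reason"))  # present because the flag was set
    print(result.as_numpy("num_token_ids"))
```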
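The `sampling_parameters` input is the workaround called out in `_get_input_tensors`: request parameters are not yet supported via BLS, so a client can instead pass vLLM sampling options as one serialized JSON string. A sketch of the client side, with illustrative parameter values:

```python
import json

import numpy as np
import tritonclient.grpc as grpcclient

# One BYTES element holding a JSON object; the backend decodes it in
# _get_input_tensors() and converts the values in get_sampling_params_dict().
sampling_parameters = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 128}

params_input = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
params_input.set_data_from_numpy(
    np.array([json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_)
)
```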
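One subtlety worth flagging in `_get_input_tensors`: the `prepend_input` variable first holds the optional tensor (truthy when present, `None` when absent) and is then reassigned to a bool, so the branch structure is easy to misread. The defaulting rules it implements, restated as a standalone function purely for review (not part of the diff):

```python
def resolve_prepend_input(exclude_input_in_output, stream):
    # exclude_input_in_output is None when the optional tensor is absent.
    if exclude_input_in_output is not None:
        # Prepend the prompt exactly when the client did NOT ask to exclude it.
        prepend_input = not exclude_input_in_output
    elif stream:
        # Streaming default: do not echo the prompt.
        prepend_input = False
    else:
        # Non-streaming default: echo the prompt before the completion.
        prepend_input = True
    if prepend_input and stream:
        raise ValueError(
            "When streaming, `exclude_input_in_output` = False is not allowed."
        )
    return prepend_input
```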