@@ -417,44 +417,30 @@ def response_loop(self):
                 if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
                     self.ongoing_request_count -= 1
 
-    def create_response(self, vllm_output, prepend_input, additional_outputs):
-        """
-        Parses the output from the vLLM engine into Triton
-        response.
-        """
-        prompt = ""
-        if prepend_input:
-            prompt = vllm_output.prompt
-        text_outputs = [
-            (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
-        ]
-        triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
-        )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
-
-    def create_stream_response(
-        self, vllm_output, previous_outputs_lengths, additional_outputs
+    def _create_response(
+        self, prev_request_output, request_output, prepend_input=False
     ):
-        """
-        Parses the output from the vLLM engine, extracts only newly generated
-        text and packs it into Triton response.
-        """
-        if previous_outputs_lengths is None:
-            return self.create_response(
-                vllm_output, prepend_input=False, additional_outputs=additional_outputs
-            )
-
-        text_outputs = [
-            (output.text[prev_output_length:]).encode("utf-8")
-            for output, prev_output_length in zip(
-                vllm_output.outputs, previous_outputs_lengths
-            )
+        # text_output
+        prepend_prompt = ""
+        if prev_request_output is None:
+            # this is the first response
+            if prepend_input:
+                prepend_prompt = request_output.prompt
+            prev_lens = [0] * len(request_output.outputs)
+        else:
+            # this is a subsequent response
+            prev_lens = [
+                len(prev_output.text) for prev_output in prev_request_output.outputs
+            ]
+        text_output = [
+            (prepend_prompt + output.text[prev_len:]).encode("utf-8")
+            for output, prev_len in zip(request_output.outputs, prev_lens)
         ]
-        triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
+        text_output_tensor = pb_utils.Tensor(
+            "text_output", np.asarray(text_output, dtype=self.output_dtype)
         )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
+
+        return pb_utils.InferenceResponse(output_tensors=[text_output_tensor])
 
     async def generate(self, request):
         """
@@ -481,8 +467,6 @@ async def generate(self, request):
             sampling_params_dict = self.get_sampling_params_dict(parameters)
             lora_name = sampling_params_dict.pop("lora_name", None)
             sampling_params = SamplingParams(**sampling_params_dict)
-            last_output = None
-            prev_outputs = None
             lora_request = None
             if lora_name is not None:
                 lora_id = str(self.supported_loras.index(lora_name) + 1)
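
For context on the surrounding lines: `lora_name` travels in the same parameters dict as the sampling options, so it must be popped before the dict is splatted into `SamplingParams`, which rejects unknown keywords. A hedged sketch with made-up parameter values (assumes vLLM is installed):

```python
from vllm import SamplingParams

# Hypothetical parameters as parsed from a Triton request: "lora_name" rides
# along with the sampling options but is not a SamplingParams field, so it
# must be removed before constructing SamplingParams.
sampling_params_dict = {"temperature": 0.7, "max_tokens": 32, "lora_name": "my-adapter"}
lora_name = sampling_params_dict.pop("lora_name", None)
sampling_params = SamplingParams(**sampling_params_dict)
print(lora_name, sampling_params)
```
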
@@ -494,15 +478,21 @@ async def generate(self, request):
                 request_id, prompt, sampling_params, lora_request=lora_request
             )
 
-            async for output in response_iterator:
+            prev_request_output = None
+            async for request_output in response_iterator:
+                # Cancellation state will be checked by the response loop and written to
+                # the response state if streaming. If not streaming, cancellation state
+                # needs to be checked here.
                 is_cancelled = response_state["is_cancelled"]
                 if not stream:
                     is_cancelled = response_sender.is_cancelled()
                 if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+
                     if stream:
+                        # Add cancelled final response to response loop.
                         response_state["last_response_generated"] = True
                         response = pb_utils.InferenceResponse(
                             error=pb_utils.TritonError(
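
The new comments document a split in cancellation detection: when streaming, the response loop polls `response_sender.is_cancelled()` after each send and writes the result into the shared `response_state`, so the generator only reads the flag; when not streaming, no response loop is watching, so the generator polls the sender directly. A toy asyncio sketch of the shared-flag half of that pattern (stand-in names, no Triton or vLLM involved):

```python
import asyncio

async def token_stream():
    # Stand-in for the vLLM response iterator.
    for tok in ["Hello", " world", "!"]:
        await asyncio.sleep(0.1)
        yield tok

async def generate(state):
    # Mirror the diff's pattern: re-check the shared cancellation flag on
    # every iteration and stop as soon as it flips.
    async for tok in token_stream():
        if state["is_cancelled"]:
            print("aborting request")
            return
        print("sending", tok)

async def main():
    state = {"is_cancelled": False}
    task = asyncio.create_task(generate(state))
    await asyncio.sleep(0.15)
    state["is_cancelled"] = True  # e.g. written by the response loop
    await task

asyncio.run(main())
```
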
@@ -515,48 +505,44 @@ async def generate(self, request):
                         self._response_queue.put_nowait(
                             (response_state, response, flags)
                         )
+
                     break
+
+                # Send each response if streaming.
                 if stream:
-                    prev_outputs_lengths = None
-                    if prev_outputs is not None:
-                        prev_outputs_lengths = [
-                            len(prev_output.text)
-                            for prev_output in prev_outputs.outputs
-                        ]
-                    response = self.create_stream_response(
-                        output, prev_outputs_lengths, additional_outputs
+                    response = self._create_response(
+                        prev_request_output, request_output
                     )
                     flags = 0
-                    if output.finished:
+                    if request_output.finished:
                         response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
                     self._response_queue.put_nowait((response_state, response, flags))
-                    prev_outputs = output
 
-                last_output = output
+                prev_request_output = request_output
 
+            # Send the last response which contains all the outputs if not streaming.
             if not stream:
                 response_sender.send(
-                    self.create_response(
-                        last_output, prepend_input, additional_outputs
-                    ),
+                    self._create_response(None, request_output, prepend_input),
                     flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                 )
 
         except Exception as e:
             self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
-            triton_output_tensor = pb_utils.Tensor(
+            text_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
             )
             response = pb_utils.InferenceResponse(
-                output_tensors=[triton_output_tensor], error=error
+                output_tensors=[text_output_tensor], error=error
             )
             response_sender.send(
                 response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
             )
             raise e
+
         finally:
             if decrement_ongoing_request_count:
                 self.ongoing_request_count -= 1
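
On the flags and the `finally` block above: `TRITONSERVER_RESPONSE_COMPLETE_FINAL` marks the last response for a request, and when a final response is handed to the response loop, `decrement_ongoing_request_count` is flipped off so the loop, not the generator, decrements the counter exactly once. A toy sketch of that queue hand-off (the constant value and names are stand-ins, not Triton's API):

```python
import queue

FINAL = 1  # stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL

def response_loop(q, counts):
    # Mirrors response_loop above: drain (state, response, flags) tuples and
    # retire the request once the FINAL flag arrives.
    while True:
        state, response, flags = q.get()
        print("send:", response, "flags:", flags)
        if flags == FINAL:
            counts["ongoing"] -= 1
            break

q = queue.Queue()
counts = {"ongoing": 1}
q.put_nowait(({}, "delta-1", 0))
q.put_nowait(({}, "delta-2", FINAL))
response_loop(q, counts)
assert counts["ongoing"] == 0
```
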