
Commit 892f0d0

chore: Refactor generate function
1 parent 10a5b94 commit 892f0d0

File tree

1 file changed: +40 -54 lines changed

src/model.py

Lines changed: 40 additions & 54 deletions
@@ -417,44 +417,30 @@ def response_loop(self):
             if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
                 self.ongoing_request_count -= 1
 
-    def create_response(self, vllm_output, prepend_input, additional_outputs):
-        """
-        Parses the output from the vLLM engine into Triton
-        response.
-        """
-        prompt = ""
-        if prepend_input:
-            prompt = vllm_output.prompt
-        text_outputs = [
-            (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
-        ]
-        triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
-        )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
-
-    def create_stream_response(
-        self, vllm_output, previous_outputs_lengths, additional_outputs
+    def _create_response(
+        self, prev_request_output, request_output, prepend_input=False
     ):
-        """
-        Parses the output from the vLLM engine, extracts only newly generated
-        text and packs it into Triton response.
-        """
-        if previous_outputs_lengths is None:
-            return self.create_response(
-                vllm_output, prepend_input=False, additional_outputs=additional_outputs
-            )
-
-        text_outputs = [
-            (output.text[prev_output_length:]).encode("utf-8")
-            for output, prev_output_length in zip(
-                vllm_output.outputs, previous_outputs_lengths
-            )
+        # text_output
+        prepend_prompt = ""
+        if prev_request_output is None:
+            # this is the first response
+            if prepend_input:
+                prepend_prompt = request_output.prompt
+            prev_lens = [0] * len(request_output.outputs)
+        else:
+            # this is a subsequent response
+            prev_lens = [
+                len(prev_output.text) for prev_output in prev_request_output.outputs
+            ]
+        text_output = [
+            (prepend_prompt + output.text[prev_len:]).encode("utf-8")
+            for output, prev_len in zip(request_output.outputs, prev_lens)
         ]
-        triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
+        text_output_tensor = pb_utils.Tensor(
+            "text_output", np.asarray(text_output, dtype=self.output_dtype)
         )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
+
+        return pb_utils.InferenceResponse(output_tensors=[text_output_tensor])
 
     async def generate(self, request):
         """
@@ -481,8 +467,6 @@ async def generate(self, request):
             sampling_params_dict = self.get_sampling_params_dict(parameters)
             lora_name = sampling_params_dict.pop("lora_name", None)
             sampling_params = SamplingParams(**sampling_params_dict)
-            last_output = None
-            prev_outputs = None
             lora_request = None
             if lora_name is not None:
                 lora_id = str(self.supported_loras.index(lora_name) + 1)
@@ -494,15 +478,21 @@ async def generate(self, request):
                 request_id, prompt, sampling_params, lora_request=lora_request
             )
 
-            async for output in response_iterator:
+            prev_request_output = None
+            async for request_output in response_iterator:
+                # Cancellation state will be checked by the response loop and written to
+                # the response state if streaming. If not streaming, cancellation state
+                # needs to be checked here.
                 is_cancelled = response_state["is_cancelled"]
                 if not stream:
                     is_cancelled = response_sender.is_cancelled()
                 if is_cancelled:
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
                     self.logger.log_info("[vllm] Successfully cancelled the request")
+
                     if stream:
+                        # Add cancelled final response to response loop.
                         response_state["last_response_generated"] = True
                         response = pb_utils.InferenceResponse(
                             error=pb_utils.TritonError(
@@ -515,48 +505,44 @@ async def generate(self, request):
                         self._response_queue.put_nowait(
                             (response_state, response, flags)
                         )
+
                     break
+
+                # Send each response if streaming.
                 if stream:
-                    prev_outputs_lengths = None
-                    if prev_outputs is not None:
-                        prev_outputs_lengths = [
-                            len(prev_output.text)
-                            for prev_output in prev_outputs.outputs
-                        ]
-                    response = self.create_stream_response(
-                        output, prev_outputs_lengths, additional_outputs
+                    response = self._create_response(
+                        prev_request_output, request_output
                     )
                     flags = 0
-                    if output.finished:
+                    if request_output.finished:
                         response_state["last_response_generated"] = True
                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                         decrement_ongoing_request_count = False
                     self._response_queue.put_nowait((response_state, response, flags))
-                    prev_outputs = output
 
-                last_output = output
+                prev_request_output = request_output
 
+            # Send the last response which contains all the outputs if not streaming.
             if not stream:
                 response_sender.send(
-                    self.create_response(
-                        last_output, prepend_input, additional_outputs
-                    ),
+                    self._create_response(None, request_output, prepend_input),
                     flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                 )
 
         except Exception as e:
             self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
-            triton_output_tensor = pb_utils.Tensor(
+            text_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
            )
             response = pb_utils.InferenceResponse(
-                output_tensors=[triton_output_tensor], error=error
+                output_tensors=[text_output_tensor], error=error
             )
             response_sender.send(
                 response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
             )
             raise e
+
         finally:
             if decrement_ongoing_request_count:
                 self.ongoing_request_count -= 1
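
The refactor folds create_response and create_stream_response into a single _create_response that computes per-candidate deltas against the previous RequestOutput. Below is a minimal, standalone sketch of that delta logic, included for illustration only: the extract_text_output helper is hypothetical, and types.SimpleNamespace objects stand in for vLLM's RequestOutput/CompletionOutput (which expose cumulative .text per candidate), since pb_utils and the engine are not available outside Triton.

from types import SimpleNamespace


def extract_text_output(prev_request_output, request_output, prepend_input=False):
    # Mirror of the delta logic in _create_response, minus the Triton tensor packing.
    prepend_prompt = ""
    if prev_request_output is None:
        # First (or only) response: optionally prepend the prompt; nothing to subtract.
        if prepend_input:
            prepend_prompt = request_output.prompt
        prev_lens = [0] * len(request_output.outputs)
    else:
        # Subsequent streamed response: emit only text generated since the previous one.
        prev_lens = [len(o.text) for o in prev_request_output.outputs]
    return [
        (prepend_prompt + output.text[prev_len:]).encode("utf-8")
        for output, prev_len in zip(request_output.outputs, prev_lens)
    ]


# Simulated cumulative outputs from two consecutive engine iterations
# (stand-ins for vLLM RequestOutput objects).
first = SimpleNamespace(prompt="Hello", outputs=[SimpleNamespace(text=" world")])
second = SimpleNamespace(prompt="Hello", outputs=[SimpleNamespace(text=" world!")])

print(extract_text_output(None, first, prepend_input=True))   # streaming, first chunk
print(extract_text_output(first, second))                     # streaming, delta only
print(extract_text_output(None, second, prepend_input=True))  # non-streaming, full text

Running the sketch prints [b'Hello world'], [b'!'], and [b'Hello world!'], corresponding to the streaming first-chunk, streaming delta, and non-streaming paths that the single method now covers.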
