Skip to content

Commit 58ee481

Browse files
committed
Add additional outputs to response
1 parent 892f0d0 commit 58ee481

File tree

1 file changed

+62
-6
lines changed

1 file changed

+62
-6
lines changed

src/model.py

Lines changed: 62 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -418,8 +418,10 @@ def response_loop(self):
418418
self.ongoing_request_count -= 1
419419

420420
def _create_response(
421-
self, prev_request_output, request_output, prepend_input=False
421+
self, prev_request_output, request_output, prepend_input, additional_outputs
422422
):
423+
output_tensors = []
424+
423425
# text_output
424426
prepend_prompt = ""
425427
if prev_request_output is None:
@@ -436,11 +438,57 @@ def _create_response(
436438
(prepend_prompt + output.text[prev_len:]).encode("utf-8")
437439
for output, prev_len in zip(request_output.outputs, prev_lens)
438440
]
439-
text_output_tensor = pb_utils.Tensor(
440-
"text_output", np.asarray(text_output, dtype=self.output_dtype)
441+
output_tensors.append(
442+
pb_utils.Tensor(
443+
"text_output", np.asarray(text_output, dtype=self.output_dtype)
444+
)
441445
)
442446

443-
return pb_utils.InferenceResponse(output_tensors=[text_output_tensor])
447+
# finish_reason
448+
if additional_outputs["output_finish_reason"]:
449+
finish_reason = [
450+
str(output.finish_reason) for output in request_output.outputs
451+
]
452+
output_tensors.append(
453+
pb_utils.Tensor(
454+
"finish_reason", np.asarray(finish_reason, dtype=np.object_)
455+
)
456+
)
457+
458+
# cumulative_logprob
459+
if additional_outputs["output_cumulative_logprob"]:
460+
cumulative_logprob = [
461+
output.cumulative_logprob for output in request_output.outputs
462+
]
463+
output_tensors.append(
464+
pb_utils.Tensor(
465+
"cumulative_logprob",
466+
np.asarray(cumulative_logprob, dtype=np.float32),
467+
)
468+
)
469+
470+
# num_token_ids
471+
if additional_outputs["output_num_token_ids"]:
472+
if prev_request_output is None:
473+
# this is the first response
474+
prev_lens = [0] * len(request_output.outputs)
475+
else:
476+
# this is a subsequent response
477+
prev_lens = [
478+
len(prev_output.token_ids)
479+
for prev_output in prev_request_output.outputs
480+
]
481+
num_token_ids = [
482+
(len(output.token_ids) - prev_len)
483+
for output, prev_len in zip(request_output.outputs, prev_lens)
484+
]
485+
output_tensors.append(
486+
pb_utils.Tensor(
487+
"num_token_ids", np.asarray(num_token_ids, dtype=np.uint32)
488+
)
489+
)
490+
491+
return pb_utils.InferenceResponse(output_tensors=output_tensors)
444492

445493
async def generate(self, request):
446494
"""
@@ -511,7 +559,10 @@ async def generate(self, request):
511559
# Send each response if streaming.
512560
if stream:
513561
response = self._create_response(
514-
prev_request_output, request_output
562+
prev_request_output,
563+
request_output,
564+
prepend_input=False,
565+
additional_outputs=additional_outputs,
515566
)
516567
flags = 0
517568
if request_output.finished:
@@ -525,7 +576,12 @@ async def generate(self, request):
525576
# Send the last response which contains all the outputs if not streaming.
526577
if not stream:
527578
response_sender.send(
528-
self._create_response(None, request_output, prepend_input),
579+
self._create_response(
580+
prev_request_output=None,
581+
request_output=request_output,
582+
prepend_input=prepend_input,
583+
additional_outputs=additional_outputs,
584+
),
529585
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
530586
)
531587

0 commit comments

Comments (0)