 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import gc
 import json
 import os
+import queue
 import threading
 from typing import Dict, List
 
@@ -113,13 +115,19 @@ def initialize(self, args):
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
 
+        # Starting the response thread. It allows vLLM to keep making progress while
+        # response sender(s) are sending responses to server frontend.
+        self._response_queue = queue.Queue()
+        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread.start()
+
         # Starting asyncio event loop to process the received requests asynchronously.
         self._loop = asyncio.get_event_loop()
-        self._loop_thread = threading.Thread(
+        self._event_thread = threading.Thread(
             target=self.engine_loop, args=(self._loop,)
         )
         self._shutdown_event = asyncio.Event()
-        self._loop_thread.start()
+        self._event_thread.start()
 
     def init_engine(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
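The hunk above introduces the core of this change: completed responses are handed to a dedicated thread through a `queue.Queue`, so `response_sender.send()` never blocks the asyncio loop that drives the vLLM engine. Below is a minimal, self-contained sketch of that producer/consumer hand-off; `FakeSender` and the string payload are placeholders for illustration, not the backend's actual classes.

```python
import queue
import threading


class FakeSender:
    """Hypothetical stand-in for Triton's response sender; illustration only."""

    def send(self, response, flags=0):
        print(f"sent: {response!r} (flags={flags})")


response_queue = queue.Ueue() if False else queue.Queue()  # plain queue.Queue()


def response_loop():
    # Consumer thread: drain completed responses and ship them out.
    while True:
        item = response_queue.get()
        if item is None:  # None is the shutdown sentinel.
            break
        sender, response, flags = item
        sender.send(response, flags)


worker = threading.Thread(target=response_loop)
worker.start()

# Producer side: the engine coroutine enqueues instead of sending inline.
response_queue.put_nowait((FakeSender(), "a token", 0))

# Shutdown: enqueue the sentinel, then join the worker.
response_queue.put(None)
worker.join()
```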
@@ -273,6 +281,27 @@ def get_sampling_params_dict(self, params_json):
 
         return params_dict
 
+    def response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_sender, response, response_flag = item
+            del item
+            try:
+                response_sender.send(response, response_flag)
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self.ongoing_request_count -= 1
+                del response_sender
+                if self.ongoing_request_count == 0:
+                    gc.collect()
+
     def create_response(self, vllm_output, prepend_input):
         """
         Parses the output from the vLLM engine into Triton
@@ -314,6 +343,7 @@ async def generate(self, request):
314343 """
315344 response_sender = request .get_response_sender ()
316345 self .ongoing_request_count += 1
346+ decrement_ongoing_request_count = True
317347 try :
318348 request_id = random_uuid ()
319349 prompt = pb_utils .get_input_tensor_by_name (
@@ -368,9 +398,11 @@ async def generate(self, request):
                 lora_local_path = self.lora_repository[lora_name]
                 lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
 
-            async for output in self.llm_engine.generate(
-                prompt, sampling_params, request_id, lora_request=lora_request
-            ):
+            response_iterator = await self.llm_engine.add_request(
+                request_id, prompt, sampling_params, lora_request=lora_request
+            )
+
+            async for output in response_iterator:
                 if response_sender.is_cancelled():
                     self.logger.log_info("[vllm] Cancelling the request")
                     await self.llm_engine.abort(request_id)
@@ -383,15 +415,12 @@ async def generate(self, request):
                         len(prev_output.text)
                         for prev_output in prev_outputs.outputs
                     ]
+                response = self.create_stream_response(output, prev_outputs_lengths)
+                flags = 0
                 if output.finished:
-                    response_sender.send(
-                        self.create_stream_response(output, prev_outputs_lengths),
-                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                    )
-                else:
-                    response_sender.send(
-                        self.create_stream_response(output, prev_outputs_lengths)
-                    )
+                    flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                    decrement_ongoing_request_count = False
+                self._response_queue.put_nowait((response_sender, response, flags))
                 prev_outputs = output
 
             last_output = output
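In the rewritten loop every chunk is enqueued rather than sent inline: intermediate chunks carry `flags = 0`, and only the final chunk carries `pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL`, which tells Triton's decoupled frontend that the response stream is complete (and hands the counter decrement over to the response thread). A small hedged sketch of the enqueue-side pattern, with `FINAL_FLAG` standing in for the real Triton constant:

```python
import queue

# Stand-in for pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL; illustration only.
FINAL_FLAG = 1


def enqueue_stream(response_queue, response_sender, chunks):
    """Enqueue streamed chunks; only the last one carries the final flag."""
    for i, chunk in enumerate(chunks):
        flags = FINAL_FLAG if i == len(chunks) - 1 else 0
        response_queue.put_nowait((response_sender, chunk, flags))


q = queue.Queue()
enqueue_stream(q, object(), ["Hel", "lo", "!"])
print(list(q.queue))  # the last tuple carries FINAL_FLAG
```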
@@ -403,7 +432,7 @@ async def generate(self, request):
                 )
 
         except Exception as e:
-            self.logger.log_info(f"[vllm] Error generating stream: {e}")
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
                 "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -416,7 +445,11 @@ async def generate(self, request):
             )
             raise e
         finally:
-            self.ongoing_request_count -= 1
+            if decrement_ongoing_request_count:
+                self.ongoing_request_count -= 1
+            del response_sender
+            if self.ongoing_request_count == 0:
+                gc.collect()
 
     def verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
@@ -483,6 +516,14 @@ def finalize(self):
483516 """
484517 self .logger .log_info ("[vllm] Issuing finalize to vllm backend" )
485518 self ._shutdown_event .set ()
486- if self ._loop_thread is not None :
487- self ._loop_thread .join ()
488- self ._loop_thread = None
519+
520+ # Shutdown the event thread.
521+ if self ._event_thread is not None :
522+ self ._event_thread .join ()
523+ self ._event_thread = None
524+
525+ # Shutdown the response thread.
526+ self ._response_queue .put (None )
527+ if self ._response_thread is not None :
528+ self ._response_thread .join ()
529+ self ._response_thread = None
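`finalize` now tears the two threads down in order: the engine/event-loop thread is signalled and joined first, and only then is the `None` sentinel enqueued for the response thread, so nothing can be put on the queue after the consumer exits. A minimal sketch of that two-stage shutdown, assuming simple stand-in workers rather than the backend's real threads:

```python
import queue
import threading

stop_event = threading.Event()
response_queue = queue.Queue()


def engine_worker():
    # Stand-in for the thread running the asyncio engine loop (the producer).
    stop_event.wait()


def response_worker():
    # Stand-in for response_loop(): drain items until the None sentinel.
    while response_queue.get() is not None:
        pass


engine_thread = threading.Thread(target=engine_worker)
response_thread = threading.Thread(target=response_worker)
engine_thread.start()
response_thread.start()

# Shutdown order mirrors finalize(): stop the producer first,
# then tell the consumer no more work is coming.
stop_event.set()
engine_thread.join()        # no new responses can be produced after this
response_queue.put(None)    # sentinel lets the response thread exit
response_thread.join()
```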