# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import asyncio
+ import gc
import json
import os
+ import queue
import threading
from typing import Dict, List

@@ -113,13 +115,19 @@ def initialize(self, args):
        # Counter to keep track of ongoing request counts
        self.ongoing_request_count = 0

+         # Starting the response thread. It allows vLLM to keep making progress while
+         # response sender(s) are sending responses to server frontend.
+         self._response_queue = queue.Queue()
+         self._response_thread = threading.Thread(target=self.response_loop)
+         self._response_thread.start()
+
        # Starting asyncio event loop to process the received requests asynchronously.
        self._loop = asyncio.get_event_loop()
-         self._loop_thread = threading.Thread(
+         self._event_thread = threading.Thread(
            target=self.engine_loop, args=(self._loop,)
        )
        self._shutdown_event = asyncio.Event()
-         self._loop_thread.start()
+         self._event_thread.start()

    def init_engine(self):
        # Currently, Triton needs to use decoupled policy for asynchronously
@@ -273,6 +281,27 @@ def get_sampling_params_dict(self, params_json):

        return params_dict

+     def response_loop(self):
+         while True:
+             item = self._response_queue.get()
+             # To signal shutdown, a None item will be added to the queue.
+             if item is None:
+                 break
+             response_sender, response, response_flag = item
+             del item
+             try:
+                 response_sender.send(response, response_flag)
+             except Exception as e:
+                 self.logger.log_error(
+                     f"An error occurred while sending a response: {e}"
+                 )
+             finally:
+                 if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                     self.ongoing_request_count -= 1
+                     del response_sender
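+                     # Once no requests are in flight, run a collection pass to reclaim
+                     # memory held by completed request and response objects.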
+                     if self.ongoing_request_count == 0:
+                         gc.collect()
+
    def create_response(self, vllm_output, prepend_input):
        """
        Parses the output from the vLLM engine into Triton
@@ -314,6 +343,7 @@ async def generate(self, request):
        """
        response_sender = request.get_response_sender()
        self.ongoing_request_count += 1
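+         # When the final response is handed to the response thread, that thread takes over
+         # decrementing ongoing_request_count, so the finally block below must not do it again.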
+         decrement_ongoing_request_count = True
        try:
            request_id = random_uuid()
            prompt = pb_utils.get_input_tensor_by_name(
@@ -368,9 +398,11 @@ async def generate(self, request):
                lora_local_path = self.lora_repository[lora_name]
                lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)

-             async for output in self.llm_engine.generate(
-                 prompt, sampling_params, request_id, lora_request=lora_request
-             ):
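+             # Register the request with the engine and iterate over its outputs as they are
+             # produced, rather than driving generation through llm_engine.generate().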
+             response_iterator = await self.llm_engine.add_request(
+                 request_id, prompt, sampling_params, lora_request=lora_request
+             )
+
+             async for output in response_iterator:
                if response_sender.is_cancelled():
                    self.logger.log_info("[vllm] Cancelling the request")
                    await self.llm_engine.abort(request_id)
@@ -383,15 +415,12 @@ async def generate(self, request):
                            len(prev_output.text)
                            for prev_output in prev_outputs.outputs
                        ]
+                     response = self.create_stream_response(output, prev_outputs_lengths)
+                     flags = 0
                    if output.finished:
-                         response_sender.send(
-                             self.create_stream_response(output, prev_outputs_lengths),
-                             flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                         )
-                     else:
-                         response_sender.send(
-                             self.create_stream_response(output, prev_outputs_lengths)
-                         )
+                         flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                         decrement_ongoing_request_count = False
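+                     # Hand the response off to the response thread so this coroutine is not
+                     # blocked on the frontend while the engine keeps producing output.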
+                     self._response_queue.put_nowait((response_sender, response, flags))
                    prev_outputs = output

            last_output = output
@@ -403,7 +432,7 @@ async def generate(self, request):
            )

        except Exception as e:
-             self.logger.log_info(f"[vllm] Error generating stream: {e}")
+             self.logger.log_error(f"[vllm] Error generating stream: {e}")
            error = pb_utils.TritonError(f"Error generating stream: {e}")
            triton_output_tensor = pb_utils.Tensor(
                "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
@@ -416,7 +445,11 @@ async def generate(self, request):
            )
            raise e
        finally:
-             self.ongoing_request_count -= 1
+             if decrement_ongoing_request_count:
+                 self.ongoing_request_count -= 1
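+             # Drop this coroutine's reference to the response sender; the response thread
+             # releases its own reference once the final response has been sent.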
+             del response_sender
+             if self.ongoing_request_count == 0:
+                 gc.collect()

    def verify_loras(self, request):
        # We will check if the requested lora exists here, if not we will send a
@@ -483,6 +516,14 @@ def finalize(self):
        """
        self.logger.log_info("[vllm] Issuing finalize to vllm backend")
        self._shutdown_event.set()
-         if self._loop_thread is not None:
-             self._loop_thread.join()
-             self._loop_thread = None
+
+         # Shutdown the event thread.
+         if self._event_thread is not None:
+             self._event_thread.join()
+             self._event_thread = None
+
+         # Shutdown the response thread.
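+         # A None sentinel unblocks response_loop()'s queue.get() and tells the thread to exit.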
+         self._response_queue.put(None)
+         if self._response_thread is not None:
+             self._response_thread.join()
+             self._response_thread = None
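
Below is a minimal, standalone sketch (not part of the patch itself) of the handoff pattern the change introduces: an asyncio producer puts responses on a thread-safe queue, a dedicated thread drains it, and a None sentinel signals shutdown. All names here are illustrative rather than taken from the Triton backend.

import asyncio
import queue
import threading


def response_loop(q: queue.Queue) -> None:
    # Consumer thread: process items until the None sentinel arrives.
    while True:
        item = q.get()
        if item is None:
            break
        print(f"sending response: {item}")


async def produce(q: queue.Queue) -> None:
    # Producer coroutine: hand each response off without blocking the event loop.
    for i in range(3):
        await asyncio.sleep(0.01)  # stand-in for awaiting engine output
        q.put_nowait(f"token-{i}")


def main() -> None:
    q: queue.Queue = queue.Queue()
    worker = threading.Thread(target=response_loop, args=(q,))
    worker.start()
    asyncio.run(produce(q))
    q.put(None)  # signal shutdown
    worker.join()


if __name__ == "__main__":
    main()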