From 5a3c37211bea0b1930b46e32dc311e7c75ee8117 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Tue, 15 Jul 2025 12:38:14 +0530 Subject: [PATCH 1/4] Enable usage for tensorrt-llm backend --- .../openai_frontend/engine/triton_engine.py | 18 --- .../openai_frontend/engine/utils/triton.py | 9 +- python/openai/tests/test_chat_completions.py | 18 +-- python/openai/tests/test_completions.py | 14 +- python/openai/tests/test_openai_client.py | 123 +++++------------- 5 files changed, 44 insertions(+), 138 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 499cc623e7..134500995b 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -695,15 +695,6 @@ def _validate_chat_request( if request.stream_options and not request.stream: raise Exception("`stream_options` can only be used when `stream` is True") - if ( - request.stream_options - and request.stream_options.include_usage - and metadata.backend != "vllm" - ): - raise Exception( - "`stream_options.include_usage` is currently only supported for the vLLM backend" - ) - def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest): if ( request.tool_choice @@ -844,15 +835,6 @@ def _validate_completion_request( if request.stream_options and not request.stream: raise Exception("`stream_options` can only be used when `stream` is True") - if ( - request.stream_options - and request.stream_options.include_usage - and metadata.backend != "vllm" - ): - raise Exception( - "`stream_options.include_usage` is currently only supported for the vLLM backend" - ) - def _should_stream_with_auto_tool_parsing( self, request: CreateChatCompletionRequest ): diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 3104c49911..32fe1a33a3 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -177,6 +177,10 @@ def _create_trtllm_inference_request( if guided_json is not None: inputs["guided_decoding_guide_type"] = [["json_schema"]] inputs["guided_decoding_guide"] = [[guided_json]] + + inputs["return_num_input_tokens"] = np.bool_([[True]]) + inputs["return_num_output_tokens"] = np.bool_([[True]]) + # FIXME: TRT-LLM doesn't currently support runtime changes of 'echo' and it # is configured at model load time, so we don't handle it here for now. return model.create_request(inputs=inputs) @@ -265,11 +269,6 @@ def _get_usage_from_response( """ Extracts token usage statistics from a Triton inference response. """ - # TODO: Remove this check once TRT-LLM backend supports both "num_input_tokens" - # and "num_output_tokens", and also update the test cases accordingly. 
- if backend != "vllm": - return None - prompt_tokens = None completion_tokens = None diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index 5402be451d..3939d9a7a8 100644 --- a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -41,9 +41,7 @@ class TestChatCompletions: def client(self, fastapi_client_class_scope): yield fastapi_client_class_scope - def test_chat_completions_defaults( - self, client, model: str, messages: List[dict], backend: str - ): + def test_chat_completions_defaults(self, client, model: str, messages: List[dict]): response = client.post( "/v1/chat/completions", json={"model": model, "messages": messages}, @@ -55,10 +53,7 @@ def test_chat_completions_defaults( assert message["role"] == "assistant" usage = response.json().get("usage") - if backend == "vllm": - assert usage is not None - else: - assert usage is None + assert usage is not None def test_chat_completions_system_prompt(self, client, model: str): # NOTE: Currently just sanity check that there are no issues when a @@ -536,14 +531,7 @@ def test_request_logprobs(self): def test_request_logit_bias(self): pass - def test_usage_response( - self, client, model: str, messages: List[dict], backend: str - ): - if backend != "vllm": - pytest.skip( - "Usage reporting is currently available only for the vLLM backend." - ) - + def test_usage_response(self, client, model: str, messages: List[dict]): response = client.post( "/v1/chat/completions", json={"model": model, "messages": messages}, diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index 9ec3ffe7f7..ecba399398 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -35,7 +35,7 @@ class TestCompletions: def client(self, fastapi_client_class_scope): yield fastapi_client_class_scope - def test_completions_defaults(self, client, model: str, prompt: str, backend: str): + def test_completions_defaults(self, client, model: str, prompt: str): response = client.post( "/v1/completions", json={"model": model, "prompt": prompt}, @@ -48,10 +48,7 @@ def test_completions_defaults(self, client, model: str, prompt: str, backend: st assert response.json()["choices"][0]["text"].strip() usage = response.json().get("usage") - if backend == "vllm": - assert usage is not None - else: - assert usage is None + assert usage is not None @pytest.mark.parametrize( "sampling_parameter, value", @@ -371,12 +368,7 @@ def test_lora(self): def test_multi_lora(self): pass - def test_usage_response(self, client, model: str, prompt: str, backend: str): - if backend != "vllm": - pytest.skip( - "Usage reporting is currently available only for the vLLM backend." 
- ) - + def test_usage_response(self, client, model: str, prompt: str): response = client.post( "/v1/completions", json={"model": model, "prompt": prompt}, diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 1a1001329b..5ffcbe4f1d 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -49,7 +49,7 @@ def test_openai_client_models(self, client: openai.OpenAI, backend: str): raise Exception(f"Unexpected backend {backend=}") def test_openai_client_completion( - self, client: openai.OpenAI, model: str, prompt: str, backend: str + self, client: openai.OpenAI, model: str, prompt: str ): completion = client.completions.create( prompt=prompt, @@ -61,19 +61,16 @@ def test_openai_client_completion( assert completion.choices[0].finish_reason == "stop" usage = completion.usage - if backend == "vllm": - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens - else: - assert usage is None + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens def test_openai_client_chat_completion( - self, client: openai.OpenAI, model: str, messages: List[dict], backend: str + self, client: openai.OpenAI, model: str, messages: List[dict] ): chat_completion = client.chat.completions.create( messages=messages, @@ -85,16 +82,13 @@ def test_openai_client_chat_completion( assert chat_completion.choices[0].finish_reason == "stop" usage = chat_completion.usage - if backend == "vllm": - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens - else: - assert usage is None + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens @pytest.mark.parametrize("echo", [False, True]) def test_openai_client_completion_echo( @@ -141,7 +135,7 @@ async def test_openai_client_models(self, client: openai.AsyncOpenAI, backend: s @pytest.mark.asyncio async def test_openai_client_completion( - self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str + self, client: openai.AsyncOpenAI, model: str, prompt: str ): completion = await client.completions.create( prompt=prompt, @@ -153,20 +147,17 @@ async def test_openai_client_completion( assert completion.choices[0].finish_reason == "stop" usage = completion.usage - if backend == "vllm": - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens 
+ usage.completion_tokens - else: - assert usage is None + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens @pytest.mark.asyncio async def test_openai_client_chat_completion( - self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str + self, client: openai.AsyncOpenAI, model: str, messages: List[dict] ): chat_completion = await client.chat.completions.create( messages=messages, @@ -177,16 +168,13 @@ async def test_openai_client_chat_completion( assert chat_completion.choices[0].finish_reason == "stop" usage = chat_completion.usage - if backend == "vllm": - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens - else: - assert usage is None + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens print(f"Chat completion results: {chat_completion}") @@ -300,13 +288,8 @@ async def test_chat_streaming( @pytest.mark.asyncio async def test_chat_streaming_usage_option( - self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str + self, client: openai.AsyncOpenAI, model: str, messages: List[dict] ): - if backend != "vllm": - pytest.skip( - "Usage reporting is currently available only for the vLLM backend." - ) - seed = 0 temperature = 0.0 max_tokens = 16 @@ -397,13 +380,8 @@ async def test_chat_streaming_usage_option( @pytest.mark.asyncio async def test_completion_streaming_usage_option( - self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str + self, client: openai.AsyncOpenAI, model: str, prompt: str ): - if backend != "vllm": - pytest.skip( - "Usage reporting is currently available only for the vLLM backend." - ) - seed = 0 temperature = 0.0 max_tokens = 16 @@ -509,36 +487,3 @@ async def test_stream_options_without_streaming( stream_options={"include_usage": True}, ) assert "`stream_options` can only be used when `stream` is True" in str(e.value) - - @pytest.mark.asyncio - async def test_streaming_usage_unsupported_backend( - self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str - ): - if backend == "vllm": - pytest.skip( - "This test is for backends that do not support usage reporting." 
- ) - - with pytest.raises(openai.BadRequestError) as e: - await client.completions.create( - model=model, - prompt="Test prompt", - stream=True, - stream_options={"include_usage": True}, - ) - assert ( - "`stream_options.include_usage` is currently only supported for the vLLM backend" - in str(e.value) - ) - - with pytest.raises(openai.BadRequestError) as e: - await client.chat.completions.create( - model=model, - messages=messages, - stream=True, - stream_options={"include_usage": True}, - ) - assert ( - "`stream_options.include_usage` is currently only supported for the vLLM backend" - in str(e.value) - ) From c7fbe73582650fe0c431fad4d0e290caa3efb3bb Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Tue, 30 Sep 2025 20:07:31 +0530 Subject: [PATCH 2/4] Update tests to use ensemble model --- python/openai/README.md | 18 +++++++++++------- .../openai_frontend/engine/utils/triton.py | 10 ++++++++++ python/openai/tests/conftest.py | 2 +- python/openai/tests/test_openai_client.py | 4 ++-- qa/L0_openai/generate_engine.py | 5 +++-- qa/L0_openai/test.sh | 8 ++++++-- 6 files changed, 33 insertions(+), 14 deletions(-) diff --git a/python/openai/README.md b/python/openai/README.md index f5ea906a86..d31fb749aa 100644 --- a/python/openai/README.md +++ b/python/openai/README.md @@ -341,8 +341,8 @@ python3 openai_frontend/main.py --model-repository path/to/models --tokenizer me - Note the use of `jq` is optional, but provides a nicely formatted output for JSON responses. ```bash # MODEL should be the client-facing model name in your model repository for a pipeline like TRT-LLM. -# For example, this could also be "ensemble", or something like "gpt2" if generated from Triton CLI -MODEL="tensorrt_llm_bls" +# For example, this could also be "tensorrt_llm_bls", or something like "gpt2" if generated from Triton CLI +MODEL="ensemble" curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ "model": "'${MODEL}'", "messages": [{"role": "user", "content": "Say this is a test!"}] @@ -354,13 +354,13 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ ```json { - "id": "cmpl-704c758c-8a84-11ef-b106-107c6149ca79", + "id": "cmpl-1e8ef90b-9def-11f0-8b68-89e7c3fd7d95", "choices": [ { "finish_reason": "stop", "index": 0, "message": { - "content": "It looks like you're testing the system!", + "content": "It looks like you're ready to see if I'm functioning properly. 
What would", "tool_calls": null, "role": "assistant", "function_call": null @@ -368,11 +368,15 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ "logprobs": null } ], - "created": 1728948689, - "model": "llama-3-8b-instruct", + "created": 1759231078, + "model": "ensemble", "system_fingerprint": null, "object": "chat.completion", - "usage": null + "usage": { + "completion_tokens": 16, + "prompt_tokens": 42, + "total_tokens": 58 + } } ``` diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 30fb349ade..6b0196439c 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -284,12 +284,22 @@ def _get_usage_from_response( input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) ) prompt_tokens = prompt_tokens_ptr[0] + elif input_token_tensor.data_type == tritonserver.DataType.INT32: + prompt_tokens_ptr = ctypes.cast( + input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_int32) + ) + prompt_tokens = prompt_tokens_ptr[0] if output_token_tensor.data_type == tritonserver.DataType.UINT32: completion_tokens_ptr = ctypes.cast( output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) ) completion_tokens = completion_tokens_ptr[0] + elif output_token_tensor.data_type == tritonserver.DataType.INT32: + completion_tokens_ptr = ctypes.cast( + output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_int32) + ) + completion_tokens = completion_tokens_ptr[0] if prompt_tokens is not None and completion_tokens is not None: total_tokens = prompt_tokens + completion_tokens diff --git a/python/openai/tests/conftest.py b/python/openai/tests/conftest.py index 50ba0de4ed..5bb781792e 100644 --- a/python/openai/tests/conftest.py +++ b/python/openai/tests/conftest.py @@ -51,7 +51,7 @@ def infer_test_environment(tool_call_parser): import tensorrt_llm as _ backend = "tensorrtllm" - model = "tensorrt_llm_bls" + model = "ensemble" return backend, model except ImportError: print("No tensorrt_llm installation found.") diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 5ffcbe4f1d..8f24cef96d 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -40,7 +40,7 @@ def test_openai_client_models(self, client: openai.OpenAI, backend: str): models = list(client.models.list()) print(f"Models: {models}") if backend == "tensorrtllm": - # tensorrt_llm_bls + + # ensemble + # preprocess -> tensorrt_llm -> postprocess assert len(models) == 4 elif backend == "vllm": @@ -125,7 +125,7 @@ async def test_openai_client_models(self, client: openai.AsyncOpenAI, backend: s models = [model async for model in async_models] print(f"Models: {models}") if backend == "tensorrtllm": - # tensorrt_llm_bls + + # ensemble + # preprocess -> tensorrt_llm -> postprocess assert len(models) == 4 elif backend == "vllm": diff --git a/qa/L0_openai/generate_engine.py b/qa/L0_openai/generate_engine.py index b71e084b03..07a2dfb29d 100644 --- a/qa/L0_openai/generate_engine.py +++ b/qa/L0_openai/generate_engine.py @@ -25,7 +25,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from argparse import ArgumentParser -from tensorrt_llm import LLM, BuildConfig +from tensorrt_llm import BuildConfig +from tensorrt_llm._tensorrt_engine import LLM from tensorrt_llm.plugin import PluginConfig @@ -53,4 +54,4 @@ def generate_model_engine(model: str, engines_path: str): FLAGS = parser.parse_args() generate_model_engine(FLAGS.model, FLAGS.engine_path) - print(f"model {FLAGS.model}'s engine has been saved to {FLAGS.engine_path}") + print(f"model {FLAGS.model}'s engine has been saved to {FLAGS.engine_path}") \ No newline at end of file diff --git a/qa/L0_openai/test.sh b/qa/L0_openai/test.sh index 9e098a4140..f351b61299 100755 --- a/qa/L0_openai/test.sh +++ b/qa/L0_openai/test.sh @@ -60,8 +60,11 @@ function prepare_tensorrtllm() { mkdir -p ${MODEL_REPO} cp /app/all_models/inflight_batcher_llm/* "${MODEL_REPO}" -r + + # TODO: # Ensemble model is not needed for the test - rm -rf ${MODEL_REPO}/ensemble + #rm -rf ${MODEL_REPO}/ensemble + rm -rf ${MODEL_REPO}/tensorrt_llm_bls # 1. Generate the model's trt engines python3 ../generate_engine.py --model "${MODEL}" --engine_path "${ENGINE_PATH}" @@ -70,8 +73,9 @@ function prepare_tensorrtllm() { FILL_TEMPLATE="/app/tools/fill_template.py" python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/preprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,preprocessing_instance_count:1,max_queue_size:0 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/postprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,postprocessing_instance_count:1 - python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32 + #python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},batching_strategy:inflight_fused_batching,max_queue_size:0,max_queue_delay_microseconds:1000,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,exclude_input_in_output:True + python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/ensemble/config.pbtxt triton_max_batch_size:64,logits_datatype:TYPE_FP32 } function pre_test() { From b1d50aca703f502dfceb70b30cb01ab86d637e5c Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Mon, 6 Oct 2025 15:01:40 +0530 Subject: [PATCH 3/4] Undo test changes --- python/openai/README.md | 22 ++++++++++------------ python/openai/tests/conftest.py | 2 +- python/openai/tests/test_openai_client.py | 4 ++-- qa/L0_openai/generate_engine.py | 2 +- qa/L0_openai/test.sh | 8 ++------ 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/python/openai/README.md b/python/openai/README.md index d31fb749aa..796596e531 100644 --- a/python/openai/README.md +++ b/python/openai/README.md @@ -51,7 +51,7 @@ docker run -it --net=host --gpus all --rm \ -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \ -e HF_TOKEN \ - nvcr.io/nvidia/tritonserver:25.08-vllm-python-py3 + nvcr.io/nvidia/tritonserver:25.09-vllm-python-py3 ``` 2. Launch the OpenAI-compatible Triton Inference Server: @@ -341,8 +341,8 @@ python3 openai_frontend/main.py --model-repository path/to/models --tokenizer me - Note the use of `jq` is optional, but provides a nicely formatted output for JSON responses. 
```bash # MODEL should be the client-facing model name in your model repository for a pipeline like TRT-LLM. -# For example, this could also be "tensorrt_llm_bls", or something like "gpt2" if generated from Triton CLI -MODEL="ensemble" +# For example, this could also be "ensemble", or something like "gpt2" if generated from Triton CLI +MODEL="tensorrt_llm_bls" curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/json' -d '{ "model": "'${MODEL}'", "messages": [{"role": "user", "content": "Say this is a test!"}] @@ -354,13 +354,13 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ ```json { - "id": "cmpl-1e8ef90b-9def-11f0-8b68-89e7c3fd7d95", + "id": "cmpl-704c758c-8a84-11ef-b106-107c6149ca79", "choices": [ { "finish_reason": "stop", "index": 0, "message": { - "content": "It looks like you're ready to see if I'm functioning properly. What would", + "content": "It looks like you're testing the system!", "tool_calls": null, "role": "assistant", "function_call": null @@ -368,15 +368,11 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ "logprobs": null } ], - "created": 1759231078, - "model": "ensemble", + "created": 1728948689, + "model": "llama-3-8b-instruct", "system_fingerprint": null, "object": "chat.completion", - "usage": { - "completion_tokens": 16, - "prompt_tokens": 42, - "total_tokens": 58 - } + "usage": null } ``` @@ -693,3 +689,5 @@ curl -H "api-key: my-secret-key" \ # Multiple APIs in single argument with shared authentication --openai-restricted-api "inference,model-repository shared-key shared-secret" ``` + +#### Add a note about usage metrics limitation diff --git a/python/openai/tests/conftest.py b/python/openai/tests/conftest.py index 5bb781792e..50ba0de4ed 100644 --- a/python/openai/tests/conftest.py +++ b/python/openai/tests/conftest.py @@ -51,7 +51,7 @@ def infer_test_environment(tool_call_parser): import tensorrt_llm as _ backend = "tensorrtllm" - model = "ensemble" + model = "tensorrt_llm_bls" return backend, model except ImportError: print("No tensorrt_llm installation found.") diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 8f24cef96d..5ffcbe4f1d 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -40,7 +40,7 @@ def test_openai_client_models(self, client: openai.OpenAI, backend: str): models = list(client.models.list()) print(f"Models: {models}") if backend == "tensorrtllm": - # ensemble + + # tensorrt_llm_bls + # preprocess -> tensorrt_llm -> postprocess assert len(models) == 4 elif backend == "vllm": @@ -125,7 +125,7 @@ async def test_openai_client_models(self, client: openai.AsyncOpenAI, backend: s models = [model async for model in async_models] print(f"Models: {models}") if backend == "tensorrtllm": - # ensemble + + # tensorrt_llm_bls + # preprocess -> tensorrt_llm -> postprocess assert len(models) == 4 elif backend == "vllm": diff --git a/qa/L0_openai/generate_engine.py b/qa/L0_openai/generate_engine.py index 07a2dfb29d..83ea35a88d 100644 --- a/qa/L0_openai/generate_engine.py +++ b/qa/L0_openai/generate_engine.py @@ -54,4 +54,4 @@ def generate_model_engine(model: str, engines_path: str): FLAGS = parser.parse_args() generate_model_engine(FLAGS.model, FLAGS.engine_path) - print(f"model {FLAGS.model}'s engine has been saved to {FLAGS.engine_path}") \ No newline at end of file + print(f"model {FLAGS.model}'s engine has been saved to {FLAGS.engine_path}") diff --git 
a/qa/L0_openai/test.sh b/qa/L0_openai/test.sh index f351b61299..9e098a4140 100755 --- a/qa/L0_openai/test.sh +++ b/qa/L0_openai/test.sh @@ -60,11 +60,8 @@ function prepare_tensorrtllm() { mkdir -p ${MODEL_REPO} cp /app/all_models/inflight_batcher_llm/* "${MODEL_REPO}" -r - - # TODO: # Ensemble model is not needed for the test - #rm -rf ${MODEL_REPO}/ensemble - rm -rf ${MODEL_REPO}/tensorrt_llm_bls + rm -rf ${MODEL_REPO}/ensemble # 1. Generate the model's trt engines python3 ../generate_engine.py --model "${MODEL}" --engine_path "${ENGINE_PATH}" @@ -73,9 +70,8 @@ function prepare_tensorrtllm() { FILL_TEMPLATE="/app/tools/fill_template.py" python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/preprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,preprocessing_instance_count:1,max_queue_size:0 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/postprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,postprocessing_instance_count:1 - #python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32 + python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},batching_strategy:inflight_fused_batching,max_queue_size:0,max_queue_delay_microseconds:1000,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,exclude_input_in_output:True - python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/ensemble/config.pbtxt triton_max_batch_size:64,logits_datatype:TYPE_FP32 } function pre_test() { From 5874e421eba776164abccdceec17ab79c1b48500 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Tue, 28 Oct 2025 23:15:47 +0530 Subject: [PATCH 4/4] Update --- qa/L0_openai/test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L0_openai/test.sh b/qa/L0_openai/test.sh index 9e098a4140..16288efea1 100755 --- a/qa/L0_openai/test.sh +++ b/qa/L0_openai/test.sh @@ -70,8 +70,8 @@ function prepare_tensorrtllm() { FILL_TEMPLATE="/app/tools/fill_template.py" python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/preprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,preprocessing_instance_count:1,max_queue_size:0 python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/postprocessing/config.pbtxt tokenizer_dir:${ENGINE_PATH},triton_max_batch_size:64,postprocessing_instance_count:1 - python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32 - python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},batching_strategy:inflight_fused_batching,max_queue_size:0,max_queue_delay_microseconds:1000,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,exclude_input_in_output:True + python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:64,decoupled_mode:True,bls_instance_count:1,accumulate_tokens:False,logits_datatype:TYPE_FP32,prompt_embedding_table_data_type:TYPE_FP16 + python3 ${FILL_TEMPLATE} -i ${MODEL_REPO}/tensorrt_llm/config.pbtxt 
triton_backend:tensorrtllm,triton_max_batch_size:64,decoupled_mode:True,max_beam_width:1,engine_dir:${ENGINE_PATH},batching_strategy:inflight_fused_batching,max_queue_size:0,max_queue_delay_microseconds:1000,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,exclude_input_in_output:True,prompt_embedding_table_data_type:TYPE_FP16 } function pre_test() {
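
For reference, below is a minimal sketch of how the usage reporting enabled by this series can be exercised end to end through the OpenAI-compatible frontend. It assumes the server from the README is running locally on port 9000, that the client-facing model name is `tensorrt_llm_bls` (it could equally be `ensemble` or a vLLM model name), and that no real API key is enforced, so a placeholder is passed only to satisfy the client; adjust these for your deployment.

```python
import openai

# Assumptions: frontend on localhost:9000, no auth enforced (placeholder key),
# and a client-facing model named "tensorrt_llm_bls" -- adjust as needed.
client = openai.OpenAI(base_url="http://localhost:9000/v1", api_key="unused")
MODEL = "tensorrt_llm_bls"
MESSAGES = [{"role": "user", "content": "Say this is a test!"}]

# Non-streaming: `usage` should now be populated for the TRT-LLM backend as well.
completion = client.chat.completions.create(
    model=MODEL, messages=MESSAGES, max_tokens=16
)
usage = completion.usage
assert usage is not None
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
print(f"prompt={usage.prompt_tokens} completion={usage.completion_tokens}")

# Streaming: with stream_options.include_usage, the final chunk carries the
# usage statistics and has an empty `choices` list.
stream = client.chat.completions.create(
    model=MODEL,
    messages=MESSAGES,
    max_tokens=16,
    stream=True,
    stream_options={"include_usage": True},
)
final_usage = None
for chunk in stream:
    if chunk.usage is not None:
        final_usage = chunk.usage
print(final_usage)
```

This mirrors what `test_usage_response` and `test_chat_streaming_usage_option` assert in the updated test suite.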