4 changes: 3 additions & 1 deletion python/openai/README.md
@@ -51,7 +51,7 @@
docker run -it --net=host --gpus all --rm \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
-e HF_TOKEN \
nvcr.io/nvidia/tritonserver:25.08-vllm-python-py3
nvcr.io/nvidia/tritonserver:25.09-vllm-python-py3
```

2. Launch the OpenAI-compatible Triton Inference Server:
@@ -689,3 +689,5 @@ curl -H "api-key: my-secret-key" \
# Multiple APIs in single argument with shared authentication
--openai-restricted-api "inference,model-repository shared-key shared-secret"
```

#### Add a note about usage metrics limitation
18 changes: 0 additions & 18 deletions python/openai/openai_frontend/engine/triton_engine.py
@@ -695,15 +695,6 @@ def _validate_chat_request(
if request.stream_options and not request.stream:
raise Exception("`stream_options` can only be used when `stream` is True")

if (
request.stream_options
and request.stream_options.include_usage
and metadata.backend != "vllm"
):
raise Exception(
"`stream_options.include_usage` is currently only supported for the vLLM backend"
)

def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest):
if (
request.tool_choice
@@ -844,15 +835,6 @@ def _validate_completion_request(
if request.stream_options and not request.stream:
raise Exception("`stream_options` can only be used when `stream` is True")

if (
request.stream_options
and request.stream_options.include_usage
and metadata.backend != "vllm"
):
raise Exception(
"`stream_options.include_usage` is currently only supported for the vLLM backend"
)

def _should_stream_with_auto_tool_parsing(
self, request: CreateChatCompletionRequest
):
19 changes: 14 additions & 5 deletions python/openai/openai_frontend/engine/utils/triton.py
@@ -177,6 +177,10 @@ def _create_trtllm_inference_request(
if guided_json is not None:
inputs["guided_decoding_guide_type"] = [["json_schema"]]
inputs["guided_decoding_guide"] = [[guided_json]]

inputs["return_num_input_tokens"] = np.bool_([[True]])
inputs["return_num_output_tokens"] = np.bool_([[True]])

# FIXME: TRT-LLM doesn't currently support runtime changes of 'echo' and it
# is configured at model load time, so we don't handle it here for now.
return model.create_request(inputs=inputs)
@@ -265,11 +269,6 @@ def _get_usage_from_response(
"""
Extracts token usage statistics from a Triton inference response.
"""
# TODO: Remove this check once TRT-LLM backend supports both "num_input_tokens"
# and "num_output_tokens", and also update the test cases accordingly.
if backend != "vllm":
return None

prompt_tokens = None
completion_tokens = None

@@ -285,12 +284,22 @@
input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32)
)
prompt_tokens = prompt_tokens_ptr[0]
elif input_token_tensor.data_type == tritonserver.DataType.INT32:
prompt_tokens_ptr = ctypes.cast(
input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_int32)
)
prompt_tokens = prompt_tokens_ptr[0]

if output_token_tensor.data_type == tritonserver.DataType.UINT32:
completion_tokens_ptr = ctypes.cast(
output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32)
)
completion_tokens = completion_tokens_ptr[0]
elif output_token_tensor.data_type == tritonserver.DataType.INT32:
completion_tokens_ptr = ctypes.cast(
output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_int32)
)
completion_tokens = completion_tokens_ptr[0]

if prompt_tokens is not None and completion_tokens is not None:
total_tokens = prompt_tokens + completion_tokens
18 changes: 3 additions & 15 deletions python/openai/tests/test_chat_completions.py
@@ -41,9 +41,7 @@ class TestChatCompletions:
def client(self, fastapi_client_class_scope):
yield fastapi_client_class_scope

def test_chat_completions_defaults(
self, client, model: str, messages: List[dict], backend: str
):
def test_chat_completions_defaults(self, client, model: str, messages: List[dict]):
response = client.post(
"/v1/chat/completions",
json={"model": model, "messages": messages},
@@ -55,10 +53,7 @@ def test_chat_completions_defaults(
assert message["role"] == "assistant"

usage = response.json().get("usage")
if backend == "vllm":
assert usage is not None
else:
assert usage is None
assert usage is not None

def test_chat_completions_system_prompt(self, client, model: str):
# NOTE: Currently just sanity check that there are no issues when a
@@ -536,14 +531,7 @@ def test_request_logprobs(self):
def test_request_logit_bias(self):
pass

def test_usage_response(
self, client, model: str, messages: List[dict], backend: str
):
if backend != "vllm":
pytest.skip(
"Usage reporting is currently available only for the vLLM backend."
)

def test_usage_response(self, client, model: str, messages: List[dict]):
response = client.post(
"/v1/chat/completions",
json={"model": model, "messages": messages},
14 changes: 3 additions & 11 deletions python/openai/tests/test_completions.py
@@ -35,7 +35,7 @@ class TestCompletions:
def client(self, fastapi_client_class_scope):
yield fastapi_client_class_scope

def test_completions_defaults(self, client, model: str, prompt: str, backend: str):
def test_completions_defaults(self, client, model: str, prompt: str):
response = client.post(
"/v1/completions",
json={"model": model, "prompt": prompt},
@@ -48,10 +48,7 @@ def test_completions_defaults(self, client, model: str, prompt: str, backend: str):
assert response.json()["choices"][0]["text"].strip()

usage = response.json().get("usage")
if backend == "vllm":
assert usage is not None
else:
assert usage is None
assert usage is not None

@pytest.mark.parametrize(
"sampling_parameter, value",
@@ -371,12 +368,7 @@ def test_lora(self):
def test_multi_lora(self):
pass

def test_usage_response(self, client, model: str, prompt: str, backend: str):
if backend != "vllm":
pytest.skip(
"Usage reporting is currently available only for the vLLM backend."
)

def test_usage_response(self, client, model: str, prompt: str):
response = client.post(
"/v1/completions",
json={"model": model, "prompt": prompt},
123 changes: 34 additions & 89 deletions python/openai/tests/test_openai_client.py
@@ -49,7 +49,7 @@ def test_openai_client_models(self, client: openai.OpenAI, backend: str):
raise Exception(f"Unexpected backend {backend=}")

def test_openai_client_completion(
self, client: openai.OpenAI, model: str, prompt: str, backend: str
self, client: openai.OpenAI, model: str, prompt: str
):
completion = client.completions.create(
prompt=prompt,
@@ -61,19 +61,16 @@ def test_openai_client_completion(
assert completion.choices[0].finish_reason == "stop"

usage = completion.usage
if backend == "vllm":
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
else:
assert usage is None
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

def test_openai_client_chat_completion(
self, client: openai.OpenAI, model: str, messages: List[dict], backend: str
self, client: openai.OpenAI, model: str, messages: List[dict]
):
chat_completion = client.chat.completions.create(
messages=messages,
@@ -85,16 +82,13 @@ def test_openai_client_chat_completion(
assert chat_completion.choices[0].finish_reason == "stop"

usage = chat_completion.usage
if backend == "vllm":
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
else:
assert usage is None
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

@pytest.mark.parametrize("echo", [False, True])
def test_openai_client_completion_echo(
@@ -141,7 +135,7 @@ async def test_openai_client_models(self, client: openai.AsyncOpenAI, backend: str):

@pytest.mark.asyncio
async def test_openai_client_completion(
self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str
self, client: openai.AsyncOpenAI, model: str, prompt: str
):
completion = await client.completions.create(
prompt=prompt,
@@ -153,20 +147,17 @@ async def test_openai_client_completion(
assert completion.choices[0].finish_reason == "stop"

usage = completion.usage
if backend == "vllm":
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
else:
assert usage is None
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

@pytest.mark.asyncio
async def test_openai_client_chat_completion(
self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str
self, client: openai.AsyncOpenAI, model: str, messages: List[dict]
):
chat_completion = await client.chat.completions.create(
messages=messages,
@@ -177,16 +168,13 @@ async def test_openai_client_chat_completion(
assert chat_completion.choices[0].finish_reason == "stop"

usage = chat_completion.usage
if backend == "vllm":
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
else:
assert usage is None
assert usage is not None
assert isinstance(usage.prompt_tokens, int)
assert isinstance(usage.completion_tokens, int)
assert isinstance(usage.total_tokens, int)
assert usage.prompt_tokens > 0
assert usage.completion_tokens > 0
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

print(f"Chat completion results: {chat_completion}")

@@ -300,13 +288,8 @@ async def test_chat_streaming(

@pytest.mark.asyncio
async def test_chat_streaming_usage_option(
self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str
self, client: openai.AsyncOpenAI, model: str, messages: List[dict]
):
if backend != "vllm":
pytest.skip(
"Usage reporting is currently available only for the vLLM backend."
)

seed = 0
temperature = 0.0
max_tokens = 16
@@ -397,13 +380,8 @@ async def test_chat_streaming_usage_option(

@pytest.mark.asyncio
async def test_completion_streaming_usage_option(
self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str
self, client: openai.AsyncOpenAI, model: str, prompt: str
):
if backend != "vllm":
pytest.skip(
"Usage reporting is currently available only for the vLLM backend."
)

seed = 0
temperature = 0.0
max_tokens = 16
@@ -509,36 +487,3 @@ async def test_stream_options_without_streaming(
stream_options={"include_usage": True},
)
assert "`stream_options` can only be used when `stream` is True" in str(e.value)

@pytest.mark.asyncio
async def test_streaming_usage_unsupported_backend(
self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str
):
if backend == "vllm":
pytest.skip(
"This test is for backends that do not support usage reporting."
)

with pytest.raises(openai.BadRequestError) as e:
await client.completions.create(
model=model,
prompt="Test prompt",
stream=True,
stream_options={"include_usage": True},
)
assert (
"`stream_options.include_usage` is currently only supported for the vLLM backend"
in str(e.value)
)

with pytest.raises(openai.BadRequestError) as e:
await client.chat.completions.create(
model=model,
messages=messages,
stream=True,
stream_options={"include_usage": True},
)
assert (
"`stream_options.include_usage` is currently only supported for the vLLM backend"
in str(e.value)
)
3 changes: 2 additions & 1 deletion qa/L0_openai/generate_engine.py
@@ -25,7 +25,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from argparse import ArgumentParser

from tensorrt_llm import LLM, BuildConfig
from tensorrt_llm import BuildConfig
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm.plugin import PluginConfig

