Commit 7a3d2a5

[Frontend] Support for chat completions input in the tokenize endpoint (#5923)
1 parent d970115 commit 7a3d2a5
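
With this change, the /tokenize endpoint accepts chat-completions-style input (a "messages" list plus "add_generation_prompt") in addition to the existing plain "prompt" field; the server renders the model's chat template over the messages before tokenizing. Below is a minimal client sketch of both request shapes, based on the new tests in this commit; the server URL is an assumption for illustration and the server is assumed to already be running with the test model.

import requests

# Assumed for illustration: a vLLM OpenAI-compatible server already running
# locally; adjust the URL and model name to your deployment.
BASE_URL = "http://localhost:8000"
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

# Plain-prompt tokenization (supported before this commit).
response = requests.post(BASE_URL + "/tokenize",
                         json={
                             "model": MODEL_NAME,
                             "prompt": "This is a test prompt.",
                             "add_special_tokens": True
                         })
response.raise_for_status()
print(response.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}

# Chat-completions-style tokenization (added by this commit): the server
# applies the model's chat template to the messages, then tokenizes.
response = requests.post(BASE_URL + "/tokenize",
                         json={
                             "model": MODEL_NAME,
                             "messages": [{
                                 "role": "user",
                                 "content": "Hi there!"
                             }],
                             "add_generation_prompt": True,
                             "add_special_tokens": False
                         })
response.raise_for_status()
print(response.json())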

File tree

9 files changed: +386 −244 lines changed

tests/async_engine/test_chat_template.py

Lines changed: 5 additions & 9 deletions
@@ -4,8 +4,8 @@

 import pytest

+from vllm.entrypoints.openai.chat_utils import load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
-from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.transformers_utils.tokenizer import get_tokenizer

 chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
@@ -64,8 +64,7 @@ def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template

@@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        load_chat_template(mock_serving_chat, chat_template=template)


 def test_no_load_chat_template_literallike():
@@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike():
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)
     template_content = tokenizer.chat_template

     assert template_content == template
@@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    load_chat_template(mock_serving_chat, chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
tests/entrypoints/openai/test_completion.py

Lines changed: 0 additions & 49 deletions
@@ -6,7 +6,6 @@
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=sample_regex,
                            guided_json=sample_json_schema))
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    for add_special in [False, True]:
-        prompt = "This is a test prompt."
-        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
-
-        response = requests.post(base_url + "/tokenize",
-                                 json={
-                                     "add_special_tokens": add_special,
-                                     "model": model_name,
-                                     "prompt": prompt
-                                 })
-        response.raise_for_status()
-        assert response.json() == {
-            "tokens": tokens,
-            "count": len(tokens),
-            "max_model_len": 8192
-        }
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
-    base_url = str(client.base_url)[:-3]
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
-
-    prompt = "This is a test prompt."
-    tokens = tokenizer.encode(prompt, add_special_tokens=False)
-
-    response = requests.post(base_url + "detokenize",
-                             json={
-                                 "model": model_name,
-                                 "tokens": tokens
-                             })
-    response.raise_for_status()
-    assert response.json() == {"prompt": prompt}
New file

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+import openai  # use the official client for correctness check
+import pytest
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ...utils import RemoteOpenAIServer
+
+# any model with a chat template should work here
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.fixture(scope="module")
+def server():
+    with RemoteOpenAIServer([
+            "--model",
+            MODEL_NAME,
+            # use half precision for speed and memory savings in CI environment
+            "--dtype",
+            "bfloat16",
+            "--max-model-len",
+            "8192",
+            "--enforce-eager",
+            "--max-num-seqs",
+            "128",
+    ]) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_completions(client: openai.AsyncOpenAI,
+                                    model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    for add_generation in [False, True]:
+        for add_special in [False, True]:
+            conversation = [{
+                "role": "user",
+                "content": "Hi there!"
+            }, {
+                "role": "assistant",
+                "content": "Nice to meet you!"
+            }, {
+                "role": "user",
+                "content": "Can I ask a question?"
+            }]
+
+            prompt = tokenizer.apply_chat_template(
+                add_generation_prompt=add_generation,
+                conversation=conversation,
+                tokenize=False)
+            tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+            response = requests.post(base_url + "/tokenize",
+                                     json={
+                                         "add_generation_prompt":
+                                         add_generation,
+                                         "add_special_tokens": add_special,
+                                         "messages": conversation,
+                                         "model": model_name
+                                     })
+            response.raise_for_status()
+
+            assert response.json() == {
+                "tokens": tokens,
+                "count": len(tokens),
+                "max_model_len": 8192
+            }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3].strip("/")
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+
+    prompt = "This is a test prompt."
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    response = requests.post(base_url + "/detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+
+    assert response.json() == {"prompt": prompt}

vllm/entrypoints/openai/api_server.py

Lines changed: 8 additions & 2 deletions
@@ -33,6 +33,8 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
@@ -46,6 +48,7 @@
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization

 logger = init_logger('vllm.entrypoints.openai.api_server')

@@ -86,7 +89,7 @@ async def health() -> Response:

 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_completion.create_tokenize(request)
+    generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):

 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_completion.create_detokenize(request)
+    generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
     global openai_serving_chat
     global openai_serving_completion
     global openai_serving_embedding
+    global openai_serving_tokenization

     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
@@ -252,6 +256,8 @@ def run_server(args, llm_engine=None):
                                                   args.prompt_adapters)
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine, model_config, served_model_names, args.chat_template)
     app.root_path = args.root_path

     logger.info("Available routes are:")
