
Commit f941124

[Feature] Support include_stop_str_in_output (#2930)
Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Parent: b89f083

File tree

fastdeploy/entrypoints/openai/serving_chat.py
fastdeploy/input/ernie_processor.py
fastdeploy/input/text_processor.py
test/ci_use/EB_Lite/test_EB_Lite_serving.py

4 files changed: 74 additions, 8 deletions


fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 6 additions & 2 deletions
@@ -119,6 +119,7 @@ async def chat_completion_stream_generator(
        num_choices = 1
        max_streaming_response_tokens = 1
        enable_thinking = None
+       include_stop_str_in_output = False
        if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
            max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]

@@ -146,6 +147,7 @@ async def chat_completion_stream_generator(
        current_waiting_time = 0
        if request.metadata is not None:
            enable_thinking = request.metadata.get("enable_thinking")
+           include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
        while num_choices > 0:
            try:
                raw_data = await asyncio.wait_for(dealer.read(), timeout=10)

@@ -169,7 +171,7 @@ async def chat_completion_stream_generator(
                    raise ValueError("{}".format(res["error_msg"]))

                self.engine_client.data_processor.process_response_dict(
-                   res, stream=True, enable_thinking=enable_thinking)
+                   res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)

                if res['metrics']['first_token_time'] is not None:
                    arrival_time = res['metrics']['first_token_time']

@@ -303,6 +305,7 @@ async def chat_completion_full_generator(
        created_time = int(time.time())
        final_res = None
        enable_thinking = None
+       include_stop_str_in_output = False
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,

@@ -335,8 +338,9 @@ async def chat_completion_full_generator(
                raise ValueError("{}".format(data["error_msg"]))
            if request.metadata is not None:
                enable_thinking = request.metadata.get("enable_thinking")
+               include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
            data = self.engine_client.data_processor.process_response_dict(
-               data, stream=False, enable_thinking=enable_thinking)
+               data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
            # api_server_logger.debug(f"Client {request_id} received: {data}")
            previous_num_tokens += len(data["outputs"]["token_ids"])
            # The logprob for handling the response
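
In both generators the flag defaults to False, is read from request.metadata, and is forwarded to process_response_dict. A minimal client-side sketch of exercising the option, in the spirit of the tests added further below; the base URL, port, and API key are placeholders for an already running FastDeploy OpenAI-compatible server, not values defined in this commit:

# Minimal usage sketch (assumed server address; adjust to your deployment).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")  # placeholder URL/key

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=16,
    # Passed through to the data processor; when True, the stop token
    # (e.g. "</s>") is kept in the returned text.
    metadata={"include_stop_str_in_output": True},
    stream=False,
)
print(response.choices[0].message.content)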

fastdeploy/input/ernie_processor.py

Lines changed: 2 additions & 2 deletions
@@ -248,7 +248,7 @@ def process_response_dict_normal(self, response_dict, **kwargs):
        token_ids = response_dict["outputs"]["token_ids"]
        is_end = response_dict["finished"]
        req_id = response_dict["request_id"]
-       if is_end and len(token_ids) > 0:
+       if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
        delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)

@@ -283,7 +283,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
        req_id = response_dict["request_id"]
        token_ids = response_dict["outputs"]["token_ids"]

-       if is_end and len(token_ids) > 0:
+       if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
        delta_text, previous_token_ids, previous_texts = self.ids2tokens(
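
The guard added here only strips the trailing eos token when the flag is absent or falsy. A standalone sketch of that behaviour, using a hypothetical strip_trailing_eos helper (not code from this repository) and assuming the stop string maps to a single trailing eos token id:

# Standalone sketch of the guard above (hypothetical helper, not from the repo).
def strip_trailing_eos(token_ids, eos_token_id, is_end, include_stop_str_in_output=False):
    """Drop a trailing eos token unless the caller asked to keep the stop string."""
    if is_end and len(token_ids) > 0 and not include_stop_str_in_output:
        if token_ids[-1] == eos_token_id:
            return token_ids[:-1]
    return token_ids

# Example: eos id 2 is removed by default, kept when the flag is True.
assert strip_trailing_eos([11, 7, 2], eos_token_id=2, is_end=True) == [11, 7]
assert strip_trailing_eos([11, 7, 2], eos_token_id=2, is_end=True,
                          include_stop_str_in_output=True) == [11, 7, 2]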

fastdeploy/input/text_processor.py

Lines changed: 3 additions & 3 deletions
@@ -355,7 +355,7 @@ def process_response_dict_normal(self, response_dict, **kwargs):
        token_ids = response_dict["outputs"]["token_ids"]
        is_end = response_dict["finished"]
        req_id = response_dict["request_id"]
-       if is_end and len(token_ids) > 0:
+       if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
        delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)

@@ -390,7 +390,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
        req_id = response_dict["request_id"]
        token_ids = response_dict["outputs"]["token_ids"]

-       if is_end and len(token_ids) > 0:
+       if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
        delta_text, previous_token_ids, previous_texts = self.ids2tokens(

@@ -430,7 +430,7 @@ def process_response_dict(self, response_dict, **kwargs):
                response_dict, enable_thinking=enable_thinking, **kwargs)
        else:
            return self.process_response_dict_normal(
-               response_dict=response_dict, enable_thinking=enable_thinking)
+               response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)

    def text2ids(self, text, max_model_len, raw_request=True):
        """

test/ci_use/EB_Lite/test_EB_Lite_serving.py

Lines changed: 63 additions & 1 deletion
@@ -313,4 +313,66 @@ def test_streaming(openai_client, capsys):
    output = []
    for chunk in response:
        output.append(chunk.choices[0].text)
-   assert len(output) > 0
+   assert len(output) > 0
+
+def test_non_streaming_with_stop_str(openai_client):
+    """
+    Test non-streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert response.choices[0].message.content.endswith("</s>")
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert not response.choices[0].message.content.endswith("</s>")
+
+def test_streaming_with_stop_str(openai_client):
+    """
+    Test streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token == "</s>"
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token != "</s>"
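
At the HTTP layer the flag travels inside the request body's "metadata" object, which is what request.metadata exposes on the server side. A raw-request sketch equivalent to the tests above, assuming a locally running server; the host and port are placeholders, not part of this commit:

# Raw HTTP sketch (assumed local endpoint; adjust host/port to your deployment).
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello, how are you?"}],
    "max_tokens": 5,
    "stream": False,
    # Read server-side via request.metadata in serving_chat.py.
    "metadata": {"include_stop_str_in_output": True},
}
resp = requests.post("http://localhost:8188/v1/chat/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])  # expected to end with "</s>"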
