Skip to content

Commit 035abff

Browse files
authored
Bring back stop string unit test (#3286)
CVS-165001
1 parent 9e1fb15 commit 035abff

File tree

1 file changed

+70
-10
lines changed

1 file changed

+70
-10
lines changed

src/test/llm/llmnode_test.cpp

Lines changed: 70 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1588,9 +1588,7 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsFinishReasonLength) {
15881588
}
15891589
}
15901590

1591-
// Potential sporadic - move to functional if problematic
15921591
TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) {
1593-
GTEST_SKIP() << "Real sporadic, either fix or move to functional";
15941592
auto params = GetParam();
15951593
std::string requestBody = R"(
15961594
{
@@ -1605,33 +1603,95 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) {
16051603
"messages": [
16061604
{
16071605
"role": "user",
1608-
"content": "What is OpenVINO?"
1606+
"content": "What is OpenVINO? In short"
16091607
}
16101608
]
16111609
}
16121610
)";
16131611

16141612
std::vector<std::string> responses;
16151613

1614+
// Gather responses
16161615
EXPECT_CALL(*writer, PartialReply(::testing::_))
16171616
.WillRepeatedly([this, &responses](std::string response) {
16181617
responses.push_back(response);
16191618
});
1619+
1620+
// dispatchToProcessor is blocking because of mocked PartialReplyBegin in fixture:
1621+
// ON_CALL(*writer, PartialReplyBegin(::testing::_)).WillByDefault(testing::Invoke([](std::function<void()> fn) { fn(); }));
16201622
EXPECT_CALL(*writer, PartialReplyEnd()).Times(1);
16211623
ASSERT_EQ(
16221624
handler->dispatchToProcessor(endpointChatCompletions, requestBody, &response, comp, responseComponents, writer, multiPartParser),
16231625
ovms::StatusCode::PARTIAL_END);
1626+
1627+
// Check if there is at least one response
1628+
ASSERT_GT(responses.size(), 0);
1629+
16241630
if (params.checkFinishReason) {
16251631
ASSERT_TRUE(responses.back().find("\"finish_reason\":\"stop\"") != std::string::npos);
16261632
}
1627-
std::regex content_regex("\"content\":\".*\\.[ ]{0,1}\"");
1628-
if (params.modelName.find("legacy") != std::string::npos) {
1629-
// In legacy streaming we don't know if the callback is the last one, so we rely on entire generation call finish.
1630-
// Because of that, we might get additional response with empty content at the end of the stream.
1631-
ASSERT_TRUE(std::regex_search(responses[responses.size() - 2], content_regex) || std::regex_search(responses.back(), content_regex));
1632-
} else {
1633-
ASSERT_TRUE(std::regex_search(responses.back(), content_regex));
1633+
1634+
// In legacy streaming we don't know if the callback is the last one, so we rely on entire generation call finish.
1635+
// Because of that, we might get additional response with empty content at the end of the stream.
1636+
const size_t numberOfLastResponsesToCheckForStopString = params.modelName.find("legacy") != std::string::npos ? 2 : 1;
1637+
1638+
// The stop string (.) does not need to be at the end of the message.
1639+
// There are cases when the last generation contains a dot followed by new lines, or the generated token is "e.g",
1640+
// or simply any token (or group of tokens) that has a dot in the middle.
1641+
1642+
// Check for no existence of a dot:
1643+
for (size_t i = 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) {
1644+
// Assert there is no dot '.' in the response
1645+
1646+
// Cut "data: " prefix
1647+
std::string dataPrefix = "data:";
1648+
std::string resp = responses[i].substr(dataPrefix.size());
1649+
1650+
rapidjson::Document d;
1651+
rapidjson::ParseResult ok = d.Parse(resp.c_str());
1652+
ASSERT_EQ(ok.Code(), 0) << d.GetParseError() << "\n"
1653+
<< resp;
1654+
1655+
ASSERT_TRUE(d["choices"].IsArray());
1656+
ASSERT_EQ(d["choices"].Size(), 1);
1657+
ASSERT_TRUE(d["choices"][0].IsObject());
1658+
ASSERT_TRUE(d["choices"][0]["delta"].IsObject());
1659+
ASSERT_TRUE(d["choices"][0]["delta"]["content"].IsString());
1660+
resp = d["choices"][0]["delta"]["content"].GetString();
1661+
ASSERT_EQ(resp.find('.'), std::string::npos) << "found dot in response: " << responses[i] << " at index: " << i << " out of: " << responses.size();
1662+
}
1663+
1664+
bool foundDotInLastResponse = false;
1665+
// Check for existence of a dot:
1666+
for (size_t i = responses.size() - numberOfLastResponsesToCheckForStopString; i < responses.size(); ++i) {
1667+
// Assert there is a dot '.' in the response
1668+
1669+
// Cut "data: " prefix
1670+
std::string dataPrefix = "data:";
1671+
std::string resp = responses[i].substr(dataPrefix.size());
1672+
1673+
// remove from resp: "data: [DONE]" (not only in the beginning)
1674+
size_t pos = resp.find("data: [DONE]");
1675+
if (pos != std::string::npos) {
1676+
resp.erase(pos, std::string("data: [DONE]").length());
1677+
}
1678+
1679+
rapidjson::Document d;
1680+
rapidjson::ParseResult ok = d.Parse(resp.c_str());
1681+
ASSERT_EQ(ok.Code(), 0) << d.GetParseError() << "\n"
1682+
<< resp;
1683+
1684+
ASSERT_TRUE(d["choices"].IsArray());
1685+
ASSERT_EQ(d["choices"].Size(), 1);
1686+
ASSERT_TRUE(d["choices"][0].IsObject());
1687+
ASSERT_TRUE(d["choices"][0]["delta"].IsObject());
1688+
ASSERT_TRUE(d["choices"][0]["delta"]["content"].IsString());
1689+
resp = d["choices"][0]["delta"]["content"].GetString();
1690+
if (resp.find('.') != std::string::npos) {
1691+
foundDotInLastResponse = true;
1692+
}
16341693
}
1694+
ASSERT_TRUE(foundDotInLastResponse) << "cannot find dot last responses";
16351695
}
16361696

16371697
TEST_P(LLMFlowHttpTestParameterized, streamCompletionsFinishReasonLength) {

0 commit comments

Comments
 (0)