Skip to content

Commit 8c9921f

Browse files
Validate create graph parameters
1 parent ca84e6a commit 8c9921f

File tree

5 files changed

+267
-1
lines changed

5 files changed

+267
-1
lines changed

src/config.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,40 @@ bool Config::validate() {
9292
std::cerr << "Error: --task parameter not set." << std::endl;
9393
return false;
9494
}
95+
std::vector allowedPipelineTypes = {"LM", "LM_CB", "VLM", "VLM_CB", "AUTO"};
96+
if (serverSettings.hfSettings.graphSettings.pipelineType.has_value() && std::find(allowedPipelineTypes.begin(), allowedPipelineTypes.end(), serverSettings.hfSettings.graphSettings.pipelineType) == allowedPipelineTypes.end()) {
97+
std::cerr << "pipeline_type: " << serverSettings.hfSettings.graphSettings.pipelineType.value() << " is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO" << std::endl;
98+
return false;
99+
}
100+
101+
std::vector allowedTargeDevices = {"CPU", "GPU", "NPU", "HETERO"};
102+
if (std::find(allowedTargeDevices.begin(), allowedTargeDevices.end(), serverSettings.hfSettings.graphSettings.targetDevice) == allowedTargeDevices.end()) {
103+
std::cerr << "target_device: " << serverSettings.hfSettings.graphSettings.targetDevice << " is not allowed. Supported devices: CPU, GPU, NPU, HETERO" << std::endl;
104+
return false;
105+
}
106+
107+
std::vector allowedBoolValues = {"false", "true"};
108+
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), serverSettings.hfSettings.graphSettings.enablePrefixCaching) == allowedBoolValues.end()) {
109+
std::cerr << "enable_prefix_caching: " << serverSettings.hfSettings.graphSettings.enablePrefixCaching << " is not allowed. Supported values: true, false" << std::endl;
110+
return false;
111+
}
112+
113+
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), serverSettings.hfSettings.graphSettings.dynamicSplitFuse) == allowedBoolValues.end()) {
114+
std::cerr << "dynamic_split_fuse: " << serverSettings.hfSettings.graphSettings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl;
115+
return false;
116+
}
117+
118+
if (serverSettings.hfSettings.graphSettings.targetDevice != "NPU") {
119+
if (serverSettings.hfSettings.graphSettings.pluginConfig.maxPromptLength.has_value()) {
120+
std::cerr << "max_prompt_len is only supported for NPU target device";
121+
return false;
122+
}
123+
}
124+
125+
if (serverSettings.hfSettings.sourceModel.rfind("OpenVINO/", 0) != 0) {
126+
std::cerr << "For now only OpenVINO models are supported";
127+
return false;
128+
}
95129
return true;
96130
}
97131
if (this->serverSettings.listServables) {

src/graph_export/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ cc_library(
2727
"@ovms//src:libovms_module",
2828
"@ovms//src:libovmsfilesystem",
2929
"@com_github_tencent_rapidjson//:rapidjson",
30+
"@mediapipe//mediapipe/framework/port:parse_text_proto",
31+
"@mediapipe//mediapipe/framework:calculator_graph",
32+
"@ovms//src:libovmsschema",
3033
],
3134
visibility = ["//visibility:public"],
3235
)

src/graph_export/graph_export.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <rapidjson/document.h>
2626
#include <rapidjson/stringbuffer.h>
2727
#include <rapidjson/writer.h>
28+
#include <rapidjson/istreamwrapper.h>
29+
#include <rapidjson/prettywriter.h>
2830
#pragma warning(pop)
2931

3032
#include "../capi_frontend/server_settings.hpp"
@@ -33,8 +35,12 @@
3335
#include "../logging.hpp"
3436
#include "../status.hpp"
3537
#include "../stringutils.hpp"
38+
#include "../schema.hpp"
3639
#include "graph_export_types.hpp"
3740

41+
#include "mediapipe/framework/port/parse_text_proto.h"
42+
#include "mediapipe/framework/calculator_graph.h"
43+
3844
namespace ovms {
3945

4046
Status createFile(const std::string& filePath, const std::string& contents) {
@@ -116,12 +122,31 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
116122
}
117123
}
118124
})";
119-
125+
::mediapipe::CalculatorGraphConfig config;
126+
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
127+
if (!success) {
128+
SPDLOG_ERROR("Created graph config couldn't be parsed.");
129+
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
130+
}
120131
// clang-format on
121132
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
122133
return createFile(fullPath, oss.str());
123134
}
124135

136+
/**
 * Verifies a generated subconfig document before it is written out.
 *
 * The text is first parsed as JSON and then validated against
 * MEDIAPIPE_SUBCONFIG_SCHEMA.
 *
 * @param subconfig text of the generated subconfig document
 * @param type      label inserted into the log messages
 * @return StatusCode::OK when both checks pass; StatusCode::JSON_INVALID when
 *         either parsing or schema validation fails
 */
static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) {
    rapidjson::Document document;
    const rapidjson::ParseResult parsed = document.Parse(subconfig.c_str());
    if (parsed.IsError()) {
        SPDLOG_LOGGER_ERROR(modelmanager_logger, "Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(parsed.Code()));
        return StatusCode::JSON_INVALID;
    }

    // Any schema failure is reported as JSON_INVALID, whatever the validator returned.
    const auto schemaStatus = validateJsonAgainstSchema(document, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str());
    if (schemaStatus != StatusCode::OK) {
        SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type);
        return StatusCode::JSON_INVALID;
    }

    return StatusCode::OK;
}
149+
125150
static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
126151
std::ostringstream oss;
127152
// clang-format off
@@ -144,6 +169,10 @@ static Status createRerankSubconfigTemplate(const std::string& directoryPath, co
144169
}
145170
]
146171
})";
172+
auto status = validateSubconfigSchema(oss.str(), "rerank");
173+
if (!status.ok()){
174+
return status;
175+
}
147176
// clang-format on
148177
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
149178
return createFile(fullPath, oss.str());
@@ -171,6 +200,10 @@ static Status createEmbeddingsSubconfigTemplate(const std::string& directoryPath
171200
}
172201
]
173202
})";
203+
auto status = validateSubconfigSchema(oss.str(), "embeddings");
204+
if (!status.ok()){
205+
return status;
206+
}
174207
// clang-format on
175208
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
176209
return createFile(fullPath, oss.str());
@@ -210,6 +243,12 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
210243
output_stream: "RESPONSE_PAYLOAD:output"
211244
})";
212245

246+
::mediapipe::CalculatorGraphConfig config;
247+
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
248+
if (!success) {
249+
SPDLOG_ERROR("Created rerank graph config couldn't be parsed.");
250+
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
251+
}
213252
// clang-format on
214253
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
215254
auto status = createFile(fullPath, oss.str());
@@ -259,6 +298,12 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
259298
}
260299
})";
261300

301+
::mediapipe::CalculatorGraphConfig config;
302+
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
303+
if (!success) {
304+
SPDLOG_ERROR("Created embeddings graph config couldn't be parsed.");
305+
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
306+
}
262307
// clang-format on
263308
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
264309
auto status = createFile(fullPath, oss.str());

src/test/graph_export_test.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,36 @@ TEST_F(GraphCreationTest, rerankPositiveDefault) {
288288
ASSERT_EQ(expectedRerankJsonContents, jsonContents) << jsonContents;
289289
}
290290

291+
// Negative case: a model name containing a raw tab character is expected to make
// createServableConfig fail with JSON_INVALID for the rerank task.
// NOTE(review): presumably modelName is interpolated verbatim into the generated
// subconfig.json, where an unescaped control character is invalid JSON — confirm
// against the rerank subconfig template.
TEST_F(GraphCreationTest, rerankCreatedJsonInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::rerank;
    hfSettings.rerankGraphSettings.targetDevice = "GPU";
    hfSettings.rerankGraphSettings.modelName = "myModel\t";
    hfSettings.rerankGraphSettings.numStreams = 2;
    hfSettings.rerankGraphSettings.maxDocLength = 18;
    hfSettings.rerankGraphSettings.version = 2;

    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Print the actual status text on failure to ease diagnosis.
    ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID) << status.string();
}
305+
306+
// Negative case: a model name containing a double quote is expected to make
// createServableConfig fail with MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID for the rerank task.
// NOTE(review): presumably the quote terminates a string early in the generated
// graph.pbtxt so the text-proto parse fails — confirm against the rerank graph template.
TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::rerank;
    hfSettings.rerankGraphSettings.targetDevice = "GPU";
    hfSettings.rerankGraphSettings.modelName = "myModel\"";
    hfSettings.rerankGraphSettings.numStreams = 2;
    hfSettings.rerankGraphSettings.maxDocLength = 18;
    hfSettings.rerankGraphSettings.version = 2;

    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Print the actual status text on failure to ease diagnosis.
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}
320+
291321
TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
292322
ovms::HFSettingsImpl hfSettings;
293323
hfSettings.task = ovms::embeddings;
@@ -312,6 +342,38 @@ TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
312342
ASSERT_EQ(expectedEmbeddingsJsonContents, jsonContents) << jsonContents;
313343
}
314344

345+
// Negative case: a model name containing a raw tab character is expected to make
// createServableConfig fail with JSON_INVALID for the embeddings task.
// NOTE(review): presumably modelName is interpolated verbatim into the generated
// subconfig.json, where an unescaped control character is invalid JSON — confirm
// against the embeddings subconfig template.
TEST_F(GraphCreationTest, embeddingsCreatedJsonInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::embeddings;
    hfSettings.embeddingsGraphSettings.targetDevice = "GPU";
    hfSettings.embeddingsGraphSettings.modelName = "myModel\t";
    hfSettings.embeddingsGraphSettings.numStreams = 2;
    hfSettings.embeddingsGraphSettings.truncate = "true";
    hfSettings.embeddingsGraphSettings.normalize = "true";
    hfSettings.embeddingsGraphSettings.version = 2;

    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Print the actual status text on failure to ease diagnosis.
    ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID) << status.string();
}
360+
361+
// Negative case: a model name containing a double quote is expected to make
// createServableConfig fail with MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID for the embeddings task.
// NOTE(review): presumably the quote terminates a string early in the generated
// graph.pbtxt so the text-proto parse fails — confirm against the embeddings graph template.
TEST_F(GraphCreationTest, embeddingsCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::embeddings;
    hfSettings.embeddingsGraphSettings.targetDevice = "GPU";
    hfSettings.embeddingsGraphSettings.modelName = "myModel\"";
    hfSettings.embeddingsGraphSettings.numStreams = 2;
    hfSettings.embeddingsGraphSettings.truncate = "true";
    hfSettings.embeddingsGraphSettings.normalize = "true";
    hfSettings.embeddingsGraphSettings.version = 2;

    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Print the actual status text on failure to ease diagnosis.
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}
376+
315377
TEST_F(GraphCreationTest, positivePluginConfigAll) {
316378
ovms::HFSettingsImpl hfSettings;
317379
ovms::TextGenGraphSettingsImpl graphSettings;
@@ -377,3 +439,13 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) {
377439
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
378440
ASSERT_EQ(status, ovms::StatusCode::OK) << status.string();
379441
}
442+
443+
// Negative case: with every other setting left at its default, a source model
// consisting of a single double quote is expected to make createServableConfig
// fail with MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID.
// NOTE(review): presumably sourceModel is interpolated into a quoted field of the
// generated graph.pbtxt — confirm against the text generation graph template.
TEST_F(GraphCreationTest, negativeCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.sourceModel = "\"";

    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Print the actual status text on failure to ease diagnosis.
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}

src/test/ovmsconfig_test.cpp

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,118 @@ TEST(OvmsGraphConfigTest, positiveDefaultStart) {
739739
ASSERT_EQ(graphSettings.draftModelDirName.has_value(), false);
740740
}
741741

742+
// Negative case: an unsupported --pipeline_type value must terminate parsing with
// OVMS_EX_USAGE and print the list of supported pipeline types.
TEST(OvmsGraphConfigTest, negativePipelineType) {
    std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
    std::string downloadPath = "test/repository";
    char* n_argv[] = {
        (char*)"ovms",
        (char*)"--pull",
        (char*)"--source_model",
        (char*)modelName.c_str(),
        (char*)"--model_repository_path",
        (char*)downloadPath.c_str(),
        (char*)"--pipeline_type",
        (char*)"INVALID",
    };

    // Derived from the array so the count cannot drift when arguments are added or removed.
    int arg_count = static_cast<int>(sizeof(n_argv) / sizeof(n_argv[0]));
    ConstructorEnabledConfig config;
    EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "pipeline_type: INVALID is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO");
}
760+
761+
// Negative case: an unsupported --graph_target_device value must terminate parsing
// with OVMS_EX_USAGE and print the list of supported devices.
TEST(OvmsGraphConfigTest, negativeTargetDevice) {
    std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
    std::string downloadPath = "test/repository";
    char* n_argv[] = {
        (char*)"ovms",
        (char*)"--pull",
        (char*)"--source_model",
        (char*)modelName.c_str(),
        (char*)"--model_repository_path",
        (char*)downloadPath.c_str(),
        (char*)"--graph_target_device",
        (char*)"INVALID",
    };

    // Derived from the array so the count cannot drift when arguments are added or removed.
    int arg_count = static_cast<int>(sizeof(n_argv) / sizeof(n_argv[0]));
    ConstructorEnabledConfig config;
    EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "target_device: INVALID is not allowed. Supported devices: CPU, GPU, NPU, HETERO");
}
779+
780+
// Negative case: a non-boolean --enable_prefix_caching value must terminate parsing
// with OVMS_EX_USAGE and print the supported values.
TEST(OvmsGraphConfigTest, negativeEnablePrefixCaching) {
    std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
    std::string downloadPath = "test/repository";
    char* n_argv[] = {
        (char*)"ovms",
        (char*)"--pull",
        (char*)"--source_model",
        (char*)modelName.c_str(),
        (char*)"--model_repository_path",
        (char*)downloadPath.c_str(),
        (char*)"--enable_prefix_caching",
        (char*)"INVALID",
    };

    // Derived from the array so the count cannot drift when arguments are added or removed.
    int arg_count = static_cast<int>(sizeof(n_argv) / sizeof(n_argv[0]));
    ConstructorEnabledConfig config;
    EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "enable_prefix_caching: INVALID is not allowed. Supported values: true, false");
}
798+
799+
// Negative case: a non-boolean --dynamic_split_fuse value must terminate parsing
// with OVMS_EX_USAGE and print the supported values.
TEST(OvmsGraphConfigTest, negativeDynamicSplitFuse) {
    std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
    std::string downloadPath = "test/repository";
    char* n_argv[] = {
        (char*)"ovms",
        (char*)"--pull",
        (char*)"--source_model",
        (char*)modelName.c_str(),
        (char*)"--model_repository_path",
        (char*)downloadPath.c_str(),
        (char*)"--dynamic_split_fuse",
        (char*)"INVALID",
    };

    // Derived from the array so the count cannot drift when arguments are added or removed.
    int arg_count = static_cast<int>(sizeof(n_argv) / sizeof(n_argv[0]));
    ConstructorEnabledConfig config;
    EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "dynamic_split_fuse: INVALID is not allowed. Supported values: true, false");
}
817+
818+
// Negative case: passing --max_prompt_len without selecting the NPU target device
// must terminate parsing with OVMS_EX_USAGE.
TEST(OvmsGraphConfigTest, negativeMaxPromptLength) {
    std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
    std::string downloadPath = "test/repository";
    char* n_argv[] = {
        (char*)"ovms",
        (char*)"--pull",
        (char*)"--source_model",
        (char*)modelName.c_str(),
        (char*)"--model_repository_path",
        (char*)downloadPath.c_str(),
        (char*)"--max_prompt_len",
        (char*)"10",
    };

    // Derived from the array so the count cannot drift when arguments are added or removed.
    int arg_count = static_cast<int>(sizeof(n_argv) / sizeof(n_argv[0]));
    ConstructorEnabledConfig config;
    EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "max_prompt_len is only supported for NPU target device");
}
836+
837+
// Negative case: a --source_model that does not start with "OpenVINO/" must
// terminate parsing with OVMS_EX_USAGE.
TEST(OvmsGraphConfigTest, negativeSourceModel) {
    std::string modelName = "NonOpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
    std::string downloadPath = "test/repository";
    char* n_argv[] = {
        (char*)"ovms",
        (char*)"--pull",
        (char*)"--source_model",
        (char*)modelName.c_str(),
        (char*)"--model_repository_path",
        (char*)downloadPath.c_str(),
    };

    // Derived from the array so the count cannot drift when arguments are added or removed.
    int arg_count = static_cast<int>(sizeof(n_argv) / sizeof(n_argv[0]));
    ConstructorEnabledConfig config;
    EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "For now only OpenVINO models are supported");
}
853+
742854
TEST(OvmsGraphConfigTest, positiveAllChangedRerank) {
743855
std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov";
744856
std::string downloadPath = "test/repository";

0 commit comments

Comments
 (0)