Skip to content

Commit d885cf3

Browse files
Validate create graph parameters (#3290)
1 parent f80f038 commit d885cf3

File tree

7 files changed

+373
-19
lines changed

7 files changed

+373
-19
lines changed

src/capi_frontend/server_settings.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ struct EmbeddingsGraphSettingsImpl {
6060
uint32_t numStreams = 1;
6161
uint32_t version = 1; // FIXME: export_embeddings_tokenizer python method - not supported currently?
6262
std::string normalize = "false";
63-
std::string truncate = "false"; // FIXME: export_embeddings_tokenizer python method - not supported currently?
6463
};
6564

6665
struct RerankGraphSettingsImpl {

src/config.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,73 @@ bool Config::check_hostname_or_ip(const std::string& input) {
8686
}
8787

8888
bool Config::validate() {
89-
// TODO: CVS-166727 Add validation of all parameters once the CLI model export flags will be implemented
9089
if (this->serverSettings.serverMode == HF_PULL_MODE) {
90+
if (!serverSettings.hfSettings.sourceModel.size()) {
91+
std::cerr << "source_model parameter is required for pull mode";
92+
return false;
93+
}
94+
if (!serverSettings.hfSettings.downloadPath.size()) {
95+
std::cerr << "model_repository_path parameter is required for pull mode";
96+
return false;
97+
}
9198
if (this->serverSettings.hfSettings.task == UNKNOWN_GRAPH) {
9299
std::cerr << "Error: --task parameter not set." << std::endl;
93100
return false;
94101
}
102+
if (serverSettings.hfSettings.sourceModel.rfind("OpenVINO/", 0) != 0) {
103+
std::cerr << "For now only OpenVINO models are supported in pulling mode";
104+
return false;
105+
}
106+
if (this->serverSettings.hfSettings.task == TEXT_GENERATION_GRAPH) {
107+
if (!std::holds_alternative<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) {
108+
std::cerr << "Graph options not initialized for text generation.";
109+
return false;
110+
}
111+
auto settings = std::get<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings);
112+
std::vector allowedPipelineTypes = {"LM", "LM_CB", "VLM", "VLM_CB", "AUTO"};
113+
if (settings.pipelineType.has_value() && std::find(allowedPipelineTypes.begin(), allowedPipelineTypes.end(), settings.pipelineType) == allowedPipelineTypes.end()) {
114+
std::cerr << "pipeline_type: " << settings.pipelineType.value() << " is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO" << std::endl;
115+
return false;
116+
}
117+
118+
std::vector allowedTargetDevices = {"CPU", "GPU", "NPU", "AUTO"};
119+
if (std::find(allowedTargetDevices.begin(), allowedTargetDevices.end(), settings.targetDevice) == allowedTargetDevices.end() && settings.targetDevice.rfind("HETERO", 0) != 0) {
120+
std::cerr << "target_device: " << settings.targetDevice << " is not allowed. Supported devices: CPU, GPU, NPU, HETERO, AUTO" << std::endl;
121+
return false;
122+
}
123+
124+
std::vector allowedBoolValues = {"false", "true"};
125+
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.enablePrefixCaching) == allowedBoolValues.end()) {
126+
std::cerr << "enable_prefix_caching: " << settings.enablePrefixCaching << " is not allowed. Supported values: true, false" << std::endl;
127+
return false;
128+
}
129+
130+
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.dynamicSplitFuse) == allowedBoolValues.end()) {
131+
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl;
132+
return false;
133+
}
134+
135+
if (settings.targetDevice != "NPU") {
136+
if (settings.pluginConfig.maxPromptLength.has_value()) {
137+
std::cerr << "max_prompt_len is only supported for NPU target device";
138+
return false;
139+
}
140+
}
141+
}
142+
143+
if (this->serverSettings.hfSettings.task == EMBEDDINGS_GRAPH) {
144+
if (!std::holds_alternative<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) {
145+
std::cerr << "Graph options not initialized for embeddings.";
146+
return false;
147+
}
148+
auto settings = std::get<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings);
149+
150+
std::vector allowedBoolValues = {"false", "true"};
151+
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.normalize) == allowedBoolValues.end()) {
152+
std::cerr << "normalize: " << settings.normalize << " is not allowed. Supported values: true, false" << std::endl;
153+
return false;
154+
}
155+
}
95156
return true;
96157
}
97158
if (this->serverSettings.serverMode == LIST_MODELS_MODE) {

src/graph_export/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ cc_library(
2929
"@ovms//src:libovmsfilesystem",
3030
"@ovms//src:libovmslocalfilesystem",
3131
"@com_github_tencent_rapidjson//:rapidjson",
32+
"@mediapipe//mediapipe/framework/port:parse_text_proto",
33+
"@mediapipe//mediapipe/framework:calculator_graph",
34+
"@ovms//src:libovmsschema",
3235
],
3336
visibility = ["//visibility:public"],
3437
)

src/graph_export/embeddings_graph_cli_parser.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ void EmbeddingsGraphCLIParser::createOptions() {
4444
"The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.",
4545
cxxopts::value<uint32_t>()->default_value("1"),
4646
"NUM_STREAMS")
47-
("truncate",
48-
"Truncate the prompts to fit to the embeddings model.",
49-
cxxopts::value<std::string>()->default_value("false"),
50-
"TRUNCATE")
5147
("normalize",
5248
"Normalize the embeddings.",
5349
cxxopts::value<std::string>()->default_value("false"),
@@ -95,7 +91,6 @@ void EmbeddingsGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl
9591
} else {
9692
embeddingsGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
9793
embeddingsGraphSettings.normalize = result->operator[]("normalize").as<std::string>();
98-
embeddingsGraphSettings.truncate = result->operator[]("truncate").as<std::string>();
9994
embeddingsGraphSettings.version = result->operator[]("model_version").as<std::uint32_t>();
10095
}
10196
hfSettings.graphSettings = std::move(embeddingsGraphSettings);

src/graph_export/graph_export.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <rapidjson/document.h>
2626
#include <rapidjson/stringbuffer.h>
2727
#include <rapidjson/writer.h>
28+
#include <rapidjson/istreamwrapper.h>
29+
#include <rapidjson/prettywriter.h>
2830
#pragma warning(pop)
2931

3032
#include "../capi_frontend/server_settings.hpp"
@@ -34,8 +36,12 @@
3436
#include "../logging.hpp"
3537
#include "../status.hpp"
3638
#include "../stringutils.hpp"
39+
#include "../schema.hpp"
3740
#include "graph_export_types.hpp"
3841

42+
#include "mediapipe/framework/port/parse_text_proto.h"
43+
#include "mediapipe/framework/calculator_graph.h"
44+
3945
namespace ovms {
4046

4147
static Status createTextGenerationGraphTemplate(const std::string& directoryPath, const TextGenGraphSettingsImpl& graphSettings) {
@@ -102,12 +108,31 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
102108
}
103109
}
104110
})";
105-
111+
::mediapipe::CalculatorGraphConfig config;
112+
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
113+
if (!success) {
114+
SPDLOG_ERROR("Created graph config file couldn't be parsed - check used task parameters values.");
115+
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
116+
}
106117
// clang-format on
107118
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
108119
return FileSystem::createFileOverwrite(fullPath, oss.str());
109120
}
110121

122+
/**
 * Checks a generated subconfig document before it is written to disk.
 *
 * The text is first parsed as JSON and then validated against
 * MEDIAPIPE_SUBCONFIG_SCHEMA.
 *
 * @param subconfig  Generated subconfig contents to validate.
 * @param type       Human-readable servable kind (e.g. "rerank", "embeddings"),
 *                   used only in log messages.
 * @return StatusCode::OK when the document parses and matches the schema;
 *         StatusCode::JSON_INVALID when either step fails.
 */
static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) {
    rapidjson::Document subconfigJson;
    const rapidjson::ParseResult parseResult = subconfigJson.Parse(subconfig.c_str());
    if (parseResult.IsError()) {
        // Both failure paths log through SPDLOG_ERROR, like the other graph export helpers in this file.
        SPDLOG_ERROR("Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(parseResult.Code()));
        return StatusCode::JSON_INVALID;
    }
    // Any schema validation failure is reported to the caller uniformly as JSON_INVALID.
    if (validateJsonAgainstSchema(subconfigJson, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str()) != StatusCode::OK) {
        SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type);
        return StatusCode::JSON_INVALID;
    }
    return StatusCode::OK;
}
135+
111136
static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
112137
std::ostringstream oss;
113138
// clang-format off
@@ -130,6 +155,10 @@ static Status createRerankSubconfigTemplate(const std::string& directoryPath, co
130155
}
131156
]
132157
})";
158+
auto status = validateSubconfigSchema(oss.str(), "rerank");
159+
if (!status.ok()){
160+
return status;
161+
}
133162
// clang-format on
134163
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
135164
return FileSystem::createFileOverwrite(fullPath, oss.str());
@@ -157,6 +186,10 @@ static Status createEmbeddingsSubconfigTemplate(const std::string& directoryPath
157186
}
158187
]
159188
})";
189+
auto status = validateSubconfigSchema(oss.str(), "embeddings");
190+
if (!status.ok()){
191+
return status;
192+
}
160193
// clang-format on
161194
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
162195
return FileSystem::createFileOverwrite(fullPath, oss.str());
@@ -196,6 +229,12 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
196229
output_stream: "RESPONSE_PAYLOAD:output"
197230
})";
198231

232+
::mediapipe::CalculatorGraphConfig config;
233+
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
234+
if (!success) {
235+
SPDLOG_ERROR("Created rerank graph config couldn't be parsed.");
236+
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
237+
}
199238
// clang-format on
200239
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
201240
auto status = FileSystem::createFileOverwrite(fullPath, oss.str());
@@ -245,6 +284,12 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
245284
}
246285
})";
247286

287+
::mediapipe::CalculatorGraphConfig config;
288+
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
289+
if (!success) {
290+
SPDLOG_ERROR("Created embeddings graph config couldn't be parsed.");
291+
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
292+
}
248293
// clang-format on
249294
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
250295
auto status = FileSystem::createFileOverwrite(fullPath, oss.str());

src/test/graph_export_test.cpp

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,14 +288,47 @@ TEST_F(GraphCreationTest, rerankPositiveDefault) {
288288
ASSERT_EQ(expectedRerankJsonContents, jsonContents) << jsonContents;
289289
}
290290

291+
// Rerank export must fail with JSON_INVALID when the model name carries a raw
// tab character (presumably embedded unescaped into the generated subconfig JSON).
TEST_F(GraphCreationTest, rerankCreatedJsonInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::RERANK_GRAPH;
    ovms::RerankGraphSettingsImpl rerankGraphSettings;
    rerankGraphSettings.targetDevice = "GPU";
    rerankGraphSettings.modelName = "myModel\t";  // control character under test
    rerankGraphSettings.numStreams = 2;
    rerankGraphSettings.maxDocLength = 18;
    rerankGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(rerankGraphSettings);
    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Stream the status text so a mismatch shows what was actually returned.
    ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID) << status.string();
}
307+
308+
// Rerank export must fail with MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID when the model
// name carries a double quote (presumably breaking a string in the generated graph.pbtxt).
TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::RERANK_GRAPH;
    ovms::RerankGraphSettingsImpl rerankGraphSettings;
    rerankGraphSettings.targetDevice = "GPU";
    rerankGraphSettings.modelName = "myModel\"";  // stray quote under test
    rerankGraphSettings.numStreams = 2;
    rerankGraphSettings.maxDocLength = 18;
    rerankGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(rerankGraphSettings);
    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Stream the status text so a mismatch shows what was actually returned.
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}
324+
291325
TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
292326
ovms::HFSettingsImpl hfSettings;
293327
hfSettings.task = ovms::EMBEDDINGS_GRAPH;
294328
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
295329
embeddingsGraphSettings.targetDevice = "GPU";
296330
embeddingsGraphSettings.modelName = "myModel";
297331
embeddingsGraphSettings.numStreams = 2;
298-
embeddingsGraphSettings.truncate = "true";
299332
embeddingsGraphSettings.normalize = "true";
300333
embeddingsGraphSettings.version = 2;
301334
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
@@ -312,6 +345,40 @@ TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
312345
ASSERT_EQ(expectedEmbeddingsJsonContents, jsonContents) << jsonContents;
313346
}
314347

348+
// Embeddings export must fail with JSON_INVALID when the model name carries a raw
// tab character (presumably embedded unescaped into the generated subconfig JSON).
TEST_F(GraphCreationTest, embeddingsCreatedJsonInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::EMBEDDINGS_GRAPH;
    ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
    embeddingsGraphSettings.targetDevice = "GPU";
    embeddingsGraphSettings.modelName = "myModel\t";  // control character under test
    embeddingsGraphSettings.numStreams = 2;
    embeddingsGraphSettings.normalize = "true";
    embeddingsGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(embeddingsGraphSettings);
    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Stream the status text so a mismatch shows what was actually returned.
    ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID) << status.string();
}
364+
365+
// Embeddings export must fail with MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID when the model
// name carries a double quote (presumably breaking a string in the generated graph.pbtxt).
TEST_F(GraphCreationTest, embeddingsCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::EMBEDDINGS_GRAPH;
    ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
    embeddingsGraphSettings.targetDevice = "GPU";
    embeddingsGraphSettings.modelName = "myModel\"";  // stray quote under test
    embeddingsGraphSettings.numStreams = 2;
    embeddingsGraphSettings.normalize = "true";
    embeddingsGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(embeddingsGraphSettings);
    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Stream the status text so a mismatch shows what was actually returned.
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}
381+
315382
TEST_F(GraphCreationTest, positivePluginConfigAll) {
316383
ovms::HFSettingsImpl hfSettings;
317384
ovms::TextGenGraphSettingsImpl graphSettings;
@@ -377,3 +444,16 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) {
377444
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
378445
ASSERT_EQ(status, ovms::StatusCode::OK) << status.string();
379446
}
447+
448+
// Text generation export must fail with MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID when the
// model path carries a double quote (presumably breaking a string in the generated graph.pbtxt).
TEST_F(GraphCreationTest, negativeCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::TEXT_GENERATION_GRAPH;
    ovms::TextGenGraphSettingsImpl graphSettings;
    graphSettings.modelPath = "invalid\"";  // stray quote under test
    hfSettings.graphSettings = std::move(graphSettings);
    std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    // Stream the status text so a mismatch shows what was actually returned.
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}

0 commit comments

Comments
 (0)