
Validate create graph parameters #3290

Merged · 11 commits · May 29, 2025
1 change: 0 additions & 1 deletion src/capi_frontend/server_settings.hpp
@@ -60,7 +60,6 @@ struct EmbeddingsGraphSettingsImpl {
uint32_t numStreams = 1;
uint32_t version = 1; // FIXME: export_embeddings_tokenizer python method - not supported currently?
std::string normalize = "false";
std::string truncate = "false"; // FIXME: export_embeddings_tokenizer python method - not supported currently?
};

struct RerankGraphSettingsImpl {
63 changes: 62 additions & 1 deletion src/config.cpp
@@ -86,12 +86,73 @@ bool Config::check_hostname_or_ip(const std::string& input) {
}

bool Config::validate() {
// TODO: CVS-166727 Add validation of all parameters once the CLI model export flags will be implemented
if (this->serverSettings.serverMode == HF_PULL_MODE) {
if (!serverSettings.hfSettings.sourceModel.size()) {
std::cerr << "source_model parameter is required for pull mode";
return false;
}
if (!serverSettings.hfSettings.downloadPath.size()) {
std::cerr << "model_repository_path parameter is required for pull mode";
return false;
}
if (this->serverSettings.hfSettings.task == UNKNOWN_GRAPH) {
std::cerr << "Error: --task parameter not set." << std::endl;
Collaborator: Remove // TODO: CVS-1667

Collaborator: Not removed.

return false;
}
if (serverSettings.hfSettings.sourceModel.rfind("OpenVINO/", 0) != 0) {
std::cerr << "For now only OpenVINO models are supported in pulling mode";
return false;
}
if (this->serverSettings.hfSettings.task == TEXT_GENERATION_GRAPH) {
if (!std::holds_alternative<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) {
std::cerr << "Graph options not initialized for text generation.";
return false;
}
auto settings = std::get<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings);
std::vector allowedPipelineTypes = {"LM", "LM_CB", "VLM", "VLM_CB", "AUTO"};
if (settings.pipelineType.has_value() && std::find(allowedPipelineTypes.begin(), allowedPipelineTypes.end(), settings.pipelineType) == allowedPipelineTypes.end()) {
std::cerr << "pipeline_type: " << settings.pipelineType.value() << " is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO" << std::endl;
return false;
}

std::vector allowedTargetDevices = {"CPU", "GPU", "NPU", "AUTO"};
if (std::find(allowedTargetDevices.begin(), allowedTargetDevices.end(), settings.targetDevice) == allowedTargetDevices.end() && settings.targetDevice.rfind("HETERO", 0) != 0) {
std::cerr << "target_device: " << settings.targetDevice << " is not allowed. Supported devices: CPU, GPU, NPU, HETERO, AUTO" << std::endl;
return false;
}

std::vector allowedBoolValues = {"false", "true"};
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.enablePrefixCaching) == allowedBoolValues.end()) {
std::cerr << "enable_prefix_caching: " << settings.enablePrefixCaching << " is not allowed. Supported values: true, false" << std::endl;
return false;
}

if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.dynamicSplitFuse) == allowedBoolValues.end()) {
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl;
return false;
}

if (settings.targetDevice != "NPU") {
if (settings.pluginConfig.maxPromptLength.has_value()) {
std::cerr << "max_prompt_len is only supported for NPU target device";
return false;
}
}
}

if (this->serverSettings.hfSettings.task == EMBEDDINGS_GRAPH) {
if (!std::holds_alternative<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) {
std::cerr << "Graph options not initialized for embeddings.";
return false;
}
auto settings = std::get<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings);

std::vector allowedBoolValues = {"false", "true"};
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.normalize) == allowedBoolValues.end()) {
std::cerr << "normalize: " << settings.normalize << " is not allowed. Supported values: true, false" << std::endl;
return false;
}
}
Collaborator: Add check for source_model and model_repository_path if they are set and add unit tests.

return true;
}
if (this->serverSettings.serverMode == LIST_MODELS_MODE) {
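
For reference, a minimal standalone sketch of the allowed-values pattern introduced in Config::validate() above; the helper isAllowedValue and the main() driver are illustrative only, not part of this PR.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Illustrative helper: returns true when `value` is one of the allowed literals.
static bool isAllowedValue(const std::string& value, const std::vector<std::string>& allowed) {
    return std::find(allowed.begin(), allowed.end(), value) != allowed.end();
}

int main() {
    const std::vector<std::string> allowedBoolValues = {"false", "true"};
    const std::vector<std::string> allowedTargetDevices = {"CPU", "GPU", "NPU", "AUTO"};

    std::string enablePrefixCaching = "maybe";    // deliberately invalid
    std::string targetDevice = "HETERO:GPU,CPU";  // accepted via the HETERO prefix rule

    if (!isAllowedValue(enablePrefixCaching, allowedBoolValues)) {
        std::cerr << "enable_prefix_caching: " << enablePrefixCaching
                  << " is not allowed. Supported values: true, false" << std::endl;
    }
    // HETERO targets pass through a prefix check, mirroring rfind("HETERO", 0) == 0 above.
    bool deviceOk = isAllowedValue(targetDevice, allowedTargetDevices) ||
                    targetDevice.rfind("HETERO", 0) == 0;
    std::cout << "target_device accepted: " << std::boolalpha << deviceOk << std::endl;
    return 0;
}
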
3 changes: 3 additions & 0 deletions src/graph_export/BUILD
@@ -29,6 +29,9 @@ cc_library(
"@ovms//src:libovmsfilesystem",
"@ovms//src:libovmslocalfilesystem",
"@com_github_tencent_rapidjson//:rapidjson",
"@mediapipe//mediapipe/framework/port:parse_text_proto",
"@mediapipe//mediapipe/framework:calculator_graph",
"@ovms//src:libovmsschema",
],
visibility = ["//visibility:public"],
)
5 changes: 0 additions & 5 deletions src/graph_export/embeddings_graph_cli_parser.cpp
@@ -44,10 +44,6 @@ void EmbeddingsGraphCLIParser::createOptions() {
"The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.",
cxxopts::value<uint32_t>()->default_value("1"),
"NUM_STREAMS")
("truncate",
"Truncate the prompts to fit to the embeddings model.",
cxxopts::value<std::string>()->default_value("false"),
"TRUNCATE")
("normalize",
"Normalize the embeddings.",
cxxopts::value<std::string>()->default_value("false"),
@@ -95,7 +91,6 @@ void EmbeddingsGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl
} else {
embeddingsGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
embeddingsGraphSettings.normalize = result->operator[]("normalize").as<std::string>();
embeddingsGraphSettings.truncate = result->operator[]("truncate").as<std::string>();
embeddingsGraphSettings.version = result->operator[]("model_version").as<std::uint32_t>();
}
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
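
For context, a minimal sketch of how a cxxopts option such as --normalize is declared and read back, mirroring the parser code above; the program name and reduced option set are illustrative, not the full OVMS parser.

#include <cstdint>
#include <iostream>
#include <string>

#include <cxxopts.hpp>

int main(int argc, char* argv[]) {
    cxxopts::Options options("export_demo", "Illustrative embeddings graph options");
    options.add_options()
        ("normalize",
         "Normalize the embeddings.",
         cxxopts::value<std::string>()->default_value("false"),
         "NORMALIZE")
        ("num_streams",
         "The number of parallel execution streams to use for the model.",
         cxxopts::value<uint32_t>()->default_value("1"),
         "NUM_STREAMS");

    auto result = options.parse(argc, argv);
    // Values are read back the same way prepare() above fills EmbeddingsGraphSettingsImpl.
    std::string normalize = result["normalize"].as<std::string>();
    uint32_t numStreams = result["num_streams"].as<uint32_t>();
    std::cout << "normalize=" << normalize << " num_streams=" << numStreams << std::endl;
    return 0;
}
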
47 changes: 46 additions & 1 deletion src/graph_export/graph_export.cpp
@@ -25,6 +25,8 @@
#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include <rapidjson/istreamwrapper.h>
#include <rapidjson/prettywriter.h>
#pragma warning(pop)

#include "../capi_frontend/server_settings.hpp"
@@ -34,8 +36,12 @@
#include "../logging.hpp"
#include "../status.hpp"
#include "../stringutils.hpp"
#include "../schema.hpp"
#include "graph_export_types.hpp"

#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/calculator_graph.h"

namespace ovms {

static Status createTextGenerationGraphTemplate(const std::string& directoryPath, const TextGenGraphSettingsImpl& graphSettings) {
@@ -102,12 +108,31 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
}
}
})";

::mediapipe::CalculatorGraphConfig config;
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
if (!success) {
SPDLOG_ERROR("Created graph config file couldn't be parsed - check used task parameters values.");
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
return FileSystem::createFileOverwrite(fullPath, oss.str());
}

static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) {
Collaborator: this is not needed

rapidjson::Document subconfigJson;
rapidjson::ParseResult parseResult = subconfigJson.Parse(subconfig.c_str());
Collaborator: I am for parsing the subconfig and mediapipe graph once they are created. @dtrawins wants to remove it. Let's decide now.

Collaborator: @rasapala subconfig will not be used once the new embeddings and rerank calculators are merged.

if (parseResult.Code()) {
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(parseResult.Code()));
return StatusCode::JSON_INVALID;
}
if (validateJsonAgainstSchema(subconfigJson, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str()) != StatusCode::OK) {
SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type);
Collaborator: Created graph config couldn't be parsed - check used task parameters values.

return StatusCode::JSON_INVALID;
}
return StatusCode::OK;
}

static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
std::ostringstream oss;
// clang-format off
@@ -130,6 +155,10 @@ static Status createRerankSubconfigTemplate(const std::string& directoryPath, co
}
]
})";
auto status = validateSubconfigSchema(oss.str(), "rerank");
if (!status.ok()){
return status;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
return FileSystem::createFileOverwrite(fullPath, oss.str());
Expand Down Expand Up @@ -157,6 +186,10 @@ static Status createEmbeddingsSubconfigTemplate(const std::string& directoryPath
}
]
})";
auto status = validateSubconfigSchema(oss.str(), "embeddings");
if (!status.ok()){
return status;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
return FileSystem::createFileOverwrite(fullPath, oss.str());
@@ -196,6 +229,12 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
output_stream: "RESPONSE_PAYLOAD:output"
})";

::mediapipe::CalculatorGraphConfig config;
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
if (!success) {
SPDLOG_ERROR("Created rerank graph config couldn't be parsed.");
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
auto status = FileSystem::createFileOverwrite(fullPath, oss.str());
Expand Down Expand Up @@ -245,6 +284,12 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
}
})";

::mediapipe::CalculatorGraphConfig config;
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
if (!success) {
SPDLOG_ERROR("Created embeddings graph config couldn't be parsed.");
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
auto status = FileSystem::createFileOverwrite(fullPath, oss.str());
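
For reference, a minimal standalone sketch of the two parse checks added above: TextFormat parsing of a generated graph.pbtxt string and RapidJSON parsing of a generated subconfig.json string. It assumes MediaPipe, protobuf, and RapidJSON are available (as in the BUILD deps added by this PR); the template strings below are placeholders, not the real templates.

#include <iostream>
#include <string>

#include <google/protobuf/text_format.h>
#include <rapidjson/document.h>
#include <rapidjson/error/en.h>

#include "mediapipe/framework/calculator_graph.h"

int main() {
    // 1) Parse the generated pbtxt into a CalculatorGraphConfig to catch syntax errors
    //    before the file is written, as createTextGenerationGraphTemplate() now does.
    const std::string pbtxt = R"(
        input_stream: "REQUEST_PAYLOAD:input"
        output_stream: "RESPONSE_PAYLOAD:output"
    )";
    ::mediapipe::CalculatorGraphConfig config;
    if (!::google::protobuf::TextFormat::ParseFromString(pbtxt, &config)) {
        std::cerr << "graph.pbtxt could not be parsed" << std::endl;
        return 1;
    }

    // 2) Parse the generated subconfig JSON; schema validation would follow on success,
    //    as validateSubconfigSchema() above does.
    const std::string subconfig = R"({"model_config_list": []})";
    rapidjson::Document doc;
    rapidjson::ParseResult parseResult = doc.Parse(subconfig.c_str());
    if (!parseResult) {
        std::cerr << "subconfig.json is not valid JSON: "
                  << rapidjson::GetParseError_En(parseResult.Code()) << std::endl;
        return 1;
    }
    std::cout << "generated configs parsed successfully" << std::endl;
    return 0;
}
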
82 changes: 81 additions & 1 deletion src/test/graph_export_test.cpp
@@ -288,14 +288,47 @@ TEST_F(GraphCreationTest, rerankPositiveDefault) {
ASSERT_EQ(expectedRerankJsonContents, jsonContents) << jsonContents;
}

TEST_F(GraphCreationTest, rerankCreatedJsonInvalid) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::RERANK_GRAPH;
ovms::RerankGraphSettingsImpl rerankGraphSettings;
rerankGraphSettings.targetDevice = "GPU";
rerankGraphSettings.modelName = "myModel\t";
rerankGraphSettings.numStreams = 2;
rerankGraphSettings.maxDocLength = 18;
rerankGraphSettings.version = 2;
hfSettings.graphSettings = std::move(rerankGraphSettings);
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID);
}

TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::RERANK_GRAPH;
ovms::RerankGraphSettingsImpl rerankGraphSettings;
rerankGraphSettings.targetDevice = "GPU";
rerankGraphSettings.modelName = "myModel\"";
rerankGraphSettings.numStreams = 2;
rerankGraphSettings.maxDocLength = 18;
rerankGraphSettings.version = 2;
hfSettings.graphSettings = std::move(rerankGraphSettings);
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID);
}

TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::EMBEDDINGS_GRAPH;
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
embeddingsGraphSettings.targetDevice = "GPU";
embeddingsGraphSettings.modelName = "myModel";
embeddingsGraphSettings.numStreams = 2;
embeddingsGraphSettings.truncate = "true";
embeddingsGraphSettings.normalize = "true";
embeddingsGraphSettings.version = 2;
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
@@ -312,6 +345,40 @@ TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
ASSERT_EQ(expectedEmbeddingsJsonContents, jsonContents) << jsonContents;
}

TEST_F(GraphCreationTest, embeddingsCreatedJsonInvalid) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::EMBEDDINGS_GRAPH;
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
embeddingsGraphSettings.targetDevice = "GPU";
embeddingsGraphSettings.modelName = "myModel\t";
embeddingsGraphSettings.numStreams = 2;
embeddingsGraphSettings.normalize = "true";
embeddingsGraphSettings.version = 2;
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID);
}

TEST_F(GraphCreationTest, embeddingsCreatedPbtxtInvalid) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::EMBEDDINGS_GRAPH;
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
embeddingsGraphSettings.targetDevice = "GPU";
embeddingsGraphSettings.modelName = "myModel\"";
embeddingsGraphSettings.numStreams = 2;
embeddingsGraphSettings.normalize = "true";
embeddingsGraphSettings.version = 2;
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID);
}

TEST_F(GraphCreationTest, positivePluginConfigAll) {
ovms::HFSettingsImpl hfSettings;
ovms::TextGenGraphSettingsImpl graphSettings;
@@ -377,3 +444,16 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) {
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::OK) << status.string();
}

TEST_F(GraphCreationTest, negativeCreatedPbtxtInvalid) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::TEXT_GENERATION_GRAPH;
ovms::TextGenGraphSettingsImpl graphSettings;
graphSettings.modelPath = "invalid\"";
hfSettings.graphSettings = std::move(graphSettings);
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
std::string subconfigPath = ovms::FileSystem::appendSlash(this->directoryPath) + "subconfig.json";
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID);
}
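
A short standalone sketch of why the negative tests above use a tab ("myModel\t") and an unescaped quote ("myModel\"") in modelName: a raw tab inside a JSON string literal is rejected by RapidJSON, while an unescaped quote inside a quoted pbtxt value terminates the string early, so TextFormat parsing fails. Assumes RapidJSON is available; the string below is illustrative.

#include <iostream>
#include <string>

#include <rapidjson/document.h>

int main() {
    // A modelName like "myModel\t" is substituted into the subconfig.json template as a raw,
    // unescaped tab inside a string literal, which RapidJSON rejects -> the JSON_INVALID path above.
    const std::string json = "{\"name\": \"myModel\t\"}";
    rapidjson::Document doc;
    doc.Parse(json.c_str());
    std::cout << "raw tab inside JSON string -> parse error: " << std::boolalpha
              << doc.HasParseError() << std::endl;
    // A modelName like "myModel\"" instead closes the quoted value early in graph.pbtxt,
    // so TextFormat parsing fails -> the MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID path above.
    return 0;
}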