Skip to content

Validate create graph parameters #3290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,66 @@ bool Config::validate() {
std::cerr << "Error: --task parameter not set." << std::endl;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove // TODO: CVS-1667

return false;
}
if (this->serverSettings.hfSettings.task == text_generation) {
if (!std::holds_alternative<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) {
std::cerr << "Graph options not initialized for text generation.";
return false;
}
auto settings = std::get<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings);
std::vector allowedPipelineTypes = {"LM", "LM_CB", "VLM", "VLM_CB", "AUTO"};
if (settings.pipelineType.has_value() && std::find(allowedPipelineTypes.begin(), allowedPipelineTypes.end(), settings.pipelineType) == allowedPipelineTypes.end()) {
std::cerr << "pipeline_type: " << settings.pipelineType.value() << " is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO" << std::endl;
return false;
}

std::vector allowedTargetDevices = {"CPU", "GPU", "NPU", "AUTO"};
if (std::find(allowedTargetDevices.begin(), allowedTargetDevices.end(), settings.targetDevice) == allowedTargetDevices.end() && settings.targetDevice.rfind("HETERO", 0) != 0) {
std::cerr << "target_device: " << settings.targetDevice << " is not allowed. Supported devices: CPU, GPU, NPU, HETERO, AUTO" << std::endl;
return false;
}

std::vector allowedBoolValues = {"false", "true"};
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.enablePrefixCaching) == allowedBoolValues.end()) {
std::cerr << "enable_prefix_caching: " << settings.enablePrefixCaching << " is not allowed. Supported values: true, false" << std::endl;
return false;
}

if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.dynamicSplitFuse) == allowedBoolValues.end()) {
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl;
return false;
}

if (settings.targetDevice != "NPU") {
if (settings.pluginConfig.maxPromptLength.has_value()) {
std::cerr << "max_prompt_len is only supported for NPU target device";
return false;
}
}

if (serverSettings.hfSettings.sourceModel.rfind("OpenVINO/", 0) != 0) {
std::cerr << "For now only OpenVINO models are supported";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be a generic check, not one limited to text_generation. @dtrawins

return false;
}
}

if (this->serverSettings.hfSettings.task == embeddings) {
if (!std::holds_alternative<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) {
std::cerr << "Graph options not initialized for embeddings.";
return false;
}
auto settings = std::get<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings);

std::vector allowedBoolValues = {"false", "true"};
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.normalize) == allowedBoolValues.end()) {
std::cerr << "normalize: " << settings.normalize << " is not allowed. Supported values: true, false" << std::endl;
return false;
}

if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.truncate) == allowedBoolValues.end()) {
std::cerr << "truncate: " << settings.truncate << " is not allowed. Supported values: true, false" << std::endl;
return false;
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a check that validates source_model and model_repository_path when they are set, and add unit tests covering it.

return true;
}
if (this->serverSettings.listServables) {
Expand Down
3 changes: 3 additions & 0 deletions src/graph_export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ cc_library(
"@ovms//src:libovms_module",
"@ovms//src:libovmsfilesystem",
"@com_github_tencent_rapidjson//:rapidjson",
"@mediapipe//mediapipe/framework/port:parse_text_proto",
"@mediapipe//mediapipe/framework:calculator_graph",
"@ovms//src:libovmsschema",
],
visibility = ["//visibility:public"],
)
Expand Down
47 changes: 46 additions & 1 deletion src/graph_export/graph_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <rapidjson/document.h>
#include <rapidjson/stringbuffer.h>
#include <rapidjson/writer.h>
#include <rapidjson/istreamwrapper.h>
#include <rapidjson/prettywriter.h>
#pragma warning(pop)

#include "../capi_frontend/server_settings.hpp"
Expand All @@ -33,8 +35,12 @@
#include "../logging.hpp"
#include "../status.hpp"
#include "../stringutils.hpp"
#include "../schema.hpp"
#include "graph_export_types.hpp"

#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/calculator_graph.h"

namespace ovms {

Status createFile(const std::string& filePath, const std::string& contents) {
Expand Down Expand Up @@ -116,12 +122,31 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
}
}
})";

::mediapipe::CalculatorGraphConfig config;
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
if (!success) {
SPDLOG_ERROR("Created graph config couldn't be parsed.");
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
return createFile(fullPath, oss.str());
}

// Parses the generated subconfig text as JSON and checks it against the
// mediapipe subconfig schema. `type` is used only for log messages.
// Returns JSON_INVALID when the text is not well-formed JSON or does not
// conform to the schema; OK otherwise.
static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) {
    rapidjson::Document doc;
    if (doc.Parse(subconfig.c_str()).HasParseError()) {
        SPDLOG_LOGGER_ERROR(modelmanager_logger, "Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(doc.GetParseError()));
        return StatusCode::JSON_INVALID;
    }
    if (validateJsonAgainstSchema(doc, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str()) != StatusCode::OK) {
        SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type);
        return StatusCode::JSON_INVALID;
    }
    return StatusCode::OK;
}

static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
std::ostringstream oss;
// clang-format off
Expand All @@ -144,6 +169,10 @@ static Status createRerankSubconfigTemplate(const std::string& directoryPath, co
}
]
})";
auto status = validateSubconfigSchema(oss.str(), "rerank");
if (!status.ok()){
return status;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
return createFile(fullPath, oss.str());
Expand Down Expand Up @@ -171,6 +200,10 @@ static Status createEmbeddingsSubconfigTemplate(const std::string& directoryPath
}
]
})";
auto status = validateSubconfigSchema(oss.str(), "embeddings");
if (!status.ok()){
return status;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"});
return createFile(fullPath, oss.str());
Expand Down Expand Up @@ -210,6 +243,12 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
output_stream: "RESPONSE_PAYLOAD:output"
})";

::mediapipe::CalculatorGraphConfig config;
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
if (!success) {
SPDLOG_ERROR("Created rerank graph config couldn't be parsed.");
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
auto status = createFile(fullPath, oss.str());
Expand Down Expand Up @@ -259,6 +298,12 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
}
})";

::mediapipe::CalculatorGraphConfig config;
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config);
if (!success) {
SPDLOG_ERROR("Created embeddings graph config couldn't be parsed.");
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
}
// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
auto status = createFile(fullPath, oss.str());
Expand Down
83 changes: 83 additions & 0 deletions src/test/graph_export_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,40 @@ TEST_F(GraphCreationTest, rerankPositiveDefault) {
ASSERT_EQ(expectedRerankJsonContents, jsonContents) << jsonContents;
}

// A raw tab character inside model_name ends up unescaped in the generated
// rerank subconfig.json, so servable config creation must fail JSON validation.
TEST_F(GraphCreationTest, rerankCreatedJsonInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::rerank;
    ovms::RerankGraphSettingsImpl rerankGraphSettings;
    rerankGraphSettings.targetDevice = "GPU";
    rerankGraphSettings.modelName = "myModel\t";  // raw tab is illegal inside a JSON string
    rerankGraphSettings.numStreams = 2;
    rerankGraphSettings.maxDocLength = 18;
    rerankGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(rerankGraphSettings);
    // No heap allocation needed; the exporter is used only within this scope.
    ovms::GraphExport graphExporter;
    auto status = graphExporter.createServableConfig(this->directoryPath, hfSettings);
    ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID) << status.string();
}

// A stray double quote inside model_name breaks the generated graph.pbtxt,
// so servable config creation must fail mediapipe graph parsing.
TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::rerank;
    ovms::RerankGraphSettingsImpl rerankGraphSettings;
    rerankGraphSettings.targetDevice = "GPU";
    rerankGraphSettings.modelName = "myModel\"";  // unescaped quote terminates the pbtxt string early
    rerankGraphSettings.numStreams = 2;
    rerankGraphSettings.maxDocLength = 18;
    rerankGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(rerankGraphSettings);
    // No heap allocation needed; the exporter is used only within this scope.
    ovms::GraphExport graphExporter;
    auto status = graphExporter.createServableConfig(this->directoryPath, hfSettings);
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}

TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
ovms::HFSettingsImpl hfSettings;
hfSettings.task = ovms::embeddings;
Expand All @@ -312,6 +346,42 @@ TEST_F(GraphCreationTest, embeddingsPositiveDefault) {
ASSERT_EQ(expectedEmbeddingsJsonContents, jsonContents) << jsonContents;
}

// A raw tab character inside model_name ends up unescaped in the generated
// embeddings subconfig.json, so servable config creation must fail JSON validation.
TEST_F(GraphCreationTest, embeddingsCreatedJsonInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::embeddings;
    ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
    embeddingsGraphSettings.targetDevice = "GPU";
    embeddingsGraphSettings.modelName = "myModel\t";  // raw tab is illegal inside a JSON string
    embeddingsGraphSettings.numStreams = 2;
    embeddingsGraphSettings.truncate = "true";
    embeddingsGraphSettings.normalize = "true";
    embeddingsGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(embeddingsGraphSettings);
    // No heap allocation needed; the exporter is used only within this scope.
    ovms::GraphExport graphExporter;
    auto status = graphExporter.createServableConfig(this->directoryPath, hfSettings);
    ASSERT_EQ(status, ovms::StatusCode::JSON_INVALID) << status.string();
}

// A stray double quote inside model_name breaks the generated graph.pbtxt,
// so servable config creation must fail mediapipe graph parsing.
TEST_F(GraphCreationTest, embeddingsCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::embeddings;
    ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
    embeddingsGraphSettings.targetDevice = "GPU";
    embeddingsGraphSettings.modelName = "myModel\"";  // unescaped quote terminates the pbtxt string early
    embeddingsGraphSettings.numStreams = 2;
    embeddingsGraphSettings.truncate = "true";
    embeddingsGraphSettings.normalize = "true";
    embeddingsGraphSettings.version = 2;
    hfSettings.graphSettings = std::move(embeddingsGraphSettings);
    // No heap allocation needed; the exporter is used only within this scope.
    ovms::GraphExport graphExporter;
    auto status = graphExporter.createServableConfig(this->directoryPath, hfSettings);
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}

TEST_F(GraphCreationTest, positivePluginConfigAll) {
ovms::HFSettingsImpl hfSettings;
ovms::TextGenGraphSettingsImpl graphSettings;
Expand Down Expand Up @@ -377,3 +447,16 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) {
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::OK) << status.string();
}

// A stray double quote inside models_path breaks the generated text-generation
// graph.pbtxt, so servable config creation must fail mediapipe graph parsing.
TEST_F(GraphCreationTest, negativeCreatedPbtxtInvalid) {
    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::text_generation;
    ovms::TextGenGraphSettingsImpl graphSettings;
    graphSettings.modelPath = "invalid\"";  // unescaped quote terminates the pbtxt string early
    hfSettings.graphSettings = std::move(graphSettings);
    // No heap allocation needed; the exporter is used only within this scope.
    ovms::GraphExport graphExporter;
    auto status = graphExporter.createServableConfig(this->directoryPath, hfSettings);
    ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID) << status.string();
}
Loading