-
Notifications
You must be signed in to change notification settings - Fork 217
Validate create graph parameters #3290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7277f7c
bf1cbe6
9119690
5197ce8
c66955c
51923cd
29df3d8
142ca0e
e63a6f2
2ee56f0
18ba9c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -86,12 +86,73 @@ bool Config::check_hostname_or_ip(const std::string& input) { | |
} | ||
|
||
bool Config::validate() { | ||
// TODO: CVS-166727 Add validation of all parameters once the CLI model export flags will be implemented | ||
if (this->serverSettings.serverMode == HF_PULL_MODE) { | ||
if (!serverSettings.hfSettings.sourceModel.size()) { | ||
std::cerr << "source_model parameter is required for pull mode"; | ||
return false; | ||
} | ||
if (!serverSettings.hfSettings.downloadPath.size()) { | ||
std::cerr << "model_repository_path parameter is required for pull mode"; | ||
return false; | ||
} | ||
if (this->serverSettings.hfSettings.task == UNKNOWN_GRAPH) { | ||
std::cerr << "Error: --task parameter not set." << std::endl; | ||
return false; | ||
} | ||
if (serverSettings.hfSettings.sourceModel.rfind("OpenVINO/", 0) != 0) { | ||
std::cerr << "For now only OpenVINO models are supported in pulling mode"; | ||
return false; | ||
} | ||
if (this->serverSettings.hfSettings.task == TEXT_GENERATION_GRAPH) { | ||
if (!std::holds_alternative<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) { | ||
std::cerr << "Graph options not initialized for text generation."; | ||
return false; | ||
} | ||
auto settings = std::get<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings); | ||
std::vector allowedPipelineTypes = {"LM", "LM_CB", "VLM", "VLM_CB", "AUTO"}; | ||
if (settings.pipelineType.has_value() && std::find(allowedPipelineTypes.begin(), allowedPipelineTypes.end(), settings.pipelineType) == allowedPipelineTypes.end()) { | ||
std::cerr << "pipeline_type: " << settings.pipelineType.value() << " is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO" << std::endl; | ||
return false; | ||
} | ||
|
||
std::vector allowedTargetDevices = {"CPU", "GPU", "NPU", "AUTO"}; | ||
if (std::find(allowedTargetDevices.begin(), allowedTargetDevices.end(), settings.targetDevice) == allowedTargetDevices.end() && settings.targetDevice.rfind("HETERO", 0) != 0) { | ||
std::cerr << "target_device: " << settings.targetDevice << " is not allowed. Supported devices: CPU, GPU, NPU, HETERO, AUTO" << std::endl; | ||
return false; | ||
} | ||
|
||
std::vector allowedBoolValues = {"false", "true"}; | ||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.enablePrefixCaching) == allowedBoolValues.end()) { | ||
std::cerr << "enable_prefix_caching: " << settings.enablePrefixCaching << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
|
||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.dynamicSplitFuse) == allowedBoolValues.end()) { | ||
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
|
||
if (settings.targetDevice != "NPU") { | ||
if (settings.pluginConfig.maxPromptLength.has_value()) { | ||
std::cerr << "max_prompt_len is only supported for NPU target device"; | ||
return false; | ||
} | ||
} | ||
} | ||
|
||
if (this->serverSettings.hfSettings.task == EMBEDDINGS_GRAPH) { | ||
if (!std::holds_alternative<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) { | ||
std::cerr << "Graph options not initialized for embeddings."; | ||
return false; | ||
} | ||
auto settings = std::get<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings); | ||
|
||
std::vector allowedBoolValues = {"false", "true"}; | ||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.normalize) == allowedBoolValues.end()) { | ||
std::cerr << "normalize: " << settings.normalize << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add check for source_model and model_repository_path if they are set and add unit tests. |
||
return true; | ||
} | ||
if (this->serverSettings.serverMode == LIST_MODELS_MODE) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,8 @@ | |
#include <rapidjson/document.h> | ||
#include <rapidjson/stringbuffer.h> | ||
#include <rapidjson/writer.h> | ||
#include <rapidjson/istreamwrapper.h> | ||
#include <rapidjson/prettywriter.h> | ||
#pragma warning(pop) | ||
|
||
#include "../capi_frontend/server_settings.hpp" | ||
|
@@ -34,8 +36,12 @@ | |
#include "../logging.hpp" | ||
#include "../status.hpp" | ||
#include "../stringutils.hpp" | ||
#include "../schema.hpp" | ||
#include "graph_export_types.hpp" | ||
|
||
#include "mediapipe/framework/port/parse_text_proto.h" | ||
#include "mediapipe/framework/calculator_graph.h" | ||
|
||
namespace ovms { | ||
|
||
static Status createTextGenerationGraphTemplate(const std::string& directoryPath, const TextGenGraphSettingsImpl& graphSettings) { | ||
|
@@ -102,12 +108,31 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath | |
} | ||
} | ||
})"; | ||
|
||
::mediapipe::CalculatorGraphConfig config; | ||
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config); | ||
if (!success) { | ||
SPDLOG_ERROR("Created graph config file couldn't be parsed - check used task parameters values."); | ||
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; | ||
} | ||
// clang-format on | ||
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"}); | ||
return FileSystem::createFileOverwrite(fullPath, oss.str()); | ||
} | ||
|
||
static Status validateSubconfigSchema(const std::string& subconfig, const std::string& type) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not needed |
||
rapidjson::Document subconfigJson; | ||
rapidjson::ParseResult parseResult = subconfigJson.Parse(subconfig.c_str()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am for parsing subconfig and mediapipe graph once it is created. @dtrawins want to remove it. Lets decide now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rasapala subconfig will not be used then new embeddings and rerank calculators are merged. |
||
if (parseResult.Code()) { | ||
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Created {} subconfig file is not a valid JSON file. Error: {}", type, rapidjson::GetParseError_En(parseResult.Code())); | ||
return StatusCode::JSON_INVALID; | ||
} | ||
if (validateJsonAgainstSchema(subconfigJson, MEDIAPIPE_SUBCONFIG_SCHEMA.c_str()) != StatusCode::OK) { | ||
SPDLOG_ERROR("Created {} subconfig file is not in valid configuration format", type); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created graph config couldn't be parsed - check used task parameters values. |
||
return StatusCode::JSON_INVALID; | ||
} | ||
return StatusCode::OK; | ||
} | ||
|
||
static Status createRerankSubconfigTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) { | ||
std::ostringstream oss; | ||
// clang-format off | ||
|
@@ -130,6 +155,10 @@ static Status createRerankSubconfigTemplate(const std::string& directoryPath, co | |
} | ||
] | ||
})"; | ||
auto status = validateSubconfigSchema(oss.str(), "rerank"); | ||
if (!status.ok()){ | ||
return status; | ||
} | ||
// clang-format on | ||
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"}); | ||
return FileSystem::createFileOverwrite(fullPath, oss.str()); | ||
|
@@ -157,6 +186,10 @@ static Status createEmbeddingsSubconfigTemplate(const std::string& directoryPath | |
} | ||
] | ||
})"; | ||
auto status = validateSubconfigSchema(oss.str(), "embeddings"); | ||
if (!status.ok()){ | ||
return status; | ||
} | ||
// clang-format on | ||
std::string fullPath = FileSystem::joinPath({directoryPath, "subconfig.json"}); | ||
return FileSystem::createFileOverwrite(fullPath, oss.str()); | ||
|
@@ -196,6 +229,12 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const | |
output_stream: "RESPONSE_PAYLOAD:output" | ||
})"; | ||
|
||
::mediapipe::CalculatorGraphConfig config; | ||
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config); | ||
if (!success) { | ||
SPDLOG_ERROR("Created rerank graph config couldn't be parsed."); | ||
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; | ||
} | ||
// clang-format on | ||
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"}); | ||
auto status = FileSystem::createFileOverwrite(fullPath, oss.str()); | ||
|
@@ -245,6 +284,12 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co | |
} | ||
})"; | ||
|
||
::mediapipe::CalculatorGraphConfig config; | ||
bool success = ::google::protobuf::TextFormat::ParseFromString(oss.str(), &config); | ||
if (!success) { | ||
SPDLOG_ERROR("Created embeddings graph config couldn't be parsed."); | ||
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID; | ||
} | ||
// clang-format on | ||
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"}); | ||
auto status = FileSystem::createFileOverwrite(fullPath, oss.str()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove // TODO: CVS-1667
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not removed.