-
Notifications
You must be signed in to change notification settings - Fork 217
Validate create graph parameters #3290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,6 +92,66 @@ bool Config::validate() { | |
std::cerr << "Error: --task parameter not set." << std::endl; | ||
return false; | ||
} | ||
if (this->serverSettings.hfSettings.task == text_generation) { | ||
if (!std::holds_alternative<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) { | ||
std::cerr << "Graph options not initialized for text generation."; | ||
return false; | ||
} | ||
auto settings = std::get<TextGenGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings); | ||
std::vector allowedPipelineTypes = {"LM", "LM_CB", "VLM", "VLM_CB", "AUTO"}; | ||
if (settings.pipelineType.has_value() && std::find(allowedPipelineTypes.begin(), allowedPipelineTypes.end(), settings.pipelineType) == allowedPipelineTypes.end()) { | ||
std::cerr << "pipeline_type: " << settings.pipelineType.value() << " is not allowed. Supported types: LM, LM_CB, VLM, VLM_CB, AUTO" << std::endl; | ||
return false; | ||
} | ||
|
||
std::vector allowedTargetDevices = {"CPU", "GPU", "NPU", "AUTO"}; | ||
if (std::find(allowedTargetDevices.begin(), allowedTargetDevices.end(), settings.targetDevice) == allowedTargetDevices.end() && settings.targetDevice.rfind("HETERO", 0) != 0) { | ||
std::cerr << "target_device: " << settings.targetDevice << " is not allowed. Supported devices: CPU, GPU, NPU, HETERO, AUTO" << std::endl; | ||
return false; | ||
} | ||
|
||
std::vector allowedBoolValues = {"false", "true"}; | ||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.enablePrefixCaching) == allowedBoolValues.end()) { | ||
std::cerr << "enable_prefix_caching: " << settings.enablePrefixCaching << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
|
||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.dynamicSplitFuse) == allowedBoolValues.end()) { | ||
std::cerr << "dynamic_split_fuse: " << settings.dynamicSplitFuse << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
|
||
if (settings.targetDevice != "NPU") { | ||
if (settings.pluginConfig.maxPromptLength.has_value()) { | ||
std::cerr << "max_prompt_len is only supported for NPU target device"; | ||
return false; | ||
} | ||
} | ||
|
||
if (serverSettings.hfSettings.sourceModel.rfind("OpenVINO/", 0) != 0) { | ||
std::cerr << "For now only OpenVINO models are supported"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should be generic check not only for text_generation. @dtrawins |
||
return false; | ||
} | ||
} | ||
|
||
if (this->serverSettings.hfSettings.task == embeddings) { | ||
if (!std::holds_alternative<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings)) { | ||
std::cerr << "Graph options not initialized for embeddings."; | ||
return false; | ||
} | ||
auto settings = std::get<EmbeddingsGraphSettingsImpl>(this->serverSettings.hfSettings.graphSettings); | ||
|
||
std::vector allowedBoolValues = {"false", "true"}; | ||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.normalize) == allowedBoolValues.end()) { | ||
std::cerr << "normalize: " << settings.normalize << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
|
||
if (std::find(allowedBoolValues.begin(), allowedBoolValues.end(), settings.truncate) == allowedBoolValues.end()) { | ||
std::cerr << "truncate: " << settings.truncate << " is not allowed. Supported values: true, false" << std::endl; | ||
return false; | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add check for source_model and model_repository_path if they are set and add unit tests. |
||
return true; | ||
} | ||
if (this->serverSettings.listServables) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove // TODO: CVS-1667