diff --git a/demos/common/export_models/README.md b/demos/common/export_models/README.md index 04744a672e..209d7df67b 100644 --- a/demos/common/export_models/README.md +++ b/demos/common/export_models/README.md @@ -13,16 +13,17 @@ python export_model.py --help ``` Expected Output: ```console -usage: export_model.py [-h] {text_generation,embeddings,rerank} ... +usage: export_model.py [-h] {text_generation,embeddings,rerank,image_generation} ... Export Hugging face models to OVMS models repository including all configuration for deployments positional arguments: - {text_generation,embeddings,rerank} + {text_generation,embeddings,rerank,image_generation} subcommand help text_generation export model for chat and completion endpoints embeddings export model for embeddings endpoint rerank export model for rerank endpoint + image_generation export model for image generation endpoint ``` For every use case subcommand there is adjusted list of parameters: @@ -134,6 +135,15 @@ python export_model.py rerank \ --num_streams 2 ``` +### Image Generation Models +```console +python export_model.py image_generation \ + --source_model dreamlike-art/dreamlike-anime-1.0 \ + --weight-format int8 \ + --config_file_path models/config_all.json \ + --max_resolution 2048x2048 +``` + ## Deployment example After exporting models using the commands above (which use `--model_repository_path models` and `--config_file_path models/config_all.json`), you can deploy them with either with Docker or on Baremetal. 
diff --git a/src/BUILD b/src/BUILD index cf7f64ae0a..c8f358c4ab 100644 --- a/src/BUILD +++ b/src/BUILD @@ -346,6 +346,7 @@ cc_library( "//src/graph_export:graph_cli_parser", "//src/graph_export:rerank_graph_cli_parser", "//src/graph_export:embeddings_graph_cli_parser", + "//src/graph_export:image_generation_graph_cli_parser", ], visibility = ["//visibility:public",], local_defines = COMMON_LOCAL_DEFINES, diff --git a/src/capi_frontend/server_settings.hpp b/src/capi_frontend/server_settings.hpp index eb0ffe9ee9..48f2edb66a 100644 --- a/src/capi_frontend/server_settings.hpp +++ b/src/capi_frontend/server_settings.hpp @@ -70,13 +70,25 @@ struct RerankGraphSettingsImpl { uint32_t version = 1; // FIXME: export_rerank_tokenizer python method - not supported currently? }; +struct ImageGenerationGraphSettingsImpl { + std::string modelName = ""; + std::string modelPath = "./"; + std::string targetDevice = "CPU"; + std::string maxResolution = ""; + std::string defaultResolution = ""; + std::optional maxNumberImagesPerPrompt; + std::optional defaultNumInferenceSteps; + std::optional maxNumInferenceSteps; + std::string pluginConfig; +}; + struct HFSettingsImpl { std::string targetDevice = "CPU"; std::string sourceModel = ""; std::string downloadPath = ""; bool overwriteModels = false; GraphExportType task = TEXT_GENERATION_GRAPH; - std::variant graphSettings; + std::variant graphSettings; }; struct ServerSettingsImpl { diff --git a/src/cli_parser.cpp b/src/cli_parser.cpp index c9df7c0120..9bdc0e1e95 100644 --- a/src/cli_parser.cpp +++ b/src/cli_parser.cpp @@ -27,6 +27,7 @@ #include "graph_export/graph_cli_parser.hpp" #include "graph_export/rerank_graph_cli_parser.hpp" #include "graph_export/embeddings_graph_cli_parser.hpp" +#include "graph_export/image_generation_graph_cli_parser.hpp" #include "ovms_exit_codes.hpp" #include "filesystem.hpp" #include "version.hpp" @@ -153,7 +154,7 @@ void CLIParser::parse(int argc, char** argv) { cxxopts::value(), "MODEL_REPOSITORY_PATH") 
("task", - "Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint.", + "Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint, image_generation - image generation/edit/inpainting endpoints.", cxxopts::value()->default_value("text_generation"), "TASK") ("list_models", @@ -249,6 +250,12 @@ void CLIParser::parse(int argc, char** argv) { this->graphOptionsParser = std::move(cliParser); break; } + case IMAGE_GENERATION_GRAPH: { + ImageGenerationGraphCLIParser cliParser; + unmatchedOptions = cliParser.parse(result->unmatched()); + this->graphOptionsParser = std::move(cliParser); + break; + } case UNKNOWN_GRAPH: { std::cerr << "error parsing options - --task parameter unsupported value: " + result->operator[]("task").as(); exit(OVMS_EX_USAGE); @@ -501,6 +508,10 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl& } break; } + case IMAGE_GENERATION_GRAPH: { + std::get(this->graphOptionsParser).prepare(serverSettings, hfSettings, modelName); + break; + } case UNKNOWN_GRAPH: { throw std::logic_error("Error: --task parameter unsupported value: " + result->operator[]("task").as()); break; diff --git a/src/cli_parser.hpp b/src/cli_parser.hpp index baa9759b77..bb0d338ace 100644 --- a/src/cli_parser.hpp +++ b/src/cli_parser.hpp @@ -24,6 +24,7 @@ #include "graph_export/graph_cli_parser.hpp" #include "graph_export/rerank_graph_cli_parser.hpp" #include "graph_export/embeddings_graph_cli_parser.hpp" +#include "graph_export/image_generation_graph_cli_parser.hpp" namespace ovms { @@ -33,7 +34,7 @@ struct ModelsSettingsImpl; class CLIParser { std::unique_ptr options; std::unique_ptr result; - std::variant graphOptionsParser; + std::variant graphOptionsParser; public: CLIParser() = default; diff --git a/src/graph_export/BUILD b/src/graph_export/BUILD index a6cb6a844d..e5cff96216 100644 --- 
a/src/graph_export/BUILD +++ b/src/graph_export/BUILD @@ -66,6 +66,22 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "image_generation_graph_cli_parser", + srcs = ["image_generation_graph_cli_parser.cpp"], + hdrs = ["image_generation_graph_cli_parser.hpp"], + deps = [ + "@ovms//src/graph_export:graph_export_types", + "@ovms//src/graph_export:graph_cli_parser", + "@ovms//src:cpp_headers", + "@ovms//src:libovms_server_settings", + "@ovms//src:ovms_exit_codes", + "@com_github_jarro2783_cxxopts//:cxxopts", + "@com_github_tencent_rapidjson//:rapidjson", + ], + visibility = ["//visibility:public"], +) + cc_library( name = "embeddings_graph_cli_parser", srcs = ["embeddings_graph_cli_parser.cpp"], diff --git a/src/graph_export/graph_export.cpp b/src/graph_export/graph_export.cpp index 046ce19689..185ef93604 100644 --- a/src/graph_export/graph_export.cpp +++ b/src/graph_export/graph_export.cpp @@ -299,6 +299,65 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co return createEmbeddingsSubconfigTemplate(directoryPath, graphSettings); } +static Status createImageGenerationGraphTemplate(const std::string& directoryPath, const ImageGenerationGraphSettingsImpl& graphSettings) { + std::ostringstream oss; + // clang-format off + oss << R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: ")" << graphSettings.modelPath << R"(" + target_device: ")" << graphSettings.targetDevice << R"(")"; + + if (graphSettings.pluginConfig.size()) { + oss << R"( + plugin_config: ')" << graphSettings.pluginConfig << R"(')"; + } + + if (graphSettings.maxResolution.size()) { + oss << R"( 
+ max_resolution: ")" << graphSettings.maxResolution << R"(")"; + } + + if (graphSettings.defaultResolution.size()) { + oss << R"( + default_resolution: ")" << graphSettings.defaultResolution << R"(")"; + } + + if (graphSettings.maxNumberImagesPerPrompt.has_value()) { + oss << R"( + max_number_images_per_prompt: )" << graphSettings.maxNumberImagesPerPrompt.value(); + } + + if (graphSettings.defaultNumInferenceSteps.has_value()) { + oss << R"( + default_num_inference_steps: )" << graphSettings.defaultNumInferenceSteps.value(); + } + + if (graphSettings.maxNumInferenceSteps.has_value()) { + oss << R"( + max_num_inference_steps: )" << graphSettings.maxNumInferenceSteps.value(); + } + + oss << R"( + } + } +} +)"; + + // clang-format on + std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"}); + return FileSystem::createFileOverwrite(fullPath, oss.str()); +} + GraphExport::GraphExport() { } @@ -344,6 +403,13 @@ Status GraphExport::createServableConfig(const std::string& directoryPath, const SPDLOG_ERROR("Graph options not initialized for rerank."); return StatusCode::INTERNAL_ERROR; } + } else if (hfSettings.task == IMAGE_GENERATION_GRAPH) { + if (std::holds_alternative(hfSettings.graphSettings)) { + return createImageGenerationGraphTemplate(directoryPath, std::get(hfSettings.graphSettings)); + } else { + SPDLOG_ERROR("Graph options not initialized for image generation."); + return StatusCode::INTERNAL_ERROR; + } } else if (hfSettings.task == UNKNOWN_GRAPH) { SPDLOG_ERROR("Graph options not initialized."); return StatusCode::INTERNAL_ERROR; diff --git a/src/graph_export/graph_export_types.hpp b/src/graph_export/graph_export_types.hpp index 355c52147e..e10e095631 100644 --- a/src/graph_export/graph_export_types.hpp +++ b/src/graph_export/graph_export_types.hpp @@ -18,10 +18,12 @@ #include #pragma once namespace ovms { + enum GraphExportType { TEXT_GENERATION_GRAPH, RERANK_GRAPH, EMBEDDINGS_GRAPH, + IMAGE_GENERATION_GRAPH, UNKNOWN_GRAPH }; @@ -29,12 
+31,14 @@ const std::map typeToString = { {TEXT_GENERATION_GRAPH, "text_generation"}, {RERANK_GRAPH, "rerank"}, {EMBEDDINGS_GRAPH, "embeddings"}, + {IMAGE_GENERATION_GRAPH, "image_generation"}, {UNKNOWN_GRAPH, "unknown_graph"}}; const std::map stringToType = { {"text_generation", TEXT_GENERATION_GRAPH}, {"rerank", RERANK_GRAPH}, {"embeddings", EMBEDDINGS_GRAPH}, + {"image_generation", IMAGE_GENERATION_GRAPH}, {"unknown_graph", UNKNOWN_GRAPH}}; std::string enumToString(GraphExportType type); diff --git a/src/graph_export/image_generation_graph_cli_parser.cpp b/src/graph_export/image_generation_graph_cli_parser.cpp new file mode 100644 index 0000000000..a266e8fda7 --- /dev/null +++ b/src/graph_export/image_generation_graph_cli_parser.cpp @@ -0,0 +1,172 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** +#include "image_generation_graph_cli_parser.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#pragma warning(push) +#pragma warning(disable : 6313) +#include +#include +#include +#include +#pragma warning(pop) + +#include "../capi_frontend/server_settings.hpp" +#include "../ovms_exit_codes.hpp" +#include "../status.hpp" + +namespace ovms { + +static bool isValidResolution(const std::string& resolution) { + static const std::regex pattern(R"(\d+x\d+)"); + return std::regex_match(resolution, pattern); +} + +ImageGenerationGraphSettingsImpl& ImageGenerationGraphCLIParser::defaultGraphSettings() { + static ImageGenerationGraphSettingsImpl instance; + return instance; +} + +void ImageGenerationGraphCLIParser::createOptions() { + this->options = std::make_unique("ovms --pull [PULL OPTIONS ... ]", "--pull --task image generation/edit/inpainting graph options"); + options->allow_unrecognised_options(); + + // clang-format off + options->add_options("image_generation") + ("max_resolution", + "Max allowed resolution in a format of WxH; W=width H=height. If not specified, inherited from model.", + cxxopts::value(), + "MAX_RESOLUTION") + ("default_resolution", + "Default resolution when not specified by client in a format of WxH; W=width H=height. 
If not specified, inherited from model.", + cxxopts::value(), + "DEFAULT_RESOLUTION") + ("max_number_images_per_prompt", + "Max allowed number of images client is allowed to request for a given prompt.", + cxxopts::value(), + "MAX_NUMBER_IMAGES_PER_PROMPT") + ("default_num_inference_steps", + "Default number of inference steps when not specified by client.", + cxxopts::value(), + "DEFAULT_NUM_INFERENCE_STEPS") + ("max_num_inference_steps", + "Max allowed number of inference steps client is allowed to request for a given prompt.", + cxxopts::value(), + "MAX_NUM_INFERENCE_STEPS") + ("num_streams", + "The number of parallel execution streams to use for the image generation models. Use at least 2 on 2 socket CPU systems.", + cxxopts::value(), + "NUM_STREAMS"); +} + +void ImageGenerationGraphCLIParser::printHelp() { + if (!this->options) { + this->createOptions(); + } + std::cout << options->help({"image_generation"}) << std::endl; +} + +std::vector ImageGenerationGraphCLIParser::parse(const std::vector& unmatchedOptions) { + if (!this->options) { + this->createOptions(); + } + std::vector cStrArray; + cStrArray.reserve(unmatchedOptions.size() + 1); + cStrArray.push_back("ovms graph"); + std::transform(unmatchedOptions.begin(), unmatchedOptions.end(), std::back_inserter(cStrArray), [](const std::string& str) { return str.c_str(); }); + const char* const* args = cStrArray.data(); + result = std::make_unique(options->parse(cStrArray.size(), args)); + + return result->unmatched(); +} + +void ImageGenerationGraphCLIParser::prepare(ServerSettingsImpl& serverSettings, HFSettingsImpl& hfSettings, const std::string& modelName) { + ImageGenerationGraphSettingsImpl imageGenerationGraphSettings = ImageGenerationGraphCLIParser::defaultGraphSettings(); + imageGenerationGraphSettings.targetDevice = hfSettings.targetDevice; + // Deduce model name + if (modelName != "") { + imageGenerationGraphSettings.modelName = modelName; + } else { + imageGenerationGraphSettings.modelName = 
hfSettings.sourceModel; + } + if (nullptr == result) { + // Pull with default arguments - no arguments from user + if (serverSettings.serverMode != HF_PULL_MODE && serverSettings.serverMode != HF_PULL_AND_START_MODE) { + throw std::logic_error("Tried to prepare server and model settings without graph parse result"); + } + } else { + imageGenerationGraphSettings.maxResolution = result->count("max_resolution") ? result->operator[]("max_resolution").as() : ""; + if (!imageGenerationGraphSettings.maxResolution.empty() && !isValidResolution(imageGenerationGraphSettings.maxResolution)) { + throw std::invalid_argument("Invalid max_resolution format. Expected WxH, e.g., 1024x1024"); + } + imageGenerationGraphSettings.defaultResolution = result->count("default_resolution") ? result->operator[]("default_resolution").as() : ""; + if (!imageGenerationGraphSettings.defaultResolution.empty() && !isValidResolution(imageGenerationGraphSettings.defaultResolution)) { + throw std::invalid_argument("Invalid default_resolution format. 
Expected WxH, e.g., 1024x1024"); + } + if (result->count("max_number_images_per_prompt")) { + imageGenerationGraphSettings.maxNumberImagesPerPrompt = result->operator[]("max_number_images_per_prompt").as(); + if (imageGenerationGraphSettings.maxNumberImagesPerPrompt == 0) { + throw std::invalid_argument("max_number_images_per_prompt must be greater than 0"); + } + } + if (result->count("default_num_inference_steps")) { + imageGenerationGraphSettings.defaultNumInferenceSteps = result->operator[]("default_num_inference_steps").as(); + if (imageGenerationGraphSettings.defaultNumInferenceSteps == 0) { + throw std::invalid_argument("default_num_inference_steps must be greater than 0"); + } + } + if (result->count("max_num_inference_steps")) { + imageGenerationGraphSettings.maxNumInferenceSteps = result->operator[]("max_num_inference_steps").as(); + if (imageGenerationGraphSettings.maxNumInferenceSteps == 0) { + throw std::invalid_argument("max_num_inference_steps must be greater than 0"); + } + } + + if (result->count("num_streams") || serverSettings.cacheDir != "") { + rapidjson::Document pluginConfigDoc; + pluginConfigDoc.SetObject(); + rapidjson::Document::AllocatorType& allocator = pluginConfigDoc.GetAllocator(); + if (result->count("num_streams")) { + uint32_t numStreams = result->operator[]("num_streams").as(); + if (numStreams == 0) { + throw std::invalid_argument("num_streams must be greater than 0"); + } + pluginConfigDoc.AddMember("NUM_STREAMS", numStreams, allocator); + } + + if (!serverSettings.cacheDir.empty()) { + pluginConfigDoc.AddMember("CACHE_DIR", rapidjson::Value(serverSettings.cacheDir.c_str(), allocator), allocator); + } + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + pluginConfigDoc.Accept(writer); + imageGenerationGraphSettings.pluginConfig = buffer.GetString(); + } + } + + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); +} + +} // namespace ovms diff --git 
a/src/graph_export/image_generation_graph_cli_parser.hpp b/src/graph_export/image_generation_graph_cli_parser.hpp new file mode 100644 index 0000000000..2207b39f90 --- /dev/null +++ b/src/graph_export/image_generation_graph_cli_parser.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright 2025 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include +#include +#include + +#include + +#include "graph_cli_parser.hpp" + +namespace ovms { + +struct HFSettingsImpl; +struct ImageGenerationGraphSettingsImpl; +struct ServerSettingsImpl; +class Status; + +class ImageGenerationGraphCLIParser : public GraphCLIParser { +public: + ImageGenerationGraphCLIParser() = default; + std::vector parse(const std::vector& unmatchedOptions); + void prepare(ServerSettingsImpl& serverMode, HFSettingsImpl& hfSettings, const std::string& modelName); + + void printHelp(); + void createOptions(); + +private: + static ImageGenerationGraphSettingsImpl& defaultGraphSettings(); +}; + +} // namespace ovms diff --git a/src/test/graph_export_test.cpp b/src/test/graph_export_test.cpp index 162f1c76f7..f4543d7c4c 100644 --- a/src/test/graph_export_test.cpp +++ b/src/test/graph_export_test.cpp @@ -246,6 +246,52 @@ const std::string expectedEmbeddingsGraphContents = R"( } )"; +const std::string 
expectedImageGenerationGraphContents = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + target_device: "GPU" + plugin_config: '{"NUM_STREAMS":14,"CACHE_DIR":"/cache"}' + max_resolution: "3000x4000" + default_resolution: "300x400" + max_number_images_per_prompt: 7 + default_num_inference_steps: 2 + max_num_inference_steps: 3 + } + } +} + +)"; + +const std::string expectedImageGenerationGraphContentsDefault = R"( +input_stream: "HTTP_REQUEST_PAYLOAD:input" +output_stream: "HTTP_RESPONSE_PAYLOAD:output" + +node: { + name: "ImageGenExecutor" + calculator: "ImageGenCalculator" + input_stream: "HTTP_REQUEST_PAYLOAD:input" + input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes" + output_stream: "HTTP_RESPONSE_PAYLOAD:output" + node_options: { + [type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: { + models_path: "./" + target_device: "CPU" + } + } +} + +)"; + class GraphCreationTest : public TestWithTempDir { protected: void TearDown() { @@ -435,6 +481,10 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) { status = graphExporter->createServableConfig(this->directoryPath, hfSettings); ASSERT_EQ(status, ovms::StatusCode::INTERNAL_ERROR); + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::INTERNAL_ERROR); + hfSettings.task = ovms::UNKNOWN_GRAPH; status = graphExporter->createServableConfig(this->directoryPath, hfSettings); ASSERT_EQ(status, ovms::StatusCode::INTERNAL_ERROR); @@ -457,3 +507,38 @@ TEST_F(GraphCreationTest, negativeCreatedPbtxtInvalid) { auto status = 
graphExporter->createServableConfig(this->directoryPath, hfSettings); ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID); } + +TEST_F(GraphCreationTest, imageGenerationPositiveDefault) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenerationGraphContentsDefault, graphContents) << graphContents; +} + +TEST_F(GraphCreationTest, imageGenerationPositiveFull) { + ovms::HFSettingsImpl hfSettings; + hfSettings.task = ovms::IMAGE_GENERATION_GRAPH; + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings; + imageGenerationGraphSettings.pluginConfig = "{\"NUM_STREAMS\":14,\"CACHE_DIR\":\"/cache\"}"; + imageGenerationGraphSettings.targetDevice = "GPU"; + imageGenerationGraphSettings.defaultResolution = "300x400"; + imageGenerationGraphSettings.maxResolution = "3000x4000"; + imageGenerationGraphSettings.maxNumberImagesPerPrompt = 7; + imageGenerationGraphSettings.defaultNumInferenceSteps = 2; + imageGenerationGraphSettings.maxNumInferenceSteps = 3; + hfSettings.graphSettings = std::move(imageGenerationGraphSettings); + std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt"; + std::unique_ptr graphExporter = std::make_unique(); + auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings); + ASSERT_EQ(status, ovms::StatusCode::OK); + + std::string graphContents = GetFileContents(graphPath); + ASSERT_EQ(expectedImageGenerationGraphContents, 
graphContents) << graphContents; +} diff --git a/src/test/ovmsconfig_test.cpp b/src/test/ovmsconfig_test.cpp index fc4ae04f2a..c1f5782c5f 100644 --- a/src/test/ovmsconfig_test.cpp +++ b/src/test/ovmsconfig_test.cpp @@ -380,6 +380,126 @@ TEST_F(OvmsConfigDeathTest, hfBadRerankGraphParameter) { EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "task: rerank - error parsing options - unmatched arguments : --normalize, true,"); } +TEST_F(OvmsConfigDeathTest, notSupportedImageGenerationGraphParameter) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + "image_generation", + "--unsupported_param", + "true", + }; + int arg_count = 10; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), + "task: image_generation - error parsing options - unmatched arguments : --unsupported_param, true,"); +} + +TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_NumStreamsZero) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + "image_generation", + "--num_streams", + "0", + }; + int arg_count = 10; + EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); +} + +TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_MaxResolutionWrongFormat) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + "image_generation", + "--max_resolution", + "hello", + }; + int arg_count = 10; + EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); +} + +TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_DefaultResolutionWrongFormat) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + 
"--task", + "image_generation", + "--default_resolution", + "hello", + }; + int arg_count = 10; + EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); +} + +TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_MaxNumberImagesPerPromptZero) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + "image_generation", + "--max_number_images_per_prompt", + "0", + }; + int arg_count = 10; + EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); +} + +TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_DefaultNumInferenceStepsZero) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + "image_generation", + "--default_num_inference_steps", + "0", + }; + int arg_count = 10; + EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); +} + +TEST_F(OvmsConfigDeathTest, negativeImageGenerationGraph_MaxNumInferenceStepsZero) { + char* n_argv[] = { + "ovms", + "--pull", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + "image_generation", + "--max_num_inference_steps", + "0", + }; + int arg_count = 10; + EXPECT_THROW(ovms::Config::instance().parse(arg_count, n_argv), std::invalid_argument); +} + TEST_F(OvmsConfigDeathTest, hfBadEmbeddingsGraphParameter) { char* n_argv[] = { "ovms", @@ -600,6 +720,21 @@ TEST_F(OvmsConfigDeathTest, modifyModelConfigDisableMissingModelNameWithPath) { int arg_count = 7; EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "Set model_name with add_to_config, remove_from_config"); } +TEST_F(OvmsConfigDeathTest, hfBadImageGenerationGraphNoPull) { + char* n_argv[] = { + "ovms", + "--source_model", + "some/model", + "--model_repository_path", + "/some/path", + "--task", + 
"image_generation", + "--unsupported_param", + "true", + }; + int arg_count = 9; + EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "task: image_generation - error parsing options - unmatched arguments : --unsupported_param, true,"); +} TEST_F(OvmsConfigDeathTest, hfPullNoSourceModel) { char* n_argv[] = { @@ -1108,6 +1243,91 @@ TEST(OvmsGraphConfigTest, positiveSomeChangedRerank) { ASSERT_EQ(rerankGraphSettings.modelName, servingName); } +TEST(OvmsGraphConfigTest, positiveAllChangedImageGeneration) { + std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; + std::string downloadPath = "test/repository"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)modelName.c_str(), + (char*)"--model_repository_path", + (char*)downloadPath.c_str(), + (char*)"--task", + (char*)"image_generation", + (char*)"--cache_dir", + (char*)"/cache", + (char*)"--target_device", + (char*)"GPU", + (char*)"--num_streams", + (char*)"14", + (char*)"--max_resolution", + (char*)"3000x4000", + (char*)"--default_resolution", + (char*)"300x400", + (char*)"--max_number_images_per_prompt", + (char*)"7", + (char*)"--default_num_inference_steps", + (char*)"2", + (char*)"--max_num_inference_steps", + (char*)"3", + }; + + int arg_count = 24; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + + auto& hfSettings = config.getServerSettings().hfSettings; + ASSERT_EQ(hfSettings.sourceModel, modelName); + ASSERT_EQ(hfSettings.downloadPath, downloadPath); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::HF_PULL_MODE); + ASSERT_EQ(hfSettings.task, ovms::IMAGE_GENERATION_GRAPH); + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(imageGenerationGraphSettings.targetDevice, "GPU"); + ASSERT_EQ(imageGenerationGraphSettings.maxResolution, "3000x4000"); + ASSERT_EQ(imageGenerationGraphSettings.defaultResolution, 
"300x400"); + ASSERT_TRUE(imageGenerationGraphSettings.maxNumberImagesPerPrompt.has_value()); + ASSERT_EQ(imageGenerationGraphSettings.maxNumberImagesPerPrompt.value(), 7); + ASSERT_TRUE(imageGenerationGraphSettings.defaultNumInferenceSteps.has_value()); + ASSERT_EQ(imageGenerationGraphSettings.defaultNumInferenceSteps.value(), 2); + ASSERT_TRUE(imageGenerationGraphSettings.maxNumInferenceSteps.has_value()); + ASSERT_EQ(imageGenerationGraphSettings.maxNumInferenceSteps.value(), 3); + ASSERT_EQ(imageGenerationGraphSettings.pluginConfig, "{\"NUM_STREAMS\":14,\"CACHE_DIR\":\"/cache\"}"); +} + +TEST(OvmsGraphConfigTest, positiveDefaultImageGeneration) { + std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; + std::string downloadPath = "test/repository"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--pull", + (char*)"--source_model", + (char*)modelName.c_str(), + (char*)"--model_repository_path", + (char*)downloadPath.c_str(), + (char*)"--task", + (char*)"image_generation", + }; + + int arg_count = 8; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + + auto& hfSettings = config.getServerSettings().hfSettings; + ASSERT_EQ(hfSettings.sourceModel, modelName); + ASSERT_EQ(hfSettings.downloadPath, downloadPath); + ASSERT_EQ(config.getServerSettings().serverMode, ovms::HF_PULL_MODE); + ASSERT_EQ(hfSettings.task, ovms::IMAGE_GENERATION_GRAPH); + ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings = std::get(hfSettings.graphSettings); + ASSERT_EQ(imageGenerationGraphSettings.targetDevice, "CPU"); + ASSERT_TRUE(imageGenerationGraphSettings.maxResolution.empty()); + ASSERT_TRUE(imageGenerationGraphSettings.defaultResolution.empty()); + ASSERT_FALSE(imageGenerationGraphSettings.maxNumberImagesPerPrompt.has_value()); + ASSERT_FALSE(imageGenerationGraphSettings.defaultNumInferenceSteps.has_value()); + ASSERT_FALSE(imageGenerationGraphSettings.maxNumInferenceSteps.has_value()); + 
ASSERT_TRUE(imageGenerationGraphSettings.pluginConfig.empty()); +} + TEST(OvmsGraphConfigTest, positiveAllChangedEmbeddings) { std::string modelName = "OpenVINO/Phi-3-mini-FastDraft-50M-int8-ov"; std::string downloadPath = "test/repository";