Skip to content

Commit da8e03d

Browse files
authored
ImageGen models support for --pull (#3302)
CVS-166467
1 parent d885cf3 commit da8e03d

File tree

12 files changed

+649
-5
lines changed

12 files changed

+649
-5
lines changed

demos/common/export_models/README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@ python export_model.py --help
1313
```
1414
Expected Output:
1515
```console
16-
usage: export_model.py [-h] {text_generation,embeddings,rerank} ...
16+
usage: export_model.py [-h] {text_generation,embeddings,rerank,image_generation} ...
1717

1818
Export Hugging face models to OVMS models repository including all configuration for deployments
1919

2020
positional arguments:
21-
{text_generation,embeddings,rerank}
21+
{text_generation,embeddings,rerank,image_generation}
2222
subcommand help
2323
text_generation export model for chat and completion endpoints
2424
embeddings export model for embeddings endpoint
2525
rerank export model for rerank endpoint
26+
image_generation export model for image generation endpoint
2627
```
2728
For every use case subcommand there is an adjusted list of parameters:
2829

@@ -134,6 +135,15 @@ python export_model.py rerank \
134135
--num_streams 2
135136
```
136137

138+
### Image Generation Models
139+
```console
140+
python export_model.py image_generation \
141+
--source_model dreamlike-art/dreamlike-anime-1.0 \
142+
--weight-format int8 \
143+
--config_file_path models/config_all.json \
144+
--max_resolution 2048x2048
145+
```
146+
137147
## Deployment example
138148

139149
After exporting models using the commands above (which use `--model_repository_path models` and `--config_file_path models/config_all.json`), you can deploy them either with Docker or on Baremetal.

src/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ cc_library(
346346
"//src/graph_export:graph_cli_parser",
347347
"//src/graph_export:rerank_graph_cli_parser",
348348
"//src/graph_export:embeddings_graph_cli_parser",
349+
"//src/graph_export:image_generation_graph_cli_parser",
349350
],
350351
visibility = ["//visibility:public",],
351352
local_defines = COMMON_LOCAL_DEFINES,

src/capi_frontend/server_settings.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,25 @@ struct RerankGraphSettingsImpl {
7070
uint32_t version = 1; // FIXME: export_rerank_tokenizer python method - not supported currently?
7171
};
7272

73+
// Settings for the image_generation export task. Held as one alternative of
// HFSettingsImpl::graphSettings and read by createImageGenerationGraphTemplate
// when it generates graph.pbtxt.
struct ImageGenerationGraphSettingsImpl {
74+
std::string modelName = "";  // NOTE(review): presumably the servable name; not written into graph.pbtxt by createImageGenerationGraphTemplate
75+
std::string modelPath = "./";  // emitted as models_path
76+
std::string targetDevice = "CPU";  // emitted as target_device
77+
std::string maxResolution = "";  // emitted as max_resolution only when non-empty
78+
std::string defaultResolution = "";  // emitted as default_resolution only when non-empty
79+
std::optional<uint32_t> maxNumberImagesPerPrompt;  // emitted as max_number_images_per_prompt only when set
80+
std::optional<uint32_t> defaultNumInferenceSteps;  // emitted as default_num_inference_steps only when set
81+
std::optional<uint32_t> maxNumInferenceSteps;  // emitted as max_num_inference_steps only when set
82+
std::string pluginConfig;  // emitted as plugin_config (wrapped in single quotes) only when non-empty
83+
};
84+
7385
struct HFSettingsImpl {
7486
std::string targetDevice = "CPU";
7587
std::string sourceModel = "";
7688
std::string downloadPath = "";
7789
bool overwriteModels = false;
7890
GraphExportType task = TEXT_GENERATION_GRAPH;
79-
std::variant<TextGenGraphSettingsImpl, RerankGraphSettingsImpl, EmbeddingsGraphSettingsImpl> graphSettings;
91+
std::variant<TextGenGraphSettingsImpl, RerankGraphSettingsImpl, EmbeddingsGraphSettingsImpl, ImageGenerationGraphSettingsImpl> graphSettings;
8092
};
8193

8294
struct ServerSettingsImpl {

src/cli_parser.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "graph_export/graph_cli_parser.hpp"
2828
#include "graph_export/rerank_graph_cli_parser.hpp"
2929
#include "graph_export/embeddings_graph_cli_parser.hpp"
30+
#include "graph_export/image_generation_graph_cli_parser.hpp"
3031
#include "ovms_exit_codes.hpp"
3132
#include "filesystem.hpp"
3233
#include "version.hpp"
@@ -153,7 +154,7 @@ void CLIParser::parse(int argc, char** argv) {
153154
cxxopts::value<std::string>(),
154155
"MODEL_REPOSITORY_PATH")
155156
("task",
156-
"Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint.",
157+
"Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint, image_generation - image generation/edit/inpainting endpoints.",
157158
cxxopts::value<std::string>()->default_value("text_generation"),
158159
"TASK")
159160
("list_models",
@@ -249,6 +250,12 @@ void CLIParser::parse(int argc, char** argv) {
249250
this->graphOptionsParser = std::move(cliParser);
250251
break;
251252
}
253+
case IMAGE_GENERATION_GRAPH: {
254+
ImageGenerationGraphCLIParser cliParser;
255+
unmatchedOptions = cliParser.parse(result->unmatched());
256+
this->graphOptionsParser = std::move(cliParser);
257+
break;
258+
}
252259
case UNKNOWN_GRAPH: {
253260
std::cerr << "error parsing options - --task parameter unsupported value: " + result->operator[]("task").as<std::string>();
254261
exit(OVMS_EX_USAGE);
@@ -501,6 +508,10 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl&
501508
}
502509
break;
503510
}
511+
case IMAGE_GENERATION_GRAPH: {
512+
std::get<ImageGenerationGraphCLIParser>(this->graphOptionsParser).prepare(serverSettings, hfSettings, modelName);
513+
break;
514+
}
504515
case UNKNOWN_GRAPH: {
505516
throw std::logic_error("Error: --task parameter unsupported value: " + result->operator[]("task").as<std::string>());
506517
break;

src/cli_parser.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "graph_export/graph_cli_parser.hpp"
2525
#include "graph_export/rerank_graph_cli_parser.hpp"
2626
#include "graph_export/embeddings_graph_cli_parser.hpp"
27+
#include "graph_export/image_generation_graph_cli_parser.hpp"
2728

2829
namespace ovms {
2930

@@ -33,7 +34,7 @@ struct ModelsSettingsImpl;
3334
class CLIParser {
3435
std::unique_ptr<cxxopts::Options> options;
3536
std::unique_ptr<cxxopts::ParseResult> result;
36-
std::variant<GraphCLIParser, RerankGraphCLIParser, EmbeddingsGraphCLIParser> graphOptionsParser;
37+
std::variant<GraphCLIParser, RerankGraphCLIParser, EmbeddingsGraphCLIParser, ImageGenerationGraphCLIParser> graphOptionsParser;
3738

3839
public:
3940
CLIParser() = default;

src/graph_export/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ cc_library(
6666
visibility = ["//visibility:public"],
6767
)
6868

69+
cc_library(
70+
name = "image_generation_graph_cli_parser",
71+
srcs = ["image_generation_graph_cli_parser.cpp"],
72+
hdrs = ["image_generation_graph_cli_parser.hpp"],
73+
deps = [
74+
"@ovms//src/graph_export:graph_export_types",
75+
"@ovms//src/graph_export:graph_cli_parser",
76+
"@ovms//src:cpp_headers",
77+
"@ovms//src:libovms_server_settings",
78+
"@ovms//src:ovms_exit_codes",
79+
"@com_github_jarro2783_cxxopts//:cxxopts",
80+
"@com_github_tencent_rapidjson//:rapidjson",
81+
],
82+
visibility = ["//visibility:public"],
83+
)
84+
6985
cc_library(
7086
name = "embeddings_graph_cli_parser",
7187
srcs = ["embeddings_graph_cli_parser.cpp"],

src/graph_export/graph_export.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,65 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
299299
return createEmbeddingsSubconfigTemplate(directoryPath, graphSettings);
300300
}
301301

302+
// Builds the graph.pbtxt text for a single "ImageGenExecutor" node that uses
// the "ImageGenCalculator" calculator, and writes it to
// <directoryPath>/graph.pbtxt, overwriting any existing file.
//
// directoryPath - directory in which graph.pbtxt is created.
// graphSettings - values inserted into the calculator node_options. models_path
//                 and target_device are always written; the remaining options
//                 are written only when the corresponding setting is non-empty
//                 (strings) or has a value (optionals).
//
// Returns the Status produced by FileSystem::createFileOverwrite.
//
// NOTE(review): setting values are streamed into the text verbatim, with no
// escaping. A value containing the surrounding quote character would
// presumably yield a malformed graph - confirm inputs are validated upstream.
static Status createImageGenerationGraphTemplate(const std::string& directoryPath, const ImageGenerationGraphSettingsImpl& graphSettings) {
303+
std::ostringstream oss;
304+
// clang-format off
305+
oss << R"(
306+
input_stream: "HTTP_REQUEST_PAYLOAD:input"
307+
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
308+
309+
node: {
310+
name: "ImageGenExecutor"
311+
calculator: "ImageGenCalculator"
312+
input_stream: "HTTP_REQUEST_PAYLOAD:input"
313+
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
314+
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
315+
node_options: {
316+
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
317+
models_path: ")" << graphSettings.modelPath << R"("
318+
target_device: ")" << graphSettings.targetDevice << R"(")";
319+
320+
// plugin_config is the only string option wrapped in single quotes rather
// than double quotes - presumably because its value is expected to contain
// double quotes itself (e.g. JSON); TODO confirm.
if (graphSettings.pluginConfig.size()) {
321+
oss << R"(
322+
plugin_config: ')" << graphSettings.pluginConfig << R"(')";
323+
}
324+
325+
// Optional string options: skipped entirely when the setting is empty.
if (graphSettings.maxResolution.size()) {
326+
oss << R"(
327+
max_resolution: ")" << graphSettings.maxResolution << R"(")";
328+
}
329+
330+
if (graphSettings.defaultResolution.size()) {
331+
oss << R"(
332+
default_resolution: ")" << graphSettings.defaultResolution << R"(")";
333+
}
334+
335+
// Optional numeric options: written unquoted, and only when a value is set.
if (graphSettings.maxNumberImagesPerPrompt.has_value()) {
336+
oss << R"(
337+
max_number_images_per_prompt: )" << graphSettings.maxNumberImagesPerPrompt.value();
338+
}
339+
340+
if (graphSettings.defaultNumInferenceSteps.has_value()) {
341+
oss << R"(
342+
default_num_inference_steps: )" << graphSettings.defaultNumInferenceSteps.value();
343+
}
344+
345+
if (graphSettings.maxNumInferenceSteps.has_value()) {
346+
oss << R"(
347+
max_num_inference_steps: )" << graphSettings.maxNumInferenceSteps.value();
348+
}
349+
350+
// Close the options message, node_options and node blocks opened above.
oss << R"(
351+
}
352+
}
353+
}
354+
)";
355+
356+
// clang-format on
357+
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
358+
return FileSystem::createFileOverwrite(fullPath, oss.str());
359+
}
360+
302361
GraphExport::GraphExport() {
303362
}
304363

@@ -344,6 +403,13 @@ Status GraphExport::createServableConfig(const std::string& directoryPath, const
344403
SPDLOG_ERROR("Graph options not initialized for rerank.");
345404
return StatusCode::INTERNAL_ERROR;
346405
}
406+
} else if (hfSettings.task == IMAGE_GENERATION_GRAPH) {
407+
if (std::holds_alternative<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings)) {
408+
return createImageGenerationGraphTemplate(directoryPath, std::get<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings));
409+
} else {
410+
SPDLOG_ERROR("Graph options not initialized for image generation.");
411+
return StatusCode::INTERNAL_ERROR;
412+
}
347413
} else if (hfSettings.task == UNKNOWN_GRAPH) {
348414
SPDLOG_ERROR("Graph options not initialized.");
349415
return StatusCode::INTERNAL_ERROR;

src/graph_export/graph_export_types.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,27 @@
1818
#include <string>
1919
#pragma once
2020
namespace ovms {
21+
2122
// Kinds of graph that can be exported. Every enumerator has a matching entry
// in both typeToString and stringToType below; keep all three in sync when
// adding a new kind.
enum GraphExportType {
2223
TEXT_GENERATION_GRAPH,
2324
RERANK_GRAPH,
2425
EMBEDDINGS_GRAPH,
26+
IMAGE_GENERATION_GRAPH,
2527
UNKNOWN_GRAPH  // reported as an unsupported --task value by the CLI parser
2628
};
2729

2830
// Maps each GraphExportType enumerator to its textual name. The names for the
// supported kinds match the values listed in the --task help text.
const std::map<GraphExportType, std::string> typeToString = {
2931
{TEXT_GENERATION_GRAPH, "text_generation"},
3032
{RERANK_GRAPH, "rerank"},
3133
{EMBEDDINGS_GRAPH, "embeddings"},
34+
{IMAGE_GENERATION_GRAPH, "image_generation"},
3235
{UNKNOWN_GRAPH, "unknown_graph"}};
3336

3437
// Inverse of typeToString: maps a textual name back to its GraphExportType
// enumerator.
const std::map<std::string, GraphExportType> stringToType = {
3538
{"text_generation", TEXT_GENERATION_GRAPH},
3639
{"rerank", RERANK_GRAPH},
3740
{"embeddings", EMBEDDINGS_GRAPH},
41+
{"image_generation", IMAGE_GENERATION_GRAPH},
3842
{"unknown_graph", UNKNOWN_GRAPH}};
3943

4044
std::string enumToString(GraphExportType type);

0 commit comments

Comments
 (0)