Skip to content

ImageGen models support for --pull #3302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions demos/common/export_models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ python export_model.py --help
```
Expected Output:
```console
usage: export_model.py [-h] {text_generation,embeddings,rerank} ...
usage: export_model.py [-h] {text_generation,embeddings,rerank,image_generation} ...

Export Hugging face models to OVMS models repository including all configuration for deployments

positional arguments:
{text_generation,embeddings,rerank}
{text_generation,embeddings,rerank,image_generation}
subcommand help
text_generation export model for chat and completion endpoints
embeddings export model for embeddings endpoint
rerank export model for rerank endpoint
image_generation export model for image generation endpoint
```
For every use case subcommand there is an adjusted list of parameters:

Expand Down Expand Up @@ -134,6 +135,15 @@ python export_model.py rerank \
--num_streams 2
```

### Image Generation Models
```console
python export_model.py image_generation \
--source_model dreamlike-art/dreamlike-anime-1.0 \
--weight-format int8 \
--config_file_path models/config_all.json \
--max_resolution 2048x2048
```

## Deployment example

After exporting models using the commands above (which use `--model_repository_path models` and `--config_file_path models/config_all.json`), you can deploy them either with Docker or on Baremetal.
Expand Down
1 change: 1 addition & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ cc_library(
"//src/graph_export:graph_cli_parser",
"//src/graph_export:rerank_graph_cli_parser",
"//src/graph_export:embeddings_graph_cli_parser",
"//src/graph_export:image_generation_graph_cli_parser",
],
visibility = ["//visibility:public",],
local_defines = COMMON_LOCAL_DEFINES,
Expand Down
14 changes: 13 additions & 1 deletion src/capi_frontend/server_settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,25 @@ struct RerankGraphSettingsImpl {
uint32_t version = 1; // FIXME: export_rerank_tokenizer python method - not supported currently?
};

// Settings for the image_generation export task (--task image_generation).
// Consumed by createImageGenerationGraphTemplate() when rendering graph.pbtxt:
// empty strings / nullopt members are omitted from the generated file so the
// calculator falls back to its own defaults.
struct ImageGenerationGraphSettingsImpl {
    std::string modelName = "";
    std::string modelPath = "./";
    std::string targetDevice = "CPU";
    // Resolution strings, e.g. "2048x2048"; empty means "do not emit".
    std::string maxResolution = "";
    std::string defaultResolution = "";
    // Optional numeric limits; nullopt means "do not emit" (calculator default).
    std::optional<uint32_t> maxNumberImagesPerPrompt;
    std::optional<uint32_t> defaultNumInferenceSteps;
    std::optional<uint32_t> maxNumInferenceSteps;
    // Raw plugin configuration string (presumably JSON — emitted verbatim; TODO confirm).
    std::string pluginConfig;
};

// Top-level settings for pulling a model from Hugging Face and exporting it.
// 'task' selects which alternative of 'graphSettings' is active; the CLI
// parser for that task fills in the matching variant member.
struct HFSettingsImpl {
    std::string targetDevice = "CPU";
    std::string sourceModel = "";
    std::string downloadPath = "";
    bool overwriteModels = false;
    GraphExportType task = TEXT_GENERATION_GRAPH;
    // One settings struct per supported export task; must stay in sync with GraphExportType.
    std::variant<TextGenGraphSettingsImpl, RerankGraphSettingsImpl, EmbeddingsGraphSettingsImpl, ImageGenerationGraphSettingsImpl> graphSettings;
};

struct ServerSettingsImpl {
Expand Down
13 changes: 12 additions & 1 deletion src/cli_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "graph_export/graph_cli_parser.hpp"
#include "graph_export/rerank_graph_cli_parser.hpp"
#include "graph_export/embeddings_graph_cli_parser.hpp"
#include "graph_export/image_generation_graph_cli_parser.hpp"
#include "ovms_exit_codes.hpp"
#include "filesystem.hpp"
#include "version.hpp"
Expand Down Expand Up @@ -153,7 +154,7 @@ void CLIParser::parse(int argc, char** argv) {
cxxopts::value<std::string>(),
"MODEL_REPOSITORY_PATH")
("task",
"Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint.",
"Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint, image_generation - image generation/edit/inpainting endpoints.",
cxxopts::value<std::string>()->default_value("text_generation"),
"TASK")
("list_models",
Expand Down Expand Up @@ -249,6 +250,12 @@ void CLIParser::parse(int argc, char** argv) {
this->graphOptionsParser = std::move(cliParser);
break;
}
case IMAGE_GENERATION_GRAPH: {
ImageGenerationGraphCLIParser cliParser;
unmatchedOptions = cliParser.parse(result->unmatched());
this->graphOptionsParser = std::move(cliParser);
break;
}
case UNKNOWN_GRAPH: {
std::cerr << "error parsing options - --task parameter unsupported value: " + result->operator[]("task").as<std::string>();
exit(OVMS_EX_USAGE);
Expand Down Expand Up @@ -501,6 +508,10 @@ void CLIParser::prepareGraph(ServerSettingsImpl& serverSettings, HFSettingsImpl&
}
break;
}
case IMAGE_GENERATION_GRAPH: {
std::get<ImageGenerationGraphCLIParser>(this->graphOptionsParser).prepare(serverSettings, hfSettings, modelName);
break;
}
case UNKNOWN_GRAPH: {
throw std::logic_error("Error: --task parameter unsupported value: " + result->operator[]("task").as<std::string>());
break;
Expand Down
3 changes: 2 additions & 1 deletion src/cli_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "graph_export/graph_cli_parser.hpp"
#include "graph_export/rerank_graph_cli_parser.hpp"
#include "graph_export/embeddings_graph_cli_parser.hpp"
#include "graph_export/image_generation_graph_cli_parser.hpp"

namespace ovms {

Expand All @@ -33,7 +34,7 @@ struct ModelsSettingsImpl;
class CLIParser {
std::unique_ptr<cxxopts::Options> options;
std::unique_ptr<cxxopts::ParseResult> result;
std::variant<GraphCLIParser, RerankGraphCLIParser, EmbeddingsGraphCLIParser> graphOptionsParser;
std::variant<GraphCLIParser, RerankGraphCLIParser, EmbeddingsGraphCLIParser, ImageGenerationGraphCLIParser> graphOptionsParser;

public:
CLIParser() = default;
Expand Down
16 changes: 16 additions & 0 deletions src/graph_export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ cc_library(
visibility = ["//visibility:public"],
)

# CLI options parser for the image_generation export task; mirrors the
# sibling rerank/embeddings parser targets in this package.
cc_library(
name = "image_generation_graph_cli_parser",
srcs = ["image_generation_graph_cli_parser.cpp"],
hdrs = ["image_generation_graph_cli_parser.hpp"],
deps = [
"@ovms//src/graph_export:graph_export_types",
"@ovms//src/graph_export:graph_cli_parser",
"@ovms//src:cpp_headers",
"@ovms//src:libovms_server_settings",
"@ovms//src:ovms_exit_codes",
"@com_github_jarro2783_cxxopts//:cxxopts",
"@com_github_tencent_rapidjson//:rapidjson",
],
visibility = ["//visibility:public"],
)

cc_library(
name = "embeddings_graph_cli_parser",
srcs = ["embeddings_graph_cli_parser.cpp"],
Expand Down
66 changes: 66 additions & 0 deletions src/graph_export/graph_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,65 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
return createEmbeddingsSubconfigTemplate(directoryPath, graphSettings);
}

// Renders a graph.pbtxt for an image generation servable and writes it into
// directoryPath, overwriting any existing file. Mandatory fields (models_path,
// target_device) are always emitted; optional settings are appended only when
// present, so absent fields use the ImageGenCalculator defaults.
static Status createImageGenerationGraphTemplate(const std::string& directoryPath, const ImageGenerationGraphSettingsImpl& graphSettings) {
std::ostringstream oss;
// clang-format off
oss << R"(
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"

node: {
name: "ImageGenExecutor"
calculator: "ImageGenCalculator"
input_stream: "HTTP_REQUEST_PAYLOAD:input"
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
models_path: ")" << graphSettings.modelPath << R"("
target_device: ")" << graphSettings.targetDevice << R"(")";

// NOTE(review): settings values are interpolated into the pbtxt without
// escaping; assumes they contain no quote characters — TODO confirm the CLI
// layer validates this.
// plugin_config is wrapped in single quotes (the value itself may contain
// double quotes).
if (graphSettings.pluginConfig.size()) {
oss << R"(
plugin_config: ')" << graphSettings.pluginConfig << R"(')";
}

if (graphSettings.maxResolution.size()) {
oss << R"(
max_resolution: ")" << graphSettings.maxResolution << R"(")";
}

if (graphSettings.defaultResolution.size()) {
oss << R"(
default_resolution: ")" << graphSettings.defaultResolution << R"(")";
}

if (graphSettings.maxNumberImagesPerPrompt.has_value()) {
oss << R"(
max_number_images_per_prompt: )" << graphSettings.maxNumberImagesPerPrompt.value();
}

if (graphSettings.defaultNumInferenceSteps.has_value()) {
oss << R"(
default_num_inference_steps: )" << graphSettings.defaultNumInferenceSteps.value();
}

if (graphSettings.maxNumInferenceSteps.has_value()) {
oss << R"(
max_num_inference_steps: )" << graphSettings.maxNumInferenceSteps.value();
}

// Close node_options, node and the graph definition.
oss << R"(
}
}
}
)";

// clang-format on
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
return FileSystem::createFileOverwrite(fullPath, oss.str());
}

// No state to initialize; all export logic lives in the member functions.
GraphExport::GraphExport() {
}

Expand Down Expand Up @@ -344,6 +403,13 @@ Status GraphExport::createServableConfig(const std::string& directoryPath, const
SPDLOG_ERROR("Graph options not initialized for rerank.");
return StatusCode::INTERNAL_ERROR;
}
} else if (hfSettings.task == IMAGE_GENERATION_GRAPH) {
if (std::holds_alternative<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings)) {
return createImageGenerationGraphTemplate(directoryPath, std::get<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings));
} else {
SPDLOG_ERROR("Graph options not initialized for image generation.");
return StatusCode::INTERNAL_ERROR;
}
} else if (hfSettings.task == UNKNOWN_GRAPH) {
SPDLOG_ERROR("Graph options not initialized.");
return StatusCode::INTERNAL_ERROR;
Expand Down
4 changes: 4 additions & 0 deletions src/graph_export/graph_export_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,27 @@
#include <string>
#pragma once
namespace ovms {

// Export task selected via the --task CLI option; UNKNOWN_GRAPH marks an
// unrecognized value. Keep in sync with typeToString / stringToType below.
enum GraphExportType {
TEXT_GENERATION_GRAPH,  // chat and completion endpoints
RERANK_GRAPH,  // rerank endpoint
EMBEDDINGS_GRAPH,  // embeddings endpoint
IMAGE_GENERATION_GRAPH,  // image generation endpoint
UNKNOWN_GRAPH  // sentinel for unsupported --task values
};

// Maps each GraphExportType to its --task CLI token.
// 'inline' (C++17, already used in this codebase via std::optional/std::variant)
// gives the map a single definition across translation units instead of one
// internal-linkage copy per including TU.
inline const std::map<GraphExportType, std::string> typeToString = {
    {TEXT_GENERATION_GRAPH, "text_generation"},
    {RERANK_GRAPH, "rerank"},
    {EMBEDDINGS_GRAPH, "embeddings"},
    {IMAGE_GENERATION_GRAPH, "image_generation"},
    {UNKNOWN_GRAPH, "unknown_graph"}};

// Inverse of typeToString: maps a --task CLI token to its GraphExportType.
// 'inline' (C++17) avoids a per-TU internal-linkage copy of this header-scope map.
inline const std::map<std::string, GraphExportType> stringToType = {
    {"text_generation", TEXT_GENERATION_GRAPH},
    {"rerank", RERANK_GRAPH},
    {"embeddings", EMBEDDINGS_GRAPH},
    {"image_generation", IMAGE_GENERATION_GRAPH},
    {"unknown_graph", UNKNOWN_GRAPH}};

std::string enumToString(GraphExportType type);
Expand Down
Loading