Skip to content

ImageGen models support for --pull #3302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ cc_library(
"//src/graph_export:graph_cli_parser",
"//src/graph_export:rerank_graph_cli_parser",
"//src/graph_export:embeddings_graph_cli_parser",
"//src/graph_export:image_generation_graph_cli_parser",
],
visibility = ["//visibility:public",],
local_defines = COMMON_LOCAL_DEFINES,
Expand Down
9 changes: 8 additions & 1 deletion src/capi_frontend/server_settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ struct RerankGraphSettingsImpl {
uint32_t version = 1; // FIXME: export_rerank_tokenizer python method - not supported currently?
};

// Settings consumed when generating the graph for an image generation servable.
// Every field carries a default so a default-constructed instance is usable as-is.
struct ImageGenerationGraphSettingsImpl {
    // Name of the model; empty until filled in by the CLI parser.
    std::string modelName{};
    // Emitted as `models_path` in the generated graph.
    std::string modelPath{"./"};
    // Emitted as `target_device` in the generated graph.
    std::string targetDevice{"CPU"};
    // Emitted as `default_resolution`; used when a request does not specify a size.
    std::string defaultResolution{"512x512"};
};

struct HFSettingsImpl {
std::string targetDevice = "CPU";
std::string sourceModel = "";
Expand All @@ -69,7 +76,7 @@ struct HFSettingsImpl {
bool pullHfAndStartModelMode = false;
bool overwriteModels = false;
ExportType task = text_generation;
std::variant<TextGenGraphSettingsImpl, RerankGraphSettingsImpl, EmbeddingsGraphSettingsImpl> graphSettings;
std::variant<TextGenGraphSettingsImpl, RerankGraphSettingsImpl, EmbeddingsGraphSettingsImpl, ImageGenerationGraphSettingsImpl> graphSettings;
};

struct ServerSettingsImpl {
Expand Down
13 changes: 12 additions & 1 deletion src/cli_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "graph_export/graph_cli_parser.hpp"
#include "graph_export/rerank_graph_cli_parser.hpp"
#include "graph_export/embeddings_graph_cli_parser.hpp"
#include "graph_export/image_generation_graph_cli_parser.hpp"
#include "ovms_exit_codes.hpp"
#include "filesystem.hpp"
#include "version.hpp"
Expand Down Expand Up @@ -152,7 +153,7 @@ void CLIParser::parse(int argc, char** argv) {
cxxopts::value<std::string>(),
"MODEL_REPOSITORY_PATH")
("task",
"Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint.",
"Choose type of model export: text_generation - chat and completion endpoints, embeddings - embeddings endpoint, rerank - rerank endpoint, image_generation - image generation/edit/inpainting endpoints.",
cxxopts::value<std::string>()->default_value("text_generation"),
"TASK")
("list_models",
Expand Down Expand Up @@ -240,6 +241,12 @@ void CLIParser::parse(int argc, char** argv) {
this->graphOptionsParser = std::move(cliParser);
break;
}
case image_generation: {
ImageGenerationGraphCLIParser cliParser;
unmatchedOptions = cliParser.parse(result->unmatched());
this->graphOptionsParser = std::move(cliParser);
break;
}
case unknown: {
std::cerr << "error parsing options - --task parameter unsupported value: " + result->operator[]("task").as<std::string>();
exit(OVMS_EX_USAGE);
Expand Down Expand Up @@ -477,6 +484,10 @@ void CLIParser::prepareGraph(HFSettingsImpl& hfSettings, const std::string& mode
}
break;
}
case image_generation: {
std::get<ImageGenerationGraphCLIParser>(this->graphOptionsParser).prepare(hfSettings, modelName);
break;
}
case unknown: {
throw std::logic_error("Error: --task parameter unsupported value: " + result->operator[]("task").as<std::string>());
break;
Expand Down
3 changes: 2 additions & 1 deletion src/cli_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "graph_export/graph_cli_parser.hpp"
#include "graph_export/rerank_graph_cli_parser.hpp"
#include "graph_export/embeddings_graph_cli_parser.hpp"
#include "graph_export/image_generation_graph_cli_parser.hpp"

namespace ovms {

Expand All @@ -33,7 +34,7 @@ struct ModelsSettingsImpl;
class CLIParser {
std::unique_ptr<cxxopts::Options> options;
std::unique_ptr<cxxopts::ParseResult> result;
std::variant<GraphCLIParser, RerankGraphCLIParser, EmbeddingsGraphCLIParser> graphOptionsParser;
std::variant<GraphCLIParser, RerankGraphCLIParser, EmbeddingsGraphCLIParser, ImageGenerationGraphCLIParser> graphOptionsParser;

public:
CLIParser() = default;
Expand Down
15 changes: 15 additions & 0 deletions src/graph_export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ cc_library(
visibility = ["//visibility:public"],
)

# CLI parser for the image generation graph export options
# (image_generation_graph_cli_parser.cpp/.hpp). Publicly visible so targets
# outside this package can depend on it.
cc_library(
name = "image_generation_graph_cli_parser",
srcs = ["image_generation_graph_cli_parser.cpp"],
hdrs = ["image_generation_graph_cli_parser.hpp"],
deps = [
"@ovms//src/graph_export:graph_export_types",
"@ovms//src/graph_export:graph_cli_parser",
"@ovms//src:cpp_headers",
"@ovms//src:libovms_server_settings",
"@ovms//src:ovms_exit_codes",
"@com_github_jarro2783_cxxopts//:cxxopts",
],
visibility = ["//visibility:public"],
)

cc_library(
name = "embeddings_graph_cli_parser",
srcs = ["embeddings_graph_cli_parser.cpp"],
Expand Down
37 changes: 37 additions & 0 deletions src/graph_export/graph_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,36 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
return createEmbeddingsSubconfigTemplate(directoryPath, graphSettings);
}

// Writes the graph definition for an image generation servable to
// <directoryPath>/graph.pbtxt.
//
// The template declares a single "ImageGenCalculator" node and fills in
// models_path, target_device and default_resolution from graphSettings.
// The values are streamed into the template verbatim - no quoting or escaping
// is applied here.
//
// Returns the Status produced by createFile().
static Status createImageGenerationGraphTemplate(const std::string& directoryPath, const ImageGenerationGraphSettingsImpl& graphSettings) {
std::ostringstream oss;
// The raw string below is whitespace-sensitive: it is written to disk as-is.
// clang-format off
oss << R"(
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"

node: {
name: "ImageGenExecutor"
calculator: "ImageGenCalculator"
input_stream: "HTTP_REQUEST_PAYLOAD:input"
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
models_path: ")" << graphSettings.modelPath << R"(",
target_device: ")" << graphSettings.targetDevice << R"(",
default_resolution: ")" << graphSettings.defaultResolution << R"("
}
}
}
)";

// TODO: Remaining params

// clang-format on
// The graph file always gets the fixed name "graph.pbtxt" inside directoryPath.
std::string fullPath = FileSystem::joinPath({directoryPath, "graph.pbtxt"});
return createFile(fullPath, oss.str());
}

GraphExport::GraphExport() {
}

Expand Down Expand Up @@ -298,6 +328,13 @@ Status GraphExport::createServableConfig(const std::string& directoryPath, const
SPDLOG_ERROR("Graph options not initialized for rerank.");
return StatusCode::INTERNAL_ERROR;
}
} else if (hfSettings.task == image_generation) {
if (std::holds_alternative<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings)) {
return createImageGenerationGraphTemplate(directoryPath, std::get<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings));
} else {
SPDLOG_ERROR("Graph options not initialized for image generation.");
return StatusCode::INTERNAL_ERROR;
}
} else if (hfSettings.task == unknown) {
SPDLOG_ERROR("Graph options not initialized.");
return StatusCode::INTERNAL_ERROR;
Expand Down
3 changes: 3 additions & 0 deletions src/graph_export/graph_export_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ enum ExportType {
text_generation,
rerank,
embeddings,
image_generation,
unknown
};

const std::map<ExportType, std::string> typeToString = {
{text_generation, "text_generation"},
{rerank, "rerank"},
{embeddings, "embeddings"},
{image_generation, "image_generation"},
{unknown, "unknown"}};

const std::map<std::string, ExportType> stringToType = {
{"text_generation", text_generation},
{"rerank", rerank},
{"embeddings", embeddings},
{"image_generation", image_generation},
{"unknown", unknown}};

std::string enumToString(ExportType type);
Expand Down
95 changes: 95 additions & 0 deletions src/graph_export/image_generation_graph_cli_parser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "image_generation_graph_cli_parser.hpp"

#include <algorithm>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "../capi_frontend/server_settings.hpp"
#include "../ovms_exit_codes.hpp"
#include "../status.hpp"

namespace ovms {

// Returns a reference to a function-local static settings object holding the
// struct's default member values.
ImageGenerationGraphSettingsImpl& ImageGenerationGraphCLIParser::defaultGraphSettings() {
    static ImageGenerationGraphSettingsImpl defaults{};
    return defaults;
}

// Builds the cxxopts option set for the image generation graph export options.
// Unrecognised options are allowed so they can be returned to the caller as
// unmatched instead of failing the parse.
void ImageGenerationGraphCLIParser::createOptions() {
    // The description spells the flag as "--pull" (two dashes), matching the
    // program string passed as the first argument.
    this->options = std::make_unique<cxxopts::Options>("ovms --pull [PULL OPTIONS ... ]", "--pull --task image generation/edit/inpainting graph options");
    options->allow_unrecognised_options();

    // clang-format off
    options->add_options("image_generation")
        ("graph_target_device",
            "CPU, GPU, NPU or HETERO, default is CPU.",
            cxxopts::value<std::string>()->default_value("CPU"),
            "GRAPH_TARGET_DEVICE")
        ("default_resolution",
            "Default width and height of requested image in case user does not provide it.",
            cxxopts::value<std::string>()->default_value("512x512"),
            "DEFAULT_RESOLUTION");
    // clang-format on
}

void ImageGenerationGraphCLIParser::printHelp() {
if (!this->options) {
this->createOptions();
}
std::cout << options->help({"image_generation"}) << std::endl;
}

std::vector<std::string> ImageGenerationGraphCLIParser::parse(const std::vector<std::string>& unmatchedOptions) {
if (!this->options) {
this->createOptions();
}
std::vector<const char*> cStrArray;
cStrArray.reserve(unmatchedOptions.size() + 1);
cStrArray.push_back("ovms graph");
std::transform(unmatchedOptions.begin(), unmatchedOptions.end(), std::back_inserter(cStrArray), [](const std::string& str) { return str.c_str(); });
const char* const* args = cStrArray.data();
result = std::make_unique<cxxopts::ParseResult>(options->parse(cStrArray.size(), args));

return result->unmatched();
}

// Fills hfSettings.graphSettings with the image generation graph settings.
//
// The model name is taken from `modelName` when it is non-empty, otherwise
// from hfSettings.sourceModel. The target device and default resolution come
// from the parse result; when parse() was never called the defaults are kept.
//
// Throws std::logic_error when there is no parse result and neither pull mode
// is enabled in hfSettings.
void ImageGenerationGraphCLIParser::prepare(HFSettingsImpl& hfSettings, const std::string& modelName) {
    ImageGenerationGraphSettingsImpl imageGenerationGraphSettings = ImageGenerationGraphCLIParser::defaultGraphSettings();
    // Deduce model name: an explicit name wins over the source model.
    imageGenerationGraphSettings.modelName = modelName.empty() ? hfSettings.sourceModel : modelName;

    if (nullptr == result) {
        // Pull with default arguments - no arguments from user.
        // Only reject the call when NEITHER pull mode is active; the previous
        // `||` required both flags to be set at once.
        if (!hfSettings.pullHfModelMode && !hfSettings.pullHfAndStartModelMode) {
            throw std::logic_error("Tried to prepare server and model settings without graph parse result");
        }
    } else {
        imageGenerationGraphSettings.targetDevice = result->operator[]("graph_target_device").as<std::string>();
        imageGenerationGraphSettings.defaultResolution = result->operator[]("default_resolution").as<std::string>();
    }

    hfSettings.graphSettings = std::move(imageGenerationGraphSettings);
}

} // namespace ovms
45 changes: 45 additions & 0 deletions src/graph_export/image_generation_graph_cli_parser.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//*****************************************************************************
// Copyright 2025 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include <memory>
#include <string>
#include <vector>

#include <cxxopts.hpp>

#include "graph_cli_parser.hpp"

namespace ovms {

struct HFSettingsImpl;
struct ImageGenerationGraphSettingsImpl;
class Status;

/// CLI parser for the image generation graph export options
/// (`graph_target_device`, `default_resolution`).
class ImageGenerationGraphCLIParser : public GraphCLIParser {
public:
    ImageGenerationGraphCLIParser() = default;

    /// Builds the option set for the "image_generation" option group.
    void createOptions();

    /// Prints the help text of the "image_generation" option group to stdout.
    void printHelp();

    /// Parses the options the main CLI parser left unmatched.
    /// @param unmatchedOptions options not consumed by the main parser
    /// @return options this parser did not recognise either
    std::vector<std::string> parse(const std::vector<std::string>& unmatchedOptions);

    /// Stores the image generation graph settings in hfSettings.graphSettings.
    /// @param hfSettings settings object to update
    /// @param modelName model name to use; when empty, hfSettings.sourceModel is used
    /// @throws std::logic_error can be thrown when called without a prior parse()
    void prepare(HFSettingsImpl& hfSettings, const std::string& modelName);

private:
    /// Returns the settings object holding the default values.
    static ImageGenerationGraphSettingsImpl& defaultGraphSettings();
};

} // namespace ovms
41 changes: 41 additions & 0 deletions src/test/graph_export_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,27 @@ const std::string expectedEmbeddingsGraphContents = R"(
}
)";

// Expected graph.pbtxt contents for an image generation export with
// target_device "GPU" and default_resolution "800x800"; models_path keeps its
// default "./". The literal is compared byte-for-byte, so its whitespace matters.
// NOTE(review): this literal ends with one more newline than the template in
// graph_export.cpp - presumably createFile() appends it; confirm.
const std::string expectedImageGenerationGraphContents = R"(
input_stream: "HTTP_REQUEST_PAYLOAD:input"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"

node: {
name: "ImageGenExecutor"
calculator: "ImageGenCalculator"
input_stream: "HTTP_REQUEST_PAYLOAD:input"
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
node_options: {
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
models_path: "./",
target_device: "GPU",
default_resolution: "800x800"
}
}
}

)";

class GraphCreationTest : public TestWithTempDir {
protected:
void TearDown() {
Expand Down Expand Up @@ -368,6 +389,10 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) {
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::INTERNAL_ERROR);

hfSettings.task = ovms::image_generation;
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::INTERNAL_ERROR);

hfSettings.task = ovms::unknown;
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::INTERNAL_ERROR);
Expand All @@ -377,3 +402,19 @@ TEST_F(GraphCreationTest, negativeGraphOptionsNotInitialized) {
status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
ASSERT_EQ(status, ovms::StatusCode::OK) << status.string();
}

// Exports an image generation graph with a non-default device and resolution
// and checks the written graph.pbtxt against the expected contents.
TEST_F(GraphCreationTest, imageGenerationPositiveDefault) {
    // Override two fields so the template substitution is visible in the output.
    ovms::ImageGenerationGraphSettingsImpl settings;
    settings.targetDevice = "GPU";
    settings.defaultResolution = "800x800";

    ovms::HFSettingsImpl hfSettings;
    hfSettings.task = ovms::image_generation;
    hfSettings.graphSettings = std::move(settings);

    std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
    auto graphExporter = std::make_unique<ovms::GraphExport>();
    auto status = graphExporter->createServableConfig(this->directoryPath, hfSettings);
    ASSERT_EQ(status, ovms::StatusCode::OK);

    std::string graphContents = GetFileContents(graphPath);
    ASSERT_EQ(expectedImageGenerationGraphContents, graphContents) << graphContents;
}
Loading