Skip to content

Commit ccef99f

Browse files
committed
Input parsing text2image p1
1 parent 0d3aab5 commit ccef99f

File tree

8 files changed

+401
-37
lines changed

8 files changed

+401
-37
lines changed

src/BUILD

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,12 +2812,12 @@ cc_test(
28122812
+ select({
28132813
"//:not_disable_mediapipe": [
28142814
"test/embeddingsnode_test.cpp",
2815-
"test/mediapipeflow_test.cpp",
2815+
#"test/mediapipeflow_test.cpp", FIXME
28162816
"test/mediapipe/inputsidepacketusertestcalc.cc",
28172817
"test/reranknode_test.cpp",
28182818
"test/rerank_handler_test.cpp",
28192819
"test/rerank_chunking_test.cpp",
2820-
"test/streaming_test.cpp", # Mediapipe enabled
2820+
# "test/streaming_test.cpp", # Mediapipe enabled FIXME
28212821
"test/mediapipe_validation_test.cpp", # Mediapipe enabled
28222822
"test/get_mediapipe_graph_metadata_response_test.cpp",
28232823
"test/mediapipe_framework_test.cpp",
@@ -2836,7 +2836,7 @@ cc_test(
28362836
"//:not_disable_python": [
28372837
# OvmsPyTensor is currently not used in OVMS core and is just a base for the binding.
28382838
# "test/python/ovms_py_tensor_test.cpp",
2839-
"test/pythonnode_test.cpp",
2839+
#"test/pythonnode_test.cpp", FIXME
28402840
# LLM logic uses Python for processing Jinja templates when built with Python enabled
28412841
"test/llm/llmtemplate_test.cpp",
28422842
],
@@ -3004,6 +3004,7 @@ cc_test(
30043004
"//src/test/mediapipe/calculators:dependency_free_http_test_calculators",
30053005
"@mediapipe//mediapipe/calculators/ovms:ovms_calculator",
30063006
"@mediapipe//mediapipe/framework:calculator_runner",
3007+
":text2image_test",
30073008
],
30083009
"//:disable_mediapipe" :
30093010
[
@@ -3174,6 +3175,19 @@ cc_library(
31743175
copts = COPTS_TESTS,
31753176
)
31763177

3178+
cc_library(
3179+
name = "text2image_test",
3180+
linkstatic = 1,
3181+
alwayslink = True,
3182+
srcs = ["test/text2image_test.cpp"],
3183+
deps = [
3184+
"//src:test_utils",
3185+
"//src/image_gen:imagegenutils",
3186+
],
3187+
local_defines = COMMON_LOCAL_DEFINES,
3188+
copts = COPTS_TESTS,
3189+
)
3190+
31773191
filegroup(
31783192
name = "release_custom_nodes",
31793193
srcs = [

src/image_conversion.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "profiler.hpp"
2424
#pragma warning(push)
2525
#pragma warning(disable : 6262)
26-
#include "stb_image.h" // NOLINT
26+
#include "stb_image.h" // NOLINT
2727
#include "stb_image_write.h" // NOLINT
2828
#pragma warning(default : 6262)
2929
#pragma warning(disable : 6001 4324 6385 6386)
@@ -107,13 +107,13 @@ std::string save_image_stbi(ov::Tensor tensor) {
107107

108108
// Write PNG to memory using our buffer
109109
int success = stbi_write_png_to_func(
110-
write_func, // Our write function
111-
&png_buffer, // Context (our buffer)
112-
width, // Image width
113-
height, // Image height
114-
channels, // Number of channels
115-
image_data, // Image data
116-
width * channels); // Stride (bytes per row)
110+
write_func, // Our write function
111+
&png_buffer, // Context (our buffer)
112+
width, // Image width
113+
height, // Image height
114+
channels, // Number of channels
115+
image_data, // Image data
116+
width * channels); // Stride (bytes per row)
117117

118118
if (!success) {
119119
throw std::runtime_error{"Failed to encode image to PNG format"};

src/image_gen/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,27 @@ cc_library(
5353
alwayslink = 1,
5454
)
5555

56+
cc_library(
57+
name = "imagegenutils",
58+
srcs = ["imagegenutils.cpp"],
59+
hdrs = ["imagegenutils.hpp"],
60+
deps = [
61+
"@com_google_absl//absl/strings",
62+
"@com_google_absl//absl/status",
63+
"//src:httppayload",
64+
"//src:libovmslogging",
65+
"//src:libimage_conversion",
66+
"//src:libovmsstring_utils",
67+
] + select({
68+
"//conditions:default": ["//third_party:genai", ":llm_engine"],
69+
"//:not_genai_bin" : [":llm_engine"],
70+
}),
71+
visibility = ["//visibility:public"],
72+
local_defines = COMMON_LOCAL_DEFINES,
73+
copts = COPTS_ADJUSTED,
74+
linkopts = LINKOPTS_ADJUSTED,
75+
)
76+
5677
cc_library(
5778
name = "image_gen_calculator",
5879
srcs = ["http_image_gen_calculator.cc"],
@@ -63,6 +84,7 @@ cc_library(
6384
"image_gen_calculator_cc_proto",
6485
":pipelines",
6586
"//src:libimage_conversion",
87+
":imagegenutils",
6688
]+ select({
6789
"//conditions:default": ["//third_party:genai", ":llm_engine"],
6890
"//:not_genai_bin" : [":llm_engine"],

src/image_gen/http_image_gen_calculator.cc

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "../image_conversion.hpp"
3030

3131
#include "pipelines.hpp"
32+
#include "imagegenutils.hpp" // FIXME split
3233

3334
#pragma warning(push)
3435
#pragma warning(disable : 6001 4324 6385 6386)
@@ -83,39 +84,24 @@ class ImageGenCalculator : public CalculatorBase {
8384
if (!payload.parsedJson->IsObject()) {
8485
return absl::InvalidArgumentError("JSON body must be an object");
8586
}
86-
87-
// get prompt field as string
88-
auto promptIt = payload.parsedJson->FindMember("prompt");
89-
if (promptIt == payload.parsedJson->MemberEnd()) {
90-
return absl::InvalidArgumentError("prompt field is missing in JSON body");
91-
}
92-
if (!promptIt->value.IsString()) {
93-
return absl::InvalidArgumentError("prompt field is not a string");
94-
}
95-
std::string prompt = promptIt->value.GetString();
87+
SET_OR_RETURN(std::string, prompt, getPromptField(payload));
9688

9789
// TODO: Support more pipeline types
9890
// Depending on URI, select text2ImagePipeline/image2ImagePipeline/inpaintingPipeline
9991

10092
// curl -X POST localhost:11338/v3/images/generations -H "Content-Type: application/json" -d '{ "model": "endpoint", "prompt": "A cute baby sea otter", "n": 1, "size": "1024x1024" }'
93+
// FIXME routing request to different pipelines (enum?)
10194
ov::genai::Text2ImagePipeline request = pipe->text2ImagePipeline.clone();
102-
ov::Tensor image = request.generate(prompt,
103-
ov::AnyMap{
104-
ov::genai::width(512), // todo: get from req
105-
ov::genai::height(512), // todo: get from req
106-
ov::genai::num_inference_steps(20), // todo: get from req
107-
ov::genai::num_images_per_prompt(1)}); // todo: get from req
95+
SET_OR_RETURN(ov::AnyMap, requestOptions, getImageGenerationRequestOptions(payload));
10896

109-
std::string res = save_image_stbi(image);
97+
ov::Tensor image = request.generate(prompt, requestOptions);
98+
std::string imageAsString = save_image_stbi(image);
11099

111100
// Convert the image to a base64 string
112-
std::string base64_image;
113-
absl::Base64Escape(res, &base64_image);
114-
101+
std::string base64image;
102+
absl::Base64Escape(imageAsString, &base64image);
115103
// Create the JSON response
116-
std::string json_response = absl::StrCat("{\"data\":[{\"b64_json\":\"", base64_image, "\"}]}");
117-
// Produce std::string packet
118-
auto output = absl::make_unique<std::string>(json_response);
104+
auto output = generateJSONResponseFromB64Image(base64image);
119105
cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(output.release(), cc->InputTimestamp());
120106

121107
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "ImageGenCalculator [Node: {}] Process end", cc->NodeName());

0 commit comments

Comments
 (0)