Skip to content

Commit 05784bf

Browse files
committed
working
1 parent 58d6739 commit 05784bf

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

demos/common/export_models/export_model.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def add_common_arguments(parser):
6464
parser_rerank.add_argument('--num_streams', default="1", help='The number of parallel execution streams to use for the model. Use at least 2 on 2 socket CPU systems.', dest='num_streams')
6565
parser_rerank.add_argument('--max_doc_length', default=16000, type=int, help='Maximum length of input documents in tokens', dest='max_doc_length')
6666
parser_rerank.add_argument('--version', default="1", help='version of the model', dest='version')
67+
68+
parser_image_generation = subparsers.add_parser('image_generation', help='export model for image generation endpoint')
69+
add_common_arguments(parser_image_generation)
70+
parser_image_generation.add_argument('--resolution', default="512x512", help='Resolution of generated images if not specified by the reques', dest='resolution') # unused for now, param as an example
6771
args = vars(parser.parse_args())
6872

6973
embedding_graph_template = """input_stream: "REQUEST_PAYLOAD:input"
@@ -213,6 +217,23 @@ def add_common_arguments(parser):
213217
]
214218
}"""
215219

220+
image_generation_graph_template = """input_stream: "HTTP_REQUEST_PAYLOAD:input"
221+
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
222+
223+
node: {
224+
name: "ImageGenExecutor"
225+
calculator: "ImageGenCalculator"
226+
input_stream: "HTTP_REQUEST_PAYLOAD:input"
227+
input_side_packet: "IMAGE_GEN_NODE_RESOURCES:pipes"
228+
output_stream: "HTTP_RESPONSE_PAYLOAD:output"
229+
node_options: {
230+
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
231+
models_path: "{{model_path}}",
232+
#resolution: "{{resolution}}", # unused for now
233+
}
234+
}
235+
}"""
236+
216237
def export_rerank_tokenizer(source_model, destination_path, max_length):
217238
import openvino as ov
218239
from openvino_tokenizers import convert_tokenizer
@@ -448,6 +469,27 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
448469
add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))
449470

450471

472+
def export_image_generation_model(model_repository_path, source_model, model_name, precision, task_parameters, config_file_path, resolution):
473+
model_path = "./"
474+
model_index_path = os.path.join(target_path, 'model_index.json')
475+
476+
if os.path.isfile(model_index_path):
477+
print("Model index file already exists. Skipping conversion.")
478+
return
479+
480+
optimum_command = "optimum-cli export openvino --model {} --weight-format {} {}".format(source_model, precision, target_path)
481+
482+
if os.system(optimum_command):
483+
raise ValueError("Failed to export image generation model model", source_model)
484+
485+
gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(image_generation_graph_template)
486+
graph_content = gtemplate.render(model_path=model_path, resolution=resolution)
487+
with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
488+
f.write(graph_content)
489+
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))
490+
add_servable_to_config(config_file_path, model_name, os.path.relpath( os.path.join(model_repository_path, model_name), os.path.dirname(config_file_path)))
491+
492+
451493
if not os.path.isdir(args['model_repository_path']):
452494
raise ValueError(f"The model repository path '{args['model_repository_path']}' is not a valid directory.")
453495
if args['source_model'] is None:
@@ -477,4 +519,6 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
477519
elif args['task'] == 'rerank':
478520
export_rerank_model(args['model_repository_path'], args['source_model'], args['model_name'] ,args['precision'], template_parameters, str(args['version']), args['config_file_path'], args['max_doc_length'])
479521

522+
elif args['task'] == 'image_generation':
523+
export_image_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'], args['resolution'])
480524

demos/common/export_models/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ einops
1313
torchvision==0.21.0
1414
timm==1.0.15
1515
auto-gptq==0.7.1
16+
diffusers==0.33.1 # for image generation

0 commit comments

Comments
 (0)